mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 12:07:42 +00:00
[improvement](python) LLM related services support Java service invocation (#484)
This commit is contained in:
@@ -104,6 +104,19 @@
|
|||||||
<version>${mockito-inline.version}</version>
|
<version>${mockito-inline.version}</version>
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<!--langchain4j-->
|
||||||
|
<dependency>
|
||||||
|
<groupId>dev.langchain4j</groupId>
|
||||||
|
<artifactId>langchain4j-open-ai</artifactId>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>dev.langchain4j</groupId>
|
||||||
|
<artifactId>langchain4j</artifactId>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>dev.langchain4j</groupId>
|
||||||
|
<artifactId>langchain4j-chroma</artifactId>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
|
|||||||
@@ -4,10 +4,11 @@ import com.google.common.collect.Lists;
|
|||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
|
import com.tencent.supersonic.common.util.ComponentFactory;
|
||||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||||
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
|
||||||
import com.tencent.supersonic.knowledge.dictionary.EmbeddingResult;
|
import com.tencent.supersonic.knowledge.dictionary.EmbeddingResult;
|
||||||
import com.tencent.supersonic.semantic.model.domain.listener.MetaEmbeddingListener;
|
import com.tencent.supersonic.semantic.model.domain.listener.MetaEmbeddingListener;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
@@ -32,8 +33,8 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
|||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private OptimizationConfig optimizationConfig;
|
private OptimizationConfig optimizationConfig;
|
||||||
@Autowired
|
|
||||||
private EmbeddingUtils embeddingUtils;
|
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean needDelete(EmbeddingResult oneRoundResult, EmbeddingResult existResult) {
|
public boolean needDelete(EmbeddingResult oneRoundResult, EmbeddingResult existResult) {
|
||||||
@@ -83,7 +84,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
|||||||
.queryEmbeddings(null)
|
.queryEmbeddings(null)
|
||||||
.build();
|
.build();
|
||||||
// step2. retrieveQuery by detectSegment
|
// step2. retrieveQuery by detectSegment
|
||||||
List<RetrieveQueryResult> retrieveQueryResults = embeddingUtils.retrieveQuery(
|
List<RetrieveQueryResult> retrieveQueryResults = s2EmbeddingStore.retrieveQuery(
|
||||||
MetaEmbeddingListener.COLLECTION_NAME, retrieveQuery, embeddingNumber);
|
MetaEmbeddingListener.COLLECTION_NAME, retrieveQuery, embeddingNumber);
|
||||||
|
|
||||||
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
|
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
|
||||||
@@ -97,7 +98,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
|||||||
retrievals.removeIf(retrieval -> retrieval.getDistance() > distance.doubleValue());
|
retrievals.removeIf(retrieval -> retrieval.getDistance() > distance.doubleValue());
|
||||||
if (CollectionUtils.isNotEmpty(detectModelIds)) {
|
if (CollectionUtils.isNotEmpty(detectModelIds)) {
|
||||||
retrievals.removeIf(retrieval -> {
|
retrievals.removeIf(retrieval -> {
|
||||||
String modelIdStr = retrieval.getMetadata().get("modelId");
|
String modelIdStr = retrieval.getMetadata().get("modelId").toString();
|
||||||
if (StringUtils.isBlank(modelIdStr)) {
|
if (StringUtils.isBlank(modelIdStr)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,84 @@
|
|||||||
|
package com.tencent.supersonic.chat.parser;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||||
|
import com.tencent.supersonic.chat.parser.plugin.function.FunctionCallPromptGenerator;
|
||||||
|
import com.tencent.supersonic.chat.parser.sql.llm.prompt.OutputFormat;
|
||||||
|
import com.tencent.supersonic.chat.parser.sql.llm.prompt.SqlExampleLoader;
|
||||||
|
import com.tencent.supersonic.chat.parser.sql.llm.prompt.SqlPromptGenerator;
|
||||||
|
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
|
||||||
|
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
|
import dev.langchain4j.data.message.AiMessage;
|
||||||
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
|
import dev.langchain4j.model.input.Prompt;
|
||||||
|
import dev.langchain4j.model.input.PromptTemplate;
|
||||||
|
import dev.langchain4j.model.output.Response;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
public class EmbedLLMProxy implements LLMProxy {
|
||||||
|
|
||||||
|
public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
|
||||||
|
|
||||||
|
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
|
||||||
|
|
||||||
|
SqlExampleLoader sqlExampleLoader = ContextUtils.getBean(SqlExampleLoader.class);
|
||||||
|
|
||||||
|
OptimizationConfig config = ContextUtils.getBean(OptimizationConfig.class);
|
||||||
|
|
||||||
|
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||||
|
config.getText2sqlCollectionName(), config.getText2sqlFewShotsNum());
|
||||||
|
|
||||||
|
String queryText = llmReq.getQueryText();
|
||||||
|
String modelName = llmReq.getSchema().getModelName();
|
||||||
|
List<String> fieldNameList = llmReq.getSchema().getFieldNameList();
|
||||||
|
List<ElementValue> linking = llmReq.getLinking();
|
||||||
|
|
||||||
|
SqlPromptGenerator sqlPromptGenerator = ContextUtils.getBean(SqlPromptGenerator.class);
|
||||||
|
String linkingPromptStr = sqlPromptGenerator.generateSchemaLinkingPrompt(queryText, modelName, fieldNameList,
|
||||||
|
linking, sqlExamples);
|
||||||
|
|
||||||
|
Prompt linkingPrompt = PromptTemplate.from(JsonUtil.toString(linkingPromptStr)).apply(new HashMap<>());
|
||||||
|
Response<AiMessage> linkingResult = chatLanguageModel.generate(linkingPrompt.toSystemMessage());
|
||||||
|
|
||||||
|
String schemaLinkStr = OutputFormat.schemaLinkParse(linkingResult.content().text());
|
||||||
|
|
||||||
|
String generateSqlPrompt = sqlPromptGenerator.generateSqlPrompt(queryText, modelName, schemaLinkStr,
|
||||||
|
llmReq.getCurrentDate(), sqlExamples);
|
||||||
|
|
||||||
|
Prompt sqlPrompt = PromptTemplate.from(JsonUtil.toString(generateSqlPrompt)).apply(new HashMap<>());
|
||||||
|
Response<AiMessage> sqlResult = chatLanguageModel.generate(sqlPrompt.toSystemMessage());
|
||||||
|
|
||||||
|
LLMResp result = new LLMResp();
|
||||||
|
result.setQuery(queryText);
|
||||||
|
result.setSchemaLinkingOutput(linkingPromptStr);
|
||||||
|
result.setSchemaLinkStr(schemaLinkStr);
|
||||||
|
result.setModelName(modelName);
|
||||||
|
result.setSqlOutput(sqlResult.content().text());
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FunctionResp requestFunction(FunctionReq functionReq) {
|
||||||
|
|
||||||
|
FunctionCallPromptGenerator promptGenerator = ContextUtils.getBean(FunctionCallPromptGenerator.class);
|
||||||
|
|
||||||
|
String functionCallPrompt = promptGenerator.generateFunctionCallPrompt(functionReq.getQueryText(),
|
||||||
|
functionReq.getPluginConfigs());
|
||||||
|
|
||||||
|
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
|
||||||
|
|
||||||
|
String functionSelect = chatLanguageModel.generate(functionCallPrompt);
|
||||||
|
|
||||||
|
return OutputFormat.functionCallParse(functionSelect);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,44 @@
|
|||||||
|
package com.tencent.supersonic.chat.parser.plugin.function;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.parser.sql.llm.prompt.InputFormat;
|
||||||
|
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
|
@Component
|
||||||
|
@Slf4j
|
||||||
|
public class FunctionCallPromptGenerator {
|
||||||
|
|
||||||
|
public String generateFunctionCallPrompt(String queryText, List<PluginParseConfig> toolConfigList) {
|
||||||
|
List<String> toolExplainList = toolConfigList.stream()
|
||||||
|
.map(this::constructPluginPrompt)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
String functionList = String.join(InputFormat.SEPERATOR, toolExplainList);
|
||||||
|
return constructTaskPrompt(queryText, functionList);
|
||||||
|
}
|
||||||
|
|
||||||
|
public String constructPluginPrompt(PluginParseConfig parseConfig) {
|
||||||
|
String toolName = parseConfig.getName();
|
||||||
|
String toolDescription = parseConfig.getDescription();
|
||||||
|
List<String> toolExamples = parseConfig.getExamples();
|
||||||
|
|
||||||
|
StringBuilder prompt = new StringBuilder();
|
||||||
|
prompt.append("【工具名称】\n").append(toolName).append("\n");
|
||||||
|
prompt.append("【工具描述】\n").append(toolDescription).append("\n");
|
||||||
|
prompt.append("【工具适用问题示例】\n");
|
||||||
|
for (String example : toolExamples) {
|
||||||
|
prompt.append(example).append("\n");
|
||||||
|
}
|
||||||
|
return prompt.toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
public String constructTaskPrompt(String queryText, String functionList) {
|
||||||
|
String instruction = String.format("问题为:%s\n请根据问题和工具的描述,选择对应的工具,完成任务。"
|
||||||
|
+ "请注意,只能选择1个工具。请一步一步地分析选择工具的原因(每个工具的【工具适用问题示例】是选择的重要参考依据),"
|
||||||
|
+ "并给出最终选择,输出格式为json,key为’分析过程‘, ’选择工具‘", queryText);
|
||||||
|
|
||||||
|
return String.format("工具选择如下:\n\n%s\n\n【任务说明】\n%s", functionList, instruction);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,42 @@
|
|||||||
|
package com.tencent.supersonic.chat.parser.sql.llm.prompt;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
public class InputFormat {
|
||||||
|
|
||||||
|
public static final String SEPERATOR = "\n\n";
|
||||||
|
|
||||||
|
public static String format(String template, List<String> templateKey,
|
||||||
|
List<Map<String, String>> toFormatList) {
|
||||||
|
List<String> result = new ArrayList<>();
|
||||||
|
|
||||||
|
for (Map<String, String> formatItem : toFormatList) {
|
||||||
|
Map<String, String> retrievalMeta = subDict(formatItem, templateKey);
|
||||||
|
result.add(format(template, retrievalMeta));
|
||||||
|
}
|
||||||
|
|
||||||
|
return String.join(SEPERATOR, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static String format(String input, Map<String, String> replacements) {
|
||||||
|
for (Map.Entry<String, String> entry : replacements.entrySet()) {
|
||||||
|
input = input.replace(entry.getKey(), entry.getValue());
|
||||||
|
}
|
||||||
|
return input;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static Map<String, String> subDict(Map<String, String> dict, List<String> keys) {
|
||||||
|
Map<String, String> subDict = new HashMap<>();
|
||||||
|
for (String key : keys) {
|
||||||
|
if (dict.containsKey(key)) {
|
||||||
|
subDict.put(key, dict.get(key));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return subDict;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,53 @@
|
|||||||
|
package com.tencent.supersonic.chat.parser.sql.llm.prompt;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
|
||||||
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.regex.Matcher;
|
||||||
|
import java.util.regex.Pattern;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
|
/***
|
||||||
|
* output format
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class OutputFormat {
|
||||||
|
|
||||||
|
public static final String PATTERN = "\\{[^{}]+\\}";
|
||||||
|
|
||||||
|
public static String schemaLinkParse(String schemaLinkOutput) {
|
||||||
|
try {
|
||||||
|
schemaLinkOutput = schemaLinkOutput.trim();
|
||||||
|
String pattern = "Schema_links:(.*)";
|
||||||
|
Pattern regexPattern = Pattern.compile(pattern, Pattern.DOTALL);
|
||||||
|
Matcher matcher = regexPattern.matcher(schemaLinkOutput);
|
||||||
|
if (matcher.find()) {
|
||||||
|
schemaLinkOutput = matcher.group(1).trim();
|
||||||
|
} else {
|
||||||
|
schemaLinkOutput = null;
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("", e);
|
||||||
|
schemaLinkOutput = null;
|
||||||
|
}
|
||||||
|
return schemaLinkOutput;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static FunctionResp functionCallParse(String llmOutput) {
|
||||||
|
try {
|
||||||
|
String[] findResult = llmOutput.split(PATTERN);
|
||||||
|
String result = findResult[0].trim();
|
||||||
|
|
||||||
|
Map<String, String> resultDict = JsonUtil.toMap(result, String.class, String.class);
|
||||||
|
log.info("result:{},resultDict:{}", result, resultDict);
|
||||||
|
|
||||||
|
String selection = resultDict.get("选择工具");
|
||||||
|
FunctionResp resp = new FunctionResp();
|
||||||
|
resp.setToolSelection(selection);
|
||||||
|
return resp;
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("", e);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,32 @@
|
|||||||
|
package com.tencent.supersonic.chat.parser.sql.llm.prompt;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class SqlExample {
|
||||||
|
|
||||||
|
@JsonProperty("currentDate")
|
||||||
|
private String currentDate;
|
||||||
|
|
||||||
|
@JsonProperty("tableName")
|
||||||
|
private String tableName;
|
||||||
|
|
||||||
|
@JsonProperty("fieldsList")
|
||||||
|
private String fieldsList;
|
||||||
|
|
||||||
|
@JsonProperty("question")
|
||||||
|
private String question;
|
||||||
|
|
||||||
|
@JsonProperty("priorSchemaLinks")
|
||||||
|
private String priorSchemaLinks;
|
||||||
|
|
||||||
|
@JsonProperty("analysis")
|
||||||
|
private String analysis;
|
||||||
|
|
||||||
|
@JsonProperty("schemaLinks")
|
||||||
|
private String schemaLinks;
|
||||||
|
|
||||||
|
@JsonProperty("sql")
|
||||||
|
private String sql;
|
||||||
|
}
|
||||||
@@ -0,0 +1,76 @@
|
|||||||
|
package com.tencent.supersonic.chat.parser.sql.llm.prompt;
|
||||||
|
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.core.type.TypeReference;
|
||||||
|
import com.tencent.supersonic.common.util.ComponentFactory;
|
||||||
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections4.CollectionUtils;
|
||||||
|
import org.springframework.core.io.ClassPathResource;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
@Component
|
||||||
|
public class SqlExampleLoader {
|
||||||
|
|
||||||
|
private static final String EXAMPLE_JSON_FILE = "example.json";
|
||||||
|
|
||||||
|
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
|
||||||
|
private TypeReference<List<SqlExample>> valueTypeRef = new TypeReference<List<SqlExample>>() {
|
||||||
|
};
|
||||||
|
|
||||||
|
public List<SqlExample> getSqlExamples() throws IOException {
|
||||||
|
ClassPathResource resource = new ClassPathResource(EXAMPLE_JSON_FILE);
|
||||||
|
InputStream inputStream = resource.getInputStream();
|
||||||
|
return JsonUtil.INSTANCE.getObjectMapper().readValue(inputStream, valueTypeRef);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void addEmbeddingStore(List<SqlExample> sqlExamples, String collectionName) {
|
||||||
|
List<EmbeddingQuery> queries = new ArrayList<>();
|
||||||
|
for (int i = 0; i < sqlExamples.size(); i++) {
|
||||||
|
SqlExample sqlExample = sqlExamples.get(i);
|
||||||
|
String question = sqlExample.getQuestion();
|
||||||
|
Map<String, Object> metaDataMap = JsonUtil.toMap(JsonUtil.toString(sqlExample), String.class, Object.class);
|
||||||
|
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
|
||||||
|
embeddingQuery.setQueryId(String.valueOf(i));
|
||||||
|
embeddingQuery.setQuery(question);
|
||||||
|
embeddingQuery.setMetadata(metaDataMap);
|
||||||
|
queries.add(embeddingQuery);
|
||||||
|
}
|
||||||
|
s2EmbeddingStore.addQuery(collectionName, queries);
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<Map<String, String>> retrieverSqlExamples(String queryText, String collectionName, int maxResults) {
|
||||||
|
|
||||||
|
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(Collections.singletonList(queryText))
|
||||||
|
.queryEmbeddings(null).build();
|
||||||
|
|
||||||
|
List<RetrieveQueryResult> resultList = s2EmbeddingStore.retrieveQuery(collectionName, retrieveQuery,
|
||||||
|
maxResults);
|
||||||
|
List<Map<String, String>> result = new ArrayList<>();
|
||||||
|
if (CollectionUtils.isEmpty(resultList)) {
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
for (Retrieval retrieval : resultList.get(0).getRetrieval()) {
|
||||||
|
if (Objects.nonNull(retrieval.getMetadata()) && !retrieval.getMetadata().isEmpty()) {
|
||||||
|
Map<String, String> convertedMap = retrieval.getMetadata().entrySet().stream()
|
||||||
|
.collect(Collectors.toMap(Map.Entry::getKey, entry -> String.valueOf(entry.getValue())));
|
||||||
|
result.add(convertedMap);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,65 @@
|
|||||||
|
package com.tencent.supersonic.chat.parser.sql.llm.prompt;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
|
@Component
|
||||||
|
@Slf4j
|
||||||
|
public class SqlPromptGenerator {
|
||||||
|
|
||||||
|
public String generateSchemaLinkingPrompt(String question, String modelName, List<String> fieldsList,
|
||||||
|
List<ElementValue> priorSchemaLinks, List<Map<String, String>> exampleList) {
|
||||||
|
|
||||||
|
String exampleTemplate = "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\n"
|
||||||
|
+ "问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schemaLinks}";
|
||||||
|
|
||||||
|
List<String> exampleKeys = Arrays.asList("tableName", "fieldsList", "priorSchemaLinks", "question", "analysis",
|
||||||
|
"schemaLinks");
|
||||||
|
|
||||||
|
String schemaLinkingPrompt = InputFormat.format(exampleTemplate, exampleKeys, exampleList);
|
||||||
|
|
||||||
|
String newCaseTemplate = "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\n"
|
||||||
|
+ "问题:{question}\n分析: 让我们一步一步地思考。";
|
||||||
|
|
||||||
|
String newCasePrompt = newCaseTemplate.replace("{tableName}", modelName)
|
||||||
|
.replace("{fieldsList}", fieldsList.toString())
|
||||||
|
.replace("{priorSchemaLinks}", getPriorSchemaLinks(priorSchemaLinks))
|
||||||
|
.replace("{question}", question);
|
||||||
|
|
||||||
|
String instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links";
|
||||||
|
return instruction + InputFormat.SEPERATOR + schemaLinkingPrompt + InputFormat.SEPERATOR + newCasePrompt;
|
||||||
|
}
|
||||||
|
|
||||||
|
private String getPriorSchemaLinks(List<ElementValue> priorSchemaLinks) {
|
||||||
|
return priorSchemaLinks.stream()
|
||||||
|
.map(elementValue -> "'" + elementValue.getFieldName() + "'->" + elementValue.getFieldValue())
|
||||||
|
.collect(Collectors.joining(",", "[", "]"));
|
||||||
|
}
|
||||||
|
|
||||||
|
public String generateSqlPrompt(String question, String modelName, String schemaLinkStr, String dataDate,
|
||||||
|
List<Map<String, String>> exampleList) {
|
||||||
|
|
||||||
|
List<String> exampleKeys = Arrays.asList("question", "currentDate", "tableName", "schemaLinks", "sql");
|
||||||
|
String exampleTemplate = "问题:{question}\nCurrent_date:{currentDate}\nTable {tableName}\n"
|
||||||
|
+ "Schema_links:{schemaLinks}\nSQL:{sql}";
|
||||||
|
|
||||||
|
String sqlExamplePrompt = InputFormat.format(exampleTemplate, exampleKeys, exampleList);
|
||||||
|
|
||||||
|
String newCaseTemplate = "问题:{question}\nCurrent_date:{currentDate}\nTable {tableName}\n"
|
||||||
|
+ "Schema_links:{schemaLinks}\nSQL:";
|
||||||
|
|
||||||
|
String newCasePrompt = newCaseTemplate.replace("{question}", question)
|
||||||
|
.replace("{currentDate}", dataDate)
|
||||||
|
.replace("{tableName}", modelName)
|
||||||
|
.replace("{schemaLinks}", schemaLinkStr);
|
||||||
|
|
||||||
|
String instruction = "# 根据schema_links为每个问题生成SQL查询语句";
|
||||||
|
return instruction + InputFormat.SEPERATOR + sqlExamplePrompt + InputFormat.SEPERATOR + newCasePrompt;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -19,12 +19,13 @@ import com.tencent.supersonic.chat.query.plugin.WebBase;
|
|||||||
import com.tencent.supersonic.chat.service.AgentService;
|
import com.tencent.supersonic.chat.service.AgentService;
|
||||||
import com.tencent.supersonic.chat.service.PluginService;
|
import com.tencent.supersonic.chat.service.PluginService;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
|
import com.tencent.supersonic.common.util.ComponentFactory;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
|
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
|
||||||
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
|
|
||||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||||
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
@@ -39,7 +40,6 @@ import org.apache.commons.collections.CollectionUtils;
|
|||||||
import org.apache.commons.lang3.tuple.Pair;
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
import org.springframework.context.event.EventListener;
|
import org.springframework.context.event.EventListener;
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
import org.springframework.web.client.RestTemplate;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Component
|
@Component
|
||||||
@@ -47,14 +47,10 @@ public class PluginManager {
|
|||||||
|
|
||||||
private EmbeddingConfig embeddingConfig;
|
private EmbeddingConfig embeddingConfig;
|
||||||
|
|
||||||
private RestTemplate restTemplate;
|
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
|
||||||
|
|
||||||
private EmbeddingUtils embeddingUtils;
|
public PluginManager(EmbeddingConfig embeddingConfig) {
|
||||||
|
|
||||||
public PluginManager(EmbeddingConfig embeddingConfig, RestTemplate restTemplate, EmbeddingUtils embeddingUtils) {
|
|
||||||
this.embeddingConfig = embeddingConfig;
|
this.embeddingConfig = embeddingConfig;
|
||||||
this.restTemplate = restTemplate;
|
|
||||||
this.embeddingUtils = embeddingUtils;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static List<Plugin> getPluginAgentCanSupport(Integer agentId) {
|
public static List<Plugin> getPluginAgentCanSupport(Integer agentId) {
|
||||||
@@ -133,7 +129,7 @@ public class PluginManager {
|
|||||||
embeddingQuery.setQueryId(id);
|
embeddingQuery.setQueryId(id);
|
||||||
queries.add(embeddingQuery);
|
queries.add(embeddingQuery);
|
||||||
}
|
}
|
||||||
embeddingUtils.deleteQuery(presetCollection, queries);
|
s2EmbeddingStore.deleteQuery(presetCollection, queries);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void requestEmbeddingPluginAdd(List<EmbeddingQuery> queries) {
|
public void requestEmbeddingPluginAdd(List<EmbeddingQuery> queries) {
|
||||||
@@ -141,7 +137,7 @@ public class PluginManager {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
String presetCollection = embeddingConfig.getPresetCollection();
|
String presetCollection = embeddingConfig.getPresetCollection();
|
||||||
embeddingUtils.addQuery(presetCollection, queries);
|
s2EmbeddingStore.addQuery(presetCollection, queries);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void requestEmbeddingPluginAddALL(List<Plugin> plugins) {
|
public void requestEmbeddingPluginAddALL(List<Plugin> plugins) {
|
||||||
@@ -150,13 +146,12 @@ public class PluginManager {
|
|||||||
|
|
||||||
public RetrieveQueryResult recognize(String embeddingText) {
|
public RetrieveQueryResult recognize(String embeddingText) {
|
||||||
|
|
||||||
EmbeddingUtils embeddingUtils = ContextUtils.getBean(EmbeddingUtils.class);
|
|
||||||
|
|
||||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder()
|
RetrieveQuery retrieveQuery = RetrieveQuery.builder()
|
||||||
.queryTextsList(Collections.singletonList(embeddingText))
|
.queryTextsList(Collections.singletonList(embeddingText))
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
List<RetrieveQueryResult> resultList = embeddingUtils.retrieveQuery(embeddingConfig.getPresetCollection(),
|
List<RetrieveQueryResult> resultList = s2EmbeddingStore.retrieveQuery(embeddingConfig.getPresetCollection(),
|
||||||
retrieveQuery, embeddingConfig.getNResult());
|
retrieveQuery, embeddingConfig.getNResult());
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -7,13 +7,12 @@ import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
|||||||
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
|
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||||
import com.tencent.supersonic.common.pojo.QueryType;
|
import com.tencent.supersonic.common.pojo.QueryType;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ComponentFactory;
|
||||||
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
|
|
||||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||||
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
|
||||||
import com.tencent.supersonic.semantic.model.domain.listener.MetaEmbeddingListener;
|
import com.tencent.supersonic.semantic.model.domain.listener.MetaEmbeddingListener;
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
@@ -21,6 +20,7 @@ import java.util.List;
|
|||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* SimilarMetricQueryResponder fills recommended metrics based on embedding similarity.
|
* SimilarMetricQueryResponder fills recommended metrics based on embedding similarity.
|
||||||
@@ -29,6 +29,8 @@ public class SimilarMetricQueryResponder implements QueryResponder {
|
|||||||
|
|
||||||
private static final int METRIC_RECOMMEND_SIZE = 5;
|
private static final int METRIC_RECOMMEND_SIZE = 5;
|
||||||
|
|
||||||
|
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void fillInfo(QueryResult queryResult, SemanticParseInfo semanticParseInfo, ExecuteQueryReq queryReq) {
|
public void fillInfo(QueryResult queryResult, SemanticParseInfo semanticParseInfo, ExecuteQueryReq queryReq) {
|
||||||
fillSimilarMetric(queryResult.getChatContext());
|
fillSimilarMetric(queryResult.getChatContext());
|
||||||
@@ -46,8 +48,7 @@ public class SimilarMetricQueryResponder implements QueryResponder {
|
|||||||
filterCondition.put("type", SchemaElementType.METRIC.name());
|
filterCondition.put("type", SchemaElementType.METRIC.name());
|
||||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(metricNames)
|
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(metricNames)
|
||||||
.filterCondition(filterCondition).queryEmbeddings(null).build();
|
.filterCondition(filterCondition).queryEmbeddings(null).build();
|
||||||
EmbeddingUtils embeddingUtils = ContextUtils.getBean(EmbeddingUtils.class);
|
List<RetrieveQueryResult> retrieveQueryResults = s2EmbeddingStore.retrieveQuery(
|
||||||
List<RetrieveQueryResult> retrieveQueryResults = embeddingUtils.retrieveQuery(
|
|
||||||
MetaEmbeddingListener.COLLECTION_NAME, retrieveQuery, METRIC_RECOMMEND_SIZE + 1);
|
MetaEmbeddingListener.COLLECTION_NAME, retrieveQuery, METRIC_RECOMMEND_SIZE + 1);
|
||||||
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
|
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
|
||||||
return;
|
return;
|
||||||
@@ -66,7 +67,8 @@ public class SimilarMetricQueryResponder implements QueryResponder {
|
|||||||
SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(retrieval.getMetadata()),
|
SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(retrieval.getMetadata()),
|
||||||
SchemaElement.class);
|
SchemaElement.class);
|
||||||
if (retrieval.getMetadata().containsKey("modelId")) {
|
if (retrieval.getMetadata().containsKey("modelId")) {
|
||||||
schemaElement.setModel(Long.parseLong(retrieval.getMetadata().get("modelId")));
|
String modelId = retrieval.getMetadata().get("modelId").toString();
|
||||||
|
schemaElement.setModel(Long.parseLong(modelId));
|
||||||
}
|
}
|
||||||
schemaElement.setOrder(++metricOrder);
|
schemaElement.setOrder(++metricOrder);
|
||||||
parseInfo.getMetrics().add(schemaElement);
|
parseInfo.getMetrics().add(schemaElement);
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
package com.tencent.supersonic.chat.query.llm.analytics;
|
package com.tencent.supersonic.chat.query.llm.analytics;
|
||||||
|
|
||||||
import com.alibaba.fastjson.JSONObject;
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
|
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
|
||||||
@@ -14,13 +13,11 @@ import com.tencent.supersonic.chat.query.QueryManager;
|
|||||||
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
|
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
|
||||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||||
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
|
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
|
||||||
import com.tencent.supersonic.common.pojo.Aggregator;
|
import com.tencent.supersonic.common.pojo.Aggregator;
|
||||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||||
import com.tencent.supersonic.common.pojo.QueryType;
|
import com.tencent.supersonic.common.pojo.QueryType;
|
||||||
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
|
|
||||||
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
|
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
|
||||||
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
|
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
@@ -31,8 +28,6 @@ import java.util.stream.Collectors;
|
|||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.calcite.sql.parser.SqlParseException;
|
import org.apache.calcite.sql.parser.SqlParseException;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.springframework.http.HttpMethod;
|
|
||||||
import org.springframework.http.ResponseEntity;
|
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
@@ -43,7 +38,6 @@ public class MetricAnalyzeQuery extends LLMSemanticQuery {
|
|||||||
|
|
||||||
public static final String QUERY_MODE = "METRIC_INTERPRET";
|
public static final String QUERY_MODE = "METRIC_INTERPRET";
|
||||||
|
|
||||||
|
|
||||||
public MetricAnalyzeQuery() {
|
public MetricAnalyzeQuery() {
|
||||||
QueryManager.register(this);
|
QueryManager.register(this);
|
||||||
}
|
}
|
||||||
@@ -151,23 +145,7 @@ public class MetricAnalyzeQuery extends LLMSemanticQuery {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public String fetchInterpret(String queryText, String dataText) {
|
public String fetchInterpret(String queryText, String dataText) {
|
||||||
LLMAnswerReq lLmAnswerReq = new LLMAnswerReq();
|
return "";
|
||||||
lLmAnswerReq.setQueryText(queryText);
|
|
||||||
lLmAnswerReq.setPluginOutput(dataText);
|
|
||||||
EmbeddingUtils embeddingUtils = ContextUtils.getBean(EmbeddingUtils.class);
|
|
||||||
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
|
||||||
String metricAnalyzeQueryCollection = embeddingConfig.getMetricAnalyzeQueryCollection();
|
|
||||||
|
|
||||||
String url = String.format("%s/retrieve_query?collection_name=%s", embeddingConfig.getUrl(),
|
|
||||||
metricAnalyzeQueryCollection);
|
|
||||||
|
|
||||||
ResponseEntity<String> responseEntity = embeddingUtils.doRequest(url, JSONObject.toJSONString(lLmAnswerReq),
|
|
||||||
HttpMethod.POST);
|
|
||||||
LLMAnswerResp lLmAnswerResp = JSONObject.parseObject(responseEntity.getBody(), LLMAnswerResp.class);
|
|
||||||
if (lLmAnswerResp != null) {
|
|
||||||
return lLmAnswerResp.getAssistantMessage();
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,11 +8,11 @@ import com.tencent.supersonic.chat.parser.LLMProxy;
|
|||||||
import com.tencent.supersonic.chat.parser.sql.llm.ModelResolver;
|
import com.tencent.supersonic.chat.parser.sql.llm.ModelResolver;
|
||||||
import com.tencent.supersonic.chat.processor.ParseResultProcessor;
|
import com.tencent.supersonic.chat.processor.ParseResultProcessor;
|
||||||
import com.tencent.supersonic.chat.query.QueryResponder;
|
import com.tencent.supersonic.chat.query.QueryResponder;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
|
||||||
import org.springframework.core.io.support.SpringFactoriesLoader;
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
|
import org.springframework.core.io.support.SpringFactoriesLoader;
|
||||||
|
|
||||||
public class ComponentFactory {
|
public class ComponentFactory {
|
||||||
|
|
||||||
|
|||||||
@@ -4,11 +4,12 @@ import com.google.common.collect.Lists;
|
|||||||
import com.tencent.supersonic.chat.api.pojo.request.SolvedQueryReq;
|
import com.tencent.supersonic.chat.api.pojo.request.SolvedQueryReq;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp;
|
import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
|
import com.tencent.supersonic.common.util.ComponentFactory;
|
||||||
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
|
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
|
||||||
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
|
|
||||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||||
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
|
||||||
import java.net.URI;
|
import java.net.URI;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
@@ -36,11 +37,11 @@ public class SolvedQueryManager {
|
|||||||
|
|
||||||
private EmbeddingConfig embeddingConfig;
|
private EmbeddingConfig embeddingConfig;
|
||||||
|
|
||||||
private EmbeddingUtils embeddingUtils;
|
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
|
||||||
|
|
||||||
public SolvedQueryManager(EmbeddingConfig embeddingConfig, EmbeddingUtils embeddingUtils) {
|
|
||||||
|
public SolvedQueryManager(EmbeddingConfig embeddingConfig) {
|
||||||
this.embeddingConfig = embeddingConfig;
|
this.embeddingConfig = embeddingConfig;
|
||||||
this.embeddingUtils = embeddingUtils;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void saveSolvedQuery(SolvedQueryReq solvedQueryReq) {
|
public void saveSolvedQuery(SolvedQueryReq solvedQueryReq) {
|
||||||
@@ -54,12 +55,12 @@ public class SolvedQueryManager {
|
|||||||
embeddingQuery.setQueryId(uniqueId);
|
embeddingQuery.setQueryId(uniqueId);
|
||||||
embeddingQuery.setQuery(queryText);
|
embeddingQuery.setQuery(queryText);
|
||||||
|
|
||||||
Map<String, String> metaData = new HashMap<>();
|
Map<String, Object> metaData = new HashMap<>();
|
||||||
metaData.put("modelId", String.valueOf(solvedQueryReq.getModelId()));
|
metaData.put("modelId", (solvedQueryReq.getModelId()));
|
||||||
metaData.put("agentId", String.valueOf(solvedQueryReq.getAgentId()));
|
metaData.put("agentId", solvedQueryReq.getAgentId());
|
||||||
embeddingQuery.setMetadata(metaData);
|
embeddingQuery.setMetadata(metaData);
|
||||||
String solvedQueryCollection = embeddingConfig.getSolvedQueryCollection();
|
String solvedQueryCollection = embeddingConfig.getSolvedQueryCollection();
|
||||||
embeddingUtils.addQuery(solvedQueryCollection, Lists.newArrayList(embeddingQuery));
|
s2EmbeddingStore.addQuery(solvedQueryCollection, Lists.newArrayList(embeddingQuery));
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.warn("save history question to embedding failed, queryText:{}", queryText, e);
|
log.warn("save history question to embedding failed, queryText:{}", queryText, e);
|
||||||
}
|
}
|
||||||
@@ -80,7 +81,7 @@ public class SolvedQueryManager {
|
|||||||
.queryTextsList(Lists.newArrayList(queryText))
|
.queryTextsList(Lists.newArrayList(queryText))
|
||||||
.filterCondition(filterCondition)
|
.filterCondition(filterCondition)
|
||||||
.build();
|
.build();
|
||||||
List<RetrieveQueryResult> resultList = embeddingUtils.retrieveQuery(solvedQueryCollection, retrieveQuery,
|
List<RetrieveQueryResult> resultList = s2EmbeddingStore.retrieveQuery(solvedQueryCollection, retrieveQuery,
|
||||||
solvedQueryResultNum);
|
solvedQueryResultNum);
|
||||||
|
|
||||||
log.info("[embedding] recognize result body:{}", resultList);
|
log.info("[embedding] recognize result body:{}", resultList);
|
||||||
|
|||||||
@@ -161,6 +161,20 @@
|
|||||||
<artifactId>spring-boot-starter-web</artifactId>
|
<artifactId>spring-boot-starter-web</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
|
<!--langchain4j-->
|
||||||
|
<dependency>
|
||||||
|
<groupId>dev.langchain4j</groupId>
|
||||||
|
<artifactId>langchain4j-open-ai</artifactId>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>dev.langchain4j</groupId>
|
||||||
|
<artifactId>langchain4j</artifactId>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>dev.langchain4j</groupId>
|
||||||
|
<artifactId>langchain4j-chroma</artifactId>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
</project>
|
</project>
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ public class EmbeddingConfig {
|
|||||||
@Value("${embedding.metric.analyzeQuery.collection:solved_query_collection}")
|
@Value("${embedding.metric.analyzeQuery.collection:solved_query_collection}")
|
||||||
private String metricAnalyzeQueryCollection;
|
private String metricAnalyzeQueryCollection;
|
||||||
|
|
||||||
|
@Value("${embedding.metric.analyzeQuery.nResult:5}")
|
||||||
|
private int metricAnalyzeQueryResultNum;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,23 @@
|
|||||||
|
package com.tencent.supersonic.common.util;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
|
||||||
|
import java.util.Objects;
|
||||||
|
import org.springframework.core.io.support.SpringFactoriesLoader;
|
||||||
|
|
||||||
|
public class ComponentFactory {
|
||||||
|
|
||||||
|
private static S2EmbeddingStore s2EmbeddingStore;
|
||||||
|
|
||||||
|
public static S2EmbeddingStore getS2EmbeddingStore() {
|
||||||
|
if (Objects.isNull(s2EmbeddingStore)) {
|
||||||
|
s2EmbeddingStore = init(S2EmbeddingStore.class);
|
||||||
|
}
|
||||||
|
return s2EmbeddingStore;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static <T> T init(Class<T> factoryType) {
|
||||||
|
return SpringFactoriesLoader.loadFactories(factoryType,
|
||||||
|
Thread.currentThread().getContextClassLoader()).get(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -13,7 +13,7 @@ public class EmbeddingQuery {
|
|||||||
|
|
||||||
private String query;
|
private String query;
|
||||||
|
|
||||||
private Map<String, String> metadata;
|
private Map<String, Object> metadata;
|
||||||
|
|
||||||
private List<Double> queryEmbedding;
|
private List<Double> queryEmbedding;
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,83 @@
|
|||||||
|
package com.tencent.supersonic.common.util.embedding;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import dev.langchain4j.data.embedding.Embedding;
|
||||||
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
|
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||||
|
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
|
||||||
|
|
||||||
|
private static Map<String, InMemoryEmbeddingStore<EmbeddingQuery>> collectionNameToStore =
|
||||||
|
new ConcurrentHashMap<>();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void addCollection(String collectionName) {
|
||||||
|
collectionNameToStore.computeIfAbsent(collectionName, k -> new InMemoryEmbeddingStore());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void addQuery(String collectionName, List<EmbeddingQuery> queries) {
|
||||||
|
InMemoryEmbeddingStore<EmbeddingQuery> embeddingStore = getEmbeddingStore(collectionName);
|
||||||
|
EmbeddingModel embeddingModel = ContextUtils.getBean(EmbeddingModel.class);
|
||||||
|
for (EmbeddingQuery query : queries) {
|
||||||
|
String question = query.getQuery();
|
||||||
|
Embedding embedding = embeddingModel.embed(question).content();
|
||||||
|
embeddingStore.add(query.getQueryId(), embedding, query);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private InMemoryEmbeddingStore<EmbeddingQuery> getEmbeddingStore(String collectionName) {
|
||||||
|
InMemoryEmbeddingStore<EmbeddingQuery> embeddingStore = collectionNameToStore.get(collectionName);
|
||||||
|
if (Objects.isNull(embeddingStore)) {
|
||||||
|
synchronized (InMemoryS2EmbeddingStore.class) {
|
||||||
|
addCollection(collectionName);
|
||||||
|
embeddingStore = collectionNameToStore.get(collectionName);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
return embeddingStore;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void deleteQuery(String collectionName, List<EmbeddingQuery> queries) {
|
||||||
|
//not support in InMemoryEmbeddingStore
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
|
||||||
|
InMemoryEmbeddingStore<EmbeddingQuery> embeddingStore = getEmbeddingStore(collectionName);
|
||||||
|
EmbeddingModel embeddingModel = ContextUtils.getBean(EmbeddingModel.class);
|
||||||
|
|
||||||
|
List<RetrieveQueryResult> results = new ArrayList<>();
|
||||||
|
|
||||||
|
List<String> queryTextsList = retrieveQuery.getQueryTextsList();
|
||||||
|
for (String queryText : queryTextsList) {
|
||||||
|
Embedding embeddedText = embeddingModel.embed(queryText).content();
|
||||||
|
List<EmbeddingMatch<EmbeddingQuery>> relevant = embeddingStore.findRelevant(embeddedText, num);
|
||||||
|
|
||||||
|
RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult();
|
||||||
|
retrieveQueryResult.setQuery(queryText);
|
||||||
|
List<Retrieval> retrievals = new ArrayList<>();
|
||||||
|
for (EmbeddingMatch<EmbeddingQuery> embeddingMatch : relevant) {
|
||||||
|
Retrieval retrieval = new Retrieval();
|
||||||
|
retrieval.setDistance(embeddingMatch.score());
|
||||||
|
retrieval.setId(embeddingMatch.embeddingId());
|
||||||
|
retrieval.setQuery(embeddingMatch.embedded().getQuery());
|
||||||
|
retrieval.setMetadata(embeddingMatch.embedded().getMetadata());
|
||||||
|
retrievals.add(retrieval);
|
||||||
|
}
|
||||||
|
retrieveQueryResult.setRetrieval(retrievals);
|
||||||
|
results.add(retrieveQueryResult);
|
||||||
|
}
|
||||||
|
|
||||||
|
return results;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,6 +4,10 @@ import com.alibaba.fastjson.JSONObject;
|
|||||||
import com.alibaba.fastjson.serializer.SerializerFeature;
|
import com.alibaba.fastjson.serializer.SerializerFeature;
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
|
import java.net.URI;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.core.ParameterizedTypeReference;
|
import org.springframework.core.ParameterizedTypeReference;
|
||||||
@@ -12,18 +16,12 @@ import org.springframework.http.HttpHeaders;
|
|||||||
import org.springframework.http.HttpMethod;
|
import org.springframework.http.HttpMethod;
|
||||||
import org.springframework.http.MediaType;
|
import org.springframework.http.MediaType;
|
||||||
import org.springframework.http.ResponseEntity;
|
import org.springframework.http.ResponseEntity;
|
||||||
import org.springframework.stereotype.Component;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
import org.springframework.web.client.RestTemplate;
|
import org.springframework.web.client.RestTemplate;
|
||||||
import org.springframework.web.util.UriComponentsBuilder;
|
import org.springframework.web.util.UriComponentsBuilder;
|
||||||
import java.net.URI;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Optional;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Component
|
public class PythonS2EmbeddingStore implements S2EmbeddingStore {
|
||||||
public class EmbeddingUtils {
|
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
private EmbeddingConfig embeddingConfig;
|
private EmbeddingConfig embeddingConfig;
|
||||||
@@ -103,6 +101,5 @@ public class EmbeddingUtils {
|
|||||||
}
|
}
|
||||||
return ResponseEntity.of(Optional.empty());
|
return ResponseEntity.of(Optional.empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -15,7 +15,7 @@ public class Retrieval {
|
|||||||
|
|
||||||
protected String query;
|
protected String query;
|
||||||
|
|
||||||
protected Map<String, String> metadata;
|
protected Map<String, Object> metadata;
|
||||||
|
|
||||||
public static Long getLongId(String id) {
|
public static Long getLongId(String id) {
|
||||||
if (StringUtils.isBlank(id)) {
|
if (StringUtils.isBlank(id)) {
|
||||||
|
|||||||
@@ -0,0 +1,19 @@
|
|||||||
|
package com.tencent.supersonic.common.util.embedding;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Supersonic EmbeddingStore
|
||||||
|
* Added the functionality of adding and querying collection names.
|
||||||
|
*/
|
||||||
|
public interface S2EmbeddingStore {
|
||||||
|
|
||||||
|
void addCollection(String collectionName);
|
||||||
|
|
||||||
|
void addQuery(String collectionName, List<EmbeddingQuery> queries);
|
||||||
|
|
||||||
|
void deleteQuery(String collectionName, List<EmbeddingQuery> queries);
|
||||||
|
|
||||||
|
List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num);
|
||||||
|
|
||||||
|
}
|
||||||
@@ -43,6 +43,16 @@
|
|||||||
<groupId>org.projectlombok</groupId>
|
<groupId>org.projectlombok</groupId>
|
||||||
<artifactId>lombok</artifactId>
|
<artifactId>lombok</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<!--langchain4j-->
|
||||||
|
<dependency>
|
||||||
|
<groupId>dev.langchain4j</groupId>
|
||||||
|
<artifactId>langchain4j-spring-boot-starter</artifactId>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>dev.langchain4j</groupId>
|
||||||
|
<artifactId>langchain4j-embeddings-all-minilm-l6-v2</artifactId>
|
||||||
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
</project>
|
</project>
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
package dev.langchain4j;
|
||||||
|
|
||||||
|
enum ModelProvider {
|
||||||
|
OPEN_AI,
|
||||||
|
HUGGING_FACE,
|
||||||
|
LOCAL_AI,
|
||||||
|
IN_MEMORY
|
||||||
|
}
|
||||||
@@ -0,0 +1,283 @@
|
|||||||
|
package dev.langchain4j;
|
||||||
|
|
||||||
|
import static dev.langchain4j.ModelProvider.OPEN_AI;
|
||||||
|
import static dev.langchain4j.exception.IllegalConfigurationException.illegalConfiguration;
|
||||||
|
import static dev.langchain4j.internal.Utils.isNullOrBlank;
|
||||||
|
|
||||||
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
|
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
|
||||||
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
|
import dev.langchain4j.model.huggingface.HuggingFaceChatModel;
|
||||||
|
import dev.langchain4j.model.huggingface.HuggingFaceEmbeddingModel;
|
||||||
|
import dev.langchain4j.model.huggingface.HuggingFaceLanguageModel;
|
||||||
|
import dev.langchain4j.model.language.LanguageModel;
|
||||||
|
import dev.langchain4j.model.localai.LocalAiChatModel;
|
||||||
|
import dev.langchain4j.model.localai.LocalAiEmbeddingModel;
|
||||||
|
import dev.langchain4j.model.localai.LocalAiLanguageModel;
|
||||||
|
import dev.langchain4j.model.moderation.ModerationModel;
|
||||||
|
import dev.langchain4j.model.openai.OpenAiChatModel;
|
||||||
|
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
|
||||||
|
import dev.langchain4j.model.openai.OpenAiLanguageModel;
|
||||||
|
import dev.langchain4j.model.openai.OpenAiModerationModel;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
|
||||||
|
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||||
|
import org.springframework.context.annotation.Bean;
|
||||||
|
import org.springframework.context.annotation.Configuration;
|
||||||
|
import org.springframework.context.annotation.Lazy;
|
||||||
|
import org.springframework.context.annotation.Primary;
|
||||||
|
|
||||||
|
@Configuration
|
||||||
|
@EnableConfigurationProperties(LangChain4jProperties.class)
|
||||||
|
public class S2LangChain4jAutoConfiguration {
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private LangChain4jProperties properties;
|
||||||
|
|
||||||
|
@Bean
|
||||||
|
@Lazy
|
||||||
|
@ConditionalOnMissingBean
|
||||||
|
ChatLanguageModel chatLanguageModel(LangChain4jProperties properties) {
|
||||||
|
if (properties.getChatModel() == null) {
|
||||||
|
throw illegalConfiguration("\n\nPlease define 'langchain4j.chat-model' properties, for example:\n"
|
||||||
|
+ "langchain4j.chat-model.provider = openai\n"
|
||||||
|
+ "langchain4j.chat-model.openai.api-key = sk-...\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (properties.getChatModel().getProvider()) {
|
||||||
|
|
||||||
|
case OPEN_AI:
|
||||||
|
OpenAi openAi = properties.getChatModel().getOpenAi();
|
||||||
|
if (openAi == null || isNullOrBlank(openAi.getApiKey())) {
|
||||||
|
throw illegalConfiguration("\n\nPlease define 'langchain4j.chat-model.openai.api-key' property");
|
||||||
|
}
|
||||||
|
return OpenAiChatModel.builder()
|
||||||
|
.baseUrl(openAi.getBaseUrl())
|
||||||
|
.apiKey(openAi.getApiKey())
|
||||||
|
.modelName(openAi.getModelName())
|
||||||
|
.temperature(openAi.getTemperature())
|
||||||
|
.topP(openAi.getTopP())
|
||||||
|
.maxTokens(openAi.getMaxTokens())
|
||||||
|
.presencePenalty(openAi.getPresencePenalty())
|
||||||
|
.frequencyPenalty(openAi.getFrequencyPenalty())
|
||||||
|
.timeout(openAi.getTimeout())
|
||||||
|
.maxRetries(openAi.getMaxRetries())
|
||||||
|
.logRequests(openAi.getLogRequests())
|
||||||
|
.logResponses(openAi.getLogResponses())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
case HUGGING_FACE:
|
||||||
|
HuggingFace huggingFace = properties.getChatModel().getHuggingFace();
|
||||||
|
if (huggingFace == null || isNullOrBlank(huggingFace.getAccessToken())) {
|
||||||
|
throw illegalConfiguration(
|
||||||
|
"\n\nPlease define 'langchain4j.chat-model.huggingface.access-token' property");
|
||||||
|
}
|
||||||
|
return HuggingFaceChatModel.builder()
|
||||||
|
.accessToken(huggingFace.getAccessToken())
|
||||||
|
.modelId(huggingFace.getModelId())
|
||||||
|
.timeout(huggingFace.getTimeout())
|
||||||
|
.temperature(huggingFace.getTemperature())
|
||||||
|
.maxNewTokens(huggingFace.getMaxNewTokens())
|
||||||
|
.returnFullText(huggingFace.getReturnFullText())
|
||||||
|
.waitForModel(huggingFace.getWaitForModel())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
case LOCAL_AI:
|
||||||
|
LocalAi localAi = properties.getChatModel().getLocalAi();
|
||||||
|
if (localAi == null || isNullOrBlank(localAi.getBaseUrl())) {
|
||||||
|
throw illegalConfiguration("\n\nPlease define 'langchain4j.chat-model.localai.base-url' property");
|
||||||
|
}
|
||||||
|
if (isNullOrBlank(localAi.getModelName())) {
|
||||||
|
throw illegalConfiguration(
|
||||||
|
"\n\nPlease define 'langchain4j.chat-model.localai.model-name' property");
|
||||||
|
}
|
||||||
|
return LocalAiChatModel.builder()
|
||||||
|
.baseUrl(localAi.getBaseUrl())
|
||||||
|
.modelName(localAi.getModelName())
|
||||||
|
.temperature(localAi.getTemperature())
|
||||||
|
.topP(localAi.getTopP())
|
||||||
|
.maxTokens(localAi.getMaxTokens())
|
||||||
|
.timeout(localAi.getTimeout())
|
||||||
|
.maxRetries(localAi.getMaxRetries())
|
||||||
|
.logRequests(localAi.getLogRequests())
|
||||||
|
.logResponses(localAi.getLogResponses())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
default:
|
||||||
|
throw illegalConfiguration("Unsupported chat model provider: %s",
|
||||||
|
properties.getChatModel().getProvider());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Bean
|
||||||
|
@Lazy
|
||||||
|
@ConditionalOnMissingBean
|
||||||
|
LanguageModel languageModel(LangChain4jProperties properties) {
|
||||||
|
if (properties.getLanguageModel() == null) {
|
||||||
|
throw illegalConfiguration("\n\nPlease define 'langchain4j.language-model' properties, for example:\n"
|
||||||
|
+ "langchain4j.language-model.provider = openai\n"
|
||||||
|
+ "langchain4j.language-model.openai.api-key = sk-...\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (properties.getLanguageModel().getProvider()) {
|
||||||
|
|
||||||
|
case OPEN_AI:
|
||||||
|
OpenAi openAi = properties.getLanguageModel().getOpenAi();
|
||||||
|
if (openAi == null || isNullOrBlank(openAi.getApiKey())) {
|
||||||
|
throw illegalConfiguration(
|
||||||
|
"\n\nPlease define 'langchain4j.language-model.openai.api-key' property");
|
||||||
|
}
|
||||||
|
return OpenAiLanguageModel.builder()
|
||||||
|
.apiKey(openAi.getApiKey())
|
||||||
|
.modelName(openAi.getModelName())
|
||||||
|
.temperature(openAi.getTemperature())
|
||||||
|
.timeout(openAi.getTimeout())
|
||||||
|
.maxRetries(openAi.getMaxRetries())
|
||||||
|
.logRequests(openAi.getLogRequests())
|
||||||
|
.logResponses(openAi.getLogResponses())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
case HUGGING_FACE:
|
||||||
|
HuggingFace huggingFace = properties.getLanguageModel().getHuggingFace();
|
||||||
|
if (huggingFace == null || isNullOrBlank(huggingFace.getAccessToken())) {
|
||||||
|
throw illegalConfiguration(
|
||||||
|
"\n\nPlease define 'langchain4j.language-model.huggingface.access-token' property");
|
||||||
|
}
|
||||||
|
return HuggingFaceLanguageModel.builder()
|
||||||
|
.accessToken(huggingFace.getAccessToken())
|
||||||
|
.modelId(huggingFace.getModelId())
|
||||||
|
.timeout(huggingFace.getTimeout())
|
||||||
|
.temperature(huggingFace.getTemperature())
|
||||||
|
.maxNewTokens(huggingFace.getMaxNewTokens())
|
||||||
|
.returnFullText(huggingFace.getReturnFullText())
|
||||||
|
.waitForModel(huggingFace.getWaitForModel())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
case LOCAL_AI:
|
||||||
|
LocalAi localAi = properties.getLanguageModel().getLocalAi();
|
||||||
|
if (localAi == null || isNullOrBlank(localAi.getBaseUrl())) {
|
||||||
|
throw illegalConfiguration(
|
||||||
|
"\n\nPlease define 'langchain4j.language-model.localai.base-url' property");
|
||||||
|
}
|
||||||
|
if (isNullOrBlank(localAi.getModelName())) {
|
||||||
|
throw illegalConfiguration(
|
||||||
|
"\n\nPlease define 'langchain4j.language-model.localai.model-name' property");
|
||||||
|
}
|
||||||
|
return LocalAiLanguageModel.builder()
|
||||||
|
.baseUrl(localAi.getBaseUrl())
|
||||||
|
.modelName(localAi.getModelName())
|
||||||
|
.temperature(localAi.getTemperature())
|
||||||
|
.topP(localAi.getTopP())
|
||||||
|
.maxTokens(localAi.getMaxTokens())
|
||||||
|
.timeout(localAi.getTimeout())
|
||||||
|
.maxRetries(localAi.getMaxRetries())
|
||||||
|
.logRequests(localAi.getLogRequests())
|
||||||
|
.logResponses(localAi.getLogResponses())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
default:
|
||||||
|
throw illegalConfiguration("Unsupported language model provider: %s",
|
||||||
|
properties.getLanguageModel().getProvider());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Bean
|
||||||
|
@Lazy
|
||||||
|
@ConditionalOnMissingBean
|
||||||
|
@Primary
|
||||||
|
EmbeddingModel embeddingModel(LangChain4jProperties properties) {
|
||||||
|
if (properties.getEmbeddingModel() == null || properties.getEmbeddingModel().getProvider() == null) {
|
||||||
|
throw illegalConfiguration("\n\nPlease define 'langchain4j.embedding-model' properties, for example:\n"
|
||||||
|
+ "langchain4j.embedding-model.provider = openai\n"
|
||||||
|
+ "langchain4j.embedding-model.openai.api-key = sk-...\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (properties.getEmbeddingModel().getProvider()) {
|
||||||
|
|
||||||
|
case OPEN_AI:
|
||||||
|
OpenAi openAi = properties.getEmbeddingModel().getOpenAi();
|
||||||
|
if (openAi == null || isNullOrBlank(openAi.getApiKey())) {
|
||||||
|
throw illegalConfiguration(
|
||||||
|
"\n\nPlease define 'langchain4j.embedding-model.openai.api-key' property");
|
||||||
|
}
|
||||||
|
return OpenAiEmbeddingModel.builder()
|
||||||
|
.apiKey(openAi.getApiKey())
|
||||||
|
.modelName(openAi.getModelName())
|
||||||
|
.timeout(openAi.getTimeout())
|
||||||
|
.maxRetries(openAi.getMaxRetries())
|
||||||
|
.logRequests(openAi.getLogRequests())
|
||||||
|
.logResponses(openAi.getLogResponses())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
case HUGGING_FACE:
|
||||||
|
HuggingFace huggingFace = properties.getEmbeddingModel().getHuggingFace();
|
||||||
|
if (huggingFace == null || isNullOrBlank(huggingFace.getAccessToken())) {
|
||||||
|
throw illegalConfiguration(
|
||||||
|
"\n\nPlease define 'langchain4j.embedding-model.huggingface.access-token' property");
|
||||||
|
}
|
||||||
|
return HuggingFaceEmbeddingModel.builder()
|
||||||
|
.accessToken(huggingFace.getAccessToken())
|
||||||
|
.modelId(huggingFace.getModelId())
|
||||||
|
.waitForModel(huggingFace.getWaitForModel())
|
||||||
|
.timeout(huggingFace.getTimeout())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
case LOCAL_AI:
|
||||||
|
LocalAi localAi = properties.getEmbeddingModel().getLocalAi();
|
||||||
|
if (localAi == null || isNullOrBlank(localAi.getBaseUrl())) {
|
||||||
|
throw illegalConfiguration(
|
||||||
|
"\n\nPlease define 'langchain4j.embedding-model.localai.base-url' property");
|
||||||
|
}
|
||||||
|
if (isNullOrBlank(localAi.getModelName())) {
|
||||||
|
throw illegalConfiguration(
|
||||||
|
"\n\nPlease define 'langchain4j.embedding-model.localai.model-name' property");
|
||||||
|
}
|
||||||
|
return LocalAiEmbeddingModel.builder()
|
||||||
|
.baseUrl(localAi.getBaseUrl())
|
||||||
|
.modelName(localAi.getModelName())
|
||||||
|
.timeout(localAi.getTimeout())
|
||||||
|
.maxRetries(localAi.getMaxRetries())
|
||||||
|
.logRequests(localAi.getLogRequests())
|
||||||
|
.logResponses(localAi.getLogResponses())
|
||||||
|
.build();
|
||||||
|
case IN_MEMORY:
|
||||||
|
return new AllMiniLmL6V2EmbeddingModel();
|
||||||
|
default:
|
||||||
|
throw illegalConfiguration("Unsupported embedding model provider: %s",
|
||||||
|
properties.getEmbeddingModel().getProvider());
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Bean
|
||||||
|
@Lazy
|
||||||
|
@ConditionalOnMissingBean
|
||||||
|
ModerationModel moderationModel(LangChain4jProperties properties) {
|
||||||
|
if (properties.getModerationModel() == null) {
|
||||||
|
throw illegalConfiguration("\n\nPlease define 'langchain4j.moderation-model' properties, for example:\n"
|
||||||
|
+ "langchain4j.moderation-model.provider = openai\n"
|
||||||
|
+ "langchain4j.moderation-model.openai.api-key = sk-...\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (properties.getModerationModel().getProvider() != OPEN_AI) {
|
||||||
|
throw illegalConfiguration("Unsupported moderation model provider: %s",
|
||||||
|
properties.getModerationModel().getProvider());
|
||||||
|
}
|
||||||
|
|
||||||
|
OpenAi openAi = properties.getModerationModel().getOpenAi();
|
||||||
|
if (openAi == null || isNullOrBlank(openAi.getApiKey())) {
|
||||||
|
throw illegalConfiguration("\n\nPlease define 'langchain4j.moderation-model.openai.api-key' property");
|
||||||
|
}
|
||||||
|
|
||||||
|
return OpenAiModerationModel.builder()
|
||||||
|
.apiKey(openAi.getApiKey())
|
||||||
|
.modelName(openAi.getModelName())
|
||||||
|
.timeout(openAi.getTimeout())
|
||||||
|
.maxRetries(openAi.getMaxRetries())
|
||||||
|
.logRequests(openAi.getLogRequests())
|
||||||
|
.logResponses(openAi.getLogResponses())
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,43 @@
|
|||||||
|
package com.tencent.supersonic;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||||
|
import com.tencent.supersonic.chat.parser.sql.llm.prompt.SqlExample;
|
||||||
|
import com.tencent.supersonic.chat.parser.sql.llm.prompt.SqlExampleLoader;
|
||||||
|
import com.tencent.supersonic.chat.parser.EmbedLLMProxy;
|
||||||
|
import com.tencent.supersonic.chat.parser.LLMProxy;
|
||||||
|
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||||
|
import java.util.List;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.boot.CommandLineRunner;
|
||||||
|
import org.springframework.core.annotation.Order;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
@Component
|
||||||
|
@Order(4)
|
||||||
|
public class EmbeddingInitListener implements CommandLineRunner {
|
||||||
|
|
||||||
|
protected LLMProxy llmProxy = ComponentFactory.getLLMProxy();
|
||||||
|
@Autowired
|
||||||
|
private SqlExampleLoader sqlExampleLoader;
|
||||||
|
@Autowired
|
||||||
|
private OptimizationConfig optimizationConfig;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void run(String... args) {
|
||||||
|
initSqlExamples();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void initSqlExamples() {
|
||||||
|
try {
|
||||||
|
if (llmProxy instanceof EmbedLLMProxy) {
|
||||||
|
List<SqlExample> sqlExamples = sqlExampleLoader.getSqlExamples();
|
||||||
|
String collectionName = optimizationConfig.getText2sqlCollectionName();
|
||||||
|
sqlExampleLoader.addEmbeddingStore(sqlExamples, collectionName);
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("initSqlExamples error", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,9 +1,11 @@
|
|||||||
package com.tencent.supersonic;
|
package com.tencent.supersonic;
|
||||||
|
|
||||||
|
import dev.langchain4j.S2LangChain4jAutoConfiguration;
|
||||||
import org.springframework.boot.SpringApplication;
|
import org.springframework.boot.SpringApplication;
|
||||||
import org.springframework.boot.autoconfigure.SpringBootApplication;
|
import org.springframework.boot.autoconfigure.SpringBootApplication;
|
||||||
import org.springframework.boot.autoconfigure.data.mongo.MongoDataAutoConfiguration;
|
import org.springframework.boot.autoconfigure.data.mongo.MongoDataAutoConfiguration;
|
||||||
import org.springframework.boot.autoconfigure.mongo.MongoAutoConfiguration;
|
import org.springframework.boot.autoconfigure.mongo.MongoAutoConfiguration;
|
||||||
|
import org.springframework.context.annotation.Import;
|
||||||
import org.springframework.scheduling.annotation.EnableAsync;
|
import org.springframework.scheduling.annotation.EnableAsync;
|
||||||
import org.springframework.scheduling.annotation.EnableScheduling;
|
import org.springframework.scheduling.annotation.EnableScheduling;
|
||||||
|
|
||||||
@@ -11,6 +13,7 @@ import org.springframework.scheduling.annotation.EnableScheduling;
|
|||||||
exclude = {MongoAutoConfiguration.class, MongoDataAutoConfiguration.class})
|
exclude = {MongoAutoConfiguration.class, MongoDataAutoConfiguration.class})
|
||||||
@EnableScheduling
|
@EnableScheduling
|
||||||
@EnableAsync
|
@EnableAsync
|
||||||
|
@Import(S2LangChain4jAutoConfiguration.class)
|
||||||
public class StandaloneLauncher {
|
public class StandaloneLauncher {
|
||||||
|
|
||||||
public static void main(String[] args) {
|
public static void main(String[] args) {
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ com.tencent.supersonic.chat.processor.ParseResultProcessor=\
|
|||||||
com.tencent.supersonic.chat.processor.RespBuildProcessor
|
com.tencent.supersonic.chat.processor.RespBuildProcessor
|
||||||
|
|
||||||
com.tencent.supersonic.chat.parser.LLMProxy=\
|
com.tencent.supersonic.chat.parser.LLMProxy=\
|
||||||
com.tencent.supersonic.chat.parser.PythonLLMProxy
|
com.tencent.supersonic.chat.parser.EmbedLLMProxy
|
||||||
|
|
||||||
com.tencent.supersonic.chat.api.component.SemanticInterpreter=\
|
com.tencent.supersonic.chat.api.component.SemanticInterpreter=\
|
||||||
com.tencent.supersonic.knowledge.semantic.LocalSemanticInterpreter
|
com.tencent.supersonic.knowledge.semantic.LocalSemanticInterpreter
|
||||||
@@ -47,3 +47,6 @@ com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\
|
|||||||
|
|
||||||
com.tencent.supersonic.chat.query.QueryResponder=\
|
com.tencent.supersonic.chat.query.QueryResponder=\
|
||||||
com.tencent.supersonic.chat.query.SimilarMetricQueryResponder
|
com.tencent.supersonic.chat.query.SimilarMetricQueryResponder
|
||||||
|
|
||||||
|
com.tencent.supersonic.common.util.embedding.S2EmbeddingStore=\
|
||||||
|
com.tencent.supersonic.common.util.embedding.InMemoryS2EmbeddingStore
|
||||||
@@ -36,8 +36,40 @@ mybatis:
|
|||||||
|
|
||||||
llm:
|
llm:
|
||||||
parser:
|
parser:
|
||||||
url: http://127.0.0.1:9092
|
url:
|
||||||
embedding:
|
embedding:
|
||||||
url: http://127.0.0.1:9092
|
url: http://127.0.0.1:9092
|
||||||
functionCall:
|
functionCall:
|
||||||
url: http://127.0.0.1:9092
|
url: http://127.0.0.1:9092
|
||||||
|
|
||||||
|
#langchain4j config
|
||||||
|
langchain4j:
|
||||||
|
#1.chat-model
|
||||||
|
chat-model:
|
||||||
|
provider: open_ai
|
||||||
|
openai:
|
||||||
|
api-key: api_key
|
||||||
|
model-name: gpt-3.5-turbo
|
||||||
|
temperature: 0.0
|
||||||
|
timeout: PT60S
|
||||||
|
#2.embedding-model
|
||||||
|
embedding-model:
|
||||||
|
provider: in_memory
|
||||||
|
# embedding-model:
|
||||||
|
# hugging-face:
|
||||||
|
# access-token: hg_access_token
|
||||||
|
# model-id: sentence-transformers/all-MiniLM-L6-v2
|
||||||
|
# timeout: 1h
|
||||||
|
|
||||||
|
# embedding-model:
|
||||||
|
# provider: open_ai
|
||||||
|
# openai:
|
||||||
|
# api-key: api_key
|
||||||
|
# modelName: all-minilm-l6-v2.onnx
|
||||||
|
|
||||||
|
|
||||||
|
#langchain4j log
|
||||||
|
logging:
|
||||||
|
level:
|
||||||
|
dev.langchain4j: DEBUG
|
||||||
|
dev.ai4j.openai4j: DEBUG
|
||||||
@@ -1,16 +1,13 @@
|
|||||||
package com.tencent.supersonic.integration;
|
package com.tencent.supersonic.integration;
|
||||||
|
|
||||||
import com.alibaba.fastjson.JSONObject;
|
|
||||||
import com.tencent.supersonic.StandaloneLauncher;
|
import com.tencent.supersonic.StandaloneLauncher;
|
||||||
import com.tencent.supersonic.chat.query.llm.analytics.LLMAnswerResp;
|
import com.tencent.supersonic.chat.query.llm.analytics.LLMAnswerResp;
|
||||||
import com.tencent.supersonic.chat.service.AgentService;
|
import com.tencent.supersonic.chat.service.AgentService;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
|
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.junit.runner.RunWith;
|
import org.junit.runner.RunWith;
|
||||||
import org.springframework.boot.test.context.SpringBootTest;
|
import org.springframework.boot.test.context.SpringBootTest;
|
||||||
import org.springframework.boot.test.mock.mockito.MockBean;
|
import org.springframework.boot.test.mock.mockito.MockBean;
|
||||||
import org.springframework.http.ResponseEntity;
|
|
||||||
import org.springframework.test.context.ActiveProfiles;
|
import org.springframework.test.context.ActiveProfiles;
|
||||||
import org.springframework.test.context.junit4.SpringRunner;
|
import org.springframework.test.context.junit4.SpringRunner;
|
||||||
|
|
||||||
@@ -21,13 +18,9 @@ public class MetricInterpretTest {
|
|||||||
|
|
||||||
@MockBean
|
@MockBean
|
||||||
private AgentService agentService;
|
private AgentService agentService;
|
||||||
|
|
||||||
@MockBean
|
@MockBean
|
||||||
private EmbeddingConfig embeddingConfig;
|
private EmbeddingConfig embeddingConfig;
|
||||||
|
|
||||||
@MockBean
|
|
||||||
private EmbeddingUtils embeddingUtils;
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMetricInterpret() throws Exception {
|
public void testMetricInterpret() throws Exception {
|
||||||
MockConfiguration.mockAgent(agentService);
|
MockConfiguration.mockAgent(agentService);
|
||||||
@@ -36,7 +29,6 @@ public class MetricInterpretTest {
|
|||||||
LLMAnswerResp lLmAnswerResp = new LLMAnswerResp();
|
LLMAnswerResp lLmAnswerResp = new LLMAnswerResp();
|
||||||
lLmAnswerResp.setAssistantMessage("alice最近在超音数的访问情况有增多");
|
lLmAnswerResp.setAssistantMessage("alice最近在超音数的访问情况有增多");
|
||||||
|
|
||||||
MockConfiguration.embeddingUtils(embeddingUtils, ResponseEntity.ok(JSONObject.toJSONString(lLmAnswerResp)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,20 +1,17 @@
|
|||||||
package com.tencent.supersonic.integration;
|
package com.tencent.supersonic.integration;
|
||||||
|
|
||||||
|
|
||||||
import static org.mockito.ArgumentMatchers.anyObject;
|
|
||||||
import static org.mockito.Mockito.when;
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||||
import com.tencent.supersonic.chat.service.AgentService;
|
import com.tencent.supersonic.chat.service.AgentService;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
|
|
||||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||||
import com.tencent.supersonic.util.DataUtils;
|
import com.tencent.supersonic.util.DataUtils;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.context.annotation.Configuration;
|
import org.springframework.context.annotation.Configuration;
|
||||||
import org.springframework.http.ResponseEntity;
|
|
||||||
|
|
||||||
@Configuration
|
@Configuration
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@@ -38,7 +35,4 @@ public class MockConfiguration {
|
|||||||
when(agentService.getAgent(1)).thenReturn(DataUtils.getAgent());
|
when(agentService.getAgent(1)).thenReturn(DataUtils.getAgent());
|
||||||
}
|
}
|
||||||
|
|
||||||
public static void embeddingUtils(EmbeddingUtils embeddingUtils, ResponseEntity<String> responseEntity) {
|
|
||||||
when(embeddingUtils.doRequest(anyObject(), anyObject(), anyObject())).thenReturn(responseEntity);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
52
pom.xml
52
pom.xml
@@ -71,6 +71,7 @@
|
|||||||
<spotless.python.black.version>22.3.0</spotless.python.black.version>
|
<spotless.python.black.version>22.3.0</spotless.python.black.version>
|
||||||
<easyexcel.version>2.2.6</easyexcel.version>
|
<easyexcel.version>2.2.6</easyexcel.version>
|
||||||
<poi.version>3.17</poi.version>
|
<poi.version>3.17</poi.version>
|
||||||
|
<langchain4j.version>0.24.0</langchain4j.version>
|
||||||
</properties>
|
</properties>
|
||||||
|
|
||||||
<dependencyManagement>
|
<dependencyManagement>
|
||||||
@@ -94,6 +95,57 @@
|
|||||||
<artifactId>guava</artifactId>
|
<artifactId>guava</artifactId>
|
||||||
<version>${guava.version}</version>
|
<version>${guava.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<!--langchain4j-->
|
||||||
|
<dependency>
|
||||||
|
<groupId>dev.langchain4j</groupId>
|
||||||
|
<artifactId>langchain4j-parent</artifactId>
|
||||||
|
<version>${langchain4j.version}</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>dev.langchain4j</groupId>
|
||||||
|
<artifactId>langchain4j</artifactId>
|
||||||
|
<version>${langchain4j.version}</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>dev.langchain4j</groupId>
|
||||||
|
<artifactId>langchain4j-core</artifactId>
|
||||||
|
<version>${langchain4j.version}</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>dev.langchain4j</groupId>
|
||||||
|
<artifactId>langchain4j-spring-boot-starter</artifactId>
|
||||||
|
<version>${langchain4j.version}</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>dev.langchain4j</groupId>
|
||||||
|
<artifactId>langchain4j-open-ai</artifactId>
|
||||||
|
<version>${langchain4j.version}</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>dev.langchain4j</groupId>
|
||||||
|
<artifactId>langchain4j-hugging-face</artifactId>
|
||||||
|
<version>${langchain4j.version}</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>dev.langchain4j</groupId>
|
||||||
|
<artifactId>langchain4j-chroma</artifactId>
|
||||||
|
<version>${langchain4j.version}</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>dev.langchain4j</groupId>
|
||||||
|
<artifactId>langchain4j-embeddings</artifactId>
|
||||||
|
<version>${langchain4j.version}</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>dev.langchain4j</groupId>
|
||||||
|
<artifactId>langchain4j-hugging-face</artifactId>
|
||||||
|
<version>${langchain4j.version}</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>dev.langchain4j</groupId>
|
||||||
|
<artifactId>langchain4j-embeddings-all-minilm-l6-v2</artifactId>
|
||||||
|
<version>${langchain4j.version}</version>
|
||||||
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
</dependencyManagement>
|
</dependencyManagement>
|
||||||
|
|
||||||
|
|||||||
@@ -4,13 +4,13 @@ import com.alibaba.fastjson.JSONObject;
|
|||||||
import com.tencent.supersonic.common.pojo.DataEvent;
|
import com.tencent.supersonic.common.pojo.DataEvent;
|
||||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||||
import com.tencent.supersonic.common.pojo.enums.EventType;
|
import com.tencent.supersonic.common.pojo.enums.EventType;
|
||||||
|
import com.tencent.supersonic.common.util.ComponentFactory;
|
||||||
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
|
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
|
||||||
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
|
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
|
||||||
import org.springframework.beans.factory.annotation.Value;
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
import org.springframework.context.ApplicationListener;
|
import org.springframework.context.ApplicationListener;
|
||||||
import org.springframework.scheduling.annotation.Async;
|
import org.springframework.scheduling.annotation.Async;
|
||||||
@@ -23,8 +23,7 @@ public class MetaEmbeddingListener implements ApplicationListener<DataEvent> {
|
|||||||
|
|
||||||
public static final String COLLECTION_NAME = "meta_collection";
|
public static final String COLLECTION_NAME = "meta_collection";
|
||||||
|
|
||||||
@Autowired
|
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
|
||||||
private EmbeddingUtils embeddingUtils;
|
|
||||||
|
|
||||||
@Value("${embedding.operation.sleep.time:3000}")
|
@Value("${embedding.operation.sleep.time:3000}")
|
||||||
private Integer embeddingOperationSleepTime;
|
private Integer embeddingOperationSleepTime;
|
||||||
@@ -55,14 +54,14 @@ public class MetaEmbeddingListener implements ApplicationListener<DataEvent> {
|
|||||||
} catch (InterruptedException e) {
|
} catch (InterruptedException e) {
|
||||||
log.error("", e);
|
log.error("", e);
|
||||||
}
|
}
|
||||||
embeddingUtils.addCollection(COLLECTION_NAME);
|
s2EmbeddingStore.addCollection(COLLECTION_NAME);
|
||||||
if (event.getEventType().equals(EventType.ADD)) {
|
if (event.getEventType().equals(EventType.ADD)) {
|
||||||
embeddingUtils.addQuery(COLLECTION_NAME, embeddingQueries);
|
s2EmbeddingStore.addQuery(COLLECTION_NAME, embeddingQueries);
|
||||||
} else if (event.getEventType().equals(EventType.DELETE)) {
|
} else if (event.getEventType().equals(EventType.DELETE)) {
|
||||||
embeddingUtils.deleteQuery(COLLECTION_NAME, embeddingQueries);
|
s2EmbeddingStore.deleteQuery(COLLECTION_NAME, embeddingQueries);
|
||||||
} else if (event.getEventType().equals(EventType.UPDATE)) {
|
} else if (event.getEventType().equals(EventType.UPDATE)) {
|
||||||
embeddingUtils.deleteQuery(COLLECTION_NAME, embeddingQueries);
|
s2EmbeddingStore.deleteQuery(COLLECTION_NAME, embeddingQueries);
|
||||||
embeddingUtils.addQuery(COLLECTION_NAME, embeddingQueries);
|
s2EmbeddingStore.addQuery(COLLECTION_NAME, embeddingQueries);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user