From abbe8c84a15e4b067f5ea8d949ed251132a294ad Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Fri, 8 Dec 2023 19:24:58 +0800 Subject: [PATCH] [improvement](python) LLM related services support Java service invocation (#484) --- chat/core/pom.xml | 13 + .../chat/mapper/EmbeddingMatchStrategy.java | 11 +- .../supersonic/chat/parser/EmbedLLMProxy.java | 84 ++++++ .../function/FunctionCallPromptGenerator.java | 44 +++ .../parser/sql/llm/prompt/InputFormat.java | 42 +++ .../parser/sql/llm/prompt/OutputFormat.java | 53 ++++ .../parser/sql/llm/prompt/SqlExample.java | 32 ++ .../sql/llm/prompt/SqlExampleLoader.java | 76 +++++ .../sql/llm/prompt/SqlPromptGenerator.java | 65 ++++ .../supersonic/chat/plugin/PluginManager.java | 19 +- .../query/SimilarMetricQueryResponder.java | 14 +- .../llm/analytics/MetricAnalyzeQuery.java | 24 +- .../chat/utils/ComponentFactory.java | 4 +- .../chat/utils/SolvedQueryManager.java | 19 +- common/pom.xml | 14 + .../common/config/EmbeddingConfig.java | 3 +- .../common/util/ComponentFactory.java | 23 ++ .../common/util/embedding/EmbeddingQuery.java | 2 +- .../embedding/InMemoryS2EmbeddingStore.java | 83 +++++ ...Utils.java => PythonS2EmbeddingStore.java} | 13 +- .../common/util/embedding/Retrieval.java | 2 +- .../util/embedding/S2EmbeddingStore.java | 19 ++ launchers/common/pom.xml | 10 + .../java/dev/langchain4j/ModelProvider.java | 8 + .../S2LangChain4jAutoConfiguration.java | 283 ++++++++++++++++++ .../supersonic/EmbeddingInitListener.java | 43 +++ .../supersonic/StandaloneLauncher.java | 3 + .../main/resources/META-INF/spring.factories | 7 +- .../src/main/resources/application-local.yaml | 36 ++- .../integration/MetricInterpretTest.java | 8 - .../integration/MockConfiguration.java | 6 - pom.xml | 52 ++++ .../listener/MetaEmbeddingListener.java | 17 +- 33 files changed, 1037 insertions(+), 95 deletions(-) create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/parser/EmbedLLMProxy.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/function/FunctionCallPromptGenerator.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/prompt/InputFormat.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/prompt/OutputFormat.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/prompt/SqlExample.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/prompt/SqlExampleLoader.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/prompt/SqlPromptGenerator.java create mode 100644 common/src/main/java/com/tencent/supersonic/common/util/ComponentFactory.java create mode 100644 common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryS2EmbeddingStore.java rename common/src/main/java/com/tencent/supersonic/common/util/embedding/{EmbeddingUtils.java => PythonS2EmbeddingStore.java} (98%) create mode 100644 common/src/main/java/com/tencent/supersonic/common/util/embedding/S2EmbeddingStore.java create mode 100644 launchers/common/src/main/java/dev/langchain4j/ModelProvider.java create mode 100644 launchers/common/src/main/java/dev/langchain4j/S2LangChain4jAutoConfiguration.java create mode 100644 launchers/standalone/src/main/java/com/tencent/supersonic/EmbeddingInitListener.java diff --git a/chat/core/pom.xml b/chat/core/pom.xml index 24806171f..f169048c2 100644 --- a/chat/core/pom.xml +++ b/chat/core/pom.xml @@ -104,6 +104,19 @@ ${mockito-inline.version} test + + + dev.langchain4j + langchain4j-open-ai + + + dev.langchain4j + langchain4j + + + dev.langchain4j + langchain4j-chroma + diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/EmbeddingMatchStrategy.java b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/EmbeddingMatchStrategy.java index 8163af681..739dadeab 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/EmbeddingMatchStrategy.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/EmbeddingMatchStrategy.java @@ -4,10 +4,11 @@ import com.google.common.collect.Lists; import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.config.OptimizationConfig; import com.tencent.supersonic.common.pojo.Constants; -import com.tencent.supersonic.common.util.embedding.EmbeddingUtils; +import com.tencent.supersonic.common.util.ComponentFactory; import com.tencent.supersonic.common.util.embedding.Retrieval; import com.tencent.supersonic.common.util.embedding.RetrieveQuery; import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult; +import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore; import com.tencent.supersonic.knowledge.dictionary.EmbeddingResult; import com.tencent.supersonic.semantic.model.domain.listener.MetaEmbeddingListener; import java.util.Comparator; @@ -32,8 +33,8 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy { @Autowired private OptimizationConfig optimizationConfig; - @Autowired - private EmbeddingUtils embeddingUtils; + + private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore(); @Override public boolean needDelete(EmbeddingResult oneRoundResult, EmbeddingResult existResult) { @@ -83,7 +84,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy { .queryEmbeddings(null) .build(); // step2. retrieveQuery by detectSegment - List retrieveQueryResults = embeddingUtils.retrieveQuery( + List retrieveQueryResults = s2EmbeddingStore.retrieveQuery( MetaEmbeddingListener.COLLECTION_NAME, retrieveQuery, embeddingNumber); if (CollectionUtils.isEmpty(retrieveQueryResults)) { @@ -97,7 +98,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy { retrievals.removeIf(retrieval -> retrieval.getDistance() > distance.doubleValue()); if (CollectionUtils.isNotEmpty(detectModelIds)) { retrievals.removeIf(retrieval -> { - String modelIdStr = retrieval.getMetadata().get("modelId"); + String modelIdStr = retrieval.getMetadata().get("modelId").toString(); if (StringUtils.isBlank(modelIdStr)) { return true; } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/EmbedLLMProxy.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/EmbedLLMProxy.java new file mode 100644 index 000000000..bb371c247 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/EmbedLLMProxy.java @@ -0,0 +1,84 @@ +package com.tencent.supersonic.chat.parser; + +import com.tencent.supersonic.chat.config.OptimizationConfig; +import com.tencent.supersonic.chat.parser.plugin.function.FunctionCallPromptGenerator; +import com.tencent.supersonic.chat.parser.sql.llm.prompt.OutputFormat; +import com.tencent.supersonic.chat.parser.sql.llm.prompt.SqlExampleLoader; +import com.tencent.supersonic.chat.parser.sql.llm.prompt.SqlPromptGenerator; +import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq; +import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp; +import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq; +import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue; +import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp; +import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.common.util.JsonUtil; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.input.Prompt; +import dev.langchain4j.model.input.PromptTemplate; +import dev.langchain4j.model.output.Response; +import lombok.extern.slf4j.Slf4j; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +@Slf4j +public class EmbedLLMProxy implements LLMProxy { + + public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) { + + ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class); + + SqlExampleLoader sqlExampleLoader = ContextUtils.getBean(SqlExampleLoader.class); + + OptimizationConfig config = ContextUtils.getBean(OptimizationConfig.class); + + List> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(), + config.getText2sqlCollectionName(), config.getText2sqlFewShotsNum()); + + String queryText = llmReq.getQueryText(); + String modelName = llmReq.getSchema().getModelName(); + List fieldNameList = llmReq.getSchema().getFieldNameList(); + List linking = llmReq.getLinking(); + + SqlPromptGenerator sqlPromptGenerator = ContextUtils.getBean(SqlPromptGenerator.class); + String linkingPromptStr = sqlPromptGenerator.generateSchemaLinkingPrompt(queryText, modelName, fieldNameList, + linking, sqlExamples); + + Prompt linkingPrompt = PromptTemplate.from(JsonUtil.toString(linkingPromptStr)).apply(new HashMap<>()); + Response linkingResult = chatLanguageModel.generate(linkingPrompt.toSystemMessage()); + + String schemaLinkStr = OutputFormat.schemaLinkParse(linkingResult.content().text()); + + String generateSqlPrompt = sqlPromptGenerator.generateSqlPrompt(queryText, modelName, schemaLinkStr, + llmReq.getCurrentDate(), sqlExamples); + + Prompt sqlPrompt = PromptTemplate.from(JsonUtil.toString(generateSqlPrompt)).apply(new HashMap<>()); + Response sqlResult = chatLanguageModel.generate(sqlPrompt.toSystemMessage()); + + LLMResp result = new LLMResp(); + result.setQuery(queryText); + result.setSchemaLinkingOutput(linkingPromptStr); + result.setSchemaLinkStr(schemaLinkStr); + result.setModelName(modelName); + result.setSqlOutput(sqlResult.content().text()); + return result; + } + + @Override + public FunctionResp requestFunction(FunctionReq functionReq) { + + FunctionCallPromptGenerator promptGenerator = ContextUtils.getBean(FunctionCallPromptGenerator.class); + + String functionCallPrompt = promptGenerator.generateFunctionCallPrompt(functionReq.getQueryText(), + functionReq.getPluginConfigs()); + + ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class); + + String functionSelect = chatLanguageModel.generate(functionCallPrompt); + + return OutputFormat.functionCallParse(functionSelect); + } + +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/function/FunctionCallPromptGenerator.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/function/FunctionCallPromptGenerator.java new file mode 100644 index 000000000..a1f21ae30 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/function/FunctionCallPromptGenerator.java @@ -0,0 +1,44 @@ +package com.tencent.supersonic.chat.parser.plugin.function; + +import com.tencent.supersonic.chat.parser.sql.llm.prompt.InputFormat; +import com.tencent.supersonic.chat.plugin.PluginParseConfig; +import java.util.List; +import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Component; + +@Component +@Slf4j +public class FunctionCallPromptGenerator { + + public String generateFunctionCallPrompt(String queryText, List toolConfigList) { + List toolExplainList = toolConfigList.stream() + .map(this::constructPluginPrompt) + .collect(Collectors.toList()); + String functionList = String.join(InputFormat.SEPERATOR, toolExplainList); + return constructTaskPrompt(queryText, functionList); + } + + public String constructPluginPrompt(PluginParseConfig parseConfig) { + String toolName = parseConfig.getName(); + String toolDescription = parseConfig.getDescription(); + List toolExamples = parseConfig.getExamples(); + + StringBuilder prompt = new StringBuilder(); + prompt.append("【工具名称】\n").append(toolName).append("\n"); + prompt.append("【工具描述】\n").append(toolDescription).append("\n"); + prompt.append("【工具适用问题示例】\n"); + for (String example : toolExamples) { + prompt.append(example).append("\n"); + } + return prompt.toString(); + } + + public String constructTaskPrompt(String queryText, String functionList) { + String instruction = String.format("问题为:%s\n请根据问题和工具的描述,选择对应的工具,完成任务。" + + "请注意,只能选择1个工具。请一步一步地分析选择工具的原因(每个工具的【工具适用问题示例】是选择的重要参考依据)," + + "并给出最终选择,输出格式为json,key为’分析过程‘, ’选择工具‘", queryText); + + return String.format("工具选择如下:\n\n%s\n\n【任务说明】\n%s", functionList, instruction); + } +} \ No newline at end of file diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/prompt/InputFormat.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/prompt/InputFormat.java new file mode 100644 index 000000000..146e126d3 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/prompt/InputFormat.java @@ -0,0 +1,42 @@ +package com.tencent.supersonic.chat.parser.sql.llm.prompt; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +public class InputFormat { + + public static final String SEPERATOR = "\n\n"; + + public static String format(String template, List templateKey, + List> toFormatList) { + List result = new ArrayList<>(); + + for (Map formatItem : toFormatList) { + Map retrievalMeta = subDict(formatItem, templateKey); + result.add(format(template, retrievalMeta)); + } + + return String.join(SEPERATOR, result); + } + + public static String format(String input, Map replacements) { + for (Map.Entry entry : replacements.entrySet()) { + input = input.replace(entry.getKey(), entry.getValue()); + } + return input; + } + + private static Map subDict(Map dict, List keys) { + Map subDict = new HashMap<>(); + for (String key : keys) { + if (dict.containsKey(key)) { + subDict.put(key, dict.get(key)); + } + } + return subDict; + } +} \ No newline at end of file diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/prompt/OutputFormat.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/prompt/OutputFormat.java new file mode 100644 index 000000000..ff09b79fb --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/prompt/OutputFormat.java @@ -0,0 +1,53 @@ +package com.tencent.supersonic.chat.parser.sql.llm.prompt; + +import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp; +import com.tencent.supersonic.common.util.JsonUtil; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import lombok.extern.slf4j.Slf4j; + +/*** + * output format + */ +@Slf4j +public class OutputFormat { + + public static final String PATTERN = "\\{[^{}]+\\}"; + + public static String schemaLinkParse(String schemaLinkOutput) { + try { + schemaLinkOutput = schemaLinkOutput.trim(); + String pattern = "Schema_links:(.*)"; + Pattern regexPattern = Pattern.compile(pattern, Pattern.DOTALL); + Matcher matcher = regexPattern.matcher(schemaLinkOutput); + if (matcher.find()) { + schemaLinkOutput = matcher.group(1).trim(); + } else { + schemaLinkOutput = null; + } + } catch (Exception e) { + log.error("", e); + schemaLinkOutput = null; + } + return schemaLinkOutput; + } + + public static FunctionResp functionCallParse(String llmOutput) { + try { + String[] findResult = llmOutput.split(PATTERN); + String result = findResult[0].trim(); + + Map resultDict = JsonUtil.toMap(result, String.class, String.class); + log.info("result:{},resultDict:{}", result, resultDict); + + String selection = resultDict.get("选择工具"); + FunctionResp resp = new FunctionResp(); + resp.setToolSelection(selection); + return resp; + } catch (Exception e) { + log.error("", e); + return null; + } + } +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/prompt/SqlExample.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/prompt/SqlExample.java new file mode 100644 index 000000000..c2f8df143 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/prompt/SqlExample.java @@ -0,0 +1,32 @@ +package com.tencent.supersonic.chat.parser.sql.llm.prompt; + +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.Data; + +@Data +public class SqlExample { + + @JsonProperty("currentDate") + private String currentDate; + + @JsonProperty("tableName") + private String tableName; + + @JsonProperty("fieldsList") + private String fieldsList; + + @JsonProperty("question") + private String question; + + @JsonProperty("priorSchemaLinks") + private String priorSchemaLinks; + + @JsonProperty("analysis") + private String analysis; + + @JsonProperty("schemaLinks") + private String schemaLinks; + + @JsonProperty("sql") + private String sql; +} \ No newline at end of file diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/prompt/SqlExampleLoader.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/prompt/SqlExampleLoader.java new file mode 100644 index 000000000..a8ec09ecb --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/prompt/SqlExampleLoader.java @@ -0,0 +1,76 @@ +package com.tencent.supersonic.chat.parser.sql.llm.prompt; + + +import com.fasterxml.jackson.core.type.TypeReference; +import com.tencent.supersonic.common.util.ComponentFactory; +import com.tencent.supersonic.common.util.JsonUtil; +import com.tencent.supersonic.common.util.embedding.EmbeddingQuery; +import com.tencent.supersonic.common.util.embedding.Retrieval; +import com.tencent.supersonic.common.util.embedding.RetrieveQuery; +import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult; +import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore; +import java.io.IOException; +import java.io.InputStream; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.collections4.CollectionUtils; +import org.springframework.core.io.ClassPathResource; +import org.springframework.stereotype.Component; + +@Slf4j +@Component +public class SqlExampleLoader { + + private static final String EXAMPLE_JSON_FILE = "example.json"; + + private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore(); + private TypeReference> valueTypeRef = new TypeReference>() { + }; + + public List getSqlExamples() throws IOException { + ClassPathResource resource = new ClassPathResource(EXAMPLE_JSON_FILE); + InputStream inputStream = resource.getInputStream(); + return JsonUtil.INSTANCE.getObjectMapper().readValue(inputStream, valueTypeRef); + } + + public void addEmbeddingStore(List sqlExamples, String collectionName) { + List queries = new ArrayList<>(); + for (int i = 0; i < sqlExamples.size(); i++) { + SqlExample sqlExample = sqlExamples.get(i); + String question = sqlExample.getQuestion(); + Map metaDataMap = JsonUtil.toMap(JsonUtil.toString(sqlExample), String.class, Object.class); + EmbeddingQuery embeddingQuery = new EmbeddingQuery(); + embeddingQuery.setQueryId(String.valueOf(i)); + embeddingQuery.setQuery(question); + embeddingQuery.setMetadata(metaDataMap); + queries.add(embeddingQuery); + } + s2EmbeddingStore.addQuery(collectionName, queries); + } + + public List> retrieverSqlExamples(String queryText, String collectionName, int maxResults) { + + RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(Collections.singletonList(queryText)) + .queryEmbeddings(null).build(); + + List resultList = s2EmbeddingStore.retrieveQuery(collectionName, retrieveQuery, + maxResults); + List> result = new ArrayList<>(); + if (CollectionUtils.isEmpty(resultList)) { + return result; + } + for (Retrieval retrieval : resultList.get(0).getRetrieval()) { + if (Objects.nonNull(retrieval.getMetadata()) && !retrieval.getMetadata().isEmpty()) { + Map convertedMap = retrieval.getMetadata().entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, entry -> String.valueOf(entry.getValue()))); + result.add(convertedMap); + } + } + return result; + } +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/prompt/SqlPromptGenerator.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/prompt/SqlPromptGenerator.java new file mode 100644 index 000000000..3552f06f4 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/prompt/SqlPromptGenerator.java @@ -0,0 +1,65 @@ +package com.tencent.supersonic.chat.parser.sql.llm.prompt; + +import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Component; + +@Component +@Slf4j +public class SqlPromptGenerator { + + public String generateSchemaLinkingPrompt(String question, String modelName, List fieldsList, + List priorSchemaLinks, List> exampleList) { + + String exampleTemplate = "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\n" + + "问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schemaLinks}"; + + List exampleKeys = Arrays.asList("tableName", "fieldsList", "priorSchemaLinks", "question", "analysis", + "schemaLinks"); + + String schemaLinkingPrompt = InputFormat.format(exampleTemplate, exampleKeys, exampleList); + + String newCaseTemplate = "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\n" + + "问题:{question}\n分析: 让我们一步一步地思考。"; + + String newCasePrompt = newCaseTemplate.replace("{tableName}", modelName) + .replace("{fieldsList}", fieldsList.toString()) + .replace("{priorSchemaLinks}", getPriorSchemaLinks(priorSchemaLinks)) + .replace("{question}", question); + + String instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links"; + return instruction + InputFormat.SEPERATOR + schemaLinkingPrompt + InputFormat.SEPERATOR + newCasePrompt; + } + + private String getPriorSchemaLinks(List priorSchemaLinks) { + return priorSchemaLinks.stream() + .map(elementValue -> "'" + elementValue.getFieldName() + "'->" + elementValue.getFieldValue()) + .collect(Collectors.joining(",", "[", "]")); + } + + public String generateSqlPrompt(String question, String modelName, String schemaLinkStr, String dataDate, + List> exampleList) { + + List exampleKeys = Arrays.asList("question", "currentDate", "tableName", "schemaLinks", "sql"); + String exampleTemplate = "问题:{question}\nCurrent_date:{currentDate}\nTable {tableName}\n" + + "Schema_links:{schemaLinks}\nSQL:{sql}"; + + String sqlExamplePrompt = InputFormat.format(exampleTemplate, exampleKeys, exampleList); + + String newCaseTemplate = "问题:{question}\nCurrent_date:{currentDate}\nTable {tableName}\n" + + "Schema_links:{schemaLinks}\nSQL:"; + + String newCasePrompt = newCaseTemplate.replace("{question}", question) + .replace("{currentDate}", dataDate) + .replace("{tableName}", modelName) + .replace("{schemaLinks}", schemaLinkStr); + + String instruction = "# 根据schema_links为每个问题生成SQL查询语句"; + return instruction + InputFormat.SEPERATOR + sqlExamplePrompt + InputFormat.SEPERATOR + newCasePrompt; + } + +} \ No newline at end of file diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/plugin/PluginManager.java b/chat/core/src/main/java/com/tencent/supersonic/chat/plugin/PluginManager.java index 4c16ab4cb..c7aa2ad3e 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/plugin/PluginManager.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/plugin/PluginManager.java @@ -19,12 +19,13 @@ import com.tencent.supersonic.chat.query.plugin.WebBase; import com.tencent.supersonic.chat.service.AgentService; import com.tencent.supersonic.chat.service.PluginService; import com.tencent.supersonic.common.config.EmbeddingConfig; +import com.tencent.supersonic.common.util.ComponentFactory; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.embedding.EmbeddingQuery; -import com.tencent.supersonic.common.util.embedding.EmbeddingUtils; import com.tencent.supersonic.common.util.embedding.Retrieval; import com.tencent.supersonic.common.util.embedding.RetrieveQuery; import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult; +import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -39,7 +40,6 @@ import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.tuple.Pair; import org.springframework.context.event.EventListener; import org.springframework.stereotype.Component; -import org.springframework.web.client.RestTemplate; @Slf4j @Component @@ -47,14 +47,10 @@ public class PluginManager { private EmbeddingConfig embeddingConfig; - private RestTemplate restTemplate; + private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore(); - private EmbeddingUtils embeddingUtils; - - public PluginManager(EmbeddingConfig embeddingConfig, RestTemplate restTemplate, EmbeddingUtils embeddingUtils) { + public PluginManager(EmbeddingConfig embeddingConfig) { this.embeddingConfig = embeddingConfig; - this.restTemplate = restTemplate; - this.embeddingUtils = embeddingUtils; } public static List getPluginAgentCanSupport(Integer agentId) { @@ -133,7 +129,7 @@ public class PluginManager { embeddingQuery.setQueryId(id); queries.add(embeddingQuery); } - embeddingUtils.deleteQuery(presetCollection, queries); + s2EmbeddingStore.deleteQuery(presetCollection, queries); } public void requestEmbeddingPluginAdd(List queries) { @@ -141,7 +137,7 @@ public class PluginManager { return; } String presetCollection = embeddingConfig.getPresetCollection(); - embeddingUtils.addQuery(presetCollection, queries); + s2EmbeddingStore.addQuery(presetCollection, queries); } public void requestEmbeddingPluginAddALL(List plugins) { @@ -150,13 +146,12 @@ public class PluginManager { public RetrieveQueryResult recognize(String embeddingText) { - EmbeddingUtils embeddingUtils = ContextUtils.getBean(EmbeddingUtils.class); RetrieveQuery retrieveQuery = RetrieveQuery.builder() .queryTextsList(Collections.singletonList(embeddingText)) .build(); - List resultList = embeddingUtils.retrieveQuery(embeddingConfig.getPresetCollection(), + List resultList = s2EmbeddingStore.retrieveQuery(embeddingConfig.getPresetCollection(), retrieveQuery, embeddingConfig.getNResult()); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/SimilarMetricQueryResponder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/SimilarMetricQueryResponder.java index 3f3a0e616..0a0178ddc 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/SimilarMetricQueryResponder.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/SimilarMetricQueryResponder.java @@ -7,13 +7,12 @@ import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.common.pojo.QueryType; -import com.tencent.supersonic.common.util.ContextUtils; -import com.tencent.supersonic.common.util.embedding.EmbeddingUtils; +import com.tencent.supersonic.common.util.ComponentFactory; import com.tencent.supersonic.common.util.embedding.Retrieval; import com.tencent.supersonic.common.util.embedding.RetrieveQuery; import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult; +import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore; import com.tencent.supersonic.semantic.model.domain.listener.MetaEmbeddingListener; -import org.springframework.util.CollectionUtils; import java.util.Collections; import java.util.Comparator; import java.util.HashMap; @@ -21,6 +20,7 @@ import java.util.List; import java.util.Map; import java.util.Set; import java.util.stream.Collectors; +import org.springframework.util.CollectionUtils; /** * SimilarMetricQueryResponder fills recommended metrics based on embedding similarity. @@ -29,6 +29,8 @@ public class SimilarMetricQueryResponder implements QueryResponder { private static final int METRIC_RECOMMEND_SIZE = 5; + private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore(); + @Override public void fillInfo(QueryResult queryResult, SemanticParseInfo semanticParseInfo, ExecuteQueryReq queryReq) { fillSimilarMetric(queryResult.getChatContext()); @@ -46,8 +48,7 @@ public class SimilarMetricQueryResponder implements QueryResponder { filterCondition.put("type", SchemaElementType.METRIC.name()); RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(metricNames) .filterCondition(filterCondition).queryEmbeddings(null).build(); - EmbeddingUtils embeddingUtils = ContextUtils.getBean(EmbeddingUtils.class); - List retrieveQueryResults = embeddingUtils.retrieveQuery( + List retrieveQueryResults = s2EmbeddingStore.retrieveQuery( MetaEmbeddingListener.COLLECTION_NAME, retrieveQuery, METRIC_RECOMMEND_SIZE + 1); if (CollectionUtils.isEmpty(retrieveQueryResults)) { return; @@ -66,7 +67,8 @@ public class SimilarMetricQueryResponder implements QueryResponder { SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(retrieval.getMetadata()), SchemaElement.class); if (retrieval.getMetadata().containsKey("modelId")) { - schemaElement.setModel(Long.parseLong(retrieval.getMetadata().get("modelId"))); + String modelId = retrieval.getMetadata().get("modelId").toString(); + schemaElement.setModel(Long.parseLong(modelId)); } schemaElement.setOrder(++metricOrder); parseInfo.getMetrics().add(schemaElement); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/analytics/MetricAnalyzeQuery.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/analytics/MetricAnalyzeQuery.java index 581c3a02f..ac0b36b81 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/analytics/MetricAnalyzeQuery.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/analytics/MetricAnalyzeQuery.java @@ -1,6 +1,5 @@ package com.tencent.supersonic.chat.query.llm.analytics; -import com.alibaba.fastjson.JSONObject; import com.google.common.collect.Lists; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.chat.api.component.SemanticInterpreter; @@ -14,13 +13,11 @@ import com.tencent.supersonic.chat.query.QueryManager; import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery; import com.tencent.supersonic.chat.utils.ComponentFactory; import com.tencent.supersonic.chat.utils.QueryReqBuilder; -import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.pojo.Aggregator; import com.tencent.supersonic.common.pojo.QueryColumn; import com.tencent.supersonic.common.pojo.QueryType; import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; import com.tencent.supersonic.common.util.ContextUtils; -import com.tencent.supersonic.common.util.embedding.EmbeddingUtils; import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; import java.util.HashMap; @@ -31,8 +28,6 @@ import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import org.apache.calcite.sql.parser.SqlParseException; import org.apache.commons.lang3.StringUtils; -import org.springframework.http.HttpMethod; -import org.springframework.http.ResponseEntity; import org.springframework.stereotype.Component; import org.springframework.util.CollectionUtils; @@ -43,7 +38,6 @@ public class MetricAnalyzeQuery extends LLMSemanticQuery { public static final String QUERY_MODE = "METRIC_INTERPRET"; - public MetricAnalyzeQuery() { QueryManager.register(this); } @@ -151,23 +145,7 @@ public class MetricAnalyzeQuery extends LLMSemanticQuery { } public String fetchInterpret(String queryText, String dataText) { - LLMAnswerReq lLmAnswerReq = new LLMAnswerReq(); - lLmAnswerReq.setQueryText(queryText); - lLmAnswerReq.setPluginOutput(dataText); - EmbeddingUtils embeddingUtils = ContextUtils.getBean(EmbeddingUtils.class); - EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class); - String metricAnalyzeQueryCollection = embeddingConfig.getMetricAnalyzeQueryCollection(); - - String url = String.format("%s/retrieve_query?collection_name=%s", embeddingConfig.getUrl(), - metricAnalyzeQueryCollection); - - ResponseEntity responseEntity = embeddingUtils.doRequest(url, JSONObject.toJSONString(lLmAnswerReq), - HttpMethod.POST); - LLMAnswerResp lLmAnswerResp = JSONObject.parseObject(responseEntity.getBody(), LLMAnswerResp.class); - if (lLmAnswerResp != null) { - return lLmAnswerResp.getAssistantMessage(); - } - return null; + return ""; } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/utils/ComponentFactory.java b/chat/core/src/main/java/com/tencent/supersonic/chat/utils/ComponentFactory.java index 3ac663f83..744573d15 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/utils/ComponentFactory.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/utils/ComponentFactory.java @@ -8,11 +8,11 @@ import com.tencent.supersonic.chat.parser.LLMProxy; import com.tencent.supersonic.chat.parser.sql.llm.ModelResolver; import com.tencent.supersonic.chat.processor.ParseResultProcessor; import com.tencent.supersonic.chat.query.QueryResponder; -import org.apache.commons.collections.CollectionUtils; -import org.springframework.core.io.support.SpringFactoriesLoader; import java.util.ArrayList; import java.util.List; import java.util.Objects; +import org.apache.commons.collections.CollectionUtils; +import org.springframework.core.io.support.SpringFactoriesLoader; public class ComponentFactory { diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/utils/SolvedQueryManager.java b/chat/core/src/main/java/com/tencent/supersonic/chat/utils/SolvedQueryManager.java index 82817fb55..628e5db5a 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/utils/SolvedQueryManager.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/utils/SolvedQueryManager.java @@ -4,11 +4,12 @@ import com.google.common.collect.Lists; import com.tencent.supersonic.chat.api.pojo.request.SolvedQueryReq; import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp; import com.tencent.supersonic.common.config.EmbeddingConfig; +import com.tencent.supersonic.common.util.ComponentFactory; import com.tencent.supersonic.common.util.embedding.EmbeddingQuery; -import com.tencent.supersonic.common.util.embedding.EmbeddingUtils; import com.tencent.supersonic.common.util.embedding.Retrieval; import com.tencent.supersonic.common.util.embedding.RetrieveQuery; import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult; +import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore; import java.net.URI; import java.util.HashMap; import java.util.HashSet; @@ -36,11 +37,11 @@ public class SolvedQueryManager { private EmbeddingConfig embeddingConfig; - private EmbeddingUtils embeddingUtils; + private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore(); - public SolvedQueryManager(EmbeddingConfig embeddingConfig, EmbeddingUtils embeddingUtils) { + + public SolvedQueryManager(EmbeddingConfig embeddingConfig) { this.embeddingConfig = embeddingConfig; - this.embeddingUtils = embeddingUtils; } public void saveSolvedQuery(SolvedQueryReq solvedQueryReq) { @@ -54,12 +55,12 @@ public class SolvedQueryManager { embeddingQuery.setQueryId(uniqueId); embeddingQuery.setQuery(queryText); - Map metaData = new HashMap<>(); - metaData.put("modelId", String.valueOf(solvedQueryReq.getModelId())); - metaData.put("agentId", String.valueOf(solvedQueryReq.getAgentId())); + Map metaData = new HashMap<>(); + metaData.put("modelId", (solvedQueryReq.getModelId())); + metaData.put("agentId", solvedQueryReq.getAgentId()); embeddingQuery.setMetadata(metaData); String solvedQueryCollection = embeddingConfig.getSolvedQueryCollection(); - embeddingUtils.addQuery(solvedQueryCollection, Lists.newArrayList(embeddingQuery)); + s2EmbeddingStore.addQuery(solvedQueryCollection, Lists.newArrayList(embeddingQuery)); } catch (Exception e) { log.warn("save history question to embedding failed, queryText:{}", queryText, e); } @@ -80,7 +81,7 @@ public class SolvedQueryManager { .queryTextsList(Lists.newArrayList(queryText)) .filterCondition(filterCondition) .build(); - List resultList = embeddingUtils.retrieveQuery(solvedQueryCollection, retrieveQuery, + List resultList = s2EmbeddingStore.retrieveQuery(solvedQueryCollection, retrieveQuery, solvedQueryResultNum); log.info("[embedding] recognize result body:{}", resultList); diff --git a/common/pom.xml b/common/pom.xml index 0a9bda594..0bd06bb8d 100644 --- a/common/pom.xml +++ b/common/pom.xml @@ -161,6 +161,20 @@ spring-boot-starter-web + + + dev.langchain4j + langchain4j-open-ai + + + dev.langchain4j + langchain4j + + + dev.langchain4j + langchain4j-chroma + + diff --git a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingConfig.java b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingConfig.java index e4a4169f6..178117f31 100644 --- a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingConfig.java @@ -29,6 +29,7 @@ public class EmbeddingConfig { @Value("${embedding.metric.analyzeQuery.collection:solved_query_collection}") private String metricAnalyzeQueryCollection; - + @Value("${embedding.metric.analyzeQuery.nResult:5}") + private int metricAnalyzeQueryResultNum; } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/ComponentFactory.java b/common/src/main/java/com/tencent/supersonic/common/util/ComponentFactory.java new file mode 100644 index 000000000..768856ce1 --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/util/ComponentFactory.java @@ -0,0 +1,23 @@ +package com.tencent.supersonic.common.util; + +import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore; +import java.util.Objects; +import org.springframework.core.io.support.SpringFactoriesLoader; + +public class ComponentFactory { + + private static S2EmbeddingStore s2EmbeddingStore; + + public static S2EmbeddingStore getS2EmbeddingStore() { + if (Objects.isNull(s2EmbeddingStore)) { + s2EmbeddingStore = init(S2EmbeddingStore.class); + } + return s2EmbeddingStore; + } + + private static T init(Class factoryType) { + return SpringFactoriesLoader.loadFactories(factoryType, + Thread.currentThread().getContextClassLoader()).get(0); + } + +} diff --git a/common/src/main/java/com/tencent/supersonic/common/util/embedding/EmbeddingQuery.java b/common/src/main/java/com/tencent/supersonic/common/util/embedding/EmbeddingQuery.java index 46d6e17fe..f57448857 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/embedding/EmbeddingQuery.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/embedding/EmbeddingQuery.java @@ -13,7 +13,7 @@ public class EmbeddingQuery { private String query; - private Map metadata; + private Map metadata; private List queryEmbedding; diff --git a/common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryS2EmbeddingStore.java b/common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryS2EmbeddingStore.java new file mode 100644 index 000000000..4f562c6b0 --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryS2EmbeddingStore.java @@ -0,0 +1,83 @@ +package com.tencent.supersonic.common.util.embedding; + +import com.tencent.supersonic.common.util.ContextUtils; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +public class InMemoryS2EmbeddingStore implements S2EmbeddingStore { + + private static Map> collectionNameToStore = + new ConcurrentHashMap<>(); + + @Override + public void addCollection(String collectionName) { + collectionNameToStore.computeIfAbsent(collectionName, k -> new InMemoryEmbeddingStore()); + } + + @Override + public void addQuery(String collectionName, List queries) { + InMemoryEmbeddingStore embeddingStore = getEmbeddingStore(collectionName); + EmbeddingModel embeddingModel = ContextUtils.getBean(EmbeddingModel.class); + for (EmbeddingQuery query : queries) { + String question = query.getQuery(); + Embedding embedding = embeddingModel.embed(question).content(); + embeddingStore.add(query.getQueryId(), embedding, query); + } + } + + private InMemoryEmbeddingStore getEmbeddingStore(String collectionName) { + InMemoryEmbeddingStore embeddingStore = collectionNameToStore.get(collectionName); + if (Objects.isNull(embeddingStore)) { + synchronized (InMemoryS2EmbeddingStore.class) { + addCollection(collectionName); + embeddingStore = collectionNameToStore.get(collectionName); + } + + } + return embeddingStore; + } + + @Override + public void deleteQuery(String collectionName, List queries) { + //not support in InMemoryEmbeddingStore + } + + @Override + public List retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) { + InMemoryEmbeddingStore embeddingStore = getEmbeddingStore(collectionName); + EmbeddingModel embeddingModel = ContextUtils.getBean(EmbeddingModel.class); + + List results = new ArrayList<>(); + + List queryTextsList = retrieveQuery.getQueryTextsList(); + for (String queryText : queryTextsList) { + Embedding embeddedText = embeddingModel.embed(queryText).content(); + List> relevant = embeddingStore.findRelevant(embeddedText, num); + + RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult(); + retrieveQueryResult.setQuery(queryText); + List retrievals = new ArrayList<>(); + for (EmbeddingMatch embeddingMatch : relevant) { + Retrieval retrieval = new Retrieval(); + retrieval.setDistance(embeddingMatch.score()); + retrieval.setId(embeddingMatch.embeddingId()); + retrieval.setQuery(embeddingMatch.embedded().getQuery()); + retrieval.setMetadata(embeddingMatch.embedded().getMetadata()); + retrievals.add(retrieval); + } + retrieveQueryResult.setRetrieval(retrievals); + results.add(retrieveQueryResult); + } + + return results; + } +} diff --git a/common/src/main/java/com/tencent/supersonic/common/util/embedding/EmbeddingUtils.java b/common/src/main/java/com/tencent/supersonic/common/util/embedding/PythonS2EmbeddingStore.java similarity index 98% rename from common/src/main/java/com/tencent/supersonic/common/util/embedding/EmbeddingUtils.java rename to common/src/main/java/com/tencent/supersonic/common/util/embedding/PythonS2EmbeddingStore.java index 90166ac65..fdb29c258 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/embedding/EmbeddingUtils.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/embedding/PythonS2EmbeddingStore.java @@ -4,6 +4,10 @@ import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.serializer.SerializerFeature; import com.google.common.collect.Lists; import com.tencent.supersonic.common.config.EmbeddingConfig; +import java.net.URI; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.core.ParameterizedTypeReference; @@ -12,18 +16,12 @@ import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; -import org.springframework.stereotype.Component; import org.springframework.util.CollectionUtils; import org.springframework.web.client.RestTemplate; import org.springframework.web.util.UriComponentsBuilder; -import java.net.URI; -import java.util.List; -import java.util.Optional; -import java.util.stream.Collectors; @Slf4j -@Component -public class EmbeddingUtils { +public class PythonS2EmbeddingStore implements S2EmbeddingStore { @Autowired private EmbeddingConfig embeddingConfig; @@ -103,6 +101,5 @@ public class EmbeddingUtils { } return ResponseEntity.of(Optional.empty()); } - } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/embedding/Retrieval.java b/common/src/main/java/com/tencent/supersonic/common/util/embedding/Retrieval.java index c1e02672c..1be07160e 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/embedding/Retrieval.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/embedding/Retrieval.java @@ -15,7 +15,7 @@ public class Retrieval { protected String query; - protected Map metadata; + protected Map metadata; public static Long getLongId(String id) { if (StringUtils.isBlank(id)) { diff --git a/common/src/main/java/com/tencent/supersonic/common/util/embedding/S2EmbeddingStore.java b/common/src/main/java/com/tencent/supersonic/common/util/embedding/S2EmbeddingStore.java new file mode 100644 index 000000000..1f8e92fc7 --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/util/embedding/S2EmbeddingStore.java @@ -0,0 +1,19 @@ +package com.tencent.supersonic.common.util.embedding; + +import java.util.List; + +/** + * Supersonic EmbeddingStore + * Added the functionality of adding and querying collection names. + */ +public interface S2EmbeddingStore { + + void addCollection(String collectionName); + + void addQuery(String collectionName, List queries); + + void deleteQuery(String collectionName, List queries); + + List retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num); + +} diff --git a/launchers/common/pom.xml b/launchers/common/pom.xml index f0f3644d1..b954f3a39 100644 --- a/launchers/common/pom.xml +++ b/launchers/common/pom.xml @@ -43,6 +43,16 @@ org.projectlombok lombok + + + dev.langchain4j + langchain4j-spring-boot-starter + + + + dev.langchain4j + langchain4j-embeddings-all-minilm-l6-v2 + \ No newline at end of file diff --git a/launchers/common/src/main/java/dev/langchain4j/ModelProvider.java b/launchers/common/src/main/java/dev/langchain4j/ModelProvider.java new file mode 100644 index 000000000..9c3953657 --- /dev/null +++ b/launchers/common/src/main/java/dev/langchain4j/ModelProvider.java @@ -0,0 +1,8 @@ +package dev.langchain4j; + +enum ModelProvider { + OPEN_AI, + HUGGING_FACE, + LOCAL_AI, + IN_MEMORY +} \ No newline at end of file diff --git a/launchers/common/src/main/java/dev/langchain4j/S2LangChain4jAutoConfiguration.java b/launchers/common/src/main/java/dev/langchain4j/S2LangChain4jAutoConfiguration.java new file mode 100644 index 000000000..e94264d17 --- /dev/null +++ b/launchers/common/src/main/java/dev/langchain4j/S2LangChain4jAutoConfiguration.java @@ -0,0 +1,283 @@ +package dev.langchain4j; + +import static dev.langchain4j.ModelProvider.OPEN_AI; +import static dev.langchain4j.exception.IllegalConfigurationException.illegalConfiguration; +import static dev.langchain4j.internal.Utils.isNullOrBlank; + +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.huggingface.HuggingFaceChatModel; +import dev.langchain4j.model.huggingface.HuggingFaceEmbeddingModel; +import dev.langchain4j.model.huggingface.HuggingFaceLanguageModel; +import dev.langchain4j.model.language.LanguageModel; +import dev.langchain4j.model.localai.LocalAiChatModel; +import dev.langchain4j.model.localai.LocalAiEmbeddingModel; +import dev.langchain4j.model.localai.LocalAiLanguageModel; +import dev.langchain4j.model.moderation.ModerationModel; +import dev.langchain4j.model.openai.OpenAiChatModel; +import dev.langchain4j.model.openai.OpenAiEmbeddingModel; +import dev.langchain4j.model.openai.OpenAiLanguageModel; +import dev.langchain4j.model.openai.OpenAiModerationModel; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Lazy; +import org.springframework.context.annotation.Primary; + +@Configuration +@EnableConfigurationProperties(LangChain4jProperties.class) +public class S2LangChain4jAutoConfiguration { + + @Autowired + private LangChain4jProperties properties; + + @Bean + @Lazy + @ConditionalOnMissingBean + ChatLanguageModel chatLanguageModel(LangChain4jProperties properties) { + if (properties.getChatModel() == null) { + throw illegalConfiguration("\n\nPlease define 'langchain4j.chat-model' properties, for example:\n" + + "langchain4j.chat-model.provider = openai\n" + + "langchain4j.chat-model.openai.api-key = sk-...\n"); + } + + switch (properties.getChatModel().getProvider()) { + + case OPEN_AI: + OpenAi openAi = properties.getChatModel().getOpenAi(); + if (openAi == null || isNullOrBlank(openAi.getApiKey())) { + throw illegalConfiguration("\n\nPlease define 'langchain4j.chat-model.openai.api-key' property"); + } + return OpenAiChatModel.builder() + .baseUrl(openAi.getBaseUrl()) + .apiKey(openAi.getApiKey()) + .modelName(openAi.getModelName()) + .temperature(openAi.getTemperature()) + .topP(openAi.getTopP()) + .maxTokens(openAi.getMaxTokens()) + .presencePenalty(openAi.getPresencePenalty()) + .frequencyPenalty(openAi.getFrequencyPenalty()) + .timeout(openAi.getTimeout()) + .maxRetries(openAi.getMaxRetries()) + .logRequests(openAi.getLogRequests()) + .logResponses(openAi.getLogResponses()) + .build(); + + case HUGGING_FACE: + HuggingFace huggingFace = properties.getChatModel().getHuggingFace(); + if (huggingFace == null || isNullOrBlank(huggingFace.getAccessToken())) { + throw illegalConfiguration( + "\n\nPlease define 'langchain4j.chat-model.huggingface.access-token' property"); + } + return HuggingFaceChatModel.builder() + .accessToken(huggingFace.getAccessToken()) + .modelId(huggingFace.getModelId()) + .timeout(huggingFace.getTimeout()) + .temperature(huggingFace.getTemperature()) + .maxNewTokens(huggingFace.getMaxNewTokens()) + .returnFullText(huggingFace.getReturnFullText()) + .waitForModel(huggingFace.getWaitForModel()) + .build(); + + case LOCAL_AI: + LocalAi localAi = properties.getChatModel().getLocalAi(); + if (localAi == null || isNullOrBlank(localAi.getBaseUrl())) { + throw illegalConfiguration("\n\nPlease define 'langchain4j.chat-model.localai.base-url' property"); + } + if (isNullOrBlank(localAi.getModelName())) { + throw illegalConfiguration( + "\n\nPlease define 'langchain4j.chat-model.localai.model-name' property"); + } + return LocalAiChatModel.builder() + .baseUrl(localAi.getBaseUrl()) + .modelName(localAi.getModelName()) + .temperature(localAi.getTemperature()) + .topP(localAi.getTopP()) + .maxTokens(localAi.getMaxTokens()) + .timeout(localAi.getTimeout()) + .maxRetries(localAi.getMaxRetries()) + .logRequests(localAi.getLogRequests()) + .logResponses(localAi.getLogResponses()) + .build(); + + default: + throw illegalConfiguration("Unsupported chat model provider: %s", + properties.getChatModel().getProvider()); + } + } + + @Bean + @Lazy + @ConditionalOnMissingBean + LanguageModel languageModel(LangChain4jProperties properties) { + if (properties.getLanguageModel() == null) { + throw illegalConfiguration("\n\nPlease define 'langchain4j.language-model' properties, for example:\n" + + "langchain4j.language-model.provider = openai\n" + + "langchain4j.language-model.openai.api-key = sk-...\n"); + } + + switch (properties.getLanguageModel().getProvider()) { + + case OPEN_AI: + OpenAi openAi = properties.getLanguageModel().getOpenAi(); + if (openAi == null || isNullOrBlank(openAi.getApiKey())) { + throw illegalConfiguration( + "\n\nPlease define 'langchain4j.language-model.openai.api-key' property"); + } + return OpenAiLanguageModel.builder() + .apiKey(openAi.getApiKey()) + .modelName(openAi.getModelName()) + .temperature(openAi.getTemperature()) + .timeout(openAi.getTimeout()) + .maxRetries(openAi.getMaxRetries()) + .logRequests(openAi.getLogRequests()) + .logResponses(openAi.getLogResponses()) + .build(); + + case HUGGING_FACE: + HuggingFace huggingFace = properties.getLanguageModel().getHuggingFace(); + if (huggingFace == null || isNullOrBlank(huggingFace.getAccessToken())) { + throw illegalConfiguration( + "\n\nPlease define 'langchain4j.language-model.huggingface.access-token' property"); + } + return HuggingFaceLanguageModel.builder() + .accessToken(huggingFace.getAccessToken()) + .modelId(huggingFace.getModelId()) + .timeout(huggingFace.getTimeout()) + .temperature(huggingFace.getTemperature()) + .maxNewTokens(huggingFace.getMaxNewTokens()) + .returnFullText(huggingFace.getReturnFullText()) + .waitForModel(huggingFace.getWaitForModel()) + .build(); + + case LOCAL_AI: + LocalAi localAi = properties.getLanguageModel().getLocalAi(); + if (localAi == null || isNullOrBlank(localAi.getBaseUrl())) { + throw illegalConfiguration( + "\n\nPlease define 'langchain4j.language-model.localai.base-url' property"); + } + if (isNullOrBlank(localAi.getModelName())) { + throw illegalConfiguration( + "\n\nPlease define 'langchain4j.language-model.localai.model-name' property"); + } + return LocalAiLanguageModel.builder() + .baseUrl(localAi.getBaseUrl()) + .modelName(localAi.getModelName()) + .temperature(localAi.getTemperature()) + .topP(localAi.getTopP()) + .maxTokens(localAi.getMaxTokens()) + .timeout(localAi.getTimeout()) + .maxRetries(localAi.getMaxRetries()) + .logRequests(localAi.getLogRequests()) + .logResponses(localAi.getLogResponses()) + .build(); + + default: + throw illegalConfiguration("Unsupported language model provider: %s", + properties.getLanguageModel().getProvider()); + } + } + + @Bean + @Lazy + @ConditionalOnMissingBean + @Primary + EmbeddingModel embeddingModel(LangChain4jProperties properties) { + if (properties.getEmbeddingModel() == null || properties.getEmbeddingModel().getProvider() == null) { + throw illegalConfiguration("\n\nPlease define 'langchain4j.embedding-model' properties, for example:\n" + + "langchain4j.embedding-model.provider = openai\n" + + "langchain4j.embedding-model.openai.api-key = sk-...\n"); + } + + switch (properties.getEmbeddingModel().getProvider()) { + + case OPEN_AI: + OpenAi openAi = properties.getEmbeddingModel().getOpenAi(); + if (openAi == null || isNullOrBlank(openAi.getApiKey())) { + throw illegalConfiguration( + "\n\nPlease define 'langchain4j.embedding-model.openai.api-key' property"); + } + return OpenAiEmbeddingModel.builder() + .apiKey(openAi.getApiKey()) + .modelName(openAi.getModelName()) + .timeout(openAi.getTimeout()) + .maxRetries(openAi.getMaxRetries()) + .logRequests(openAi.getLogRequests()) + .logResponses(openAi.getLogResponses()) + .build(); + + case HUGGING_FACE: + HuggingFace huggingFace = properties.getEmbeddingModel().getHuggingFace(); + if (huggingFace == null || isNullOrBlank(huggingFace.getAccessToken())) { + throw illegalConfiguration( + "\n\nPlease define 'langchain4j.embedding-model.huggingface.access-token' property"); + } + return HuggingFaceEmbeddingModel.builder() + .accessToken(huggingFace.getAccessToken()) + .modelId(huggingFace.getModelId()) + .waitForModel(huggingFace.getWaitForModel()) + .timeout(huggingFace.getTimeout()) + .build(); + + case LOCAL_AI: + LocalAi localAi = properties.getEmbeddingModel().getLocalAi(); + if (localAi == null || isNullOrBlank(localAi.getBaseUrl())) { + throw illegalConfiguration( + "\n\nPlease define 'langchain4j.embedding-model.localai.base-url' property"); + } + if (isNullOrBlank(localAi.getModelName())) { + throw illegalConfiguration( + "\n\nPlease define 'langchain4j.embedding-model.localai.model-name' property"); + } + return LocalAiEmbeddingModel.builder() + .baseUrl(localAi.getBaseUrl()) + .modelName(localAi.getModelName()) + .timeout(localAi.getTimeout()) + .maxRetries(localAi.getMaxRetries()) + .logRequests(localAi.getLogRequests()) + .logResponses(localAi.getLogResponses()) + .build(); + case IN_MEMORY: + return new AllMiniLmL6V2EmbeddingModel(); + default: + throw illegalConfiguration("Unsupported embedding model provider: %s", + properties.getEmbeddingModel().getProvider()); + + + } + } + + @Bean + @Lazy + @ConditionalOnMissingBean + ModerationModel moderationModel(LangChain4jProperties properties) { + if (properties.getModerationModel() == null) { + throw illegalConfiguration("\n\nPlease define 'langchain4j.moderation-model' properties, for example:\n" + + "langchain4j.moderation-model.provider = openai\n" + + "langchain4j.moderation-model.openai.api-key = sk-...\n"); + } + + if (properties.getModerationModel().getProvider() != OPEN_AI) { + throw illegalConfiguration("Unsupported moderation model provider: %s", + properties.getModerationModel().getProvider()); + } + + OpenAi openAi = properties.getModerationModel().getOpenAi(); + if (openAi == null || isNullOrBlank(openAi.getApiKey())) { + throw illegalConfiguration("\n\nPlease define 'langchain4j.moderation-model.openai.api-key' property"); + } + + return OpenAiModerationModel.builder() + .apiKey(openAi.getApiKey()) + .modelName(openAi.getModelName()) + .timeout(openAi.getTimeout()) + .maxRetries(openAi.getMaxRetries()) + .logRequests(openAi.getLogRequests()) + .logResponses(openAi.getLogResponses()) + .build(); + } + +} \ No newline at end of file diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/EmbeddingInitListener.java b/launchers/standalone/src/main/java/com/tencent/supersonic/EmbeddingInitListener.java new file mode 100644 index 000000000..1a46e2261 --- /dev/null +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/EmbeddingInitListener.java @@ -0,0 +1,43 @@ +package com.tencent.supersonic; + +import com.tencent.supersonic.chat.config.OptimizationConfig; +import com.tencent.supersonic.chat.parser.sql.llm.prompt.SqlExample; +import com.tencent.supersonic.chat.parser.sql.llm.prompt.SqlExampleLoader; +import com.tencent.supersonic.chat.parser.EmbedLLMProxy; +import com.tencent.supersonic.chat.parser.LLMProxy; +import com.tencent.supersonic.chat.utils.ComponentFactory; +import java.util.List; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.CommandLineRunner; +import org.springframework.core.annotation.Order; +import org.springframework.stereotype.Component; + +@Slf4j +@Component +@Order(4) +public class EmbeddingInitListener implements CommandLineRunner { + + protected LLMProxy llmProxy = ComponentFactory.getLLMProxy(); + @Autowired + private SqlExampleLoader sqlExampleLoader; + @Autowired + private OptimizationConfig optimizationConfig; + + @Override + public void run(String... args) { + initSqlExamples(); + } + + public void initSqlExamples() { + try { + if (llmProxy instanceof EmbedLLMProxy) { + List sqlExamples = sqlExampleLoader.getSqlExamples(); + String collectionName = optimizationConfig.getText2sqlCollectionName(); + sqlExampleLoader.addEmbeddingStore(sqlExamples, collectionName); + } + } catch (Exception e) { + log.error("initSqlExamples error", e); + } + } +} diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/StandaloneLauncher.java b/launchers/standalone/src/main/java/com/tencent/supersonic/StandaloneLauncher.java index 3a8134f17..05a526c96 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/StandaloneLauncher.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/StandaloneLauncher.java @@ -1,9 +1,11 @@ package com.tencent.supersonic; +import dev.langchain4j.S2LangChain4jAutoConfiguration; import org.springframework.boot.SpringApplication; import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.boot.autoconfigure.data.mongo.MongoDataAutoConfiguration; import org.springframework.boot.autoconfigure.mongo.MongoAutoConfiguration; +import org.springframework.context.annotation.Import; import org.springframework.scheduling.annotation.EnableAsync; import org.springframework.scheduling.annotation.EnableScheduling; @@ -11,6 +13,7 @@ import org.springframework.scheduling.annotation.EnableScheduling; exclude = {MongoAutoConfiguration.class, MongoDataAutoConfiguration.class}) @EnableScheduling @EnableAsync +@Import(S2LangChain4jAutoConfiguration.class) public class StandaloneLauncher { public static void main(String[] args) { diff --git a/launchers/standalone/src/main/resources/META-INF/spring.factories b/launchers/standalone/src/main/resources/META-INF/spring.factories index d639e8e52..35545c639 100644 --- a/launchers/standalone/src/main/resources/META-INF/spring.factories +++ b/launchers/standalone/src/main/resources/META-INF/spring.factories @@ -31,7 +31,7 @@ com.tencent.supersonic.chat.processor.ParseResultProcessor=\ com.tencent.supersonic.chat.processor.RespBuildProcessor com.tencent.supersonic.chat.parser.LLMProxy=\ - com.tencent.supersonic.chat.parser.PythonLLMProxy + com.tencent.supersonic.chat.parser.EmbedLLMProxy com.tencent.supersonic.chat.api.component.SemanticInterpreter=\ com.tencent.supersonic.knowledge.semantic.LocalSemanticInterpreter @@ -46,4 +46,7 @@ com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\ com.tencent.supersonic.auth.authentication.adaptor.DefaultUserAdaptor com.tencent.supersonic.chat.query.QueryResponder=\ - com.tencent.supersonic.chat.query.SimilarMetricQueryResponder \ No newline at end of file + com.tencent.supersonic.chat.query.SimilarMetricQueryResponder + +com.tencent.supersonic.common.util.embedding.S2EmbeddingStore=\ + com.tencent.supersonic.common.util.embedding.InMemoryS2EmbeddingStore \ No newline at end of file diff --git a/launchers/standalone/src/main/resources/application-local.yaml b/launchers/standalone/src/main/resources/application-local.yaml index b4c829433..eb46b44b9 100644 --- a/launchers/standalone/src/main/resources/application-local.yaml +++ b/launchers/standalone/src/main/resources/application-local.yaml @@ -36,8 +36,40 @@ mybatis: llm: parser: - url: http://127.0.0.1:9092 + url: embedding: url: http://127.0.0.1:9092 functionCall: - url: http://127.0.0.1:9092 \ No newline at end of file + url: http://127.0.0.1:9092 + +#langchain4j config +langchain4j: + #1.chat-model + chat-model: + provider: open_ai + openai: + api-key: api_key + model-name: gpt-3.5-turbo + temperature: 0.0 + timeout: PT60S + #2.embedding-model + embedding-model: + provider: in_memory +# embedding-model: +# hugging-face: +# access-token: hg_access_token +# model-id: sentence-transformers/all-MiniLM-L6-v2 +# timeout: 1h + +# embedding-model: +# provider: open_ai +# openai: +# api-key: api_key +# modelName: all-minilm-l6-v2.onnx + + +#langchain4j log +logging: + level: + dev.langchain4j: DEBUG + dev.ai4j.openai4j: DEBUG \ No newline at end of file diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricInterpretTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricInterpretTest.java index d2cbadb98..2a249da32 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricInterpretTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricInterpretTest.java @@ -1,16 +1,13 @@ package com.tencent.supersonic.integration; -import com.alibaba.fastjson.JSONObject; import com.tencent.supersonic.StandaloneLauncher; import com.tencent.supersonic.chat.query.llm.analytics.LLMAnswerResp; import com.tencent.supersonic.chat.service.AgentService; import com.tencent.supersonic.common.config.EmbeddingConfig; -import com.tencent.supersonic.common.util.embedding.EmbeddingUtils; import org.junit.Test; import org.junit.runner.RunWith; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.boot.test.mock.mockito.MockBean; -import org.springframework.http.ResponseEntity; import org.springframework.test.context.ActiveProfiles; import org.springframework.test.context.junit4.SpringRunner; @@ -21,13 +18,9 @@ public class MetricInterpretTest { @MockBean private AgentService agentService; - @MockBean private EmbeddingConfig embeddingConfig; - @MockBean - private EmbeddingUtils embeddingUtils; - @Test public void testMetricInterpret() throws Exception { MockConfiguration.mockAgent(agentService); @@ -36,7 +29,6 @@ public class MetricInterpretTest { LLMAnswerResp lLmAnswerResp = new LLMAnswerResp(); lLmAnswerResp.setAssistantMessage("alice最近在超音数的访问情况有增多"); - MockConfiguration.embeddingUtils(embeddingUtils, ResponseEntity.ok(JSONObject.toJSONString(lLmAnswerResp))); } } diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MockConfiguration.java b/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MockConfiguration.java index a7ea37da1..099de2a76 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MockConfiguration.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MockConfiguration.java @@ -1,20 +1,17 @@ package com.tencent.supersonic.integration; -import static org.mockito.ArgumentMatchers.anyObject; import static org.mockito.Mockito.when; import com.google.common.collect.Lists; import com.tencent.supersonic.chat.plugin.PluginManager; import com.tencent.supersonic.chat.service.AgentService; import com.tencent.supersonic.common.config.EmbeddingConfig; -import com.tencent.supersonic.common.util.embedding.EmbeddingUtils; import com.tencent.supersonic.common.util.embedding.Retrieval; import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult; import com.tencent.supersonic.util.DataUtils; import lombok.extern.slf4j.Slf4j; import org.springframework.context.annotation.Configuration; -import org.springframework.http.ResponseEntity; @Configuration @Slf4j @@ -38,7 +35,4 @@ public class MockConfiguration { when(agentService.getAgent(1)).thenReturn(DataUtils.getAgent()); } - public static void embeddingUtils(EmbeddingUtils embeddingUtils, ResponseEntity responseEntity) { - when(embeddingUtils.doRequest(anyObject(), anyObject(), anyObject())).thenReturn(responseEntity); - } } diff --git a/pom.xml b/pom.xml index f3a474c20..310f8194d 100644 --- a/pom.xml +++ b/pom.xml @@ -71,6 +71,7 @@ 22.3.0 2.2.6 3.17 + 0.24.0 @@ -94,6 +95,57 @@ guava ${guava.version} + + + dev.langchain4j + langchain4j-parent + ${langchain4j.version} + + + dev.langchain4j + langchain4j + ${langchain4j.version} + + + dev.langchain4j + langchain4j-core + ${langchain4j.version} + + + dev.langchain4j + langchain4j-spring-boot-starter + ${langchain4j.version} + + + dev.langchain4j + langchain4j-open-ai + ${langchain4j.version} + + + dev.langchain4j + langchain4j-hugging-face + ${langchain4j.version} + + + dev.langchain4j + langchain4j-chroma + ${langchain4j.version} + + + dev.langchain4j + langchain4j-embeddings + ${langchain4j.version} + + + dev.langchain4j + langchain4j-hugging-face + ${langchain4j.version} + + + dev.langchain4j + langchain4j-embeddings-all-minilm-l6-v2 + ${langchain4j.version} + diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/listener/MetaEmbeddingListener.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/listener/MetaEmbeddingListener.java index 44737aa45..402bff338 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/listener/MetaEmbeddingListener.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/listener/MetaEmbeddingListener.java @@ -4,13 +4,13 @@ import com.alibaba.fastjson.JSONObject; import com.tencent.supersonic.common.pojo.DataEvent; import com.tencent.supersonic.common.pojo.enums.DictWordType; import com.tencent.supersonic.common.pojo.enums.EventType; +import com.tencent.supersonic.common.util.ComponentFactory; import com.tencent.supersonic.common.util.embedding.EmbeddingQuery; -import com.tencent.supersonic.common.util.embedding.EmbeddingUtils; +import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore; import java.util.List; import java.util.Map; import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; -import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.context.ApplicationListener; import org.springframework.scheduling.annotation.Async; @@ -23,8 +23,7 @@ public class MetaEmbeddingListener implements ApplicationListener { public static final String COLLECTION_NAME = "meta_collection"; - @Autowired - private EmbeddingUtils embeddingUtils; + private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore(); @Value("${embedding.operation.sleep.time:3000}") private Integer embeddingOperationSleepTime; @@ -55,14 +54,14 @@ public class MetaEmbeddingListener implements ApplicationListener { } catch (InterruptedException e) { log.error("", e); } - embeddingUtils.addCollection(COLLECTION_NAME); + s2EmbeddingStore.addCollection(COLLECTION_NAME); if (event.getEventType().equals(EventType.ADD)) { - embeddingUtils.addQuery(COLLECTION_NAME, embeddingQueries); + s2EmbeddingStore.addQuery(COLLECTION_NAME, embeddingQueries); } else if (event.getEventType().equals(EventType.DELETE)) { - embeddingUtils.deleteQuery(COLLECTION_NAME, embeddingQueries); + s2EmbeddingStore.deleteQuery(COLLECTION_NAME, embeddingQueries); } else if (event.getEventType().equals(EventType.UPDATE)) { - embeddingUtils.deleteQuery(COLLECTION_NAME, embeddingQueries); - embeddingUtils.addQuery(COLLECTION_NAME, embeddingQueries); + s2EmbeddingStore.deleteQuery(COLLECTION_NAME, embeddingQueries); + s2EmbeddingStore.addQuery(COLLECTION_NAME, embeddingQueries); } }