mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
[improvement](chat) remove Python service and rewrite it using a Java project (#428)
This commit is contained in:
@@ -116,19 +116,6 @@
|
||||
<version>${mockito-inline.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<!--langchain4j-->
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-open-ai</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-chroma</artifactId>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
|
||||
@@ -1,84 +0,0 @@
|
||||
package com.tencent.supersonic.chat.llm;
|
||||
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.llm.prompt.FunctionCallPromptGenerator;
|
||||
import com.tencent.supersonic.chat.llm.prompt.OutputFormat;
|
||||
import com.tencent.supersonic.chat.llm.prompt.SqlExampleLoader;
|
||||
import com.tencent.supersonic.chat.llm.prompt.SqlPromptGenerator;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@Slf4j
|
||||
public class EmbedLLMInterpreter implements LLMInterpreter {
|
||||
|
||||
public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
|
||||
|
||||
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
|
||||
|
||||
SqlExampleLoader sqlExampleLoader = ContextUtils.getBean(SqlExampleLoader.class);
|
||||
|
||||
OptimizationConfig config = ContextUtils.getBean(OptimizationConfig.class);
|
||||
|
||||
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
config.getText2sqlCollectionName(), config.getText2sqlFewShotsNum());
|
||||
|
||||
String queryText = llmReq.getQueryText();
|
||||
String modelName = llmReq.getSchema().getModelName();
|
||||
List<String> fieldNameList = llmReq.getSchema().getFieldNameList();
|
||||
List<ElementValue> linking = llmReq.getLinking();
|
||||
|
||||
SqlPromptGenerator sqlPromptGenerator = ContextUtils.getBean(SqlPromptGenerator.class);
|
||||
String linkingPromptStr = sqlPromptGenerator.generateSchemaLinkingPrompt(queryText, modelName, fieldNameList,
|
||||
linking, sqlExamples);
|
||||
|
||||
Prompt linkingPrompt = PromptTemplate.from(JsonUtil.toString(linkingPromptStr)).apply(new HashMap<>());
|
||||
Response<AiMessage> linkingResult = chatLanguageModel.generate(linkingPrompt.toSystemMessage());
|
||||
|
||||
String schemaLinkStr = OutputFormat.schemaLinkParse(linkingResult.content().text());
|
||||
|
||||
String generateSqlPrompt = sqlPromptGenerator.generateSqlPrompt(queryText, modelName, schemaLinkStr,
|
||||
llmReq.getCurrentDate(), sqlExamples);
|
||||
|
||||
Prompt sqlPrompt = PromptTemplate.from(JsonUtil.toString(generateSqlPrompt)).apply(new HashMap<>());
|
||||
Response<AiMessage> sqlResult = chatLanguageModel.generate(sqlPrompt.toSystemMessage());
|
||||
|
||||
LLMResp result = new LLMResp();
|
||||
result.setQuery(queryText);
|
||||
result.setSchemaLinkingOutput(linkingPromptStr);
|
||||
result.setSchemaLinkStr(schemaLinkStr);
|
||||
result.setModelName(modelName);
|
||||
result.setSqlOutput(sqlResult.content().text());
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public FunctionResp requestFunction(FunctionReq functionReq) {
|
||||
|
||||
FunctionCallPromptGenerator promptGenerator = ContextUtils.getBean(FunctionCallPromptGenerator.class);
|
||||
|
||||
String functionCallPrompt = promptGenerator.generateFunctionCallPrompt(functionReq.getQueryText(),
|
||||
functionReq.getPluginConfigs());
|
||||
|
||||
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
|
||||
|
||||
String functionSelect = chatLanguageModel.generate(functionCallPrompt);
|
||||
|
||||
return OutputFormat.functionCallParse(functionSelect);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,43 +0,0 @@
|
||||
package com.tencent.supersonic.chat.llm.listener;
|
||||
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.llm.EmbedLLMInterpreter;
|
||||
import com.tencent.supersonic.chat.llm.LLMInterpreter;
|
||||
import com.tencent.supersonic.chat.llm.prompt.SqlExample;
|
||||
import com.tencent.supersonic.chat.llm.prompt.SqlExampleLoader;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import java.util.List;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.CommandLineRunner;
|
||||
import org.springframework.core.annotation.Order;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
@Order(4)
|
||||
public class EmbeddingInitListener implements CommandLineRunner {
|
||||
|
||||
protected LLMInterpreter llmInterpreter = ComponentFactory.getLLMInterpreter();
|
||||
@Autowired
|
||||
private SqlExampleLoader sqlExampleLoader;
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
|
||||
@Override
|
||||
public void run(String... args) {
|
||||
initSqlExamples();
|
||||
}
|
||||
|
||||
public void initSqlExamples() {
|
||||
try {
|
||||
if (llmInterpreter instanceof EmbedLLMInterpreter) {
|
||||
List<SqlExample> sqlExamples = sqlExampleLoader.getSqlExamples();
|
||||
String collectionName = optimizationConfig.getText2sqlCollectionName();
|
||||
sqlExampleLoader.addEmbeddingStore(sqlExamples, collectionName);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("initSqlExamples error", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,43 +0,0 @@
|
||||
package com.tencent.supersonic.chat.llm.prompt;
|
||||
|
||||
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
public class FunctionCallPromptGenerator {
|
||||
|
||||
public String generateFunctionCallPrompt(String queryText, List<PluginParseConfig> toolConfigList) {
|
||||
List<String> toolExplainList = toolConfigList.stream()
|
||||
.map(this::constructPluginPrompt)
|
||||
.collect(Collectors.toList());
|
||||
String functionList = String.join(InputFormat.SEPERATOR, toolExplainList);
|
||||
return constructTaskPrompt(queryText, functionList);
|
||||
}
|
||||
|
||||
public String constructPluginPrompt(PluginParseConfig parseConfig) {
|
||||
String toolName = parseConfig.getName();
|
||||
String toolDescription = parseConfig.getDescription();
|
||||
List<String> toolExamples = parseConfig.getExamples();
|
||||
|
||||
StringBuilder prompt = new StringBuilder();
|
||||
prompt.append("【工具名称】\n").append(toolName).append("\n");
|
||||
prompt.append("【工具描述】\n").append(toolDescription).append("\n");
|
||||
prompt.append("【工具适用问题示例】\n");
|
||||
for (String example : toolExamples) {
|
||||
prompt.append(example).append("\n");
|
||||
}
|
||||
return prompt.toString();
|
||||
}
|
||||
|
||||
public String constructTaskPrompt(String queryText, String functionList) {
|
||||
String instruction = String.format("问题为:%s\n请根据问题和工具的描述,选择对应的工具,完成任务。"
|
||||
+ "请注意,只能选择1个工具。请一步一步地分析选择工具的原因(每个工具的【工具适用问题示例】是选择的重要参考依据),"
|
||||
+ "并给出最终选择,输出格式为json,key为’分析过程‘, ’选择工具‘", queryText);
|
||||
|
||||
return String.format("工具选择如下:\n\n%s\n\n【任务说明】\n%s", functionList, instruction);
|
||||
}
|
||||
}
|
||||
@@ -1,43 +0,0 @@
|
||||
package com.tencent.supersonic.chat.llm.prompt;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@Slf4j
|
||||
public class InputFormat {
|
||||
|
||||
public static final String SEPERATOR = "\n\n";
|
||||
|
||||
public static String format(String template, List<String> templateKey,
|
||||
List<Map<String, String>> toFormatList) {
|
||||
List<String> result = new ArrayList<>();
|
||||
|
||||
for (Map<String, String> formatItem : toFormatList) {
|
||||
Map<String, String> retrievalMeta = subDict(formatItem, templateKey);
|
||||
result.add(format(template, retrievalMeta));
|
||||
}
|
||||
|
||||
return String.join(SEPERATOR, result);
|
||||
}
|
||||
|
||||
|
||||
public static String format(String input, Map<String, String> replacements) {
|
||||
for (Map.Entry<String, String> entry : replacements.entrySet()) {
|
||||
input = input.replace(entry.getKey(), entry.getValue());
|
||||
}
|
||||
return input;
|
||||
}
|
||||
|
||||
private static Map<String, String> subDict(Map<String, String> dict, List<String> keys) {
|
||||
Map<String, String> subDict = new HashMap<>();
|
||||
for (String key : keys) {
|
||||
if (dict.containsKey(key)) {
|
||||
subDict.put(key, dict.get(key));
|
||||
}
|
||||
}
|
||||
return subDict;
|
||||
}
|
||||
}
|
||||
@@ -1,54 +0,0 @@
|
||||
package com.tencent.supersonic.chat.llm.prompt;
|
||||
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import java.util.Map;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
/***
|
||||
* output format
|
||||
*/
|
||||
@Slf4j
|
||||
public class OutputFormat {
|
||||
|
||||
public static final String PATTERN = "\\{[^{}]+\\}";
|
||||
|
||||
public static String schemaLinkParse(String schemaLinkOutput) {
|
||||
try {
|
||||
schemaLinkOutput = schemaLinkOutput.trim();
|
||||
String pattern = "Schema_links:(.*)";
|
||||
Pattern regexPattern = Pattern.compile(pattern, Pattern.DOTALL);
|
||||
Matcher matcher = regexPattern.matcher(schemaLinkOutput);
|
||||
if (matcher.find()) {
|
||||
schemaLinkOutput = matcher.group(1).trim();
|
||||
} else {
|
||||
schemaLinkOutput = null;
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("", e);
|
||||
schemaLinkOutput = null;
|
||||
}
|
||||
return schemaLinkOutput;
|
||||
}
|
||||
|
||||
|
||||
public static FunctionResp functionCallParse(String llmOutput) {
|
||||
try {
|
||||
String[] findResult = llmOutput.split(PATTERN);
|
||||
String result = findResult[0].trim();
|
||||
|
||||
Map<String, String> resultDict = JsonUtil.toMap(result, String.class, String.class);
|
||||
log.info("result:{},resultDict:{}", result, resultDict);
|
||||
|
||||
String selection = resultDict.get("选择工具");
|
||||
FunctionResp resp = new FunctionResp();
|
||||
resp.setToolSelection(selection);
|
||||
return resp;
|
||||
} catch (Exception e) {
|
||||
log.error("", e);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,32 +0,0 @@
|
||||
package com.tencent.supersonic.chat.llm.prompt;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class SqlExample {
|
||||
|
||||
@JsonProperty("currentDate")
|
||||
private String currentDate;
|
||||
|
||||
@JsonProperty("tableName")
|
||||
private String tableName;
|
||||
|
||||
@JsonProperty("fieldsList")
|
||||
private String fieldsList;
|
||||
|
||||
@JsonProperty("question")
|
||||
private String question;
|
||||
|
||||
@JsonProperty("priorSchemaLinks")
|
||||
private String priorSchemaLinks;
|
||||
|
||||
@JsonProperty("analysis")
|
||||
private String analysis;
|
||||
|
||||
@JsonProperty("schemaLinks")
|
||||
private String schemaLinks;
|
||||
|
||||
@JsonProperty("sql")
|
||||
private String sql;
|
||||
}
|
||||
@@ -1,50 +0,0 @@
|
||||
package com.tencent.supersonic.chat.llm.prompt;
|
||||
|
||||
|
||||
import com.fasterxml.jackson.core.type.TypeReference;
|
||||
import com.tencent.supersonic.chat.llm.vectordb.EmbeddingStoreOperator;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.core.io.ClassPathResource;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class SqlExampleLoader {
|
||||
|
||||
private static final String EXAMPLE_JSON_FILE = "example.json";
|
||||
@Autowired
|
||||
private EmbeddingStoreOperator embeddingStoreOperator;
|
||||
private TypeReference<List<SqlExample>> valueTypeRef = new TypeReference<List<SqlExample>>() {
|
||||
};
|
||||
|
||||
public List<SqlExample> getSqlExamples() throws IOException {
|
||||
ClassPathResource resource = new ClassPathResource(EXAMPLE_JSON_FILE);
|
||||
InputStream inputStream = resource.getInputStream();
|
||||
return JsonUtil.INSTANCE.getObjectMapper().readValue(inputStream, valueTypeRef);
|
||||
}
|
||||
|
||||
public void addEmbeddingStore(List<SqlExample> sqlExamples, String collectionName) {
|
||||
embeddingStoreOperator.addAll(sqlExamples, collectionName);
|
||||
}
|
||||
|
||||
public List<Map<String, String>> retrieverSqlExamples(String queryText, String collectionName, int maxResults) {
|
||||
List<TextSegment> textSegments = embeddingStoreOperator.retriever(queryText, collectionName, maxResults);
|
||||
|
||||
List<Map<String, String>> result = new ArrayList<>();
|
||||
for (TextSegment textSegment : textSegments) {
|
||||
if (Objects.nonNull(textSegment.metadata())) {
|
||||
result.add(textSegment.metadata().asMap());
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
@@ -1,66 +0,0 @@
|
||||
package com.tencent.supersonic.chat.llm.prompt;
|
||||
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
public class SqlPromptGenerator {
|
||||
|
||||
public String generateSchemaLinkingPrompt(String question, String modelName, List<String> fieldsList,
|
||||
List<ElementValue> priorSchemaLinks, List<Map<String, String>> exampleList) {
|
||||
|
||||
String exampleTemplate = "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\n"
|
||||
+ "问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schemaLinks}";
|
||||
|
||||
List<String> exampleKeys = Arrays.asList("tableName", "fieldsList", "priorSchemaLinks", "question", "analysis",
|
||||
"schemaLinks");
|
||||
|
||||
String schemaLinkingPrompt = InputFormat.format(exampleTemplate, exampleKeys, exampleList);
|
||||
|
||||
String newCaseTemplate = "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\n"
|
||||
+ "问题:{question}\n分析: 让我们一步一步地思考。";
|
||||
|
||||
String newCasePrompt = newCaseTemplate.replace("{tableName}", modelName)
|
||||
.replace("{fieldsList}", fieldsList.toString())
|
||||
.replace("{priorSchemaLinks}", getPriorSchemaLinks(priorSchemaLinks))
|
||||
.replace("{question}", question);
|
||||
|
||||
String instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links";
|
||||
return instruction + InputFormat.SEPERATOR + schemaLinkingPrompt + InputFormat.SEPERATOR + newCasePrompt;
|
||||
}
|
||||
|
||||
private String getPriorSchemaLinks(List<ElementValue> priorSchemaLinks) {
|
||||
return priorSchemaLinks.stream()
|
||||
.map(elementValue -> "'" + elementValue.getFieldName() + "'->" + elementValue.getFieldValue())
|
||||
.collect(Collectors.joining(",", "[", "]"));
|
||||
}
|
||||
|
||||
public String generateSqlPrompt(String question, String modelName, String schemaLinkStr, String dataDate,
|
||||
List<Map<String, String>> exampleList) {
|
||||
|
||||
List<String> exampleKeys = Arrays.asList("question", "currentDate", "tableName", "schemaLinks", "sql");
|
||||
String exampleTemplate = "问题:{question}\nCurrent_date:{currentDate}\nTable {tableName}\n"
|
||||
+ "Schema_links:{schemaLinks}\nSQL:{sql}";
|
||||
|
||||
String sqlExamplePrompt = InputFormat.format(exampleTemplate, exampleKeys, exampleList);
|
||||
|
||||
String newCaseTemplate = "问题:{question}\nCurrent_date:{currentDate}\nTable {tableName}\n"
|
||||
+ "Schema_links:{schemaLinks}\nSQL:";
|
||||
|
||||
String newCasePrompt = newCaseTemplate.replace("{question}", question)
|
||||
.replace("{currentDate}", dataDate)
|
||||
.replace("{tableName}", modelName)
|
||||
.replace("{schemaLinks}", schemaLinkStr);
|
||||
|
||||
String instruction = "# 根据schema_links为每个问题生成SQL查询语句";
|
||||
return instruction + InputFormat.SEPERATOR + sqlExamplePrompt + InputFormat.SEPERATOR + newCasePrompt;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
@@ -1,20 +0,0 @@
|
||||
package com.tencent.supersonic.chat.llm.vectordb;
|
||||
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@Slf4j
|
||||
public class EmbeddingStoreFactory {
|
||||
|
||||
private static Map<String, EmbeddingStore> collectionNameToStore = new ConcurrentHashMap<>();
|
||||
|
||||
|
||||
public static EmbeddingStore create(String collectionName) {
|
||||
return collectionNameToStore.computeIfAbsent(collectionName, k -> new InMemoryEmbeddingStore());
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
@@ -1,55 +0,0 @@
|
||||
package com.tencent.supersonic.chat.llm.vectordb;
|
||||
|
||||
import com.tencent.supersonic.chat.llm.prompt.SqlExample;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.retriever.EmbeddingStoreRetriever;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class EmbeddingStoreOperator {
|
||||
|
||||
@Autowired
|
||||
private EmbeddingModel embeddingModel;
|
||||
|
||||
public List<TextSegment> retriever(String text, String collectionName, int maxResults) {
|
||||
EmbeddingStore embeddingStore = EmbeddingStoreFactory.create(collectionName);
|
||||
EmbeddingStoreRetriever retriever = EmbeddingStoreRetriever.from(embeddingStore, embeddingModel, maxResults);
|
||||
return retriever.findRelevant(text);
|
||||
}
|
||||
|
||||
public List<String> addAll(List<SqlExample> sqlExamples, String collectionName) {
|
||||
List<Embedding> embeddings = new ArrayList<>();
|
||||
List<TextSegment> textSegments = new ArrayList<>();
|
||||
|
||||
for (SqlExample sqlExample : sqlExamples) {
|
||||
String question = sqlExample.getQuestion();
|
||||
Embedding embedding = embeddingModel.embed(question).content();
|
||||
embeddings.add(embedding);
|
||||
|
||||
Map<String, String> metaDataMap = JsonUtil.toMap(JsonUtil.toString(sqlExample), String.class,
|
||||
String.class);
|
||||
|
||||
TextSegment textSegment = TextSegment.from(question, new Metadata(metaDataMap));
|
||||
textSegments.add(textSegment);
|
||||
}
|
||||
return addAllInternal(embeddings, textSegments, collectionName);
|
||||
}
|
||||
|
||||
private List<String> addAllInternal(List<Embedding> embeddings, List<TextSegment> textSegments,
|
||||
String collectionName) {
|
||||
EmbeddingStore embeddingStore = EmbeddingStoreFactory.create(collectionName);
|
||||
return embeddingStore.addAll(embeddings, textSegments);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.llm;
|
||||
package com.tencent.supersonic.chat.parser;
|
||||
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.tencent.supersonic.chat.config.LLMParserConfig;
|
||||
@@ -9,6 +9,8 @@ import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import java.net.URI;
|
||||
import java.net.URL;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.http.HttpEntity;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
@@ -17,8 +19,6 @@ import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
import org.springframework.web.util.UriComponentsBuilder;
|
||||
import java.net.URI;
|
||||
import java.net.URL;
|
||||
|
||||
@Slf4j
|
||||
public class HttpLLMInterpreter implements LLMInterpreter {
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.llm;
|
||||
package com.tencent.supersonic.chat.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
|
||||
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
|
||||
@@ -12,7 +12,7 @@ import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.config.LLMParserConfig;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.llm.LLMInterpreter;
|
||||
import com.tencent.supersonic.chat.parser.LLMInterpreter;
|
||||
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
|
||||
|
||||
@@ -2,8 +2,8 @@ package com.tencent.supersonic.chat.parser.plugin.embedding;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.llm.HttpLLMInterpreter;
|
||||
import com.tencent.supersonic.chat.llm.LLMInterpreter;
|
||||
import com.tencent.supersonic.chat.parser.HttpLLMInterpreter;
|
||||
import com.tencent.supersonic.chat.parser.LLMInterpreter;
|
||||
import com.tencent.supersonic.chat.parser.ParseMode;
|
||||
import com.tencent.supersonic.chat.parser.plugin.PluginParser;
|
||||
import com.tencent.supersonic.chat.plugin.Plugin;
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package com.tencent.supersonic.chat.parser.plugin.function;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.llm.HttpLLMInterpreter;
|
||||
import com.tencent.supersonic.chat.llm.LLMInterpreter;
|
||||
import com.tencent.supersonic.chat.parser.HttpLLMInterpreter;
|
||||
import com.tencent.supersonic.chat.parser.LLMInterpreter;
|
||||
import com.tencent.supersonic.chat.parser.ParseMode;
|
||||
import com.tencent.supersonic.chat.parser.plugin.PluginParser;
|
||||
import com.tencent.supersonic.chat.plugin.Plugin;
|
||||
|
||||
@@ -4,7 +4,7 @@ import com.tencent.supersonic.chat.api.component.SchemaMapper;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticCorrector;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticParser;
|
||||
import com.tencent.supersonic.chat.llm.LLMInterpreter;
|
||||
import com.tencent.supersonic.chat.parser.LLMInterpreter;
|
||||
import com.tencent.supersonic.chat.parser.llm.s2sql.ModelResolver;
|
||||
import com.tencent.supersonic.chat.postprocessor.PostProcessor;
|
||||
import com.tencent.supersonic.chat.responder.execute.ExecuteResponder;
|
||||
|
||||
@@ -19,8 +19,8 @@ com.tencent.supersonic.chat.api.component.SemanticCorrector=\
|
||||
com.tencent.supersonic.chat.corrector.GroupByCorrector, \
|
||||
com.tencent.supersonic.chat.corrector.HavingCorrector
|
||||
|
||||
com.tencent.supersonic.chat.llm.LLMInterpreter=\
|
||||
com.tencent.supersonic.chat.llm.HttpLLMInterpreter
|
||||
com.tencent.supersonic.chat.parser.LLMInterpreter=\
|
||||
com.tencent.supersonic.chat.parser.HttpLLMInterpreter
|
||||
|
||||
com.tencent.supersonic.chat.api.component.SemanticInterpreter=\
|
||||
com.tencent.supersonic.knowledge.semantic.RemoteSemanticInterpreter
|
||||
|
||||
@@ -96,16 +96,6 @@
|
||||
<artifactId>spring-boot-starter-test</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<!--langchain4j-->
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-spring-boot-starter</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-embeddings-all-minilm-l6-v2</artifactId>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<profiles>
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
package com.tencent.supersonic.config;
|
||||
|
||||
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
@Configuration
|
||||
public class LangChain4jConfig {
|
||||
|
||||
@Bean
|
||||
EmbeddingModel embeddingModel() {
|
||||
return new AllMiniLmL6V2EmbeddingModel();
|
||||
}
|
||||
}
|
||||
@@ -21,8 +21,8 @@ com.tencent.supersonic.chat.api.component.SemanticCorrector=\
|
||||
com.tencent.supersonic.chat.corrector.HavingCorrector, \
|
||||
com.tencent.supersonic.chat.corrector.FromCorrector
|
||||
|
||||
com.tencent.supersonic.chat.llm.LLMInterpreter=\
|
||||
com.tencent.supersonic.chat.llm.HttpLLMInterpreter
|
||||
com.tencent.supersonic.chat.parser.LLMInterpreter=\
|
||||
com.tencent.supersonic.chat.parser.HttpLLMInterpreter
|
||||
|
||||
com.tencent.supersonic.chat.api.component.SemanticInterpreter=\
|
||||
com.tencent.supersonic.knowledge.semantic.LocalSemanticInterpreter
|
||||
@@ -49,6 +49,4 @@ com.tencent.supersonic.chat.responder.parse.ParseResponder=\
|
||||
|
||||
com.tencent.supersonic.chat.responder.execute.ExecuteResponder=\
|
||||
com.tencent.supersonic.chat.responder.execute.EntityInfoExecuteResponder, \
|
||||
com.tencent.supersonic.chat.responder.execute.SimilarMetricExecuteResponder
|
||||
|
||||
org.springframework.boot.autoconfigure.EnableAutoConfiguration=dev.langchain4j.LangChain4jAutoConfiguration
|
||||
com.tencent.supersonic.chat.responder.execute.SimilarMetricExecuteResponder
|
||||
@@ -39,19 +39,4 @@ llm:
|
||||
embedding:
|
||||
url: http://127.0.0.1:9092
|
||||
functionCall:
|
||||
url: http://127.0.0.1:9092
|
||||
|
||||
|
||||
langchain4j:
|
||||
chat-model:
|
||||
provider: open_ai
|
||||
openai:
|
||||
api-key: api_key
|
||||
model-name: gpt-3.5-turbo
|
||||
temperature: 0.0
|
||||
timeout: PT60S
|
||||
|
||||
logging:
|
||||
level:
|
||||
dev.langchain4j: DEBUG
|
||||
dev.ai4j.openai4j: DEBUG
|
||||
url: http://127.0.0.1:9092
|
||||
52
pom.xml
52
pom.xml
@@ -71,7 +71,6 @@
|
||||
<spotless.python.black.version>22.3.0</spotless.python.black.version>
|
||||
<easyexcel.version>2.2.6</easyexcel.version>
|
||||
<poi.version>3.17</poi.version>
|
||||
<langchain4j.version>0.24.0</langchain4j.version>
|
||||
</properties>
|
||||
|
||||
<dependencyManagement>
|
||||
@@ -95,57 +94,6 @@
|
||||
<artifactId>guava</artifactId>
|
||||
<version>${guava.version}</version>
|
||||
</dependency>
|
||||
<!--langchain4j-->
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-parent</artifactId>
|
||||
<version>${langchain4j.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j</artifactId>
|
||||
<version>${langchain4j.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-core</artifactId>
|
||||
<version>${langchain4j.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-spring-boot-starter</artifactId>
|
||||
<version>${langchain4j.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-open-ai</artifactId>
|
||||
<version>${langchain4j.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-hugging-face</artifactId>
|
||||
<version>${langchain4j.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-chroma</artifactId>
|
||||
<version>${langchain4j.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-embeddings</artifactId>
|
||||
<version>${langchain4j.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-hugging-face</artifactId>
|
||||
<version>${langchain4j.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-embeddings-all-minilm-l6-v2</artifactId>
|
||||
<version>${langchain4j.version}</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</dependencyManagement>
|
||||
|
||||
|
||||
Reference in New Issue
Block a user