[improvement](python) LLM parsing related services support Python service and Java service invocation (#418)

This commit is contained in:
lexluo09
2023-11-24 15:59:29 +08:00
committed by GitHub
parent 30bb9a1dc0
commit aa433baa06
28 changed files with 1054 additions and 103 deletions

View File

@@ -116,6 +116,19 @@
<version>${mockito-inline.version}</version>
<scope>test</scope>
</dependency>
<!--langchain4j-->
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai</artifactId>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j</artifactId>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-chroma</artifactId>
</dependency>
</dependencies>

View File

@@ -59,6 +59,19 @@ public class OptimizationConfig {
@Value("${s2SQL.use.switch:true}")
private boolean useS2SqlSwitch;
@Value("${text2sql.example.num:10}")
private int text2sqlExampleNum;
@Value("${text2sql.fewShots.num:10}")
private int text2sqlFewShotsNum;
@Value("${text2sql.self.consistency.num:5}")
private int text2sqlSelfConsistencyNum;
@Value("${text2sql.collection.name:text2dsl_agent_collection}")
private String text2sqlCollectionName;
@Autowired
private SysParameterService sysParameterService;

View File

@@ -0,0 +1,83 @@
package com.tencent.supersonic.chat.llm;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.llm.prompt.FunctionCallPromptGenerator;
import com.tencent.supersonic.chat.llm.prompt.OutputFormat;
import com.tencent.supersonic.chat.llm.prompt.SqlExampleLoader;
import com.tencent.supersonic.chat.llm.prompt.SqlPromptGenerator;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
/**
 * {@link LLMInterpreter} that drives a langchain4j {@link ChatLanguageModel} directly instead of
 * delegating to a remote service.
 *
 * <p>All collaborators are resolved through {@link ContextUtils#getBean} at call time.
 */
@Slf4j
public class EmbedLLMInterpreter implements LLMInterpreter {

    /**
     * Produces SQL for the request in two model round trips: a schema-linking prompt followed by a
     * SQL-generation prompt, both seeded with few-shot examples retrieved for the query text.
     *
     * @param llmReq  request providing the query text, schema (model name, field names), prior linking
     *                values and current date
     * @param modelId not used by this implementation
     * @return a response carrying the query, the schema-linking prompt, the parsed schema links
     *         ({@code null} when the model reply had no {@code Schema_links:} section), the model name
     *         and the raw SQL reply
     */
    @Override
    public LLMResp query2sql(LLMReq llmReq, Long modelId) {
        ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
        SqlExampleLoader sqlExampleLoader = ContextUtils.getBean(SqlExampleLoader.class);
        OptimizationConfig config = ContextUtils.getBean(OptimizationConfig.class);

        String queryText = llmReq.getQueryText();
        List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(queryText,
                config.getText2sqlCollectionName(), config.getText2sqlFewShotsNum());

        String modelName = llmReq.getSchema().getModelName();
        List<String> fieldNameList = llmReq.getSchema().getFieldNameList();
        List<ElementValue> linking = llmReq.getLinking();

        // Step 1: ask the model which schema elements the question refers to.
        SqlPromptGenerator sqlPromptGenerator = ContextUtils.getBean(SqlPromptGenerator.class);
        String linkingPromptStr = sqlPromptGenerator.generateSchemaLinkingPrompt(queryText, modelName, fieldNameList,
                linking, sqlExamples);
        Prompt linkingPrompt = PromptTemplate.from(JsonUtil.toString(linkingPromptStr)).apply(new HashMap<>());
        Response<AiMessage> linkingResult = chatLanguageModel.generate(linkingPrompt.toSystemMessage());
        String schemaLinkStr = OutputFormat.schemaLinkParse(linkingResult.content().text());

        // schemaLinkParse returns null when the reply has no "Schema_links:" section. Feeding that null
        // into the SQL prompt would fail inside String.replace, so fall back to an empty link list.
        String schemaLinksForPrompt = schemaLinkStr;
        if (schemaLinksForPrompt == null) {
            log.warn("schema linking output has no Schema_links section, queryText:{}", queryText);
            schemaLinksForPrompt = "";
        }

        // Step 2: generate the SQL from the question plus the resolved schema links.
        String generateSqlPrompt = sqlPromptGenerator.generateSqlPrompt(queryText, modelName, schemaLinksForPrompt,
                llmReq.getCurrentDate(), sqlExamples);
        Prompt sqlPrompt = PromptTemplate.from(JsonUtil.toString(generateSqlPrompt)).apply(new HashMap<>());
        Response<AiMessage> sqlResult = chatLanguageModel.generate(sqlPrompt.toSystemMessage());

        LLMResp result = new LLMResp();
        result.setQuery(queryText);
        // NOTE(review): this stores the schema-linking prompt, not the model's reply to it — confirm intended.
        result.setSchemaLinkingOutput(linkingPromptStr);
        result.setSchemaLinkStr(schemaLinkStr);
        result.setModelName(modelName);
        result.setSqlOutput(sqlResult.content().text());
        return result;
    }

    /**
     * Asks the chat model to pick one plugin for the query and parses its reply.
     *
     * @param functionReq request carrying the query text and the candidate plugin configurations
     * @return the parsed tool selection, or whatever {@link OutputFormat#functionCallParse} yields when
     *         the reply cannot be parsed
     */
    @Override
    public FunctionResp requestFunction(FunctionReq functionReq) {
        FunctionCallPromptGenerator promptGenerator = ContextUtils.getBean(FunctionCallPromptGenerator.class);
        String functionCallPrompt = promptGenerator.generateFunctionCallPrompt(functionReq.getQueryText(),
                functionReq.getPluginConfigs());
        ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
        String functionSelect = chatLanguageModel.generate(functionCallPrompt);
        return OutputFormat.functionCallParse(functionSelect);
    }
}

View File

@@ -0,0 +1,71 @@
package com.tencent.supersonic.chat.llm;
import com.alibaba.fastjson.JSON;
import com.tencent.supersonic.chat.config.LLMParserConfig;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionCallConfig;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import java.net.URI;
import java.net.URL;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;
@Slf4j
public class HttpLLMInterpreter implements LLMInterpreter {
public LLMResp query2sql(LLMReq llmReq, Long modelId) {
long startTime = System.currentTimeMillis();
log.info("requestLLM request, modelId:{},llmReq:{}", modelId, llmReq);
try {
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);
URL url = new URL(new URL(llmParserConfig.getUrl()), llmParserConfig.getQueryToSqlPath());
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
HttpEntity<String> entity = new HttpEntity<>(JsonUtil.toString(llmReq), headers);
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
ResponseEntity<LLMResp> responseEntity = restTemplate.exchange(url.toString(), HttpMethod.POST, entity,
LLMResp.class);
log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
System.currentTimeMillis() - startTime, url, entity, responseEntity.getBody());
return responseEntity.getBody();
} catch (Exception e) {
log.error("requestLLM error", e);
}
return null;
}
public FunctionResp requestFunction(FunctionReq functionReq) {
FunctionCallConfig functionCallInfoConfig = ContextUtils.getBean(FunctionCallConfig.class);
String url = functionCallInfoConfig.getUrl() + functionCallInfoConfig.getPluginSelectPath();
HttpHeaders headers = new HttpHeaders();
long startTime = System.currentTimeMillis();
headers.setContentType(MediaType.APPLICATION_JSON);
HttpEntity<String> entity = new HttpEntity<>(JSON.toJSONString(functionReq), headers);
URI requestUrl = UriComponentsBuilder.fromHttpUrl(url).build().encode().toUri();
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
try {
log.info("requestFunction functionReq:{}", JsonUtil.toString(functionReq));
ResponseEntity<FunctionResp> responseEntity = restTemplate.exchange(requestUrl, HttpMethod.POST, entity,
FunctionResp.class);
log.info("requestFunction responseEntity:{},cost:{}", responseEntity,
System.currentTimeMillis() - startTime);
return responseEntity.getBody();
} catch (Exception e) {
log.error("requestFunction error", e);
}
return null;
}
}

View File

@@ -0,0 +1,18 @@
package com.tencent.supersonic.chat.llm;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
/**
 * Unified interpreter for invoking the llm layer.
 *
 * <p>Two implementations accompany this interface: {@code HttpLLMInterpreter}, which posts each request
 * to a remote HTTP endpoint, and {@code EmbedLLMInterpreter}, which calls a langchain4j chat model
 * in-process.
 */
public interface LLMInterpreter {
/**
 * Asks the LLM layer to generate SQL for the query described by the request.
 *
 * @param llmReq request carrying the query text and schema information
 * @param modelId id passed along with the request; the accompanying implementations only log it or
 *        ignore it
 * @return the LLM response; {@code HttpLLMInterpreter} returns {@code null} when the remote call fails
 */
LLMResp query2sql(LLMReq llmReq, Long modelId);
/**
 * Asks the LLM layer to choose a plugin (function) for the query in the request.
 *
 * @param functionReq request carrying the query text and the candidate plugin configurations
 * @return the selection result; the accompanying implementations return {@code null} when the call or
 *         the parsing of the reply fails
 */
FunctionResp requestFunction(FunctionReq functionReq);
}

View File

@@ -0,0 +1,43 @@
package com.tencent.supersonic.chat.llm.listener;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.llm.EmbedLLMInterpreter;
import com.tencent.supersonic.chat.llm.LLMInterpreter;
import com.tencent.supersonic.chat.llm.prompt.SqlExample;
import com.tencent.supersonic.chat.llm.prompt.SqlExampleLoader;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.CommandLineRunner;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;
/**
 * Startup hook that loads the bundled SQL examples into the embedding store.
 *
 * <p>The work is only done when the configured {@link LLMInterpreter} is an {@link EmbedLLMInterpreter};
 * failures are logged and never propagate out of {@link #run}.
 */
@Slf4j
@Component
@Order(4)
public class EmbeddingInitListener implements CommandLineRunner {

    protected LLMInterpreter llmInterpreter = ComponentFactory.getLLMInterpreter();

    @Autowired
    private SqlExampleLoader sqlExampleLoader;

    @Autowired
    private OptimizationConfig optimizationConfig;

    /** Invoked by Spring Boot after the application context has started. */
    @Override
    public void run(String... args) {
        initSqlExamples();
    }

    /**
     * Reads the SQL examples and stores them under the configured text2sql collection name.
     * Does nothing for interpreters other than {@link EmbedLLMInterpreter}.
     */
    public void initSqlExamples() {
        boolean embedded = llmInterpreter instanceof EmbedLLMInterpreter;
        if (!embedded) {
            return;
        }
        try {
            List<SqlExample> examples = sqlExampleLoader.getSqlExamples();
            String targetCollection = optimizationConfig.getText2sqlCollectionName();
            sqlExampleLoader.addEmbeddingStore(examples, targetCollection);
        } catch (Exception e) {
            log.error("initSqlExamples error", e);
        }
    }
}

View File

@@ -0,0 +1,43 @@
package com.tencent.supersonic.chat.llm.prompt;
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
import java.util.List;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
/**
 * Builds the prompt that asks the language model to choose exactly one plugin ("tool") for a query.
 */
@Component
@Slf4j
public class FunctionCallPromptGenerator {

    /**
     * Renders every plugin description and wraps them, together with the query, into the task prompt.
     *
     * @param queryText      the user question
     * @param toolConfigList candidate plugins to describe to the model
     * @return the complete function-call prompt
     */
    public String generateFunctionCallPrompt(String queryText, List<PluginParseConfig> toolConfigList) {
        List<String> toolExplainList = toolConfigList.stream()
                .map(this::constructPluginPrompt)
                .collect(Collectors.toList());
        String functionList = String.join(InputFormat.SEPERATOR, toolExplainList);
        return constructTaskPrompt(queryText, functionList);
    }

    /**
     * Renders one plugin as name, description and one example question per line.
     *
     * @param parseConfig plugin configuration; a missing examples list is rendered as an empty section
     * @return the textual description of the plugin
     */
    public String constructPluginPrompt(PluginParseConfig parseConfig) {
        String toolName = parseConfig.getName();
        String toolDescription = parseConfig.getDescription();
        List<String> toolExamples = parseConfig.getExamples();

        StringBuilder prompt = new StringBuilder();
        prompt.append("【工具名称】\n").append(toolName).append("\n");
        prompt.append("【工具描述】\n").append(toolDescription).append("\n");
        prompt.append("【工具适用问题示例】\n");
        // A plugin configured without examples must not abort prompt generation for all plugins.
        if (toolExamples != null) {
            for (String example : toolExamples) {
                prompt.append(example).append("\n");
            }
        }
        return prompt.toString();
    }

    /**
     * Combines the rendered plugin list with the instruction that tells the model how to answer.
     *
     * @param queryText    the user question
     * @param functionList plugin descriptions already joined into one string
     * @return the final prompt text
     */
    public String constructTaskPrompt(String queryText, String functionList) {
        String instruction = String.format("问题为:%s\n请根据问题和工具的描述选择对应的工具完成任务。"
                + "请注意只能选择1个工具。请一步一步地分析选择工具的原因(每个工具的【工具适用问题示例】是选择的重要参考依据)"
                + "并给出最终选择输出格式为json,key为分析过程, ’选择工具‘", queryText);
        return String.format("工具选择如下:\n\n%s\n\n【任务说明】\n%s", functionList, instruction);
    }
}

View File

@@ -0,0 +1,43 @@
package com.tencent.supersonic.chat.llm.prompt;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class InputFormat {

    /** Separator placed between rendered examples and between prompt sections. */
    public static final String SEPERATOR = "\n\n";

    /**
     * Renders {@code template} once per entry of {@code toFormatList} and joins the results with
     * {@link #SEPERATOR}.
     *
     * <p>For every name in {@code templateKey} that the entry contains, the placeholder
     * <code>{name}</code> (braces included) is replaced by the entry's value. Replacing the whole
     * placeholder keeps rendered examples free of stray braces, matching how the prompt generators
     * fill the same placeholders for the new case.
     *
     * @param template     text containing <code>{name}</code> placeholders
     * @param templateKey  names of the placeholders to fill
     * @param toFormatList one value map per example to render
     * @return the rendered examples, or an empty string when {@code toFormatList} is empty
     */
    public static String format(String template, List<String> templateKey,
            List<Map<String, String>> toFormatList) {
        List<String> result = new ArrayList<>();
        for (Map<String, String> formatItem : toFormatList) {
            Map<String, String> placeholders = placeholderValues(formatItem, templateKey);
            result.add(format(template, placeholders));
        }
        return String.join(SEPERATOR, result);
    }

    /**
     * Replaces every occurrence of each map key in {@code input} by the mapped value. Keys are matched
     * literally; no braces are added.
     *
     * @param input        text to transform
     * @param replacements literal search text mapped to its replacement
     * @return the transformed text
     */
    public static String format(String input, Map<String, String> replacements) {
        for (Map.Entry<String, String> entry : replacements.entrySet()) {
            input = input.replace(entry.getKey(), entry.getValue());
        }
        return input;
    }

    /** Maps <code>{key}</code> to the value of {@code key} for every requested key present in {@code dict}. */
    private static Map<String, String> placeholderValues(Map<String, String> dict, List<String> keys) {
        Map<String, String> placeholders = new HashMap<>();
        for (String key : keys) {
            if (dict.containsKey(key)) {
                placeholders.put("{" + key + "}", dict.get(key));
            }
        }
        return placeholders;
    }
}

View File

@@ -0,0 +1,54 @@
package com.tencent.supersonic.chat.llm.prompt;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
import com.tencent.supersonic.common.util.JsonUtil;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import lombok.extern.slf4j.Slf4j;
/**
 * Parsers that extract structured results from raw language-model replies.
 */
@Slf4j
public class OutputFormat {

    /** Regex matching a flat JSON object, i.e. one without nested braces. */
    public static final String PATTERN = "\\{[^{}]+\\}";

    private static final Pattern JSON_OBJECT_PATTERN = Pattern.compile(PATTERN);

    private static final Pattern SCHEMA_LINKS_PATTERN = Pattern.compile("Schema_links:(.*)", Pattern.DOTALL);

    /**
     * Extracts everything that follows the first {@code Schema_links:} marker.
     *
     * @param schemaLinkOutput raw model reply of the schema-linking step
     * @return the trimmed text after the marker, or {@code null} when the reply is {@code null} or has
     *         no marker
     */
    public static String schemaLinkParse(String schemaLinkOutput) {
        if (schemaLinkOutput == null) {
            log.warn("schemaLinkParse received null output");
            return null;
        }
        Matcher matcher = SCHEMA_LINKS_PATTERN.matcher(schemaLinkOutput.trim());
        return matcher.find() ? matcher.group(1).trim() : null;
    }

    /**
     * Extracts the first flat JSON object from the reply and reads the selected tool from it.
     *
     * @param llmOutput raw model reply of the function-call step
     * @return a response holding the value of the {@code 选择工具} key, or {@code null} when the reply
     *         contains no JSON object or cannot be parsed
     */
    public static FunctionResp functionCallParse(String llmOutput) {
        try {
            // The JSON object is the text that MATCHES the pattern. String.split would instead return
            // the text around it, which can never be parsed as JSON.
            Matcher matcher = JSON_OBJECT_PATTERN.matcher(llmOutput);
            if (!matcher.find()) {
                log.warn("no json object found in llmOutput:{}", llmOutput);
                return null;
            }
            String result = matcher.group().trim();
            Map<String, String> resultDict = JsonUtil.toMap(result, String.class, String.class);
            log.info("result:{},resultDict:{}", result, resultDict);
            String selection = resultDict.get("选择工具");
            FunctionResp resp = new FunctionResp();
            resp.setToolSelection(selection);
            return resp;
        } catch (Exception e) {
            log.error("", e);
            return null;
        }
    }
}

View File

@@ -0,0 +1,32 @@
package com.tencent.supersonic.chat.llm.prompt;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;
/**
 * One few-shot text-to-SQL example.
 *
 * <p>Instances are deserialized from the bundled example JSON by {@code SqlExampleLoader} and converted
 * to a string map that is stored as segment metadata by {@code EmbeddingStoreOperator}. The JSON
 * property names equal the placeholder names used by the example templates in
 * {@code SqlPromptGenerator}.
 */
@Data
public class SqlExample {
// Fills the {currentDate} placeholder of the SQL-generation example template.
@JsonProperty("currentDate")
private String currentDate;
// Fills the {tableName} placeholder of both example templates.
@JsonProperty("tableName")
private String tableName;
// Fills the {fieldsList} placeholder of the schema-linking example template.
@JsonProperty("fieldsList")
private String fieldsList;
// Example question; also the text that is embedded for similarity retrieval.
@JsonProperty("question")
private String question;
// Fills the {priorSchemaLinks} placeholder of the schema-linking example template.
@JsonProperty("priorSchemaLinks")
private String priorSchemaLinks;
// Fills the {analysis} placeholder of the schema-linking example template.
@JsonProperty("analysis")
private String analysis;
// Fills the {schemaLinks} placeholder of both example templates.
@JsonProperty("schemaLinks")
private String schemaLinks;
// Fills the {sql} placeholder of the SQL-generation example template.
@JsonProperty("sql")
private String sql;
}

View File

@@ -0,0 +1,50 @@
package com.tencent.supersonic.chat.llm.prompt;
import com.fasterxml.jackson.core.type.TypeReference;
import com.tencent.supersonic.chat.llm.vectordb.EmbeddingStoreOperator;
import com.tencent.supersonic.common.util.JsonUtil;
import dev.langchain4j.data.segment.TextSegment;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.io.ClassPathResource;
import org.springframework.stereotype.Component;
/**
 * Loads the bundled few-shot SQL examples and moves them into / out of the embedding store.
 */
@Slf4j
@Component
public class SqlExampleLoader {

    private static final String EXAMPLE_JSON_FILE = "example.json";

    private static final TypeReference<List<SqlExample>> SQL_EXAMPLE_LIST_TYPE =
            new TypeReference<List<SqlExample>>() {
            };

    @Autowired
    private EmbeddingStoreOperator embeddingStoreOperator;

    /**
     * Reads all examples from the {@code example.json} classpath resource.
     *
     * @return the deserialized examples
     * @throws IOException when the resource is missing or cannot be parsed
     */
    public List<SqlExample> getSqlExamples() throws IOException {
        ClassPathResource resource = new ClassPathResource(EXAMPLE_JSON_FILE);
        // Close the stream explicitly instead of relying on the JSON parser's auto-close setting.
        try (InputStream inputStream = resource.getInputStream()) {
            return JsonUtil.INSTANCE.getObjectMapper().readValue(inputStream, SQL_EXAMPLE_LIST_TYPE);
        }
    }

    /**
     * Embeds the examples and adds them to the named collection.
     *
     * @param sqlExamples    examples to store
     * @param collectionName target collection
     */
    public void addEmbeddingStore(List<SqlExample> sqlExamples, String collectionName) {
        embeddingStoreOperator.addAll(sqlExamples, collectionName);
    }

    /**
     * Retrieves the examples most relevant to the query and returns their metadata maps.
     *
     * @param queryText      text to search with
     * @param collectionName collection to search in
     * @param maxResults     maximum number of segments to retrieve
     * @return one metadata map per retrieved segment that has metadata
     */
    public List<Map<String, String>> retrieverSqlExamples(String queryText, String collectionName, int maxResults) {
        List<TextSegment> textSegments = embeddingStoreOperator.retriever(queryText, collectionName, maxResults);
        List<Map<String, String>> result = new ArrayList<>();
        for (TextSegment textSegment : textSegments) {
            if (Objects.nonNull(textSegment.metadata())) {
                result.add(textSegment.metadata().asMap());
            }
        }
        return result;
    }
}

View File

@@ -0,0 +1,66 @@
package com.tencent.supersonic.chat.llm.prompt;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
/**
 * Builds the two text-to-SQL prompts: schema linking first, SQL generation second. Each prompt
 * consists of an instruction line, the rendered few-shot examples and the new case to solve.
 */
@Component
@Slf4j
public class SqlPromptGenerator {

    /**
     * Builds the schema-linking prompt.
     *
     * @param question         the user question
     * @param modelName        table name shown to the model
     * @param fieldsList       column names of the table
     * @param priorSchemaLinks field/value pairs already known for the question
     * @param exampleList      few-shot examples as placeholder-name to value maps
     * @return instruction, examples and new case joined by {@link InputFormat#SEPERATOR}
     */
    public String generateSchemaLinkingPrompt(String question, String modelName, List<String> fieldsList,
            List<ElementValue> priorSchemaLinks, List<Map<String, String>> exampleList) {

        String exampleTemplate = "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\n"
                + "问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schemaLinks}";

        List<String> exampleKeys = Arrays.asList("tableName", "fieldsList", "priorSchemaLinks", "question", "analysis",
                "schemaLinks");

        String schemaLinkingPrompt = InputFormat.format(exampleTemplate, exampleKeys, exampleList);

        String newCaseTemplate = "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\n"
                + "问题:{question}\n分析: 让我们一步一步地思考。";

        String newCasePrompt = newCaseTemplate.replace("{tableName}", modelName)
                .replace("{fieldsList}", fieldsList.toString())
                .replace("{priorSchemaLinks}", getPriorSchemaLinks(priorSchemaLinks))
                .replace("{question}", question);

        String instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links";
        return instruction + InputFormat.SEPERATOR + schemaLinkingPrompt + InputFormat.SEPERATOR + newCasePrompt;
    }

    /** Renders the prior links as {@code ['field'->value,...]}. */
    private String getPriorSchemaLinks(List<ElementValue> priorSchemaLinks) {
        return priorSchemaLinks.stream()
                .map(elementValue -> "'" + elementValue.getFieldName() + "'->" + elementValue.getFieldValue())
                .collect(Collectors.joining(",", "[", "]"));
    }

    /**
     * Builds the SQL-generation prompt.
     *
     * @param question      the user question
     * @param modelName     table name shown to the model
     * @param schemaLinkStr schema links produced by the linking step; {@code null} is rendered as an
     *                      empty value because the linking parser returns {@code null} when the model
     *                      reply has no schema-links section
     * @param dataDate      current date shown to the model
     * @param exampleList   few-shot examples as placeholder-name to value maps
     * @return instruction, examples and new case joined by {@link InputFormat#SEPERATOR}
     */
    public String generateSqlPrompt(String question, String modelName, String schemaLinkStr, String dataDate,
            List<Map<String, String>> exampleList) {

        List<String> exampleKeys = Arrays.asList("question", "currentDate", "tableName", "schemaLinks", "sql");
        String exampleTemplate = "问题:{question}\nCurrent_date:{currentDate}\nTable {tableName}\n"
                + "Schema_links:{schemaLinks}\nSQL:{sql}";

        String sqlExamplePrompt = InputFormat.format(exampleTemplate, exampleKeys, exampleList);

        String newCaseTemplate = "问题:{question}\nCurrent_date:{currentDate}\nTable {tableName}\n"
                + "Schema_links:{schemaLinks}\nSQL:";

        // String.replace rejects a null replacement, so a missing linking result becomes empty text.
        String schemaLinks = schemaLinkStr == null ? "" : schemaLinkStr;

        String newCasePrompt = newCaseTemplate.replace("{question}", question)
                .replace("{currentDate}", dataDate)
                .replace("{tableName}", modelName)
                .replace("{schemaLinks}", schemaLinks);

        String instruction = "# 根据schema_links为每个问题生成SQL查询语句";
        return instruction + InputFormat.SEPERATOR + sqlExamplePrompt + InputFormat.SEPERATOR + newCasePrompt;
    }
}

View File

@@ -0,0 +1,20 @@
package com.tencent.supersonic.chat.llm.vectordb;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import lombok.extern.slf4j.Slf4j;
/**
 * Hands out one in-memory embedding store per collection name, creating it on first request.
 */
@Slf4j
public class EmbeddingStoreFactory {

    private static final Map<String, EmbeddingStore> STORES_BY_COLLECTION = new ConcurrentHashMap<>();

    /**
     * Returns the store registered for {@code collectionName}, registering a new
     * {@link InMemoryEmbeddingStore} when none exists yet.
     *
     * @param collectionName name identifying the store
     * @return the store for that name; repeated calls with the same name return the same instance
     */
    public static EmbeddingStore create(String collectionName) {
        return STORES_BY_COLLECTION.computeIfAbsent(collectionName, name -> new InMemoryEmbeddingStore());
    }
}

View File

@@ -0,0 +1,55 @@
package com.tencent.supersonic.chat.llm.vectordb;
import com.tencent.supersonic.chat.llm.prompt.SqlExample;
import com.tencent.supersonic.common.util.JsonUtil;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.retriever.EmbeddingStoreRetriever;
import dev.langchain4j.store.embedding.EmbeddingStore;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
/**
 * Reads from and writes to the embedding stores provided by {@link EmbeddingStoreFactory}, using the
 * injected {@link EmbeddingModel} to embed text.
 */
@Service
@Slf4j
public class EmbeddingStoreOperator {

    @Autowired
    private EmbeddingModel embeddingModel;

    /**
     * Finds the stored segments most relevant to {@code text}.
     *
     * @param text           query text
     * @param collectionName collection to search
     * @param maxResults     maximum number of segments to return
     * @return the relevant segments as reported by the retriever
     */
    public List<TextSegment> retriever(String text, String collectionName, int maxResults) {
        EmbeddingStore store = EmbeddingStoreFactory.create(collectionName);
        return EmbeddingStoreRetriever.from(store, embeddingModel, maxResults).findRelevant(text);
    }

    /**
     * Embeds the question of every example and stores it, with the whole example as string metadata,
     * in the named collection.
     *
     * @param sqlExamples    examples to add
     * @param collectionName target collection
     * @return the value returned by the store's {@code addAll}
     */
    public List<String> addAll(List<SqlExample> sqlExamples, String collectionName) {
        List<Embedding> vectors = new ArrayList<>();
        List<TextSegment> segments = new ArrayList<>();
        for (SqlExample example : sqlExamples) {
            String question = example.getQuestion();
            vectors.add(embeddingModel.embed(question).content());

            // Round-trip through JSON to obtain the example as a flat String-to-String map.
            String exampleJson = JsonUtil.toString(example);
            Map<String, String> metadata = JsonUtil.toMap(exampleJson, String.class, String.class);
            segments.add(TextSegment.from(question, new Metadata(metadata)));
        }
        return store(vectors, segments, collectionName);
    }

    /** Adds the paired embeddings and segments to the store of the given collection. */
    private List<String> store(List<Embedding> vectors, List<TextSegment> segments, String collectionName) {
        EmbeddingStore target = EmbeddingStoreFactory.create(collectionName);
        return target.addAll(vectors, segments);
    }
}

View File

@@ -17,7 +17,7 @@ import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.service.LLMParserLayer;
import com.tencent.supersonic.chat.llm.LLMInterpreter;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
@@ -46,6 +46,8 @@ import org.springframework.util.CollectionUtils;
@Service
public class LLMRequestService {
protected LLMInterpreter llmInterpreter = ComponentFactory.getLLMInterpreter();
protected SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
@Autowired
private LLMParserConfig llmParserConfig;
@@ -55,8 +57,7 @@ public class LLMRequestService {
private SchemaService schemaService;
@Autowired
private OptimizationConfig optimizationConfig;
@Autowired
private LLMParserLayer llmParserLayer;
public boolean check(QueryContext queryCtx) {
QueryReq request = queryCtx.getRequest();
@@ -137,7 +138,7 @@ public class LLMRequestService {
}
public LLMResp requestLLM(LLMReq llmReq, Long modelId) {
return llmParserLayer.query2sql(llmReq, modelId);
return llmInterpreter.query2sql(llmReq, modelId);
}
protected List<String> getFieldNameList(QueryContext queryCtx, Long modelId, LLMParserConfig llmParserConfig) {

View File

@@ -2,11 +2,14 @@ package com.tencent.supersonic.chat.parser.plugin.embedding;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.llm.HttpLLMInterpreter;
import com.tencent.supersonic.chat.llm.LLMInterpreter;
import com.tencent.supersonic.chat.parser.ParseMode;
import com.tencent.supersonic.chat.parser.plugin.PluginParser;
import com.tencent.supersonic.chat.plugin.Plugin;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.plugin.PluginRecallResult;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.ContextUtils;
import java.util.Comparator;
@@ -22,10 +25,12 @@ import org.springframework.util.CollectionUtils;
@Slf4j
public class EmbeddingBasedParser extends PluginParser {
protected LLMInterpreter llmInterpreter = ComponentFactory.getLLMInterpreter();
@Override
public boolean checkPreCondition(QueryContext queryContext) {
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
if (StringUtils.isBlank(embeddingConfig.getUrl())) {
if (StringUtils.isBlank(embeddingConfig.getUrl()) && llmInterpreter instanceof HttpLLMInterpreter) {
return false;
}
List<Plugin> plugins = getPluginList(queryContext);

View File

@@ -1,7 +1,8 @@
package com.tencent.supersonic.chat.parser.plugin.function;
import com.alibaba.fastjson.JSON;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.llm.HttpLLMInterpreter;
import com.tencent.supersonic.chat.llm.LLMInterpreter;
import com.tencent.supersonic.chat.parser.ParseMode;
import com.tencent.supersonic.chat.parser.plugin.PluginParser;
import com.tencent.supersonic.chat.plugin.Plugin;
@@ -10,34 +11,29 @@ import com.tencent.supersonic.chat.plugin.PluginParseConfig;
import com.tencent.supersonic.chat.plugin.PluginRecallResult;
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
import com.tencent.supersonic.chat.service.PluginService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.util.ContextUtils;
import java.net.URI;
import com.tencent.supersonic.common.util.JsonUtil;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.Objects;
import java.util.stream.Collectors;
import com.tencent.supersonic.common.util.JsonUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.util.CollectionUtils;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;
@Slf4j
public class FunctionBasedParser extends PluginParser {
protected LLMInterpreter llmInterpreter = ComponentFactory.getLLMInterpreter();
@Override
public boolean checkPreCondition(QueryContext queryContext) {
FunctionCallConfig functionCallConfig = ContextUtils.getBean(FunctionCallConfig.class);
String functionUrl = functionCallConfig.getUrl();
if (StringUtils.isBlank(functionUrl)) {
if (StringUtils.isBlank(functionUrl) && llmInterpreter instanceof HttpLLMInterpreter) {
log.info("functionUrl:{}, skip function parser, queryText:{}", functionUrl,
queryContext.getRequest().getQueryText());
return false;
@@ -88,7 +84,7 @@ public class FunctionBasedParser extends PluginParser {
FunctionReq functionReq = FunctionReq.builder()
.queryText(queryContext.getRequest().getQueryText())
.pluginConfigs(pluginToFunctionCall).build();
functionResp = requestFunction(functionReq);
functionResp = llmInterpreter.requestFunction(functionReq);
}
return functionResp;
}
@@ -131,25 +127,4 @@ public class FunctionBasedParser extends PluginParser {
return functionDOList;
}
/**
 * Posts the function request as JSON to the plugin-selection endpoint configured in
 * {@link FunctionCallConfig}.
 *
 * @param functionReq request to send
 * @return the response body, or {@code null} when the exchange fails (the error is logged)
 */
public FunctionResp requestFunction(FunctionReq functionReq) {
    FunctionCallConfig config = ContextUtils.getBean(FunctionCallConfig.class);
    String endpoint = config.getUrl() + config.getPluginSelectPath();

    HttpHeaders jsonHeaders = new HttpHeaders();
    long begin = System.currentTimeMillis();
    jsonHeaders.setContentType(MediaType.APPLICATION_JSON);

    String payload = JSON.toJSONString(functionReq);
    HttpEntity<String> httpEntity = new HttpEntity<>(payload, jsonHeaders);
    URI target = UriComponentsBuilder.fromHttpUrl(endpoint).build().encode().toUri();
    RestTemplate client = ContextUtils.getBean(RestTemplate.class);

    try {
        log.info("requestFunction functionReq:{}", JsonUtil.toString(functionReq));
        ResponseEntity<FunctionResp> response = client.exchange(target, HttpMethod.POST, httpEntity,
                FunctionResp.class);
        log.info("requestFunction responseEntity:{},cost:{}", response,
                System.currentTimeMillis() - begin);
        return response.getBody();
    } catch (Exception e) {
        log.error("requestFunction error", e);
        return null;
    }
}
}

View File

@@ -1,13 +0,0 @@
package com.tencent.supersonic.chat.service;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
/**
 * Unified wrapper for invoking the llmparser Python service layer.
 */
public interface LLMParserLayer {
/**
 * Requests SQL generation for the query described by the request.
 *
 * @param llmReq request forwarded to the service
 * @param modelId id accompanying the request; the accompanying implementation only logs it
 * @return the service response; the accompanying implementation returns {@code null} when the call fails
 */
LLMResp query2sql(LLMReq llmReq, Long modelId);
}

View File

@@ -1,47 +0,0 @@
package com.tencent.supersonic.chat.service.impl;
import com.tencent.supersonic.chat.config.LLMParserConfig;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.chat.service.LLMParserLayer;
import com.tencent.supersonic.common.util.JsonUtil;
import java.net.URL;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestTemplate;
/**
 * {@link LLMParserLayer} that posts the request as JSON to the query-to-SQL endpoint configured in
 * {@link LLMParserConfig}.
 */
@Service
@Slf4j
public class LLMParserLayerImpl implements LLMParserLayer {

    @Autowired
    private RestTemplate restTemplate;

    @Autowired
    private LLMParserConfig llmParserConfig;

    /**
     * Sends the request and returns the response body.
     *
     * @param llmReq  request body, serialized with {@link JsonUtil}
     * @param modelId only used for logging
     * @return the response body, or {@code null} when building the URL or the exchange fails
     */
    public LLMResp query2sql(LLMReq llmReq, Long modelId) {
        long begin = System.currentTimeMillis();
        log.info("requestLLM request, modelId:{},llmReq:{}", modelId, llmReq);
        try {
            URL base = new URL(llmParserConfig.getUrl());
            URL endpoint = new URL(base, llmParserConfig.getQueryToSqlPath());

            HttpHeaders jsonHeaders = new HttpHeaders();
            jsonHeaders.setContentType(MediaType.APPLICATION_JSON);
            HttpEntity<String> httpEntity = new HttpEntity<>(JsonUtil.toString(llmReq), jsonHeaders);

            ResponseEntity<LLMResp> response = restTemplate.exchange(endpoint.toString(), HttpMethod.POST,
                    httpEntity, LLMResp.class);
            log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
                    System.currentTimeMillis() - begin, endpoint, httpEntity, response.getBody());
            return response.getBody();
        } catch (Exception e) {
            log.error("requestLLM error", e);
            return null;
        }
    }
}

View File

@@ -4,6 +4,7 @@ import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.component.SemanticCorrector;
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.llm.LLMInterpreter;
import com.tencent.supersonic.chat.parser.llm.s2sql.ModelResolver;
import com.tencent.supersonic.chat.postprocessor.PostProcessor;
import com.tencent.supersonic.chat.responder.execute.ExecuteResponder;
@@ -20,10 +21,13 @@ public class ComponentFactory {
private static List<SemanticParser> semanticParsers = new ArrayList<>();
private static List<SemanticCorrector> s2SQLCorrections = new ArrayList<>();
private static SemanticInterpreter semanticInterpreter;
private static LLMInterpreter llmInterpreter;
private static List<PostProcessor> postProcessors = new ArrayList<>();
private static List<ParseResponder> parseResponders = new ArrayList<>();
private static List<ExecuteResponder> executeResponders = new ArrayList<>();
private static ModelResolver modelResolver;
public static List<SchemaMapper> getSchemaMappers() {
return CollectionUtils.isEmpty(schemaMappers) ? init(SchemaMapper.class, schemaMappers) : schemaMappers;
}
@@ -62,6 +66,13 @@ public class ComponentFactory {
}
/**
 * Returns the cached {@link LLMInterpreter}, obtaining it through {@code init} on first use.
 */
public static LLMInterpreter getLLMInterpreter() {
    if (llmInterpreter != null) {
        return llmInterpreter;
    }
    llmInterpreter = init(LLMInterpreter.class);
    return llmInterpreter;
}
public static ModelResolver getModelResolver() {
if (Objects.isNull(modelResolver)) {
modelResolver = init(ModelResolver.class);