[improvement](python) LLM related services support Java service invocation (#484)

This commit is contained in:
lexluo09
2023-12-08 19:24:58 +08:00
committed by GitHub
parent 6c0f88d8b5
commit abbe8c84a1
33 changed files with 1037 additions and 95 deletions

View File

@@ -104,6 +104,19 @@
<version>${mockito-inline.version}</version>
<scope>test</scope>
</dependency>
<!--langchain4j-->
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai</artifactId>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j</artifactId>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-chroma</artifactId>
</dependency>
</dependencies>

View File

@@ -4,10 +4,11 @@ import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
import com.tencent.supersonic.common.util.ComponentFactory;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
import com.tencent.supersonic.knowledge.dictionary.EmbeddingResult;
import com.tencent.supersonic.semantic.model.domain.listener.MetaEmbeddingListener;
import java.util.Comparator;
@@ -32,8 +33,8 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
@Autowired
private OptimizationConfig optimizationConfig;
@Autowired
private EmbeddingUtils embeddingUtils;
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
@Override
public boolean needDelete(EmbeddingResult oneRoundResult, EmbeddingResult existResult) {
@@ -83,7 +84,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
.queryEmbeddings(null)
.build();
// step2. retrieveQuery by detectSegment
List<RetrieveQueryResult> retrieveQueryResults = embeddingUtils.retrieveQuery(
List<RetrieveQueryResult> retrieveQueryResults = s2EmbeddingStore.retrieveQuery(
MetaEmbeddingListener.COLLECTION_NAME, retrieveQuery, embeddingNumber);
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
@@ -97,7 +98,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
retrievals.removeIf(retrieval -> retrieval.getDistance() > distance.doubleValue());
if (CollectionUtils.isNotEmpty(detectModelIds)) {
retrievals.removeIf(retrieval -> {
String modelIdStr = retrieval.getMetadata().get("modelId");
String modelIdStr = retrieval.getMetadata().get("modelId").toString();
if (StringUtils.isBlank(modelIdStr)) {
return true;
}

View File

@@ -0,0 +1,84 @@
package com.tencent.supersonic.chat.parser;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionCallPromptGenerator;
import com.tencent.supersonic.chat.parser.sql.llm.prompt.OutputFormat;
import com.tencent.supersonic.chat.parser.sql.llm.prompt.SqlExampleLoader;
import com.tencent.supersonic.chat.parser.sql.llm.prompt.SqlPromptGenerator;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionReq;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import lombok.extern.slf4j.Slf4j;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@Slf4j
public class EmbedLLMProxy implements LLMProxy {
public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
SqlExampleLoader sqlExampleLoader = ContextUtils.getBean(SqlExampleLoader.class);
OptimizationConfig config = ContextUtils.getBean(OptimizationConfig.class);
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
config.getText2sqlCollectionName(), config.getText2sqlFewShotsNum());
String queryText = llmReq.getQueryText();
String modelName = llmReq.getSchema().getModelName();
List<String> fieldNameList = llmReq.getSchema().getFieldNameList();
List<ElementValue> linking = llmReq.getLinking();
SqlPromptGenerator sqlPromptGenerator = ContextUtils.getBean(SqlPromptGenerator.class);
String linkingPromptStr = sqlPromptGenerator.generateSchemaLinkingPrompt(queryText, modelName, fieldNameList,
linking, sqlExamples);
Prompt linkingPrompt = PromptTemplate.from(JsonUtil.toString(linkingPromptStr)).apply(new HashMap<>());
Response<AiMessage> linkingResult = chatLanguageModel.generate(linkingPrompt.toSystemMessage());
String schemaLinkStr = OutputFormat.schemaLinkParse(linkingResult.content().text());
String generateSqlPrompt = sqlPromptGenerator.generateSqlPrompt(queryText, modelName, schemaLinkStr,
llmReq.getCurrentDate(), sqlExamples);
Prompt sqlPrompt = PromptTemplate.from(JsonUtil.toString(generateSqlPrompt)).apply(new HashMap<>());
Response<AiMessage> sqlResult = chatLanguageModel.generate(sqlPrompt.toSystemMessage());
LLMResp result = new LLMResp();
result.setQuery(queryText);
result.setSchemaLinkingOutput(linkingPromptStr);
result.setSchemaLinkStr(schemaLinkStr);
result.setModelName(modelName);
result.setSqlOutput(sqlResult.content().text());
return result;
}
@Override
public FunctionResp requestFunction(FunctionReq functionReq) {
FunctionCallPromptGenerator promptGenerator = ContextUtils.getBean(FunctionCallPromptGenerator.class);
String functionCallPrompt = promptGenerator.generateFunctionCallPrompt(functionReq.getQueryText(),
functionReq.getPluginConfigs());
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
String functionSelect = chatLanguageModel.generate(functionCallPrompt);
return OutputFormat.functionCallParse(functionSelect);
}
}

View File

@@ -0,0 +1,44 @@
package com.tencent.supersonic.chat.parser.plugin.function;
import com.tencent.supersonic.chat.parser.sql.llm.prompt.InputFormat;
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
import java.util.List;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
@Component
@Slf4j
public class FunctionCallPromptGenerator {
public String generateFunctionCallPrompt(String queryText, List<PluginParseConfig> toolConfigList) {
List<String> toolExplainList = toolConfigList.stream()
.map(this::constructPluginPrompt)
.collect(Collectors.toList());
String functionList = String.join(InputFormat.SEPERATOR, toolExplainList);
return constructTaskPrompt(queryText, functionList);
}
public String constructPluginPrompt(PluginParseConfig parseConfig) {
String toolName = parseConfig.getName();
String toolDescription = parseConfig.getDescription();
List<String> toolExamples = parseConfig.getExamples();
StringBuilder prompt = new StringBuilder();
prompt.append("【工具名称】\n").append(toolName).append("\n");
prompt.append("【工具描述】\n").append(toolDescription).append("\n");
prompt.append("【工具适用问题示例】\n");
for (String example : toolExamples) {
prompt.append(example).append("\n");
}
return prompt.toString();
}
public String constructTaskPrompt(String queryText, String functionList) {
String instruction = String.format("问题为:%s\n请根据问题和工具的描述选择对应的工具完成任务。"
+ "请注意只能选择1个工具。请一步一步地分析选择工具的原因(每个工具的【工具适用问题示例】是选择的重要参考依据)"
+ "并给出最终选择输出格式为json,key为分析过程, ’选择工具‘", queryText);
return String.format("工具选择如下:\n\n%s\n\n【任务说明】\n%s", functionList, instruction);
}
}

View File

@@ -0,0 +1,42 @@
package com.tencent.supersonic.chat.parser.sql.llm.prompt;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class InputFormat {
public static final String SEPERATOR = "\n\n";
public static String format(String template, List<String> templateKey,
List<Map<String, String>> toFormatList) {
List<String> result = new ArrayList<>();
for (Map<String, String> formatItem : toFormatList) {
Map<String, String> retrievalMeta = subDict(formatItem, templateKey);
result.add(format(template, retrievalMeta));
}
return String.join(SEPERATOR, result);
}
public static String format(String input, Map<String, String> replacements) {
for (Map.Entry<String, String> entry : replacements.entrySet()) {
input = input.replace(entry.getKey(), entry.getValue());
}
return input;
}
private static Map<String, String> subDict(Map<String, String> dict, List<String> keys) {
Map<String, String> subDict = new HashMap<>();
for (String key : keys) {
if (dict.containsKey(key)) {
subDict.put(key, dict.get(key));
}
}
return subDict;
}
}

View File

@@ -0,0 +1,53 @@
package com.tencent.supersonic.chat.parser.sql.llm.prompt;
import com.tencent.supersonic.chat.parser.plugin.function.FunctionResp;
import com.tencent.supersonic.common.util.JsonUtil;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import lombok.extern.slf4j.Slf4j;
/***
* output format
*/
@Slf4j
public class OutputFormat {
public static final String PATTERN = "\\{[^{}]+\\}";
public static String schemaLinkParse(String schemaLinkOutput) {
try {
schemaLinkOutput = schemaLinkOutput.trim();
String pattern = "Schema_links:(.*)";
Pattern regexPattern = Pattern.compile(pattern, Pattern.DOTALL);
Matcher matcher = regexPattern.matcher(schemaLinkOutput);
if (matcher.find()) {
schemaLinkOutput = matcher.group(1).trim();
} else {
schemaLinkOutput = null;
}
} catch (Exception e) {
log.error("", e);
schemaLinkOutput = null;
}
return schemaLinkOutput;
}
public static FunctionResp functionCallParse(String llmOutput) {
try {
String[] findResult = llmOutput.split(PATTERN);
String result = findResult[0].trim();
Map<String, String> resultDict = JsonUtil.toMap(result, String.class, String.class);
log.info("result:{},resultDict:{}", result, resultDict);
String selection = resultDict.get("选择工具");
FunctionResp resp = new FunctionResp();
resp.setToolSelection(selection);
return resp;
} catch (Exception e) {
log.error("", e);
return null;
}
}
}

View File

@@ -0,0 +1,32 @@
package com.tencent.supersonic.chat.parser.sql.llm.prompt;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;
@Data
public class SqlExample {
@JsonProperty("currentDate")
private String currentDate;
@JsonProperty("tableName")
private String tableName;
@JsonProperty("fieldsList")
private String fieldsList;
@JsonProperty("question")
private String question;
@JsonProperty("priorSchemaLinks")
private String priorSchemaLinks;
@JsonProperty("analysis")
private String analysis;
@JsonProperty("schemaLinks")
private String schemaLinks;
@JsonProperty("sql")
private String sql;
}

View File

@@ -0,0 +1,76 @@
package com.tencent.supersonic.chat.parser.sql.llm.prompt;
import com.fasterxml.jackson.core.type.TypeReference;
import com.tencent.supersonic.common.util.ComponentFactory;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.springframework.core.io.ClassPathResource;
import org.springframework.stereotype.Component;
@Slf4j
@Component
public class SqlExampleLoader {
private static final String EXAMPLE_JSON_FILE = "example.json";
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
private TypeReference<List<SqlExample>> valueTypeRef = new TypeReference<List<SqlExample>>() {
};
public List<SqlExample> getSqlExamples() throws IOException {
ClassPathResource resource = new ClassPathResource(EXAMPLE_JSON_FILE);
InputStream inputStream = resource.getInputStream();
return JsonUtil.INSTANCE.getObjectMapper().readValue(inputStream, valueTypeRef);
}
public void addEmbeddingStore(List<SqlExample> sqlExamples, String collectionName) {
List<EmbeddingQuery> queries = new ArrayList<>();
for (int i = 0; i < sqlExamples.size(); i++) {
SqlExample sqlExample = sqlExamples.get(i);
String question = sqlExample.getQuestion();
Map<String, Object> metaDataMap = JsonUtil.toMap(JsonUtil.toString(sqlExample), String.class, Object.class);
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
embeddingQuery.setQueryId(String.valueOf(i));
embeddingQuery.setQuery(question);
embeddingQuery.setMetadata(metaDataMap);
queries.add(embeddingQuery);
}
s2EmbeddingStore.addQuery(collectionName, queries);
}
public List<Map<String, String>> retrieverSqlExamples(String queryText, String collectionName, int maxResults) {
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(Collections.singletonList(queryText))
.queryEmbeddings(null).build();
List<RetrieveQueryResult> resultList = s2EmbeddingStore.retrieveQuery(collectionName, retrieveQuery,
maxResults);
List<Map<String, String>> result = new ArrayList<>();
if (CollectionUtils.isEmpty(resultList)) {
return result;
}
for (Retrieval retrieval : resultList.get(0).getRetrieval()) {
if (Objects.nonNull(retrieval.getMetadata()) && !retrieval.getMetadata().isEmpty()) {
Map<String, String> convertedMap = retrieval.getMetadata().entrySet().stream()
.collect(Collectors.toMap(Map.Entry::getKey, entry -> String.valueOf(entry.getValue())));
result.add(convertedMap);
}
}
return result;
}
}

View File

@@ -0,0 +1,65 @@
package com.tencent.supersonic.chat.parser.sql.llm.prompt;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
@Component
@Slf4j
public class SqlPromptGenerator {
public String generateSchemaLinkingPrompt(String question, String modelName, List<String> fieldsList,
List<ElementValue> priorSchemaLinks, List<Map<String, String>> exampleList) {
String exampleTemplate = "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\n"
+ "问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schemaLinks}";
List<String> exampleKeys = Arrays.asList("tableName", "fieldsList", "priorSchemaLinks", "question", "analysis",
"schemaLinks");
String schemaLinkingPrompt = InputFormat.format(exampleTemplate, exampleKeys, exampleList);
String newCaseTemplate = "Table {tableName}, columns = {fieldsList}, prior_schema_links = {priorSchemaLinks}\n"
+ "问题:{question}\n分析: 让我们一步一步地思考。";
String newCasePrompt = newCaseTemplate.replace("{tableName}", modelName)
.replace("{fieldsList}", fieldsList.toString())
.replace("{priorSchemaLinks}", getPriorSchemaLinks(priorSchemaLinks))
.replace("{question}", question);
String instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links";
return instruction + InputFormat.SEPERATOR + schemaLinkingPrompt + InputFormat.SEPERATOR + newCasePrompt;
}
private String getPriorSchemaLinks(List<ElementValue> priorSchemaLinks) {
return priorSchemaLinks.stream()
.map(elementValue -> "'" + elementValue.getFieldName() + "'->" + elementValue.getFieldValue())
.collect(Collectors.joining(",", "[", "]"));
}
public String generateSqlPrompt(String question, String modelName, String schemaLinkStr, String dataDate,
List<Map<String, String>> exampleList) {
List<String> exampleKeys = Arrays.asList("question", "currentDate", "tableName", "schemaLinks", "sql");
String exampleTemplate = "问题:{question}\nCurrent_date:{currentDate}\nTable {tableName}\n"
+ "Schema_links:{schemaLinks}\nSQL:{sql}";
String sqlExamplePrompt = InputFormat.format(exampleTemplate, exampleKeys, exampleList);
String newCaseTemplate = "问题:{question}\nCurrent_date:{currentDate}\nTable {tableName}\n"
+ "Schema_links:{schemaLinks}\nSQL:";
String newCasePrompt = newCaseTemplate.replace("{question}", question)
.replace("{currentDate}", dataDate)
.replace("{tableName}", modelName)
.replace("{schemaLinks}", schemaLinkStr);
String instruction = "# 根据schema_links为每个问题生成SQL查询语句";
return instruction + InputFormat.SEPERATOR + sqlExamplePrompt + InputFormat.SEPERATOR + newCasePrompt;
}
}

View File

@@ -19,12 +19,13 @@ import com.tencent.supersonic.chat.query.plugin.WebBase;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.service.PluginService;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.ComponentFactory;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
@@ -39,7 +40,6 @@ import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.context.event.EventListener;
import org.springframework.stereotype.Component;
import org.springframework.web.client.RestTemplate;
@Slf4j
@Component
@@ -47,14 +47,10 @@ public class PluginManager {
private EmbeddingConfig embeddingConfig;
private RestTemplate restTemplate;
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
private EmbeddingUtils embeddingUtils;
public PluginManager(EmbeddingConfig embeddingConfig, RestTemplate restTemplate, EmbeddingUtils embeddingUtils) {
public PluginManager(EmbeddingConfig embeddingConfig) {
this.embeddingConfig = embeddingConfig;
this.restTemplate = restTemplate;
this.embeddingUtils = embeddingUtils;
}
public static List<Plugin> getPluginAgentCanSupport(Integer agentId) {
@@ -133,7 +129,7 @@ public class PluginManager {
embeddingQuery.setQueryId(id);
queries.add(embeddingQuery);
}
embeddingUtils.deleteQuery(presetCollection, queries);
s2EmbeddingStore.deleteQuery(presetCollection, queries);
}
public void requestEmbeddingPluginAdd(List<EmbeddingQuery> queries) {
@@ -141,7 +137,7 @@ public class PluginManager {
return;
}
String presetCollection = embeddingConfig.getPresetCollection();
embeddingUtils.addQuery(presetCollection, queries);
s2EmbeddingStore.addQuery(presetCollection, queries);
}
public void requestEmbeddingPluginAddALL(List<Plugin> plugins) {
@@ -150,13 +146,12 @@ public class PluginManager {
public RetrieveQueryResult recognize(String embeddingText) {
EmbeddingUtils embeddingUtils = ContextUtils.getBean(EmbeddingUtils.class);
RetrieveQuery retrieveQuery = RetrieveQuery.builder()
.queryTextsList(Collections.singletonList(embeddingText))
.build();
List<RetrieveQueryResult> resultList = embeddingUtils.retrieveQuery(embeddingConfig.getPresetCollection(),
List<RetrieveQueryResult> resultList = s2EmbeddingStore.retrieveQuery(embeddingConfig.getPresetCollection(),
retrieveQuery, embeddingConfig.getNResult());

View File

@@ -7,13 +7,12 @@ import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.common.pojo.QueryType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
import com.tencent.supersonic.common.util.ComponentFactory;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
import com.tencent.supersonic.semantic.model.domain.listener.MetaEmbeddingListener;
import org.springframework.util.CollectionUtils;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
@@ -21,6 +20,7 @@ import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.springframework.util.CollectionUtils;
/**
* SimilarMetricQueryResponder fills recommended metrics based on embedding similarity.
@@ -29,6 +29,8 @@ public class SimilarMetricQueryResponder implements QueryResponder {
private static final int METRIC_RECOMMEND_SIZE = 5;
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
@Override
public void fillInfo(QueryResult queryResult, SemanticParseInfo semanticParseInfo, ExecuteQueryReq queryReq) {
fillSimilarMetric(queryResult.getChatContext());
@@ -46,8 +48,7 @@ public class SimilarMetricQueryResponder implements QueryResponder {
filterCondition.put("type", SchemaElementType.METRIC.name());
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(metricNames)
.filterCondition(filterCondition).queryEmbeddings(null).build();
EmbeddingUtils embeddingUtils = ContextUtils.getBean(EmbeddingUtils.class);
List<RetrieveQueryResult> retrieveQueryResults = embeddingUtils.retrieveQuery(
List<RetrieveQueryResult> retrieveQueryResults = s2EmbeddingStore.retrieveQuery(
MetaEmbeddingListener.COLLECTION_NAME, retrieveQuery, METRIC_RECOMMEND_SIZE + 1);
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
return;
@@ -66,7 +67,8 @@ public class SimilarMetricQueryResponder implements QueryResponder {
SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(retrieval.getMetadata()),
SchemaElement.class);
if (retrieval.getMetadata().containsKey("modelId")) {
schemaElement.setModel(Long.parseLong(retrieval.getMetadata().get("modelId")));
String modelId = retrieval.getMetadata().get("modelId").toString();
schemaElement.setModel(Long.parseLong(modelId));
}
schemaElement.setOrder(++metricOrder);
parseInfo.getMetrics().add(schemaElement);

View File

@@ -1,6 +1,5 @@
package com.tencent.supersonic.chat.query.llm.analytics;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
@@ -14,13 +13,11 @@ import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.Aggregator;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.pojo.QueryType;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import java.util.HashMap;
@@ -31,8 +28,6 @@ import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.commons.lang3.StringUtils;
import org.springframework.http.HttpMethod;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
@@ -43,7 +38,6 @@ public class MetricAnalyzeQuery extends LLMSemanticQuery {
public static final String QUERY_MODE = "METRIC_INTERPRET";
public MetricAnalyzeQuery() {
QueryManager.register(this);
}
@@ -151,23 +145,7 @@ public class MetricAnalyzeQuery extends LLMSemanticQuery {
}
public String fetchInterpret(String queryText, String dataText) {
LLMAnswerReq lLmAnswerReq = new LLMAnswerReq();
lLmAnswerReq.setQueryText(queryText);
lLmAnswerReq.setPluginOutput(dataText);
EmbeddingUtils embeddingUtils = ContextUtils.getBean(EmbeddingUtils.class);
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
String metricAnalyzeQueryCollection = embeddingConfig.getMetricAnalyzeQueryCollection();
String url = String.format("%s/retrieve_query?collection_name=%s", embeddingConfig.getUrl(),
metricAnalyzeQueryCollection);
ResponseEntity<String> responseEntity = embeddingUtils.doRequest(url, JSONObject.toJSONString(lLmAnswerReq),
HttpMethod.POST);
LLMAnswerResp lLmAnswerResp = JSONObject.parseObject(responseEntity.getBody(), LLMAnswerResp.class);
if (lLmAnswerResp != null) {
return lLmAnswerResp.getAssistantMessage();
}
return null;
return "";
}
}

View File

@@ -8,11 +8,11 @@ import com.tencent.supersonic.chat.parser.LLMProxy;
import com.tencent.supersonic.chat.parser.sql.llm.ModelResolver;
import com.tencent.supersonic.chat.processor.ParseResultProcessor;
import com.tencent.supersonic.chat.query.QueryResponder;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.core.io.support.SpringFactoriesLoader;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.core.io.support.SpringFactoriesLoader;
public class ComponentFactory {

View File

@@ -4,11 +4,12 @@ import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.request.SolvedQueryReq;
import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.ComponentFactory;
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
import java.net.URI;
import java.util.HashMap;
import java.util.HashSet;
@@ -36,11 +37,11 @@ public class SolvedQueryManager {
private EmbeddingConfig embeddingConfig;
private EmbeddingUtils embeddingUtils;
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
public SolvedQueryManager(EmbeddingConfig embeddingConfig, EmbeddingUtils embeddingUtils) {
public SolvedQueryManager(EmbeddingConfig embeddingConfig) {
this.embeddingConfig = embeddingConfig;
this.embeddingUtils = embeddingUtils;
}
public void saveSolvedQuery(SolvedQueryReq solvedQueryReq) {
@@ -54,12 +55,12 @@ public class SolvedQueryManager {
embeddingQuery.setQueryId(uniqueId);
embeddingQuery.setQuery(queryText);
Map<String, String> metaData = new HashMap<>();
metaData.put("modelId", String.valueOf(solvedQueryReq.getModelId()));
metaData.put("agentId", String.valueOf(solvedQueryReq.getAgentId()));
Map<String, Object> metaData = new HashMap<>();
metaData.put("modelId", (solvedQueryReq.getModelId()));
metaData.put("agentId", solvedQueryReq.getAgentId());
embeddingQuery.setMetadata(metaData);
String solvedQueryCollection = embeddingConfig.getSolvedQueryCollection();
embeddingUtils.addQuery(solvedQueryCollection, Lists.newArrayList(embeddingQuery));
s2EmbeddingStore.addQuery(solvedQueryCollection, Lists.newArrayList(embeddingQuery));
} catch (Exception e) {
log.warn("save history question to embedding failed, queryText:{}", queryText, e);
}
@@ -80,7 +81,7 @@ public class SolvedQueryManager {
.queryTextsList(Lists.newArrayList(queryText))
.filterCondition(filterCondition)
.build();
List<RetrieveQueryResult> resultList = embeddingUtils.retrieveQuery(solvedQueryCollection, retrieveQuery,
List<RetrieveQueryResult> resultList = s2EmbeddingStore.retrieveQuery(solvedQueryCollection, retrieveQuery,
solvedQueryResultNum);
log.info("[embedding] recognize result body:{}", resultList);