mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 12:37:55 +00:00
[improvement][project]Remove unnecessary copy from Request to Context objects.
This commit is contained in:
@@ -1,7 +1,6 @@
|
|||||||
package com.tencent.supersonic.chat.api.pojo.request;
|
package com.tencent.supersonic.chat.api.pojo.request;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.User;
|
import com.tencent.supersonic.common.pojo.User;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
@@ -16,10 +15,8 @@ public class ChatParseReq {
|
|||||||
private String queryText;
|
private String queryText;
|
||||||
private Integer chatId;
|
private Integer chatId;
|
||||||
private Integer agentId;
|
private Integer agentId;
|
||||||
private Integer topN = 10;
|
|
||||||
private User user;
|
private User user;
|
||||||
private QueryFilters queryFilters;
|
private QueryFilters queryFilters;
|
||||||
private boolean saveAnswer = true;
|
private boolean saveAnswer = true;
|
||||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
|
||||||
private boolean disableLLM = false;
|
private boolean disableLLM = false;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ public class PlainTextExecutor implements ChatQueryExecutor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
String promptStr = String.format(chatApp.getPrompt(), getHistoryInputs(executeContext),
|
String promptStr = String.format(chatApp.getPrompt(), getHistoryInputs(executeContext),
|
||||||
executeContext.getQueryText());
|
executeContext.getRequest().getQueryText());
|
||||||
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
||||||
ChatLanguageModel chatLanguageModel =
|
ChatLanguageModel chatLanguageModel =
|
||||||
ModelProvider.getChatModel(chatApp.getChatModelConfig());
|
ModelProvider.getChatModel(chatApp.getChatModelConfig());
|
||||||
@@ -66,8 +66,8 @@ public class PlainTextExecutor implements ChatQueryExecutor {
|
|||||||
|
|
||||||
private String getHistoryInputs(ExecuteContext executeContext) {
|
private String getHistoryInputs(ExecuteContext executeContext) {
|
||||||
StringBuilder historyInput = new StringBuilder();
|
StringBuilder historyInput = new StringBuilder();
|
||||||
List<QueryResp> queryResps = getHistoryQueries(executeContext.getChatId(), 5);
|
List<QueryResp> queryResps = getHistoryQueries(executeContext.getRequest().getChatId(), 5);
|
||||||
queryResps.stream().forEach(p -> {
|
queryResps.forEach(p -> {
|
||||||
historyInput.append(p.getQueryText());
|
historyInput.append(p.getQueryText());
|
||||||
historyInput.append(";");
|
historyInput.append(";");
|
||||||
|
|
||||||
|
|||||||
@@ -48,9 +48,9 @@ public class SqlExecutor implements ChatQueryExecutor {
|
|||||||
.agentId(executeContext.getAgent().getId()).status(MemoryStatus.PENDING)
|
.agentId(executeContext.getAgent().getId()).status(MemoryStatus.PENDING)
|
||||||
.question(exemplar.getQuestion()).sideInfo(exemplar.getSideInfo())
|
.question(exemplar.getQuestion()).sideInfo(exemplar.getSideInfo())
|
||||||
.dbSchema(exemplar.getDbSchema()).s2sql(exemplar.getSql())
|
.dbSchema(exemplar.getDbSchema()).s2sql(exemplar.getSql())
|
||||||
.createdBy(executeContext.getUser().getName())
|
.createdBy(executeContext.getRequest().getUser().getName())
|
||||||
.updatedBy(executeContext.getUser().getName()).createdAt(new Date())
|
.updatedBy(executeContext.getRequest().getUser().getName())
|
||||||
.build());
|
.createdAt(new Date()).build());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,7 +62,8 @@ public class SqlExecutor implements ChatQueryExecutor {
|
|||||||
SemanticLayerService semanticLayer = ContextUtils.getBean(SemanticLayerService.class);
|
SemanticLayerService semanticLayer = ContextUtils.getBean(SemanticLayerService.class);
|
||||||
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
|
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
|
||||||
|
|
||||||
ChatContext chatCtx = chatContextService.getOrCreateContext(executeContext.getChatId());
|
ChatContext chatCtx =
|
||||||
|
chatContextService.getOrCreateContext(executeContext.getRequest().getChatId());
|
||||||
SemanticParseInfo parseInfo = executeContext.getParseInfo();
|
SemanticParseInfo parseInfo = executeContext.getParseInfo();
|
||||||
if (Objects.isNull(parseInfo.getSqlInfo())
|
if (Objects.isNull(parseInfo.getSqlInfo())
|
||||||
|| StringUtils.isBlank(parseInfo.getSqlInfo().getCorrectedS2SQL())) {
|
|| StringUtils.isBlank(parseInfo.getSqlInfo().getCorrectedS2SQL())) {
|
||||||
@@ -79,7 +80,8 @@ public class SqlExecutor implements ChatQueryExecutor {
|
|||||||
queryResult.setChatContext(parseInfo);
|
queryResult.setChatContext(parseInfo);
|
||||||
queryResult.setQueryMode(parseInfo.getQueryMode());
|
queryResult.setQueryMode(parseInfo.getQueryMode());
|
||||||
queryResult.setQueryTimeCost(System.currentTimeMillis() - startTime);
|
queryResult.setQueryTimeCost(System.currentTimeMillis() - startTime);
|
||||||
SemanticQueryResp queryResp = semanticLayer.queryByReq(sqlReq, executeContext.getUser());
|
SemanticQueryResp queryResp =
|
||||||
|
semanticLayer.queryByReq(sqlReq, executeContext.getRequest().getUser());
|
||||||
if (queryResp != null) {
|
if (queryResp != null) {
|
||||||
queryResult.setQueryAuthorization(queryResp.getQueryAuthorization());
|
queryResult.setQueryAuthorization(queryResp.getQueryAuthorization());
|
||||||
queryResult.setQuerySql(queryResp.getSql());
|
queryResult.setQuerySql(queryResp.getSql());
|
||||||
|
|||||||
@@ -91,16 +91,24 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void parse(ParseContext parseContext, ParseResp parseResp) {
|
public void parse(ParseContext parseContext, ParseResp parseResp) {
|
||||||
if (!parseContext.enableNL2SQL() || checkSkip(parseResp)) {
|
if (!parseContext.enableNL2SQL() || Objects.isNull(parseContext.getAgent())) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
|
||||||
|
if (Objects.isNull(queryNLReq)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
|
|
||||||
ChatContext chatCtx = chatContextService.getOrCreateContext(parseContext.getChatId());
|
|
||||||
|
|
||||||
if (!parseContext.isDisableLLM()) {
|
if (!parseContext.getRequest().isDisableLLM()) {
|
||||||
processMultiTurn(parseContext);
|
processMultiTurn(parseContext);
|
||||||
}
|
}
|
||||||
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext, chatCtx);
|
|
||||||
|
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
|
||||||
|
ChatContext chatCtx =
|
||||||
|
chatContextService.getOrCreateContext(parseContext.getRequest().getChatId());
|
||||||
|
if (chatCtx != null) {
|
||||||
|
queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
|
||||||
|
}
|
||||||
addDynamicExemplars(parseContext.getAgent().getId(), queryNLReq);
|
addDynamicExemplars(parseContext.getAgent().getId(), queryNLReq);
|
||||||
|
|
||||||
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
|
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
|
||||||
@@ -108,7 +116,7 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) {
|
if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) {
|
||||||
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
|
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
|
||||||
} else {
|
} else {
|
||||||
if (!parseContext.isDisableLLM()) {
|
if (!parseContext.getRequest().isDisableLLM()) {
|
||||||
parseResp.setErrorMsg(rewriteErrorMessage(parseContext,
|
parseResp.setErrorMsg(rewriteErrorMessage(parseContext,
|
||||||
text2SqlParseResp.getErrorMsg(), queryNLReq.getDynamicExemplars()));
|
text2SqlParseResp.getErrorMsg(), queryNLReq.getDynamicExemplars()));
|
||||||
}
|
}
|
||||||
@@ -119,16 +127,6 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
formatParseResult(parseResp);
|
formatParseResult(parseResp);
|
||||||
}
|
}
|
||||||
|
|
||||||
private boolean checkSkip(ParseResp parseResp) {
|
|
||||||
List<SemanticParseInfo> selectedParses = parseResp.getSelectedParses();
|
|
||||||
for (SemanticParseInfo semanticParseInfo : selectedParses) {
|
|
||||||
if (semanticParseInfo.getScore() >= parseResp.getQueryText().length()) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
private void formatParseResult(ParseResp parseResp) {
|
private void formatParseResult(ParseResp parseResp) {
|
||||||
List<SemanticParseInfo> selectedParses = parseResp.getSelectedParses();
|
List<SemanticParseInfo> selectedParses = parseResp.getSelectedParses();
|
||||||
for (SemanticParseInfo parseInfo : selectedParses) {
|
for (SemanticParseInfo parseInfo : selectedParses) {
|
||||||
@@ -182,7 +180,8 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
|
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
|
||||||
MapResp currentMapResult = chatLayerService.map(queryNLReq);
|
MapResp currentMapResult = chatLayerService.map(queryNLReq);
|
||||||
|
|
||||||
List<QueryResp> historyQueries = getHistoryQueries(parseContext.getChatId(), 1);
|
List<QueryResp> historyQueries =
|
||||||
|
getHistoryQueries(parseContext.getRequest().getChatId(), 1);
|
||||||
if (historyQueries.isEmpty()) {
|
if (historyQueries.isEmpty()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -208,7 +207,7 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||||
String rewrittenQuery = response.content().text();
|
String rewrittenQuery = response.content().text();
|
||||||
keyPipelineLog.info("QueryRewrite modelReq:\n{} \nmodelResp:\n{}", prompt.text(), response);
|
keyPipelineLog.info("QueryRewrite modelReq:\n{} \nmodelResp:\n{}", prompt.text(), response);
|
||||||
parseContext.setQueryText(rewrittenQuery);
|
parseContext.getRequest().setQueryText(rewrittenQuery);
|
||||||
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}", lastQuery.getQueryText(),
|
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}", lastQuery.getQueryText(),
|
||||||
currentMapResult.getQueryText(), rewrittenQuery);
|
currentMapResult.getQueryText(), rewrittenQuery);
|
||||||
}
|
}
|
||||||
@@ -222,7 +221,7 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Map<String, Object> variables = new HashMap<>();
|
Map<String, Object> variables = new HashMap<>();
|
||||||
variables.put("user_question", parseContext.getQueryText());
|
variables.put("user_question", parseContext.getRequest().getQueryText());
|
||||||
variables.put("system_message", errMsg);
|
variables.put("system_message", errMsg);
|
||||||
|
|
||||||
StringBuilder exampleStr = new StringBuilder();
|
StringBuilder exampleStr = new StringBuilder();
|
||||||
@@ -286,7 +285,7 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId);
|
String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId);
|
||||||
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
||||||
int exemplarRecallNumber =
|
int exemplarRecallNumber =
|
||||||
Integer.valueOf(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER));
|
Integer.parseInt(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER));
|
||||||
List<Text2SQLExemplar> exemplars = exemplarManager.recallExemplars(memoryCollectionName,
|
List<Text2SQLExemplar> exemplars = exemplarManager.recallExemplars(memoryCollectionName,
|
||||||
queryNLReq.getQueryText(), exemplarRecallNumber);
|
queryNLReq.getQueryText(), exemplarRecallNumber);
|
||||||
queryNLReq.getDynamicExemplars().addAll(exemplars);
|
queryNLReq.getDynamicExemplars().addAll(exemplars);
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ public abstract class PluginRecognizer {
|
|||||||
protected SemanticParseInfo buildSemanticParseInfo(Long dataSetId, ChatPlugin plugin,
|
protected SemanticParseInfo buildSemanticParseInfo(Long dataSetId, ChatPlugin plugin,
|
||||||
ParseContext parseContext, SchemaMapInfo mapInfo, double distance) {
|
ParseContext parseContext, SchemaMapInfo mapInfo, double distance) {
|
||||||
List<SchemaElementMatch> schemaElementMatches = mapInfo.getMatchedElements(dataSetId);
|
List<SchemaElementMatch> schemaElementMatches = mapInfo.getMatchedElements(dataSetId);
|
||||||
QueryFilters queryFilters = parseContext.getQueryFilters();
|
QueryFilters queryFilters = parseContext.getRequest().getQueryFilters();
|
||||||
if (schemaElementMatches == null) {
|
if (schemaElementMatches == null) {
|
||||||
schemaElementMatches = Lists.newArrayList();
|
schemaElementMatches = Lists.newArrayList();
|
||||||
}
|
}
|
||||||
@@ -86,7 +86,7 @@ public abstract class PluginRecognizer {
|
|||||||
pluginParseResult.setPlugin(plugin);
|
pluginParseResult.setPlugin(plugin);
|
||||||
pluginParseResult.setQueryFilters(queryFilters);
|
pluginParseResult.setQueryFilters(queryFilters);
|
||||||
pluginParseResult.setDistance(distance);
|
pluginParseResult.setDistance(distance);
|
||||||
pluginParseResult.setQueryText(parseContext.getQueryText());
|
pluginParseResult.setQueryText(parseContext.getRequest().getQueryText());
|
||||||
properties.put(Constants.CONTEXT, pluginParseResult);
|
properties.put(Constants.CONTEXT, pluginParseResult);
|
||||||
properties.put("type", "plugin");
|
properties.put("type", "plugin");
|
||||||
properties.put("name", plugin.getName());
|
properties.put("name", plugin.getName());
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ public class EmbeddingRecallRecognizer extends PluginRecognizer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public PluginRecallResult recallPlugin(ParseContext parseContext) {
|
public PluginRecallResult recallPlugin(ParseContext parseContext) {
|
||||||
String text = parseContext.getQueryText();
|
String text = parseContext.getRequest().getQueryText();
|
||||||
List<Retrieval> embeddingRetrievals = embeddingRecall(text);
|
List<Retrieval> embeddingRetrievals = embeddingRecall(text);
|
||||||
if (CollectionUtils.isEmpty(embeddingRetrievals)) {
|
if (CollectionUtils.isEmpty(embeddingRetrievals)) {
|
||||||
return null;
|
return null;
|
||||||
@@ -52,7 +52,7 @@ public class EmbeddingRecallRecognizer extends PluginRecognizer {
|
|||||||
}
|
}
|
||||||
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
|
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
|
||||||
double similarity = embeddingRetrieval.getSimilarity();
|
double similarity = embeddingRetrieval.getSimilarity();
|
||||||
double score = parseContext.getQueryText().length() * similarity;
|
double score = parseContext.getRequest().getQueryText().length() * similarity;
|
||||||
return PluginRecallResult.builder().plugin(plugin).dataSetIds(dataSetList)
|
return PluginRecallResult.builder().plugin(plugin).dataSetIds(dataSetList)
|
||||||
.score(score).distance(similarity).build();
|
.score(score).distance(similarity).build();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,17 +1,17 @@
|
|||||||
package com.tencent.supersonic.chat.server.pojo;
|
package com.tencent.supersonic.chat.server.pojo;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||||
import com.tencent.supersonic.common.pojo.User;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class ExecuteContext {
|
public class ExecuteContext {
|
||||||
private User user;
|
private ChatExecuteReq request;
|
||||||
private String queryText;
|
|
||||||
private Agent agent;
|
private Agent agent;
|
||||||
private Integer chatId;
|
|
||||||
private Long queryId;
|
|
||||||
private boolean saveAnswer;
|
|
||||||
private SemanticParseInfo parseInfo;
|
private SemanticParseInfo parseInfo;
|
||||||
|
|
||||||
|
public ExecuteContext(ChatExecuteReq request) {
|
||||||
|
this.request = request;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,19 +1,17 @@
|
|||||||
package com.tencent.supersonic.chat.server.pojo;
|
package com.tencent.supersonic.chat.server.pojo;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||||
import com.tencent.supersonic.common.pojo.User;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class ParseContext {
|
public class ParseContext {
|
||||||
private User user;
|
private ChatParseReq request;
|
||||||
private String queryText;
|
|
||||||
private Agent agent;
|
private Agent agent;
|
||||||
private Integer chatId;
|
|
||||||
private QueryFilters queryFilters;
|
public ParseContext(ChatParseReq request) {
|
||||||
private boolean saveAnswer = true;
|
this.request = request;
|
||||||
private boolean disableLLM = false;
|
}
|
||||||
|
|
||||||
public boolean enableNL2SQL() {
|
public boolean enableNL2SQL() {
|
||||||
if (agent == null) {
|
if (agent == null) {
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ public class DataInterpretProcessor implements ExecuteResultProcessor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Map<String, Object> variable = new HashMap<>();
|
Map<String, Object> variable = new HashMap<>();
|
||||||
variable.put("question", executeContext.getQueryText());
|
variable.put("question", executeContext.getRequest().getQueryText());
|
||||||
variable.put("data", queryResult.getTextResult());
|
variable.put("data", queryResult.getTextResult());
|
||||||
|
|
||||||
Prompt prompt = PromptTemplate.from(chatApp.getPrompt()).apply(variable);
|
Prompt prompt = PromptTemplate.from(chatApp.getPrompt()).apply(variable);
|
||||||
|
|||||||
@@ -67,8 +67,8 @@ public class MetricRatioProcessor implements ExecuteResultProcessor {
|
|||||||
|| !QueryType.AGGREGATE.equals(semanticParseInfo.getQueryType())) {
|
|| !QueryType.AGGREGATE.equals(semanticParseInfo.getQueryType())) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
AggregateInfo aggregateInfo =
|
AggregateInfo aggregateInfo = getAggregateInfo(executeContext.getRequest().getUser(),
|
||||||
getAggregateInfo(executeContext.getUser(), semanticParseInfo, queryResult);
|
semanticParseInfo, queryResult);
|
||||||
queryResult.setAggregateInfo(aggregateInfo);
|
queryResult.setAggregateInfo(aggregateInfo);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -30,8 +30,8 @@ public class QueryRecommendProcessor implements ParseResultProcessor {
|
|||||||
@SneakyThrows
|
@SneakyThrows
|
||||||
private void doProcess(ParseResp parseResp, ParseContext parseContext) {
|
private void doProcess(ParseResp parseResp, ParseContext parseContext) {
|
||||||
Long queryId = parseResp.getQueryId();
|
Long queryId = parseResp.getQueryId();
|
||||||
List<SimilarQueryRecallResp> solvedQueries =
|
List<SimilarQueryRecallResp> solvedQueries = getSimilarQueries(
|
||||||
getSimilarQueries(parseContext.getQueryText(), parseContext.getAgent().getId());
|
parseContext.getRequest().getQueryText(), parseContext.getAgent().getId());
|
||||||
ChatQueryDO chatQueryDO = getChatQuery(queryId);
|
ChatQueryDO chatQueryDO = getChatQuery(queryId);
|
||||||
chatQueryDO.setSimilarQueries(JSONObject.toJSONString(solvedQueries));
|
chatQueryDO.setSimilarQueries(JSONObject.toJSONString(solvedQueries));
|
||||||
updateChatQuery(chatQueryDO);
|
updateChatQuery(chatQueryDO);
|
||||||
|
|||||||
@@ -24,8 +24,6 @@ import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
|
|||||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||||
import com.tencent.supersonic.common.pojo.User;
|
import com.tencent.supersonic.common.pojo.User;
|
||||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||||
import com.tencent.supersonic.common.service.ChatModelService;
|
|
||||||
import com.tencent.supersonic.common.util.BeanMapper;
|
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.DateUtils;
|
import com.tencent.supersonic.common.util.DateUtils;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
@@ -86,8 +84,6 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
private SemanticLayerService semanticLayerService;
|
private SemanticLayerService semanticLayerService;
|
||||||
@Autowired
|
@Autowired
|
||||||
private AgentService agentService;
|
private AgentService agentService;
|
||||||
@Autowired
|
|
||||||
private ChatModelService chatModelService;
|
|
||||||
|
|
||||||
private final List<ChatQueryParser> chatQueryParsers = ComponentFactory.getChatParsers();
|
private final List<ChatQueryParser> chatQueryParsers = ComponentFactory.getChatParsers();
|
||||||
private final List<ChatQueryExecutor> chatQueryExecutors = ComponentFactory.getChatExecutors();
|
private final List<ChatQueryExecutor> chatQueryExecutors = ComponentFactory.getChatExecutors();
|
||||||
@@ -120,7 +116,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
processor.process(parseContext, parseResp);
|
processor.process(parseContext, parseResp);
|
||||||
}
|
}
|
||||||
|
|
||||||
chatParseReq.setQueryText(parseContext.getQueryText());
|
chatParseReq.setQueryText(parseContext.getRequest().getQueryText());
|
||||||
chatManageService.batchAddParse(chatParseReq, parseResp);
|
chatManageService.batchAddParse(chatParseReq, parseResp);
|
||||||
chatManageService.updateParseCostTime(parseResp);
|
chatManageService.updateParseCostTime(parseResp);
|
||||||
return parseResp;
|
return parseResp;
|
||||||
@@ -168,16 +164,14 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private ParseContext buildParseContext(ChatParseReq chatParseReq) {
|
private ParseContext buildParseContext(ChatParseReq chatParseReq) {
|
||||||
ParseContext parseContext = new ParseContext();
|
ParseContext parseContext = new ParseContext(chatParseReq);
|
||||||
BeanMapper.mapper(chatParseReq, parseContext);
|
|
||||||
Agent agent = agentService.getAgent(chatParseReq.getAgentId());
|
Agent agent = agentService.getAgent(chatParseReq.getAgentId());
|
||||||
parseContext.setAgent(agent);
|
parseContext.setAgent(agent);
|
||||||
return parseContext;
|
return parseContext;
|
||||||
}
|
}
|
||||||
|
|
||||||
private ExecuteContext buildExecuteContext(ChatExecuteReq chatExecuteReq) {
|
private ExecuteContext buildExecuteContext(ChatExecuteReq chatExecuteReq) {
|
||||||
ExecuteContext executeContext = new ExecuteContext();
|
ExecuteContext executeContext = new ExecuteContext(chatExecuteReq);
|
||||||
BeanMapper.mapper(chatExecuteReq, executeContext);
|
|
||||||
SemanticParseInfo parseInfo = chatManageService.getParseInfo(chatExecuteReq.getQueryId(),
|
SemanticParseInfo parseInfo = chatManageService.getParseInfo(chatExecuteReq.getQueryId(),
|
||||||
chatExecuteReq.getParseId());
|
chatExecuteReq.getParseId());
|
||||||
Agent agent = agentService.getAgent(chatExecuteReq.getAgentId());
|
Agent agent = agentService.getAgent(chatExecuteReq.getAgentId());
|
||||||
@@ -443,14 +437,14 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
if (CollectionUtils.isEmpty(valueList)) {
|
if (CollectionUtils.isEmpty(valueList)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
valueList.stream().forEach(o -> {
|
valueList.forEach(o -> {
|
||||||
StringValue stringValue = new StringValue(o);
|
StringValue stringValue = new StringValue(o);
|
||||||
parenthesedExpressionList.add(stringValue);
|
parenthesedExpressionList.add(stringValue);
|
||||||
});
|
});
|
||||||
inExpression.setLeftExpression(column);
|
inExpression.setLeftExpression(column);
|
||||||
inExpression.setRightExpression(parenthesedExpressionList);
|
inExpression.setRightExpression(parenthesedExpressionList);
|
||||||
addConditions.add(inExpression);
|
addConditions.add(inExpression);
|
||||||
contextMetricFilters.stream().forEach(o -> {
|
contextMetricFilters.forEach(o -> {
|
||||||
if (o.getName().equals(dslQueryFilter.getName())) {
|
if (o.getName().equals(dslQueryFilter.getName())) {
|
||||||
o.setValue(dslQueryFilter.getValue());
|
o.setValue(dslQueryFilter.getValue());
|
||||||
o.setOperator(dslQueryFilter.getOperator());
|
o.setOperator(dslQueryFilter.getOperator());
|
||||||
@@ -480,7 +474,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
comparisonExpression.setRightExpression(stringValue);
|
comparisonExpression.setRightExpression(stringValue);
|
||||||
}
|
}
|
||||||
addConditions.add(comparisonExpression);
|
addConditions.add(comparisonExpression);
|
||||||
contextMetricFilters.stream().forEach(o -> {
|
contextMetricFilters.forEach(o -> {
|
||||||
if (o.getName().equals(dslQueryFilter.getName())) {
|
if (o.getName().equals(dslQueryFilter.getName())) {
|
||||||
o.setValue(dslQueryFilter.getValue());
|
o.setValue(dslQueryFilter.getValue());
|
||||||
o.setOperator(dslQueryFilter.getOperator());
|
o.setOperator(dslQueryFilter.getOperator());
|
||||||
|
|||||||
@@ -1,44 +1,24 @@
|
|||||||
package com.tencent.supersonic.chat.server.util;
|
package com.tencent.supersonic.chat.server.util;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
|
||||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||||
import com.tencent.supersonic.common.util.BeanMapper;
|
import com.tencent.supersonic.common.util.BeanMapper;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||||
import org.apache.commons.collections.MapUtils;
|
|
||||||
|
|
||||||
import java.util.Objects;
|
|
||||||
|
|
||||||
public class QueryReqConverter {
|
public class QueryReqConverter {
|
||||||
|
|
||||||
public static QueryNLReq buildQueryNLReq(ParseContext parseContext) {
|
public static QueryNLReq buildQueryNLReq(ParseContext parseContext) {
|
||||||
return buildQueryNLReq(parseContext, null);
|
if (parseContext.getAgent() == null) {
|
||||||
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static QueryNLReq buildQueryNLReq(ParseContext parseContext, ChatContext chatCtx) {
|
|
||||||
QueryNLReq queryNLReq = new QueryNLReq();
|
QueryNLReq queryNLReq = new QueryNLReq();
|
||||||
BeanMapper.mapper(parseContext, queryNLReq);
|
BeanMapper.mapper(parseContext.getRequest(), queryNLReq);
|
||||||
Agent agent = parseContext.getAgent();
|
queryNLReq.setText2SQLType(parseContext.getRequest().isDisableLLM() ? Text2SQLType.ONLY_RULE
|
||||||
if (agent == null) {
|
: Text2SQLType.RULE_AND_LLM);
|
||||||
return queryNLReq;
|
queryNLReq.setDataSetIds(parseContext.getAgent().getDataSetIds());
|
||||||
}
|
|
||||||
|
|
||||||
if (parseContext.isDisableLLM()) {
|
|
||||||
queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE);
|
|
||||||
} else {
|
|
||||||
queryNLReq.setText2SQLType(Text2SQLType.RULE_AND_LLM);
|
|
||||||
}
|
|
||||||
|
|
||||||
queryNLReq.setDataSetIds(agent.getDataSetIds());
|
|
||||||
if (Objects.nonNull(queryNLReq.getMapInfo())
|
|
||||||
&& MapUtils.isNotEmpty(queryNLReq.getMapInfo().getDataSetElementMatches())) {
|
|
||||||
queryNLReq.setMapInfo(queryNLReq.getMapInfo());
|
|
||||||
}
|
|
||||||
queryNLReq.setChatAppConfig(parseContext.getAgent().getChatAppConfig());
|
queryNLReq.setChatAppConfig(parseContext.getAgent().getChatAppConfig());
|
||||||
if (chatCtx != null) {
|
|
||||||
queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
|
|
||||||
}
|
|
||||||
return queryNLReq;
|
return queryNLReq;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -70,8 +70,8 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner {
|
|||||||
RetrieveQuery.builder().queryTextsList(Lists.newArrayList(query)).build();
|
RetrieveQuery.builder().queryTextsList(Lists.newArrayList(query)).build();
|
||||||
List<RetrieveQueryResult> results =
|
List<RetrieveQueryResult> results =
|
||||||
embeddingService.retrieveQuery(collection, retrieveQuery, num);
|
embeddingService.retrieveQuery(collection, retrieveQuery, num);
|
||||||
results.stream().forEach(ret -> {
|
results.forEach(ret -> {
|
||||||
ret.getRetrieval().stream().forEach(r -> {
|
ret.getRetrieval().forEach(r -> {
|
||||||
exemplars.add(JsonUtil.mapToObject(r.getMetadata(), Text2SQLExemplar.class));
|
exemplars.add(JsonUtil.mapToObject(r.getMetadata(), Text2SQLExemplar.class));
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
Reference in New Issue
Block a user