[improvement][project]Remove unnecessary copy from Request to Context objects.

This commit is contained in:
jerryjzhang
2024-10-27 15:59:49 +08:00
parent 397b527bc6
commit 1842261dfe
14 changed files with 66 additions and 96 deletions

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.chat.api.pojo.request; package com.tencent.supersonic.chat.api.pojo.request;
import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Builder; import lombok.Builder;
@@ -16,10 +15,8 @@ public class ChatParseReq {
private String queryText; private String queryText;
private Integer chatId; private Integer chatId;
private Integer agentId; private Integer agentId;
private Integer topN = 10;
private User user; private User user;
private QueryFilters queryFilters; private QueryFilters queryFilters;
private boolean saveAnswer = true; private boolean saveAnswer = true;
private SchemaMapInfo mapInfo = new SchemaMapInfo();
private boolean disableLLM = false; private boolean disableLLM = false;
} }

View File

@@ -50,7 +50,7 @@ public class PlainTextExecutor implements ChatQueryExecutor {
} }
String promptStr = String.format(chatApp.getPrompt(), getHistoryInputs(executeContext), String promptStr = String.format(chatApp.getPrompt(), getHistoryInputs(executeContext),
executeContext.getQueryText()); executeContext.getRequest().getQueryText());
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP); Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
ChatLanguageModel chatLanguageModel = ChatLanguageModel chatLanguageModel =
ModelProvider.getChatModel(chatApp.getChatModelConfig()); ModelProvider.getChatModel(chatApp.getChatModelConfig());
@@ -66,8 +66,8 @@ public class PlainTextExecutor implements ChatQueryExecutor {
private String getHistoryInputs(ExecuteContext executeContext) { private String getHistoryInputs(ExecuteContext executeContext) {
StringBuilder historyInput = new StringBuilder(); StringBuilder historyInput = new StringBuilder();
List<QueryResp> queryResps = getHistoryQueries(executeContext.getChatId(), 5); List<QueryResp> queryResps = getHistoryQueries(executeContext.getRequest().getChatId(), 5);
queryResps.stream().forEach(p -> { queryResps.forEach(p -> {
historyInput.append(p.getQueryText()); historyInput.append(p.getQueryText());
historyInput.append(";"); historyInput.append(";");

View File

@@ -48,9 +48,9 @@ public class SqlExecutor implements ChatQueryExecutor {
.agentId(executeContext.getAgent().getId()).status(MemoryStatus.PENDING) .agentId(executeContext.getAgent().getId()).status(MemoryStatus.PENDING)
.question(exemplar.getQuestion()).sideInfo(exemplar.getSideInfo()) .question(exemplar.getQuestion()).sideInfo(exemplar.getSideInfo())
.dbSchema(exemplar.getDbSchema()).s2sql(exemplar.getSql()) .dbSchema(exemplar.getDbSchema()).s2sql(exemplar.getSql())
.createdBy(executeContext.getUser().getName()) .createdBy(executeContext.getRequest().getUser().getName())
.updatedBy(executeContext.getUser().getName()).createdAt(new Date()) .updatedBy(executeContext.getRequest().getUser().getName())
.build()); .createdAt(new Date()).build());
} }
} }
@@ -62,7 +62,8 @@ public class SqlExecutor implements ChatQueryExecutor {
SemanticLayerService semanticLayer = ContextUtils.getBean(SemanticLayerService.class); SemanticLayerService semanticLayer = ContextUtils.getBean(SemanticLayerService.class);
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class); ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
ChatContext chatCtx = chatContextService.getOrCreateContext(executeContext.getChatId()); ChatContext chatCtx =
chatContextService.getOrCreateContext(executeContext.getRequest().getChatId());
SemanticParseInfo parseInfo = executeContext.getParseInfo(); SemanticParseInfo parseInfo = executeContext.getParseInfo();
if (Objects.isNull(parseInfo.getSqlInfo()) if (Objects.isNull(parseInfo.getSqlInfo())
|| StringUtils.isBlank(parseInfo.getSqlInfo().getCorrectedS2SQL())) { || StringUtils.isBlank(parseInfo.getSqlInfo().getCorrectedS2SQL())) {
@@ -79,7 +80,8 @@ public class SqlExecutor implements ChatQueryExecutor {
queryResult.setChatContext(parseInfo); queryResult.setChatContext(parseInfo);
queryResult.setQueryMode(parseInfo.getQueryMode()); queryResult.setQueryMode(parseInfo.getQueryMode());
queryResult.setQueryTimeCost(System.currentTimeMillis() - startTime); queryResult.setQueryTimeCost(System.currentTimeMillis() - startTime);
SemanticQueryResp queryResp = semanticLayer.queryByReq(sqlReq, executeContext.getUser()); SemanticQueryResp queryResp =
semanticLayer.queryByReq(sqlReq, executeContext.getRequest().getUser());
if (queryResp != null) { if (queryResp != null) {
queryResult.setQueryAuthorization(queryResp.getQueryAuthorization()); queryResult.setQueryAuthorization(queryResp.getQueryAuthorization());
queryResult.setQuerySql(queryResp.getSql()); queryResult.setQuerySql(queryResp.getSql());

View File

@@ -91,16 +91,24 @@ public class NL2SQLParser implements ChatQueryParser {
@Override @Override
public void parse(ParseContext parseContext, ParseResp parseResp) { public void parse(ParseContext parseContext, ParseResp parseResp) {
if (!parseContext.enableNL2SQL() || checkSkip(parseResp)) { if (!parseContext.enableNL2SQL() || Objects.isNull(parseContext.getAgent())) {
return;
}
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
if (Objects.isNull(queryNLReq)) {
return; return;
} }
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
ChatContext chatCtx = chatContextService.getOrCreateContext(parseContext.getChatId());
if (!parseContext.isDisableLLM()) { if (!parseContext.getRequest().isDisableLLM()) {
processMultiTurn(parseContext); processMultiTurn(parseContext);
} }
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext, chatCtx);
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
ChatContext chatCtx =
chatContextService.getOrCreateContext(parseContext.getRequest().getChatId());
if (chatCtx != null) {
queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
}
addDynamicExemplars(parseContext.getAgent().getId(), queryNLReq); addDynamicExemplars(parseContext.getAgent().getId(), queryNLReq);
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class); ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
@@ -108,7 +116,7 @@ public class NL2SQLParser implements ChatQueryParser {
if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) { if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) {
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses()); parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
} else { } else {
if (!parseContext.isDisableLLM()) { if (!parseContext.getRequest().isDisableLLM()) {
parseResp.setErrorMsg(rewriteErrorMessage(parseContext, parseResp.setErrorMsg(rewriteErrorMessage(parseContext,
text2SqlParseResp.getErrorMsg(), queryNLReq.getDynamicExemplars())); text2SqlParseResp.getErrorMsg(), queryNLReq.getDynamicExemplars()));
} }
@@ -119,16 +127,6 @@ public class NL2SQLParser implements ChatQueryParser {
formatParseResult(parseResp); formatParseResult(parseResp);
} }
private boolean checkSkip(ParseResp parseResp) {
List<SemanticParseInfo> selectedParses = parseResp.getSelectedParses();
for (SemanticParseInfo semanticParseInfo : selectedParses) {
if (semanticParseInfo.getScore() >= parseResp.getQueryText().length()) {
return true;
}
}
return false;
}
private void formatParseResult(ParseResp parseResp) { private void formatParseResult(ParseResp parseResp) {
List<SemanticParseInfo> selectedParses = parseResp.getSelectedParses(); List<SemanticParseInfo> selectedParses = parseResp.getSelectedParses();
for (SemanticParseInfo parseInfo : selectedParses) { for (SemanticParseInfo parseInfo : selectedParses) {
@@ -182,7 +180,8 @@ public class NL2SQLParser implements ChatQueryParser {
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext); QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
MapResp currentMapResult = chatLayerService.map(queryNLReq); MapResp currentMapResult = chatLayerService.map(queryNLReq);
List<QueryResp> historyQueries = getHistoryQueries(parseContext.getChatId(), 1); List<QueryResp> historyQueries =
getHistoryQueries(parseContext.getRequest().getChatId(), 1);
if (historyQueries.isEmpty()) { if (historyQueries.isEmpty()) {
return; return;
} }
@@ -208,7 +207,7 @@ public class NL2SQLParser implements ChatQueryParser {
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage()); Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
String rewrittenQuery = response.content().text(); String rewrittenQuery = response.content().text();
keyPipelineLog.info("QueryRewrite modelReq:\n{} \nmodelResp:\n{}", prompt.text(), response); keyPipelineLog.info("QueryRewrite modelReq:\n{} \nmodelResp:\n{}", prompt.text(), response);
parseContext.setQueryText(rewrittenQuery); parseContext.getRequest().setQueryText(rewrittenQuery);
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}", lastQuery.getQueryText(), log.info("Last Query: {} Current Query: {}, Rewritten Query: {}", lastQuery.getQueryText(),
currentMapResult.getQueryText(), rewrittenQuery); currentMapResult.getQueryText(), rewrittenQuery);
} }
@@ -222,7 +221,7 @@ public class NL2SQLParser implements ChatQueryParser {
} }
Map<String, Object> variables = new HashMap<>(); Map<String, Object> variables = new HashMap<>();
variables.put("user_question", parseContext.getQueryText()); variables.put("user_question", parseContext.getRequest().getQueryText());
variables.put("system_message", errMsg); variables.put("system_message", errMsg);
StringBuilder exampleStr = new StringBuilder(); StringBuilder exampleStr = new StringBuilder();
@@ -286,7 +285,7 @@ public class NL2SQLParser implements ChatQueryParser {
String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId); String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId);
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class); ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
int exemplarRecallNumber = int exemplarRecallNumber =
Integer.valueOf(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER)); Integer.parseInt(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER));
List<Text2SQLExemplar> exemplars = exemplarManager.recallExemplars(memoryCollectionName, List<Text2SQLExemplar> exemplars = exemplarManager.recallExemplars(memoryCollectionName,
queryNLReq.getQueryText(), exemplarRecallNumber); queryNLReq.getQueryText(), exemplarRecallNumber);
queryNLReq.getDynamicExemplars().addAll(exemplars); queryNLReq.getDynamicExemplars().addAll(exemplars);

View File

@@ -72,7 +72,7 @@ public abstract class PluginRecognizer {
protected SemanticParseInfo buildSemanticParseInfo(Long dataSetId, ChatPlugin plugin, protected SemanticParseInfo buildSemanticParseInfo(Long dataSetId, ChatPlugin plugin,
ParseContext parseContext, SchemaMapInfo mapInfo, double distance) { ParseContext parseContext, SchemaMapInfo mapInfo, double distance) {
List<SchemaElementMatch> schemaElementMatches = mapInfo.getMatchedElements(dataSetId); List<SchemaElementMatch> schemaElementMatches = mapInfo.getMatchedElements(dataSetId);
QueryFilters queryFilters = parseContext.getQueryFilters(); QueryFilters queryFilters = parseContext.getRequest().getQueryFilters();
if (schemaElementMatches == null) { if (schemaElementMatches == null) {
schemaElementMatches = Lists.newArrayList(); schemaElementMatches = Lists.newArrayList();
} }
@@ -86,7 +86,7 @@ public abstract class PluginRecognizer {
pluginParseResult.setPlugin(plugin); pluginParseResult.setPlugin(plugin);
pluginParseResult.setQueryFilters(queryFilters); pluginParseResult.setQueryFilters(queryFilters);
pluginParseResult.setDistance(distance); pluginParseResult.setDistance(distance);
pluginParseResult.setQueryText(parseContext.getQueryText()); pluginParseResult.setQueryText(parseContext.getRequest().getQueryText());
properties.put(Constants.CONTEXT, pluginParseResult); properties.put(Constants.CONTEXT, pluginParseResult);
properties.put("type", "plugin"); properties.put("type", "plugin");
properties.put("name", plugin.getName()); properties.put("name", plugin.getName());

View File

@@ -30,7 +30,7 @@ public class EmbeddingRecallRecognizer extends PluginRecognizer {
} }
public PluginRecallResult recallPlugin(ParseContext parseContext) { public PluginRecallResult recallPlugin(ParseContext parseContext) {
String text = parseContext.getQueryText(); String text = parseContext.getRequest().getQueryText();
List<Retrieval> embeddingRetrievals = embeddingRecall(text); List<Retrieval> embeddingRetrievals = embeddingRecall(text);
if (CollectionUtils.isEmpty(embeddingRetrievals)) { if (CollectionUtils.isEmpty(embeddingRetrievals)) {
return null; return null;
@@ -52,7 +52,7 @@ public class EmbeddingRecallRecognizer extends PluginRecognizer {
} }
plugin.setParseMode(ParseMode.EMBEDDING_RECALL); plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
double similarity = embeddingRetrieval.getSimilarity(); double similarity = embeddingRetrieval.getSimilarity();
double score = parseContext.getQueryText().length() * similarity; double score = parseContext.getRequest().getQueryText().length() * similarity;
return PluginRecallResult.builder().plugin(plugin).dataSetIds(dataSetList) return PluginRecallResult.builder().plugin(plugin).dataSetIds(dataSetList)
.score(score).distance(similarity).build(); .score(score).distance(similarity).build();
} }

View File

@@ -1,17 +1,17 @@
package com.tencent.supersonic.chat.server.pojo; package com.tencent.supersonic.chat.server.pojo;
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import lombok.Data; import lombok.Data;
@Data @Data
public class ExecuteContext { public class ExecuteContext {
private User user; private ChatExecuteReq request;
private String queryText;
private Agent agent; private Agent agent;
private Integer chatId;
private Long queryId;
private boolean saveAnswer;
private SemanticParseInfo parseInfo; private SemanticParseInfo parseInfo;
public ExecuteContext(ChatExecuteReq request) {
this.request = request;
}
} }

View File

@@ -1,19 +1,17 @@
package com.tencent.supersonic.chat.server.pojo; package com.tencent.supersonic.chat.server.pojo;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import lombok.Data; import lombok.Data;
@Data @Data
public class ParseContext { public class ParseContext {
private User user; private ChatParseReq request;
private String queryText;
private Agent agent; private Agent agent;
private Integer chatId;
private QueryFilters queryFilters; public ParseContext(ChatParseReq request) {
private boolean saveAnswer = true; this.request = request;
private boolean disableLLM = false; }
public boolean enableNL2SQL() { public boolean enableNL2SQL() {
if (agent == null) { if (agent == null) {

View File

@@ -48,7 +48,7 @@ public class DataInterpretProcessor implements ExecuteResultProcessor {
} }
Map<String, Object> variable = new HashMap<>(); Map<String, Object> variable = new HashMap<>();
variable.put("question", executeContext.getQueryText()); variable.put("question", executeContext.getRequest().getQueryText());
variable.put("data", queryResult.getTextResult()); variable.put("data", queryResult.getTextResult());
Prompt prompt = PromptTemplate.from(chatApp.getPrompt()).apply(variable); Prompt prompt = PromptTemplate.from(chatApp.getPrompt()).apply(variable);

View File

@@ -67,8 +67,8 @@ public class MetricRatioProcessor implements ExecuteResultProcessor {
|| !QueryType.AGGREGATE.equals(semanticParseInfo.getQueryType())) { || !QueryType.AGGREGATE.equals(semanticParseInfo.getQueryType())) {
return; return;
} }
AggregateInfo aggregateInfo = AggregateInfo aggregateInfo = getAggregateInfo(executeContext.getRequest().getUser(),
getAggregateInfo(executeContext.getUser(), semanticParseInfo, queryResult); semanticParseInfo, queryResult);
queryResult.setAggregateInfo(aggregateInfo); queryResult.setAggregateInfo(aggregateInfo);
} }

View File

@@ -30,8 +30,8 @@ public class QueryRecommendProcessor implements ParseResultProcessor {
@SneakyThrows @SneakyThrows
private void doProcess(ParseResp parseResp, ParseContext parseContext) { private void doProcess(ParseResp parseResp, ParseContext parseContext) {
Long queryId = parseResp.getQueryId(); Long queryId = parseResp.getQueryId();
List<SimilarQueryRecallResp> solvedQueries = List<SimilarQueryRecallResp> solvedQueries = getSimilarQueries(
getSimilarQueries(parseContext.getQueryText(), parseContext.getAgent().getId()); parseContext.getRequest().getQueryText(), parseContext.getAgent().getId());
ChatQueryDO chatQueryDO = getChatQuery(queryId); ChatQueryDO chatQueryDO = getChatQuery(queryId);
chatQueryDO.setSimilarQueries(JSONObject.toJSONString(solvedQueries)); chatQueryDO.setSimilarQueries(JSONObject.toJSONString(solvedQueries));
updateChatQuery(chatQueryDO); updateChatQuery(chatQueryDO);

View File

@@ -24,8 +24,6 @@ import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.service.ChatModelService;
import com.tencent.supersonic.common.util.BeanMapper;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateUtils; import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
@@ -86,8 +84,6 @@ public class ChatQueryServiceImpl implements ChatQueryService {
private SemanticLayerService semanticLayerService; private SemanticLayerService semanticLayerService;
@Autowired @Autowired
private AgentService agentService; private AgentService agentService;
@Autowired
private ChatModelService chatModelService;
private final List<ChatQueryParser> chatQueryParsers = ComponentFactory.getChatParsers(); private final List<ChatQueryParser> chatQueryParsers = ComponentFactory.getChatParsers();
private final List<ChatQueryExecutor> chatQueryExecutors = ComponentFactory.getChatExecutors(); private final List<ChatQueryExecutor> chatQueryExecutors = ComponentFactory.getChatExecutors();
@@ -120,7 +116,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
processor.process(parseContext, parseResp); processor.process(parseContext, parseResp);
} }
chatParseReq.setQueryText(parseContext.getQueryText()); chatParseReq.setQueryText(parseContext.getRequest().getQueryText());
chatManageService.batchAddParse(chatParseReq, parseResp); chatManageService.batchAddParse(chatParseReq, parseResp);
chatManageService.updateParseCostTime(parseResp); chatManageService.updateParseCostTime(parseResp);
return parseResp; return parseResp;
@@ -168,16 +164,14 @@ public class ChatQueryServiceImpl implements ChatQueryService {
} }
private ParseContext buildParseContext(ChatParseReq chatParseReq) { private ParseContext buildParseContext(ChatParseReq chatParseReq) {
ParseContext parseContext = new ParseContext(); ParseContext parseContext = new ParseContext(chatParseReq);
BeanMapper.mapper(chatParseReq, parseContext);
Agent agent = agentService.getAgent(chatParseReq.getAgentId()); Agent agent = agentService.getAgent(chatParseReq.getAgentId());
parseContext.setAgent(agent); parseContext.setAgent(agent);
return parseContext; return parseContext;
} }
private ExecuteContext buildExecuteContext(ChatExecuteReq chatExecuteReq) { private ExecuteContext buildExecuteContext(ChatExecuteReq chatExecuteReq) {
ExecuteContext executeContext = new ExecuteContext(); ExecuteContext executeContext = new ExecuteContext(chatExecuteReq);
BeanMapper.mapper(chatExecuteReq, executeContext);
SemanticParseInfo parseInfo = chatManageService.getParseInfo(chatExecuteReq.getQueryId(), SemanticParseInfo parseInfo = chatManageService.getParseInfo(chatExecuteReq.getQueryId(),
chatExecuteReq.getParseId()); chatExecuteReq.getParseId());
Agent agent = agentService.getAgent(chatExecuteReq.getAgentId()); Agent agent = agentService.getAgent(chatExecuteReq.getAgentId());
@@ -443,14 +437,14 @@ public class ChatQueryServiceImpl implements ChatQueryService {
if (CollectionUtils.isEmpty(valueList)) { if (CollectionUtils.isEmpty(valueList)) {
return; return;
} }
valueList.stream().forEach(o -> { valueList.forEach(o -> {
StringValue stringValue = new StringValue(o); StringValue stringValue = new StringValue(o);
parenthesedExpressionList.add(stringValue); parenthesedExpressionList.add(stringValue);
}); });
inExpression.setLeftExpression(column); inExpression.setLeftExpression(column);
inExpression.setRightExpression(parenthesedExpressionList); inExpression.setRightExpression(parenthesedExpressionList);
addConditions.add(inExpression); addConditions.add(inExpression);
contextMetricFilters.stream().forEach(o -> { contextMetricFilters.forEach(o -> {
if (o.getName().equals(dslQueryFilter.getName())) { if (o.getName().equals(dslQueryFilter.getName())) {
o.setValue(dslQueryFilter.getValue()); o.setValue(dslQueryFilter.getValue());
o.setOperator(dslQueryFilter.getOperator()); o.setOperator(dslQueryFilter.getOperator());
@@ -480,7 +474,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
comparisonExpression.setRightExpression(stringValue); comparisonExpression.setRightExpression(stringValue);
} }
addConditions.add(comparisonExpression); addConditions.add(comparisonExpression);
contextMetricFilters.stream().forEach(o -> { contextMetricFilters.forEach(o -> {
if (o.getName().equals(dslQueryFilter.getName())) { if (o.getName().equals(dslQueryFilter.getName())) {
o.setValue(dslQueryFilter.getValue()); o.setValue(dslQueryFilter.getValue());
o.setOperator(dslQueryFilter.getOperator()); o.setOperator(dslQueryFilter.getOperator());

View File

@@ -1,44 +1,24 @@
package com.tencent.supersonic.chat.server.util; package com.tencent.supersonic.chat.server.util;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.pojo.ChatContext;
import com.tencent.supersonic.chat.server.pojo.ParseContext; import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.common.util.BeanMapper; import com.tencent.supersonic.common.util.BeanMapper;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq; import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import org.apache.commons.collections.MapUtils;
import java.util.Objects;
public class QueryReqConverter { public class QueryReqConverter {
public static QueryNLReq buildQueryNLReq(ParseContext parseContext) { public static QueryNLReq buildQueryNLReq(ParseContext parseContext) {
return buildQueryNLReq(parseContext, null); if (parseContext.getAgent() == null) {
return null;
} }
public static QueryNLReq buildQueryNLReq(ParseContext parseContext, ChatContext chatCtx) {
QueryNLReq queryNLReq = new QueryNLReq(); QueryNLReq queryNLReq = new QueryNLReq();
BeanMapper.mapper(parseContext, queryNLReq); BeanMapper.mapper(parseContext.getRequest(), queryNLReq);
Agent agent = parseContext.getAgent(); queryNLReq.setText2SQLType(parseContext.getRequest().isDisableLLM() ? Text2SQLType.ONLY_RULE
if (agent == null) { : Text2SQLType.RULE_AND_LLM);
return queryNLReq; queryNLReq.setDataSetIds(parseContext.getAgent().getDataSetIds());
}
if (parseContext.isDisableLLM()) {
queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE);
} else {
queryNLReq.setText2SQLType(Text2SQLType.RULE_AND_LLM);
}
queryNLReq.setDataSetIds(agent.getDataSetIds());
if (Objects.nonNull(queryNLReq.getMapInfo())
&& MapUtils.isNotEmpty(queryNLReq.getMapInfo().getDataSetElementMatches())) {
queryNLReq.setMapInfo(queryNLReq.getMapInfo());
}
queryNLReq.setChatAppConfig(parseContext.getAgent().getChatAppConfig()); queryNLReq.setChatAppConfig(parseContext.getAgent().getChatAppConfig());
if (chatCtx != null) {
queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
}
return queryNLReq; return queryNLReq;
} }
} }

View File

@@ -70,8 +70,8 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner {
RetrieveQuery.builder().queryTextsList(Lists.newArrayList(query)).build(); RetrieveQuery.builder().queryTextsList(Lists.newArrayList(query)).build();
List<RetrieveQueryResult> results = List<RetrieveQueryResult> results =
embeddingService.retrieveQuery(collection, retrieveQuery, num); embeddingService.retrieveQuery(collection, retrieveQuery, num);
results.stream().forEach(ret -> { results.forEach(ret -> {
ret.getRetrieval().stream().forEach(r -> { ret.getRetrieval().forEach(r -> {
exemplars.add(JsonUtil.mapToObject(r.getMetadata(), Text2SQLExemplar.class)); exemplars.add(JsonUtil.mapToObject(r.getMetadata(), Text2SQLExemplar.class));
}); });
}); });