From 1842261dfe47605e290c714aff2c33bc675ce8c4 Mon Sep 17 00:00:00 2001 From: jerryjzhang Date: Sun, 27 Oct 2024 15:59:49 +0800 Subject: [PATCH] [improvement][project]Remove unnecessary copy from `Request` to `Context` objects. --- .../chat/api/pojo/request/ChatParseReq.java | 3 -- .../server/executor/PlainTextExecutor.java | 6 +-- .../chat/server/executor/SqlExecutor.java | 12 +++--- .../chat/server/parser/NL2SQLParser.java | 39 +++++++++---------- .../plugin/recognize/PluginRecognizer.java | 4 +- .../embedding/EmbeddingRecallRecognizer.java | 4 +- .../chat/server/pojo/ExecuteContext.java | 12 +++--- .../chat/server/pojo/ParseContext.java | 14 +++---- .../execute/DataInterpretProcessor.java | 2 +- .../execute/MetricRatioProcessor.java | 4 +- .../parse/QueryRecommendProcessor.java | 4 +- .../service/impl/ChatQueryServiceImpl.java | 18 +++------ .../chat/server/util/QueryReqConverter.java | 36 ++++------------- .../service/impl/ExemplarServiceImpl.java | 4 +- 14 files changed, 66 insertions(+), 96 deletions(-) diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatParseReq.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatParseReq.java index b5e3fab9b..6980425dd 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatParseReq.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatParseReq.java @@ -1,7 +1,6 @@ package com.tencent.supersonic.chat.api.pojo.request; import com.tencent.supersonic.common.pojo.User; -import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; import lombok.AllArgsConstructor; import lombok.Builder; @@ -16,10 +15,8 @@ public class ChatParseReq { private String queryText; private Integer chatId; private Integer agentId; - private Integer topN = 10; private User user; private QueryFilters queryFilters; private boolean saveAnswer = true; - private SchemaMapInfo mapInfo = new SchemaMapInfo(); private boolean disableLLM = false; } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java index ae2204e9f..42c7c30d1 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java @@ -50,7 +50,7 @@ public class PlainTextExecutor implements ChatQueryExecutor { } String promptStr = String.format(chatApp.getPrompt(), getHistoryInputs(executeContext), - executeContext.getQueryText()); + executeContext.getRequest().getQueryText()); Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP); ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(chatApp.getChatModelConfig()); @@ -66,8 +66,8 @@ public class PlainTextExecutor implements ChatQueryExecutor { private String getHistoryInputs(ExecuteContext executeContext) { StringBuilder historyInput = new StringBuilder(); - List queryResps = getHistoryQueries(executeContext.getChatId(), 5); - queryResps.stream().forEach(p -> { + List queryResps = getHistoryQueries(executeContext.getRequest().getChatId(), 5); + queryResps.forEach(p -> { historyInput.append(p.getQueryText()); historyInput.append(";"); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java index bbf481bef..a77d8c03f 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java @@ -48,9 +48,9 @@ public class SqlExecutor implements ChatQueryExecutor { .agentId(executeContext.getAgent().getId()).status(MemoryStatus.PENDING) .question(exemplar.getQuestion()).sideInfo(exemplar.getSideInfo()) .dbSchema(exemplar.getDbSchema()).s2sql(exemplar.getSql()) - .createdBy(executeContext.getUser().getName()) - .updatedBy(executeContext.getUser().getName()).createdAt(new Date()) - .build()); + .createdBy(executeContext.getRequest().getUser().getName()) + .updatedBy(executeContext.getRequest().getUser().getName()) + .createdAt(new Date()).build()); } } @@ -62,7 +62,8 @@ public class SqlExecutor implements ChatQueryExecutor { SemanticLayerService semanticLayer = ContextUtils.getBean(SemanticLayerService.class); ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class); - ChatContext chatCtx = chatContextService.getOrCreateContext(executeContext.getChatId()); + ChatContext chatCtx = + chatContextService.getOrCreateContext(executeContext.getRequest().getChatId()); SemanticParseInfo parseInfo = executeContext.getParseInfo(); if (Objects.isNull(parseInfo.getSqlInfo()) || StringUtils.isBlank(parseInfo.getSqlInfo().getCorrectedS2SQL())) { @@ -79,7 +80,8 @@ public class SqlExecutor implements ChatQueryExecutor { queryResult.setChatContext(parseInfo); queryResult.setQueryMode(parseInfo.getQueryMode()); queryResult.setQueryTimeCost(System.currentTimeMillis() - startTime); - SemanticQueryResp queryResp = semanticLayer.queryByReq(sqlReq, executeContext.getUser()); + SemanticQueryResp queryResp = + semanticLayer.queryByReq(sqlReq, executeContext.getRequest().getUser()); if (queryResp != null) { queryResult.setQueryAuthorization(queryResp.getQueryAuthorization()); queryResult.setQuerySql(queryResp.getSql()); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index 15c6998f5..286df8002 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -91,16 +91,24 @@ public class NL2SQLParser implements ChatQueryParser { @Override public void parse(ParseContext parseContext, ParseResp parseResp) { - if (!parseContext.enableNL2SQL() || checkSkip(parseResp)) { + if (!parseContext.enableNL2SQL() || Objects.isNull(parseContext.getAgent())) { + return; + } + QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext); + if (Objects.isNull(queryNLReq)) { return; } - ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class); - ChatContext chatCtx = chatContextService.getOrCreateContext(parseContext.getChatId()); - if (!parseContext.isDisableLLM()) { + if (!parseContext.getRequest().isDisableLLM()) { processMultiTurn(parseContext); } - QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext, chatCtx); + + ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class); + ChatContext chatCtx = + chatContextService.getOrCreateContext(parseContext.getRequest().getChatId()); + if (chatCtx != null) { + queryNLReq.setContextParseInfo(chatCtx.getParseInfo()); + } addDynamicExemplars(parseContext.getAgent().getId(), queryNLReq); ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class); @@ -108,7 +116,7 @@ public class NL2SQLParser implements ChatQueryParser { if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) { parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses()); } else { - if (!parseContext.isDisableLLM()) { + if (!parseContext.getRequest().isDisableLLM()) { parseResp.setErrorMsg(rewriteErrorMessage(parseContext, text2SqlParseResp.getErrorMsg(), queryNLReq.getDynamicExemplars())); } @@ -119,16 +127,6 @@ public class NL2SQLParser implements ChatQueryParser { formatParseResult(parseResp); } - private boolean checkSkip(ParseResp parseResp) { - List selectedParses = parseResp.getSelectedParses(); - for (SemanticParseInfo semanticParseInfo : selectedParses) { - if (semanticParseInfo.getScore() >= parseResp.getQueryText().length()) { - return true; - } - } - return false; - } - private void formatParseResult(ParseResp parseResp) { List selectedParses = parseResp.getSelectedParses(); for (SemanticParseInfo parseInfo : selectedParses) { @@ -182,7 +180,8 @@ public class NL2SQLParser implements ChatQueryParser { QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext); MapResp currentMapResult = chatLayerService.map(queryNLReq); - List historyQueries = getHistoryQueries(parseContext.getChatId(), 1); + List historyQueries = + getHistoryQueries(parseContext.getRequest().getChatId(), 1); if (historyQueries.isEmpty()) { return; } @@ -208,7 +207,7 @@ public class NL2SQLParser implements ChatQueryParser { Response response = chatLanguageModel.generate(prompt.toUserMessage()); String rewrittenQuery = response.content().text(); keyPipelineLog.info("QueryRewrite modelReq:\n{} \nmodelResp:\n{}", prompt.text(), response); - parseContext.setQueryText(rewrittenQuery); + parseContext.getRequest().setQueryText(rewrittenQuery); log.info("Last Query: {} Current Query: {}, Rewritten Query: {}", lastQuery.getQueryText(), currentMapResult.getQueryText(), rewrittenQuery); } @@ -222,7 +221,7 @@ public class NL2SQLParser implements ChatQueryParser { } Map variables = new HashMap<>(); - variables.put("user_question", parseContext.getQueryText()); + variables.put("user_question", parseContext.getRequest().getQueryText()); variables.put("system_message", errMsg); StringBuilder exampleStr = new StringBuilder(); @@ -286,7 +285,7 @@ public class NL2SQLParser implements ChatQueryParser { String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId); ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class); int exemplarRecallNumber = - Integer.valueOf(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER)); + Integer.parseInt(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER)); List exemplars = exemplarManager.recallExemplars(memoryCollectionName, queryNLReq.getQueryText(), exemplarRecallNumber); queryNLReq.getDynamicExemplars().addAll(exemplars); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/PluginRecognizer.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/PluginRecognizer.java index 86db67436..9a26eeb88 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/PluginRecognizer.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/PluginRecognizer.java @@ -72,7 +72,7 @@ public abstract class PluginRecognizer { protected SemanticParseInfo buildSemanticParseInfo(Long dataSetId, ChatPlugin plugin, ParseContext parseContext, SchemaMapInfo mapInfo, double distance) { List schemaElementMatches = mapInfo.getMatchedElements(dataSetId); - QueryFilters queryFilters = parseContext.getQueryFilters(); + QueryFilters queryFilters = parseContext.getRequest().getQueryFilters(); if (schemaElementMatches == null) { schemaElementMatches = Lists.newArrayList(); } @@ -86,7 +86,7 @@ public abstract class PluginRecognizer { pluginParseResult.setPlugin(plugin); pluginParseResult.setQueryFilters(queryFilters); pluginParseResult.setDistance(distance); - pluginParseResult.setQueryText(parseContext.getQueryText()); + pluginParseResult.setQueryText(parseContext.getRequest().getQueryText()); properties.put(Constants.CONTEXT, pluginParseResult); properties.put("type", "plugin"); properties.put("name", plugin.getName()); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/embedding/EmbeddingRecallRecognizer.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/embedding/EmbeddingRecallRecognizer.java index 589192c45..715d5b82e 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/embedding/EmbeddingRecallRecognizer.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/embedding/EmbeddingRecallRecognizer.java @@ -30,7 +30,7 @@ public class EmbeddingRecallRecognizer extends PluginRecognizer { } public PluginRecallResult recallPlugin(ParseContext parseContext) { - String text = parseContext.getQueryText(); + String text = parseContext.getRequest().getQueryText(); List embeddingRetrievals = embeddingRecall(text); if (CollectionUtils.isEmpty(embeddingRetrievals)) { return null; @@ -52,7 +52,7 @@ public class EmbeddingRecallRecognizer extends PluginRecognizer { } plugin.setParseMode(ParseMode.EMBEDDING_RECALL); double similarity = embeddingRetrieval.getSimilarity(); - double score = parseContext.getQueryText().length() * similarity; + double score = parseContext.getRequest().getQueryText().length() * similarity; return PluginRecallResult.builder().plugin(plugin).dataSetIds(dataSetList) .score(score).distance(similarity).build(); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ExecuteContext.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ExecuteContext.java index 258cc7ca6..acddec632 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ExecuteContext.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ExecuteContext.java @@ -1,17 +1,17 @@ package com.tencent.supersonic.chat.server.pojo; +import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq; import com.tencent.supersonic.chat.server.agent.Agent; -import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import lombok.Data; @Data public class ExecuteContext { - private User user; - private String queryText; + private ChatExecuteReq request; private Agent agent; - private Integer chatId; - private Long queryId; - private boolean saveAnswer; private SemanticParseInfo parseInfo; + + public ExecuteContext(ChatExecuteReq request) { + this.request = request; + } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java index c6bf84011..5b1c76f34 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java @@ -1,19 +1,17 @@ package com.tencent.supersonic.chat.server.pojo; +import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq; import com.tencent.supersonic.chat.server.agent.Agent; -import com.tencent.supersonic.common.pojo.User; -import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; import lombok.Data; @Data public class ParseContext { - private User user; - private String queryText; + private ChatParseReq request; private Agent agent; - private Integer chatId; - private QueryFilters queryFilters; - private boolean saveAnswer = true; - private boolean disableLLM = false; + + public ParseContext(ChatParseReq request) { + this.request = request; + } public boolean enableNL2SQL() { if (agent == null) { diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/DataInterpretProcessor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/DataInterpretProcessor.java index 2f7b73266..34dc9fb01 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/DataInterpretProcessor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/DataInterpretProcessor.java @@ -48,7 +48,7 @@ public class DataInterpretProcessor implements ExecuteResultProcessor { } Map variable = new HashMap<>(); - variable.put("question", executeContext.getQueryText()); + variable.put("question", executeContext.getRequest().getQueryText()); variable.put("data", queryResult.getTextResult()); Prompt prompt = PromptTemplate.from(chatApp.getPrompt()).apply(variable); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/MetricRatioProcessor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/MetricRatioProcessor.java index 1d9544154..c456d49f0 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/MetricRatioProcessor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/MetricRatioProcessor.java @@ -67,8 +67,8 @@ public class MetricRatioProcessor implements ExecuteResultProcessor { || !QueryType.AGGREGATE.equals(semanticParseInfo.getQueryType())) { return; } - AggregateInfo aggregateInfo = - getAggregateInfo(executeContext.getUser(), semanticParseInfo, queryResult); + AggregateInfo aggregateInfo = getAggregateInfo(executeContext.getRequest().getUser(), + semanticParseInfo, queryResult); queryResult.setAggregateInfo(aggregateInfo); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/QueryRecommendProcessor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/QueryRecommendProcessor.java index 96dfbb215..d977cd673 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/QueryRecommendProcessor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/QueryRecommendProcessor.java @@ -30,8 +30,8 @@ public class QueryRecommendProcessor implements ParseResultProcessor { @SneakyThrows private void doProcess(ParseResp parseResp, ParseContext parseContext) { Long queryId = parseResp.getQueryId(); - List solvedQueries = - getSimilarQueries(parseContext.getQueryText(), parseContext.getAgent().getId()); + List solvedQueries = getSimilarQueries( + parseContext.getRequest().getQueryText(), parseContext.getAgent().getId()); ChatQueryDO chatQueryDO = getChatQuery(queryId); chatQueryDO.setSimilarQueries(JSONObject.toJSONString(solvedQueries)); updateChatQuery(chatQueryDO); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java index 401c581a0..772969944 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java @@ -24,8 +24,6 @@ import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper; import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; -import com.tencent.supersonic.common.service.ChatModelService; -import com.tencent.supersonic.common.util.BeanMapper; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.DateUtils; import com.tencent.supersonic.common.util.JsonUtil; @@ -86,8 +84,6 @@ public class ChatQueryServiceImpl implements ChatQueryService { private SemanticLayerService semanticLayerService; @Autowired private AgentService agentService; - @Autowired - private ChatModelService chatModelService; private final List chatQueryParsers = ComponentFactory.getChatParsers(); private final List chatQueryExecutors = ComponentFactory.getChatExecutors(); @@ -120,7 +116,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { processor.process(parseContext, parseResp); } - chatParseReq.setQueryText(parseContext.getQueryText()); + chatParseReq.setQueryText(parseContext.getRequest().getQueryText()); chatManageService.batchAddParse(chatParseReq, parseResp); chatManageService.updateParseCostTime(parseResp); return parseResp; @@ -168,16 +164,14 @@ public class ChatQueryServiceImpl implements ChatQueryService { } private ParseContext buildParseContext(ChatParseReq chatParseReq) { - ParseContext parseContext = new ParseContext(); - BeanMapper.mapper(chatParseReq, parseContext); + ParseContext parseContext = new ParseContext(chatParseReq); Agent agent = agentService.getAgent(chatParseReq.getAgentId()); parseContext.setAgent(agent); return parseContext; } private ExecuteContext buildExecuteContext(ChatExecuteReq chatExecuteReq) { - ExecuteContext executeContext = new ExecuteContext(); - BeanMapper.mapper(chatExecuteReq, executeContext); + ExecuteContext executeContext = new ExecuteContext(chatExecuteReq); SemanticParseInfo parseInfo = chatManageService.getParseInfo(chatExecuteReq.getQueryId(), chatExecuteReq.getParseId()); Agent agent = agentService.getAgent(chatExecuteReq.getAgentId()); @@ -443,14 +437,14 @@ public class ChatQueryServiceImpl implements ChatQueryService { if (CollectionUtils.isEmpty(valueList)) { return; } - valueList.stream().forEach(o -> { + valueList.forEach(o -> { StringValue stringValue = new StringValue(o); parenthesedExpressionList.add(stringValue); }); inExpression.setLeftExpression(column); inExpression.setRightExpression(parenthesedExpressionList); addConditions.add(inExpression); - contextMetricFilters.stream().forEach(o -> { + contextMetricFilters.forEach(o -> { if (o.getName().equals(dslQueryFilter.getName())) { o.setValue(dslQueryFilter.getValue()); o.setOperator(dslQueryFilter.getOperator()); @@ -480,7 +474,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { comparisonExpression.setRightExpression(stringValue); } addConditions.add(comparisonExpression); - contextMetricFilters.stream().forEach(o -> { + contextMetricFilters.forEach(o -> { if (o.getName().equals(dslQueryFilter.getName())) { o.setValue(dslQueryFilter.getValue()); o.setOperator(dslQueryFilter.getOperator()); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java index c75871f83..dfa91a667 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java @@ -1,44 +1,24 @@ package com.tencent.supersonic.chat.server.util; -import com.tencent.supersonic.chat.server.agent.Agent; -import com.tencent.supersonic.chat.server.pojo.ChatContext; import com.tencent.supersonic.chat.server.pojo.ParseContext; import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.util.BeanMapper; import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq; -import org.apache.commons.collections.MapUtils; - -import java.util.Objects; public class QueryReqConverter { public static QueryNLReq buildQueryNLReq(ParseContext parseContext) { - return buildQueryNLReq(parseContext, null); - } + if (parseContext.getAgent() == null) { + return null; + } - public static QueryNLReq buildQueryNLReq(ParseContext parseContext, ChatContext chatCtx) { QueryNLReq queryNLReq = new QueryNLReq(); - BeanMapper.mapper(parseContext, queryNLReq); - Agent agent = parseContext.getAgent(); - if (agent == null) { - return queryNLReq; - } - - if (parseContext.isDisableLLM()) { - queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE); - } else { - queryNLReq.setText2SQLType(Text2SQLType.RULE_AND_LLM); - } - - queryNLReq.setDataSetIds(agent.getDataSetIds()); - if (Objects.nonNull(queryNLReq.getMapInfo()) - && MapUtils.isNotEmpty(queryNLReq.getMapInfo().getDataSetElementMatches())) { - queryNLReq.setMapInfo(queryNLReq.getMapInfo()); - } + BeanMapper.mapper(parseContext.getRequest(), queryNLReq); + queryNLReq.setText2SQLType(parseContext.getRequest().isDisableLLM() ? Text2SQLType.ONLY_RULE + : Text2SQLType.RULE_AND_LLM); + queryNLReq.setDataSetIds(parseContext.getAgent().getDataSetIds()); queryNLReq.setChatAppConfig(parseContext.getAgent().getChatAppConfig()); - if (chatCtx != null) { - queryNLReq.setContextParseInfo(chatCtx.getParseInfo()); - } + return queryNLReq; } } diff --git a/common/src/main/java/com/tencent/supersonic/common/service/impl/ExemplarServiceImpl.java b/common/src/main/java/com/tencent/supersonic/common/service/impl/ExemplarServiceImpl.java index 7819fa8e1..8852375e0 100644 --- a/common/src/main/java/com/tencent/supersonic/common/service/impl/ExemplarServiceImpl.java +++ b/common/src/main/java/com/tencent/supersonic/common/service/impl/ExemplarServiceImpl.java @@ -70,8 +70,8 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner { RetrieveQuery.builder().queryTextsList(Lists.newArrayList(query)).build(); List results = embeddingService.retrieveQuery(collection, retrieveQuery, num); - results.stream().forEach(ret -> { - ret.getRetrieval().stream().forEach(r -> { + results.forEach(ret -> { + ret.getRetrieval().forEach(r -> { exemplars.add(JsonUtil.mapToObject(r.getMetadata(), Text2SQLExemplar.class)); }); });