From 3e0f724e9728a0782daee3ecc54008b41a05c042 Mon Sep 17 00:00:00 2001 From: jerryjzhang Date: Sun, 27 Oct 2024 17:44:29 +0800 Subject: [PATCH] [improvement][chat]Incorporate `Response` into `Context` objects. --- .../supersonic/chat/server/parser/ChatQueryParser.java | 3 +-- .../supersonic/chat/server/parser/NL2PluginParser.java | 7 +++---- .../supersonic/chat/server/parser/NL2SQLParser.java | 10 ++++++---- .../supersonic/chat/server/parser/PlainTextParser.java | 6 +++--- .../chat/server/plugin/recognize/PluginRecognizer.java | 4 ++-- .../supersonic/chat/server/pojo/ParseContext.java | 3 +++ .../server/processor/parse/ParseResultProcessor.java | 3 +-- .../processor/parse/QueryRecommendProcessor.java | 9 ++++----- .../chat/server/processor/parse/TimeCostProcessor.java | 3 ++- .../chat/server/service/ChatManageService.java | 2 +- .../server/service/impl/ChatManageServiceImpl.java | 5 ++--- .../chat/server/service/impl/ChatQueryServiceImpl.java | 9 +++++---- .../server/service/impl/RetrieveServiceImpl.java | 6 +++--- 13 files changed, 36 insertions(+), 34 deletions(-) diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/ChatQueryParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/ChatQueryParser.java index f420f7229..33410a3f7 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/ChatQueryParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/ChatQueryParser.java @@ -1,9 +1,8 @@ package com.tencent.supersonic.chat.server.parser; import com.tencent.supersonic.chat.server.pojo.ParseContext; -import com.tencent.supersonic.headless.api.pojo.response.ParseResp; public interface ChatQueryParser { - void parse(ParseContext parseContext, ParseResp parseResp); + void parse(ParseContext parseContext); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2PluginParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2PluginParser.java index 1082b56e1..c48f05da1 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2PluginParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2PluginParser.java @@ -4,7 +4,6 @@ import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer; import com.tencent.supersonic.chat.server.pojo.ParseContext; import com.tencent.supersonic.chat.server.util.ComponentFactory; import com.tencent.supersonic.common.util.JsonUtil; -import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import lombok.extern.slf4j.Slf4j; import java.util.List; @@ -16,15 +15,15 @@ public class NL2PluginParser implements ChatQueryParser { ComponentFactory.getPluginRecognizers(); @Override - public void parse(ParseContext parseContext, ParseResp parseResp) { + public void parse(ParseContext parseContext) { if (!parseContext.getAgent().containsPluginTool()) { return; } pluginRecognizers.forEach(pluginRecognizer -> { - pluginRecognizer.recognize(parseContext, parseResp); + pluginRecognizer.recognize(parseContext); log.info("{} recallResult:{}", pluginRecognizer.getClass().getSimpleName(), - JsonUtil.toString(parseResp)); + JsonUtil.toString(parseContext.getResponse())); }); } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index 286df8002..19a67a2b4 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -1,5 +1,6 @@ package com.tencent.supersonic.chat.server.parser; +import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq; import com.tencent.supersonic.chat.api.pojo.response.QueryResp; import com.tencent.supersonic.chat.server.plugin.PluginQueryManager; import com.tencent.supersonic.chat.server.pojo.ChatContext; @@ -90,7 +91,7 @@ public class NL2SQLParser implements ChatQueryParser { } @Override - public void parse(ParseContext parseContext, ParseResp parseResp) { + public void parse(ParseContext parseContext) { if (!parseContext.enableNL2SQL() || Objects.isNull(parseContext.getAgent())) { return; } @@ -98,14 +99,15 @@ public class NL2SQLParser implements ChatQueryParser { if (Objects.isNull(queryNLReq)) { return; } + ParseResp parseResp = parseContext.getResponse(); + ChatParseReq parseReq = parseContext.getRequest(); if (!parseContext.getRequest().isDisableLLM()) { processMultiTurn(parseContext); } ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class); - ChatContext chatCtx = - chatContextService.getOrCreateContext(parseContext.getRequest().getChatId()); + ChatContext chatCtx = chatContextService.getOrCreateContext(parseReq.getChatId()); if (chatCtx != null) { queryNLReq.setContextParseInfo(chatCtx.getParseInfo()); } @@ -116,7 +118,7 @@ public class NL2SQLParser implements ChatQueryParser { if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) { parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses()); } else { - if (!parseContext.getRequest().isDisableLLM()) { + if (!parseReq.isDisableLLM()) { parseResp.setErrorMsg(rewriteErrorMessage(parseContext, text2SqlParseResp.getErrorMsg(), queryNLReq.getDynamicExemplars())); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/PlainTextParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/PlainTextParser.java index a73f66ebc..154936087 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/PlainTextParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/PlainTextParser.java @@ -7,14 +7,14 @@ import com.tencent.supersonic.headless.api.pojo.response.ParseResp; public class PlainTextParser implements ChatQueryParser { @Override - public void parse(ParseContext parseContext, ParseResp parseResp) { + public void parse(ParseContext parseContext) { if (parseContext.getAgent().containsAnyTool()) { return; } SemanticParseInfo parseInfo = new SemanticParseInfo(); parseInfo.setQueryMode("PLAIN_TEXT"); - parseResp.getSelectedParses().add(parseInfo); - parseResp.setState(ParseResp.ParseState.COMPLETED); + parseContext.getResponse().getSelectedParses().add(parseInfo); + parseContext.getResponse().setState(ParseResp.ParseState.COMPLETED); } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/PluginRecognizer.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/PluginRecognizer.java index 9a26eeb88..0c8f5e864 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/PluginRecognizer.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/PluginRecognizer.java @@ -31,7 +31,7 @@ import java.util.Set; /** PluginParser defines the basic process and common methods for recalling plugins. */ public abstract class PluginRecognizer { - public void recognize(ParseContext parseContext, ParseResp parseResp) { + public void recognize(ParseContext parseContext) { if (!checkPreCondition(parseContext)) { return; } @@ -39,7 +39,7 @@ public abstract class PluginRecognizer { if (pluginRecallResult == null) { return; } - buildQuery(parseContext, parseResp, pluginRecallResult); + buildQuery(parseContext, parseContext.getResponse(), pluginRecallResult); } public abstract boolean checkPreCondition(ParseContext parseContext); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java index 5b1c76f34..1eac60017 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java @@ -2,15 +2,18 @@ package com.tencent.supersonic.chat.server.pojo; import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq; import com.tencent.supersonic.chat.server.agent.Agent; +import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import lombok.Data; @Data public class ParseContext { private ChatParseReq request; + private ParseResp response; private Agent agent; public ParseContext(ChatParseReq request) { this.request = request; + response = new ParseResp(request.getQueryText()); } public boolean enableNL2SQL() { diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/ParseResultProcessor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/ParseResultProcessor.java index 0f5fb864b..9ccd2849d 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/ParseResultProcessor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/ParseResultProcessor.java @@ -2,10 +2,9 @@ package com.tencent.supersonic.chat.server.processor.parse; import com.tencent.supersonic.chat.server.pojo.ParseContext; import com.tencent.supersonic.chat.server.processor.ResultProcessor; -import com.tencent.supersonic.headless.api.pojo.response.ParseResp; /** A ParseResultProcessor wraps things up before returning parsing results to the users. */ public interface ParseResultProcessor extends ResultProcessor { - void process(ParseContext parseContext, ParseResp parseResp); + void process(ParseContext parseContext); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/QueryRecommendProcessor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/QueryRecommendProcessor.java index d977cd673..a302e1bab 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/QueryRecommendProcessor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/QueryRecommendProcessor.java @@ -10,7 +10,6 @@ import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.common.service.ExemplarService; import com.tencent.supersonic.common.util.ContextUtils; -import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; @@ -23,13 +22,13 @@ import java.util.stream.Collectors; public class QueryRecommendProcessor implements ParseResultProcessor { @Override - public void process(ParseContext parseContext, ParseResp parseResp) { - CompletableFuture.runAsync(() -> doProcess(parseResp, parseContext)); + public void process(ParseContext parseContext) { + CompletableFuture.runAsync(() -> doProcess(parseContext)); } @SneakyThrows - private void doProcess(ParseResp parseResp, ParseContext parseContext) { - Long queryId = parseResp.getQueryId(); + private void doProcess(ParseContext parseContext) { + Long queryId = parseContext.getResponse().getQueryId(); List solvedQueries = getSimilarQueries( parseContext.getRequest().getQueryText(), parseContext.getAgent().getId()); ChatQueryDO chatQueryDO = getChatQuery(queryId); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/TimeCostProcessor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/TimeCostProcessor.java index 6c2e346f4..298664a07 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/TimeCostProcessor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/TimeCostProcessor.java @@ -9,7 +9,8 @@ import lombok.extern.slf4j.Slf4j; public class TimeCostProcessor implements ParseResultProcessor { @Override - public void process(ParseContext parseContext, ParseResp parseResp) { + public void process(ParseContext parseContext) { + ParseResp parseResp = parseContext.getResponse(); long parseStartTime = parseResp.getParseTimeCost().getParseStartTime(); parseResp.getParseTimeCost().setParseTime(System.currentTimeMillis() - parseStartTime - parseResp.getParseTimeCost().getSqlTime()); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/ChatManageService.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/ChatManageService.java index 98c89e115..c389e3bb9 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/ChatManageService.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/ChatManageService.java @@ -31,7 +31,7 @@ public interface ChatManageService { PageInfo queryInfo(PageQueryInfoReq pageQueryInfoReq, long chatId); - void createChatQuery(ChatParseReq chatParseReq, ParseResp parseResp); + Long createChatQuery(ChatParseReq chatParseReq); QueryResp getChatQuery(Long queryId); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatManageServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatManageServiceImpl.java index b649f9cfe..29f56fe49 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatManageServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatManageServiceImpl.java @@ -93,9 +93,8 @@ public class ChatManageServiceImpl implements ChatManageService { } @Override - public void createChatQuery(ChatParseReq chatParseReq, ParseResp parseResp) { - Long queryId = chatQueryRepository.createChatQuery(chatParseReq); - parseResp.setQueryId(queryId); + public Long createChatQuery(ChatParseReq chatParseReq) { + return chatQueryRepository.createChatQuery(chatParseReq); } @Override diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java index 772969944..8c786470f 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java @@ -105,15 +105,16 @@ public class ChatQueryServiceImpl implements ChatQueryService { @Override public ParseResp parse(ChatParseReq chatParseReq) { - ParseResp parseResp = new ParseResp(chatParseReq.getQueryText()); - chatManageService.createChatQuery(chatParseReq, parseResp); ParseContext parseContext = buildParseContext(chatParseReq); + ParseResp parseResp = parseContext.getResponse(); + Long queryId = chatManageService.createChatQuery(chatParseReq); + parseResp.setQueryId(queryId); for (ChatQueryParser parser : chatQueryParsers) { - parser.parse(parseContext, parseResp); + parser.parse(parseContext); } for (ParseResultProcessor processor : parseResultProcessors) { - processor.process(parseContext, parseResp); + processor.process(parseContext); } chatParseReq.setQueryText(parseContext.getRequest().getQueryText()); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/RetrieveServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/RetrieveServiceImpl.java index a9ef8ebac..7dec0670f 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/RetrieveServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/RetrieveServiceImpl.java @@ -120,7 +120,7 @@ public class RetrieveServiceImpl implements RetrieveService { return searchResults.stream().limit(RESULT_SIZE).collect(Collectors.toList()); } - private List getPossibleDataSets(QueryNLReq queryCtx, List originals, + private List getPossibleDataSets(QueryNLReq queryReq, List originals, Set dataSetIds) { if (CollectionUtils.isNotEmpty(dataSetIds)) { return new ArrayList<>(dataSetIds); @@ -128,8 +128,8 @@ public class RetrieveServiceImpl implements RetrieveService { List possibleDataSets = NatureHelper.selectPossibleDataSets(originals); if (possibleDataSets.isEmpty()) { - if (Objects.nonNull(queryCtx.getContextParseInfo())) { - possibleDataSets.add(queryCtx.getContextParseInfo().getDataSetId()); + if (Objects.nonNull(queryReq.getContextParseInfo())) { + possibleDataSets.add(queryReq.getContextParseInfo().getDataSetId()); } }