From cd972d0850e2317423d34a544424737912849c62 Mon Sep 17 00:00:00 2001 From: jerryjzhang Date: Thu, 9 Nov 2023 14:21:18 +0800 Subject: [PATCH] [improvement][chat] Remove deprecated executeQuery method --- .../api/pojo/request/ExecuteQueryReq.java | 2 + .../chat/rest/ChatQueryController.java | 7 --- .../supersonic/chat/service/QueryService.java | 2 - .../chat/service/impl/QueryServiceImpl.java | 49 ++----------------- .../integration/MetricInterpretTest.java | 12 ++++- .../plugin/PluginRecognizeTest.java | 23 ++++++++- 6 files changed, 38 insertions(+), 57 deletions(-) diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ExecuteQueryReq.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ExecuteQueryReq.java index d57bd58f0..0f0b5ced6 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ExecuteQueryReq.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ExecuteQueryReq.java @@ -3,8 +3,10 @@ package com.tencent.supersonic.chat.api.pojo.request; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; +import lombok.Builder; import lombok.Data; +@Builder @Data public class ExecuteQueryReq { private User user; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatQueryController.java b/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatQueryController.java index a0bff539f..151a11cf9 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatQueryController.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatQueryController.java @@ -41,13 +41,6 @@ public class ChatQueryController { return searchService.search(queryCtx); } - @PostMapping("query") - public Object query(@RequestBody QueryReq queryCtx, HttpServletRequest request, HttpServletResponse response) - throws Exception { - queryCtx.setUser(UserHolder.findUser(request, response)); - return queryService.executeQuery(queryCtx); - } - @PostMapping("parse") public Object parse(@RequestBody QueryReq queryCtx, HttpServletRequest request, HttpServletResponse response) throws Exception { diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/QueryService.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/QueryService.java index 81296edf9..cdefed69f 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/QueryService.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/QueryService.java @@ -20,8 +20,6 @@ public interface QueryService { QueryResult performExecution(ExecuteQueryReq queryReq) throws Exception; - QueryResult executeQuery(QueryReq queryReq) throws Exception; - SemanticParseInfo queryContext(QueryReq queryReq); QueryResult executeDirectQuery(QueryDataReq queryData, User user) throws SqlParseException; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java index 63aee2288..1fadbba58 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java @@ -254,7 +254,7 @@ public class QueryServiceImpl implements QueryService { } // save time cost data - public void saveInfo(List timeCostDOList, + private void saveInfo(List timeCostDOList, String queryText, Long queryId, String userName, Long chatId) { List list = timeCostDOList.stream() @@ -284,47 +284,6 @@ public class QueryServiceImpl implements QueryService { .queryText(queryReq.getQueryText()).build()); } - @Override - public QueryResult executeQuery(QueryReq queryReq) throws Exception { - QueryContext queryCtx = new QueryContext(queryReq); - // in order to support multi-turn conversation, chat context is needed - ChatContext chatCtx = chatService.getOrCreateContext(queryReq.getChatId()); - - schemaMappers.stream().forEach(mapper -> { - mapper.map(queryCtx); - log.info("{} result:{}", mapper.getClass().getSimpleName(), JsonUtil.toString(queryCtx)); - }); - - semanticParsers.stream().forEach(parser -> { - parser.parse(queryCtx, chatCtx); - log.info("{} result:{}", parser.getClass().getSimpleName(), JsonUtil.toString(queryCtx)); - }); - - QueryResult queryResult = null; - if (queryCtx.getCandidateQueries().size() > 0) { - log.info("pick before [{}]", queryCtx.getCandidateQueries().stream().collect( - Collectors.toList())); - List selectedQueries = querySelector.select(queryCtx.getCandidateQueries(), queryReq); - log.info("pick after [{}]", selectedQueries.stream().collect( - Collectors.toList())); - - SemanticQuery semanticQuery = selectedQueries.get(0); - queryResult = semanticQuery.execute(queryReq.getUser()); - if (queryResult != null) { - chatCtx.setQueryText(queryReq.getQueryText()); - // update chat context after a successful semantic query - if (queryReq.isSaveAnswer() && QueryState.SUCCESS.equals(queryResult.getQueryState())) { - chatCtx.setParseInfo(semanticQuery.getParseInfo()); - chatService.updateContext(chatCtx); - } - queryResult.setChatContext(chatCtx.getParseInfo()); - chatService.addQuery(queryResult, chatCtx); - } - } - - return queryResult; - } - @Override public SemanticParseInfo queryContext(QueryReq queryCtx) { ChatContext context = chatService.getOrCreateContext(queryCtx.getChatId()); @@ -463,7 +422,7 @@ public class QueryServiceImpl implements QueryService { parseInfo.setDateInfo(queryData.getDateInfo()); } - public void addTimeFilters(String date, + private void addTimeFilters(String date, T comparisonExpression, List addConditions) { Column column = new Column(TimeDimensionEnum.DAY.getChName()); @@ -513,7 +472,7 @@ public class QueryServiceImpl implements QueryService { } // add in condition to sql where condition - public void addWhereInFilters(QueryFilter dslQueryFilter, + private void addWhereInFilters(QueryFilter dslQueryFilter, InExpression inExpression, Set contextMetricFilters, List addConditions) { @@ -542,7 +501,7 @@ public class QueryServiceImpl implements QueryService { } // add where filter - public void addWhereFilters(QueryFilter dslQueryFilter, + private void addWhereFilters(QueryFilter dslQueryFilter, T comparisonExpression, Set contextMetricFilters, List addConditions) { diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricInterpretTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricInterpretTest.java index 446274f07..eb7f0a9d8 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricInterpretTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricInterpretTest.java @@ -2,7 +2,9 @@ package com.tencent.supersonic.integration; import com.alibaba.fastjson.JSONObject; import com.tencent.supersonic.StandaloneLauncher; +import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq; import com.tencent.supersonic.chat.api.pojo.request.QueryReq; +import com.tencent.supersonic.chat.api.pojo.response.ParseResp; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.chat.plugin.PluginManager; @@ -49,7 +51,15 @@ public class MetricInterpretTest { ResponseEntity.ok(JSONObject.toJSONString(lLmAnswerResp))); QueryReq queryReq = DataUtils.getQueryReqWithAgent(1000, "能不能帮我解读分析下最近alice在超音数的访问情况", DataUtils.getAgent().getId()); - QueryResult queryResult = queryService.executeQuery(queryReq); + + ParseResp parseResp = queryService.performParsing(queryReq); + ExecuteQueryReq executeReq = ExecuteQueryReq.builder().user(queryReq.getUser()) + .chatId(parseResp.getChatId()) + .queryId(parseResp.getQueryId()) + .queryText(parseResp.getQueryText()) + .parseInfo(parseResp.getSelectedParses().get(0)) + .build(); + QueryResult queryResult = queryService.performExecution(executeReq); Assert.assertEquals(queryResult.getQueryResults().get(0).get("answer"), lLmAnswerResp.getAssistantMessage()); } diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/plugin/PluginRecognizeTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/integration/plugin/PluginRecognizeTest.java index b7e81836a..ac8a577ea 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/plugin/PluginRecognizeTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/integration/plugin/PluginRecognizeTest.java @@ -1,5 +1,6 @@ package com.tencent.supersonic.integration.plugin; +import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq; import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; import com.tencent.supersonic.chat.api.pojo.request.QueryFilters; import com.tencent.supersonic.chat.api.pojo.request.QueryReq; @@ -37,7 +38,16 @@ public class PluginRecognizeTest extends BasePluginTest { MockConfiguration.mockEmbeddingRecognize(pluginManager, "alice最近的访问情况怎么样", "1"); MockConfiguration.mockEmbeddingUrl(embeddingConfig); QueryReq queryContextReq = DataUtils.getQueryReqWithAgent(1000, "alice最近的访问情况怎么样", 1); - QueryResult queryResult = queryService.executeQuery(queryContextReq); + + ParseResp parseResp = queryService.performParsing(queryContextReq); + ExecuteQueryReq executeReq = ExecuteQueryReq.builder().user(queryContextReq.getUser()) + .chatId(parseResp.getChatId()) + .queryId(parseResp.getQueryId()) + .queryText(parseResp.getQueryText()) + .parseInfo(parseResp.getSelectedParses().get(0)) + .build(); + QueryResult queryResult = queryService.performExecution(executeReq); + assertPluginRecognizeResult(queryResult); } @@ -53,7 +63,16 @@ public class PluginRecognizeTest extends BasePluginTest { queryRequest.setModelId(1L); queryFilters.getFilters().add(queryFilter); queryRequest.setQueryFilters(queryFilters); - QueryResult queryResult = queryService.executeQuery(queryRequest); + + ParseResp parseResp = queryService.performParsing(queryRequest); + ExecuteQueryReq executeReq = ExecuteQueryReq.builder().user(queryRequest.getUser()) + .chatId(parseResp.getChatId()) + .queryId(parseResp.getQueryId()) + .queryText(parseResp.getQueryText()) + .parseInfo(parseResp.getSelectedParses().get(0)) + .build(); + QueryResult queryResult = queryService.performExecution(executeReq); + assertPluginRecognizeResult(queryResult); }