From 731238de08d323d111771dea183e9eb9eae1ef83 Mon Sep 17 00:00:00 2001 From: LXW <1264174498@qq.com> Date: Sun, 12 Nov 2023 22:47:58 +0800 Subject: [PATCH] Feature/Refactor querySelect to queryRanker and fix some errors in integration tests (#369) * (fix) (chat) fix the context saving failure caused by the loss of default values caused by @builder * (fix) (chat) fix date and metrics result in parse info in integration test * (improvement) (chat) refactor querySelect to queryRanker --------- Co-authored-by: jolunoluo --- .../api/pojo/request/ExecuteQueryReq.java | 6 +- .../chat/api/pojo/response/ParseResp.java | 12 ++- .../repository/ChatQueryRepository.java | 3 +- .../impl/ChatQueryRepositoryImpl.java | 14 ++-- .../chat/query/HeuristicQuerySelector.java | 84 ------------------- .../supersonic/chat/query/QueryRanker.java | 64 ++++++++++++++ .../supersonic/chat/query/QuerySelector.java | 14 ---- .../parse/EntityInfoParseResponder.java | 2 +- .../parse/SqlInfoParseResponder.java | 2 - .../chat/service/ParseInfoService.java | 7 -- .../chat/service/impl/ChatServiceImpl.java | 3 +- .../service/impl/ParserInfoServiceImpl.java | 53 +++--------- .../chat/service/impl/QueryServiceImpl.java | 23 ++--- .../chat/utils/ComponentFactory.java | 8 -- .../main/resources/META-INF/spring.factories | 3 - .../com/tencent/supersonic/ConfigureDemo.java | 5 +- .../main/resources/META-INF/spring.factories | 3 - .../supersonic/integration/BaseQueryTest.java | 10 ++- .../integration/MetricInterpretTest.java | 4 +- .../integration/MetricQueryTest.java | 8 +- .../plugin/PluginRecognizeTest.java | 8 +- .../test/resources/META-INF/spring.factories | 3 - .../listener/MetaEmbeddingListener.java | 2 + 23 files changed, 127 insertions(+), 214 deletions(-) delete mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/query/HeuristicQuerySelector.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/query/QueryRanker.java delete mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/query/QuerySelector.java diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ExecuteQueryReq.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ExecuteQueryReq.java index 0f0b5ced6..9cc96826a 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ExecuteQueryReq.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ExecuteQueryReq.java @@ -13,8 +13,8 @@ public class ExecuteQueryReq { private Integer agentId; private Integer chatId; private String queryText; - private Long queryId = 7L; - private Integer parseId = 2; + private Long queryId; + private Integer parseId; private SemanticParseInfo parseInfo; - private boolean saveAnswer = true; + private boolean saveAnswer; } diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/ParseResp.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/ParseResp.java index 8e2c476ff..36107268c 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/ParseResp.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/ParseResp.java @@ -1,5 +1,6 @@ package com.tencent.supersonic.chat.api.pojo.response; +import com.google.common.collect.Lists; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import lombok.Data; import lombok.Getter; @@ -19,8 +20,8 @@ public class ParseResp { private String queryText; private Long queryId; private ParseState state; - private List selectedParses; - private List candidateParses; + private List selectedParses = Lists.newArrayList(); + private List candidateParses = Lists.newArrayList(); private List similarSolvedQuery; private ParseTimeCostDO parseTimeCost; @@ -29,4 +30,11 @@ public class ParseResp { PENDING, FAILED } + + public List getSelectedParses() { + selectedParses = Lists.newArrayList(); + selectedParses.addAll(candidateParses); + candidateParses.clear(); + return selectedParses; + } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/repository/ChatQueryRepository.java b/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/repository/ChatQueryRepository.java index 563d912cc..8dc8cf33d 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/repository/ChatQueryRepository.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/repository/ChatQueryRepository.java @@ -31,8 +31,7 @@ public interface ChatQueryRepository { List batchSaveParseInfo(ChatContext chatCtx, QueryReq queryReq, ParseResp parseResult, - List candidateParses, - List selectedParses); + List candidateParses); public ChatParseDO getParseInfo(Long questionId, int parseId); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/repository/impl/ChatQueryRepositoryImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/repository/impl/ChatQueryRepositoryImpl.java index 5bb874d1e..5146901b3 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/repository/impl/ChatQueryRepositoryImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/repository/impl/ChatQueryRepositoryImpl.java @@ -133,13 +133,10 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository { @Override public List batchSaveParseInfo(ChatContext chatCtx, QueryReq queryReq, ParseResp parseResult, - List candidateParses, - List selectedParses) { + List candidateParses) { Long queryId = createChatParse(parseResult, chatCtx, queryReq); List chatParseDOList = new ArrayList<>(); - log.info("candidateParses size:{},selectedParses size:{}", candidateParses.size(), selectedParses.size()); - getChatParseDO(chatCtx, queryReq, queryId, 0, 1, candidateParses, chatParseDOList); - getChatParseDO(chatCtx, queryReq, queryId, candidateParses.size(), 0, selectedParses, chatParseDOList); + getChatParseDO(chatCtx, queryReq, queryId, 0, candidateParses, chatParseDOList); chatParseMapper.batchSaveParseInfo(chatParseDOList); return chatParseDOList; } @@ -151,7 +148,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository { } } - public void getChatParseDO(ChatContext chatCtx, QueryReq queryReq, Long queryId, int base, int isCandidate, + public void getChatParseDO(ChatContext chatCtx, QueryReq queryReq, Long queryId, int base, List parses, List chatParseDOList) { for (int i = 0; i < parses.size(); i++) { ChatParseDO chatParseDO = new ChatParseDO(); @@ -160,7 +157,10 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository { chatParseDO.setQuestionId(queryId); chatParseDO.setQueryText(queryReq.getQueryText()); chatParseDO.setParseInfo(JsonUtil.toString(parses.get(i))); - chatParseDO.setIsCandidate(isCandidate); + chatParseDO.setIsCandidate(1); + if (i == 0) { + chatParseDO.setIsCandidate(0); + } chatParseDO.setParseId(base + i + 1); chatParseDO.setCreateTime(new java.util.Date()); chatParseDO.setUserName(queryReq.getUser().getName()); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/HeuristicQuerySelector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/HeuristicQuerySelector.java deleted file mode 100644 index 0b75e2709..000000000 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/HeuristicQuerySelector.java +++ /dev/null @@ -1,84 +0,0 @@ -package com.tencent.supersonic.chat.query; - -import com.tencent.supersonic.chat.api.component.SemanticQuery; -import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; -import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.chat.api.pojo.request.QueryReq; -import com.tencent.supersonic.chat.config.OptimizationConfig; -import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery; - -import java.util.List; -import java.util.ArrayList; -import java.util.OptionalDouble; - -import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery; -import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery; -import com.tencent.supersonic.common.util.ContextUtils; -import java.util.stream.Collectors; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.collections.CollectionUtils; - -@Slf4j -public class HeuristicQuerySelector implements QuerySelector { - - @Override - public List select(List candidateQueries, QueryReq queryReq) { - log.debug("pick before [{}]", candidateQueries.stream().collect(Collectors.toList())); - List selectedQueries = new ArrayList<>(); - OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class); - Double candidateThreshold = optimizationConfig.getCandidateThreshold(); - if (CollectionUtils.isNotEmpty(candidateQueries) && candidateQueries.size() == 1) { - selectedQueries.addAll(candidateQueries); - } else { - OptionalDouble maxScoreOp = candidateQueries.stream().mapToDouble( - q -> q.getParseInfo().getScore()).max(); - if (maxScoreOp.isPresent()) { - double maxScore = maxScoreOp.getAsDouble(); - - candidateQueries.stream().forEach(query -> { - SemanticParseInfo parseInfo = query.getParseInfo(); - if (!checkFullyInherited(query) - && (maxScore - parseInfo.getScore()) / maxScore <= candidateThreshold - && checkSatisfyOtherRules(query, candidateQueries)) { - selectedQueries.add(query); - } - log.info("candidate query (Model={}, queryMode={}) with score={}", - parseInfo.getModelName(), parseInfo.getQueryMode(), parseInfo.getScore()); - }); - } - } - log.debug("pick after [{}]", selectedQueries.stream().collect(Collectors.toList())); - return selectedQueries; - } - - private boolean checkSatisfyOtherRules(SemanticQuery semanticQuery, List candidateQueries) { - if (!semanticQuery.getQueryMode().equals(MetricModelQuery.QUERY_MODE)) { - return true; - } - for (SemanticQuery candidateQuery : candidateQueries) { - if (candidateQuery.getQueryMode().equals(MetricEntityQuery.QUERY_MODE) - && semanticQuery.getParseInfo().getScore() == candidateQuery.getParseInfo().getScore()) { - return false; - } - } - return true; - } - - private boolean checkFullyInherited(SemanticQuery query) { - SemanticParseInfo parseInfo = query.getParseInfo(); - if (!(query instanceof RuleSemanticQuery)) { - return false; - } - - for (SchemaElementMatch match : parseInfo.getElementMatches()) { - if (!match.isInherited()) { - return false; - } - } - if (parseInfo.getDateInfo() != null && !parseInfo.getDateInfo().isInherited()) { - return false; - } - - return true; - } -} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/QueryRanker.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/QueryRanker.java new file mode 100644 index 000000000..c4687d31d --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/QueryRanker.java @@ -0,0 +1,64 @@ +package com.tencent.supersonic.chat.query; + +import com.tencent.supersonic.chat.api.component.SemanticQuery; +import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; +import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery; +import java.util.List; +import java.util.ArrayList; +import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.collections.CollectionUtils; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Component; + +@Slf4j +@Component +public class QueryRanker { + + @Value("${candidate.top.size:5}") + private int candidateTopSize; + + public List rank(List candidateQueries) { + log.debug("pick before [{}]", candidateQueries); + if (CollectionUtils.isEmpty(candidateQueries)) { + return candidateQueries; + } + List selectedQueries = new ArrayList<>(); + if (candidateQueries.size() == 1) { + selectedQueries.addAll(candidateQueries); + } else { + selectedQueries = getTopCandidateQuery(candidateQueries); + } + log.debug("pick after [{}]", selectedQueries); + return selectedQueries; + } + + public List getTopCandidateQuery(List semanticQueries) { + return semanticQueries.stream() + .filter(query -> !checkFullyInherited(query)) + .sorted((o1, o2) -> { + if (o1.getParseInfo().getScore() < o2.getParseInfo().getScore()) { + return 1; + } else if (o1.getParseInfo().getScore() > o2.getParseInfo().getScore()) { + return -1; + } + return 0; + }).limit(candidateTopSize) + .collect(Collectors.toList()); + } + + private boolean checkFullyInherited(SemanticQuery query) { + SemanticParseInfo parseInfo = query.getParseInfo(); + if (!(query instanceof RuleSemanticQuery)) { + return false; + } + + for (SchemaElementMatch match : parseInfo.getElementMatches()) { + if (!match.isInherited()) { + return false; + } + } + return parseInfo.getDateInfo() == null || parseInfo.getDateInfo().isInherited(); + } +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/QuerySelector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/QuerySelector.java deleted file mode 100644 index 51ecf9f50..000000000 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/QuerySelector.java +++ /dev/null @@ -1,14 +0,0 @@ -package com.tencent.supersonic.chat.query; - -import com.tencent.supersonic.chat.api.component.SemanticQuery; -import com.tencent.supersonic.chat.api.pojo.request.QueryReq; - -import java.util.List; - -/** - * This interface defines the contract for a selector that picks the most suitable semantic query. - **/ -public interface QuerySelector { - - List select(List candidateQueries, QueryReq queryReq); -} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/EntityInfoParseResponder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/EntityInfoParseResponder.java index 70fce1e3f..334ae51bb 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/EntityInfoParseResponder.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/EntityInfoParseResponder.java @@ -19,7 +19,7 @@ public class EntityInfoParseResponder implements ParseResponder { @Override public void fillResponse(ParseResp parseResp, QueryContext queryContext, List chatParseDOS) { - List selectedParses = parseResp.getSelectedParses(); + List selectedParses = parseResp.getCandidateParses(); if (CollectionUtils.isEmpty(selectedParses)) { return; } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/SqlInfoParseResponder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/SqlInfoParseResponder.java index 84b12d339..676e224ab 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/SqlInfoParseResponder.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/SqlInfoParseResponder.java @@ -24,7 +24,6 @@ public class SqlInfoParseResponder implements ParseResponder { List chatParseDOS) { QueryReq queryReq = queryContext.getRequest(); Long startTime = System.currentTimeMillis(); - addSqlInfo(queryReq, parseResp.getSelectedParses()); addSqlInfo(queryReq, parseResp.getCandidateParses()); parseResp.setParseTimeCost(new ParseTimeCostDO()); parseResp.getParseTimeCost().setSqlTime(System.currentTimeMillis() - startTime); @@ -32,7 +31,6 @@ public class SqlInfoParseResponder implements ParseResponder { Map chatParseDOMap = chatParseDOS.stream() .collect(Collectors.toMap(ChatParseDO::getParseId, Function.identity(), (oldValue, newValue) -> newValue)); - updateParseInfo(chatParseDOMap, parseResp.getSelectedParses()); updateParseInfo(chatParseDOMap, parseResp.getCandidateParses()); } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/ParseInfoService.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/ParseInfoService.java index 4df39777a..1a1b2614f 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/ParseInfoService.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/ParseInfoService.java @@ -1,16 +1,9 @@ package com.tencent.supersonic.chat.service; -import com.tencent.supersonic.chat.api.component.SemanticQuery; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; -import java.util.List; public interface ParseInfoService { - List getTopCandidateParseInfo(List selectedParses, - List candidateParses); - - List sortParseInfo(List semanticQueries); - void updateParseInfo(SemanticParseInfo parseInfo); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/ChatServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/ChatServiceImpl.java index d51b69956..4eb2de5ec 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/ChatServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/ChatServiceImpl.java @@ -225,8 +225,7 @@ public class ChatServiceImpl implements ChatService { @Override public List batchAddParse(ChatContext chatCtx, QueryReq queryReq, ParseResp parseResult) { List candidateParses = parseResult.getCandidateParses(); - List selectedParses = parseResult.getSelectedParses(); - return chatQueryRepository.batchSaveParseInfo(chatCtx, queryReq, parseResult, candidateParses, selectedParses); + return chatQueryRepository.batchSaveParseInfo(chatCtx, queryReq, parseResult, candidateParses); } @Override diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/ParserInfoServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/ParserInfoServiceImpl.java index 489307a37..b7ec8b744 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/ParserInfoServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/ParserInfoServiceImpl.java @@ -2,7 +2,6 @@ package com.tencent.supersonic.chat.service.impl; import com.google.common.collect.Lists; -import com.tencent.supersonic.chat.api.component.SemanticQuery; import com.tencent.supersonic.chat.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.SemanticSchema; @@ -20,7 +19,6 @@ import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.knowledge.service.SchemaService; import java.util.ArrayList; import java.util.Arrays; -import java.util.Comparator; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -31,45 +29,12 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.Pair; -import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; @Slf4j @Service public class ParserInfoServiceImpl implements ParseInfoService { - @Value("${candidate.top.size:5}") - private int candidateTopSize; - - public List getTopCandidateParseInfo(List selectedParses, - List candidateParses) { - if (CollectionUtils.isEmpty(selectedParses) || CollectionUtils.isEmpty(candidateParses)) { - return candidateParses; - } - int selectParseSize = selectedParses.size(); - Set selectParseScoreSet = selectedParses.stream() - .map(SemanticParseInfo::getScore).collect(Collectors.toSet()); - int candidateParseSize = candidateTopSize - selectParseSize; - candidateParses = candidateParses.stream() - .filter(candidateParse -> !selectParseScoreSet.contains(candidateParse.getScore())) - .collect(Collectors.toList()); - SemanticParseInfo semanticParseInfo = selectedParses.get(0); - Long modelId = semanticParseInfo.getModelId(); - if (modelId == null || modelId <= 0) { - return candidateParses; - } - return candidateParses.stream() - .sorted(Comparator.comparing(parse -> !parse.getModelId().equals(modelId))) - .limit(candidateParseSize) - .collect(Collectors.toList()); - } - - public List sortParseInfo(List semanticQueries) { - return semanticQueries.stream() - .map(SemanticQuery::getParseInfo) - .sorted(Comparator.comparingDouble(SemanticParseInfo::getScore).reversed()) - .collect(Collectors.toList()); - } public void updateParseInfo(SemanticParseInfo parseInfo) { SqlInfo sqlInfo = parseInfo.getSqlInfo(); @@ -81,9 +46,11 @@ public class ParserInfoServiceImpl implements ParseInfoService { List expressions = SqlParserSelectHelper.getFilterExpression(logicSql); //set dataInfo try { - if (!org.springframework.util.CollectionUtils.isEmpty(expressions)) { + if (!CollectionUtils.isEmpty(expressions)) { DateConf dateInfo = getDateInfo(expressions); - parseInfo.setDateInfo(dateInfo); + if (dateInfo != null && parseInfo.getDateInfo() == null) { + parseInfo.setDateInfo(dateInfo); + } } } catch (Exception e) { log.error("set dateInfo error :", e); @@ -103,10 +70,10 @@ public class ParserInfoServiceImpl implements ParseInfoService { if (Objects.isNull(semanticSchema)) { return; } - List allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(sqlInfo.getCorrectS2SQL())); - - Set metrics = getElements(parseInfo.getModelId(), allFields, semanticSchema.getMetrics()); - parseInfo.setMetrics(metrics); + //cannot use metrics in sql to override parse info + //List allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(sqlInfo.getCorrectS2SQL())); + //Set metrics = getElements(parseInfo.getModelId(), allFields, semanticSchema.getMetrics()); + //parseInfo.setMetrics(metrics); if (SqlParserSelectFunctionHelper.hasAggregateFunction(sqlInfo.getCorrectS2SQL())) { parseInfo.setNativeQuery(false); @@ -167,8 +134,8 @@ public class ParserInfoServiceImpl implements ParseInfoService { List dateExpressions = filterExpressions.stream() .filter(expression -> TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(expression.getFieldName())) .collect(Collectors.toList()); - if (org.springframework.util.CollectionUtils.isEmpty(dateExpressions)) { - return new DateConf(); + if (CollectionUtils.isEmpty(dateExpressions)) { + return null; } DateConf dateInfo = new DateConf(); dateInfo.setDateMode(DateMode.BETWEEN); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java index 153447cf0..5d681af5c 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java @@ -27,7 +27,7 @@ import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO; import com.tencent.supersonic.chat.persistence.dataobject.CostType; import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO; import com.tencent.supersonic.chat.query.QueryManager; -import com.tencent.supersonic.chat.query.QuerySelector; +import com.tencent.supersonic.chat.query.QueryRanker; import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery; import com.tencent.supersonic.chat.responder.execute.ExecuteResponder; import com.tencent.supersonic.chat.responder.parse.ParseResponder; @@ -105,13 +105,14 @@ public class QueryServiceImpl implements QueryService { private SolvedQueryManager solvedQueryManager; @Autowired private ParseInfoService parseInfoService; + @Autowired + private QueryRanker queryRanker; @Value("${time.threshold: 100}") private Integer timeThreshold; private List schemaMappers = ComponentFactory.getSchemaMappers(); private List semanticParsers = ComponentFactory.getSemanticParsers(); - private QuerySelector querySelector = ComponentFactory.getQuerySelector(); private List parseResponders = ComponentFactory.getParseResponders(); private List executeResponders = ComponentFactory.getExecuteResponders(); private List semanticCorrectors = ComponentFactory.getSqlCorrections(); @@ -157,18 +158,12 @@ public class QueryServiceImpl implements QueryService { ParseResp parseResult; List chatParseDOS = Lists.newArrayList(); if (candidateQueries.size() > 0) { - List selectedQueries = querySelector.select(candidateQueries, queryReq); - List selectedParses = parseInfoService.sortParseInfo(selectedQueries); - List candidateParses = parseInfoService.sortParseInfo(candidateQueries); - candidateParses = parseInfoService.getTopCandidateParseInfo(selectedParses, candidateParses); - candidateQueries.forEach(semanticQuery -> parseInfoService.updateParseInfo(semanticQuery.getParseInfo())); - - parseResult = ParseResp.builder() - .chatId(queryReq.getChatId()) - .queryText(queryReq.getQueryText()) - .state(selectedParses.size() > 1 ? ParseResp.ParseState.PENDING : ParseResp.ParseState.COMPLETED) - .selectedParses(selectedParses) - .candidateParses(candidateParses) + candidateQueries = queryRanker.rank(candidateQueries); + List candidateParses = candidateQueries.stream() + .map(SemanticQuery::getParseInfo).collect(Collectors.toList()); + candidateParses.forEach(parseInfo -> parseInfoService.updateParseInfo(parseInfo)); + parseResult = ParseResp.builder().chatId(queryReq.getChatId()).queryText(queryReq.getQueryText()) + .state(ParseResp.ParseState.COMPLETED).candidateParses(candidateParses) .build(); chatParseDOS = chatService.batchAddParse(chatCtx, queryReq, parseResult); } else { diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/utils/ComponentFactory.java b/chat/core/src/main/java/com/tencent/supersonic/chat/utils/ComponentFactory.java index f03b6908c..e4538827a 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/utils/ComponentFactory.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/utils/ComponentFactory.java @@ -10,7 +10,6 @@ import java.util.List; import java.util.Objects; import com.tencent.supersonic.chat.parser.llm.s2sql.ModelResolver; -import com.tencent.supersonic.chat.query.QuerySelector; import com.tencent.supersonic.chat.responder.execute.ExecuteResponder; import com.tencent.supersonic.chat.responder.parse.ParseResponder; import org.apache.commons.collections.CollectionUtils; @@ -24,7 +23,6 @@ public class ComponentFactory { private static SemanticInterpreter semanticInterpreter; private static List parseResponders = new ArrayList<>(); private static List executeResponders = new ArrayList<>(); - private static QuerySelector querySelector; private static ModelResolver modelResolver; public static List getSchemaMappers() { return CollectionUtils.isEmpty(schemaMappers) ? init(SchemaMapper.class, schemaMappers) : schemaMappers; @@ -59,12 +57,6 @@ public class ComponentFactory { semanticInterpreter = layer; } - public static QuerySelector getQuerySelector() { - if (Objects.isNull(querySelector)) { - querySelector = init(QuerySelector.class); - } - return querySelector; - } public static ModelResolver getModelResolver() { if (Objects.isNull(modelResolver)) { diff --git a/launchers/chat/src/main/resources/META-INF/spring.factories b/launchers/chat/src/main/resources/META-INF/spring.factories index 669b18857..28f1a0e2f 100644 --- a/launchers/chat/src/main/resources/META-INF/spring.factories +++ b/launchers/chat/src/main/resources/META-INF/spring.factories @@ -20,9 +20,6 @@ com.tencent.supersonic.chat.api.component.SemanticCorrector=\ com.tencent.supersonic.chat.api.component.SemanticInterpreter=\ com.tencent.supersonic.knowledge.semantic.RemoteSemanticInterpreter -com.tencent.supersonic.chat.query.QuerySelector=\ - com.tencent.supersonic.chat.query.HeuristicQuerySelector - com.tencent.supersonic.chat.parser.llm.s2sql.ModelResolver=\ com.tencent.supersonic.chat.parser.llm.s2sql.HeuristicModelResolver diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/ConfigureDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/ConfigureDemo.java index dbf465e28..1f3ca40da 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/ConfigureDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/ConfigureDemo.java @@ -68,12 +68,13 @@ public class ConfigureDemo implements ApplicationListener ExecuteQueryReq executeReq = ExecuteQueryReq.builder().build(); executeReq.setQueryId(parseResp.getQueryId()); - executeReq.setParseId(parseResp.getSelectedParses().get(0).getId()); + executeReq.setParseId(parseResp.getCandidateParses().get(0).getId()); executeReq.setQueryText(queryRequest.getQueryText()); - executeReq.setParseInfo(parseResp.getSelectedParses().get(0)); + executeReq.setParseInfo(parseResp.getCandidateParses().get(0)); executeReq.setChatId(parseResp.getChatId()); executeReq.setUser(queryRequest.getUser()); executeReq.setAgentId(1); + executeReq.setSaveAnswer(true); queryService.performExecution(executeReq); } diff --git a/launchers/standalone/src/main/resources/META-INF/spring.factories b/launchers/standalone/src/main/resources/META-INF/spring.factories index 395a4d9c0..531fb6bb7 100644 --- a/launchers/standalone/src/main/resources/META-INF/spring.factories +++ b/launchers/standalone/src/main/resources/META-INF/spring.factories @@ -21,9 +21,6 @@ com.tencent.supersonic.chat.api.component.SemanticCorrector=\ com.tencent.supersonic.chat.api.component.SemanticInterpreter=\ com.tencent.supersonic.knowledge.semantic.LocalSemanticInterpreter -com.tencent.supersonic.chat.query.QuerySelector=\ - com.tencent.supersonic.chat.query.HeuristicQuerySelector - com.tencent.supersonic.chat.parser.llm.s2sql.ModelResolver=\ com.tencent.supersonic.chat.parser.llm.s2sql.HeuristicModelResolver diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/BaseQueryTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/integration/BaseQueryTest.java index 2ff6a901b..416f94c10 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/BaseQueryTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/integration/BaseQueryTest.java @@ -53,11 +53,12 @@ public class BaseQueryTest { ExecuteQueryReq request = ExecuteQueryReq.builder() .queryId(parseResp.getQueryId()) - .parseId(parseResp.getSelectedParses().get(0).getId()) + .parseId(parseResp.getCandidateParses().get(0).getId()) .chatId(parseResp.getChatId()) .queryText(parseResp.getQueryText()) .user(DataUtils.getUser()) - .parseInfo(parseResp.getSelectedParses().get(0)) + .parseInfo(parseResp.getCandidateParses().get(0)) + .saveAnswer(true) .build(); return queryService.performExecution(request); @@ -68,11 +69,12 @@ public class BaseQueryTest { ExecuteQueryReq request = ExecuteQueryReq.builder() .queryId(parseResp.getQueryId()) - .parseId(parseResp.getSelectedParses().get(0).getId()) + .parseId(parseResp.getCandidateParses().get(0).getId()) .chatId(parseResp.getChatId()) .queryText(parseResp.getQueryText()) .user(DataUtils.getUser()) - .parseInfo(parseResp.getSelectedParses().get(0)) + .parseInfo(parseResp.getCandidateParses().get(0)) + .saveAnswer(true) .build(); QueryResult result = queryService.performExecution(request); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricInterpretTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricInterpretTest.java index 88505ebb2..15e9656a7 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricInterpretTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricInterpretTest.java @@ -57,8 +57,8 @@ public class MetricInterpretTest { .chatId(parseResp.getChatId()) .queryId(parseResp.getQueryId()) .queryText(parseResp.getQueryText()) - .parseInfo(parseResp.getSelectedParses().get(0)) - .parseId(parseResp.getSelectedParses().get(0).getId()) + .parseInfo(parseResp.getCandidateParses().get(0)) + .parseId(parseResp.getCandidateParses().get(0).getId()) .build(); QueryResult queryResult = queryService.performExecution(executeReq); Assert.assertEquals(queryResult.getQueryResults().get(0).get("answer"), lLmAnswerResp.getAssistantMessage()); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricQueryTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricQueryTest.java index f02ea3df6..9575cef5b 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricQueryTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricQueryTest.java @@ -58,8 +58,8 @@ public class MetricQueryTest extends BaseQueryTest { //agent only support METRIC_ENTITY, METRIC_FILTER MockConfiguration.mockAgent(agentService); ParseResp parseResp = submitParseWithAgent("alice的访问次数", DataUtils.getAgent().getId()); - Assert.assertNotNull(parseResp.getSelectedParses()); - List queryModes = parseResp.getSelectedParses().stream() + Assert.assertNotNull(parseResp.getCandidateParses()); + List queryModes = parseResp.getCandidateParses().stream() .map(SemanticParseInfo::getQueryMode).collect(Collectors.toList()); Assert.assertTrue(queryModes.contains("METRIC_FILTER")); } @@ -88,7 +88,7 @@ public class MetricQueryTest extends BaseQueryTest { //agent only support METRIC_ENTITY, METRIC_FILTER MockConfiguration.mockAgent(agentService); ParseResp parseResp = submitParseWithAgent("超音数的访问次数", DataUtils.getAgent().getId()); - List queryModes = parseResp.getSelectedParses().stream() + List queryModes = parseResp.getCandidateParses().stream() .map(SemanticParseInfo::getQueryMode).collect(Collectors.toList()); Assert.assertTrue(queryModes.contains("METRIC_MODEL")); } @@ -123,7 +123,7 @@ public class MetricQueryTest extends BaseQueryTest { expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE); expectedParseInfo.setAggType(NONE); - + expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("用户名")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); List list = new ArrayList<>(); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/plugin/PluginRecognizeTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/integration/plugin/PluginRecognizeTest.java index ac8a577ea..0119f4cfe 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/plugin/PluginRecognizeTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/integration/plugin/PluginRecognizeTest.java @@ -44,7 +44,7 @@ public class PluginRecognizeTest extends BasePluginTest { .chatId(parseResp.getChatId()) .queryId(parseResp.getQueryId()) .queryText(parseResp.getQueryText()) - .parseInfo(parseResp.getSelectedParses().get(0)) + .parseInfo(parseResp.getCandidateParses().get(0)) .build(); QueryResult queryResult = queryService.performExecution(executeReq); @@ -69,7 +69,7 @@ public class PluginRecognizeTest extends BasePluginTest { .chatId(parseResp.getChatId()) .queryId(parseResp.getQueryId()) .queryText(parseResp.getQueryText()) - .parseInfo(parseResp.getSelectedParses().get(0)) + .parseInfo(parseResp.getCandidateParses().get(0)) .build(); QueryResult queryResult = queryService.performExecution(executeReq); @@ -84,8 +84,8 @@ public class PluginRecognizeTest extends BasePluginTest { QueryReq queryContextReq = DataUtils.getQueryReqWithAgent(1000, "alice最近的访问情况怎么样", DataUtils.getAgent().getId()); ParseResp parseResp = queryService.performParsing(queryContextReq); - Assert.assertTrue(parseResp.getSelectedParses() != null - && parseResp.getSelectedParses().size() > 0); + Assert.assertTrue(parseResp.getCandidateParses() != null + && parseResp.getCandidateParses().size() > 0); } } diff --git a/launchers/standalone/src/test/resources/META-INF/spring.factories b/launchers/standalone/src/test/resources/META-INF/spring.factories index 3a442fdf7..12d09fbb4 100644 --- a/launchers/standalone/src/test/resources/META-INF/spring.factories +++ b/launchers/standalone/src/test/resources/META-INF/spring.factories @@ -12,9 +12,6 @@ com.tencent.supersonic.chat.api.component.QueryProcessor=\ com.tencent.supersonic.chat.api.component.SemanticInterpreter=\ com.tencent.supersonic.knowledge.semantic.LocalSemanticInterpreter -com.tencent.supersonic.chat.query.QuerySelector=\ - com.tencent.supersonic.chat.query.HeuristicQuerySelector - com.tencent.supersonic.chat.application.query.DomainResolver=\ com.tencent.supersonic.chat.application.query.HeuristicDomainResolver diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/listener/MetaEmbeddingListener.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/listener/MetaEmbeddingListener.java index 6491cd29a..0404af90c 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/listener/MetaEmbeddingListener.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/listener/MetaEmbeddingListener.java @@ -12,6 +12,7 @@ import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationListener; +import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Component; import org.springframework.util.CollectionUtils; @@ -24,6 +25,7 @@ public class MetaEmbeddingListener implements ApplicationListener { @Autowired private EmbeddingUtils embeddingUtils; + @Async @Override public void onApplicationEvent(DataEvent event) { if (CollectionUtils.isEmpty(event.getDataItems())) {