mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
Feature/Refactor querySelect to queryRanker and fix some errors in integration tests (#369)
* (fix) (chat) fix the context saving failure caused by the loss of default values caused by @builder * (fix) (chat) fix date and metrics result in parse info in integration test * (improvement) (chat) refactor querySelect to queryRanker --------- Co-authored-by: jolunoluo
This commit is contained in:
@@ -13,8 +13,8 @@ public class ExecuteQueryReq {
|
||||
private Integer agentId;
|
||||
private Integer chatId;
|
||||
private String queryText;
|
||||
private Long queryId = 7L;
|
||||
private Integer parseId = 2;
|
||||
private Long queryId;
|
||||
private Integer parseId;
|
||||
private SemanticParseInfo parseInfo;
|
||||
private boolean saveAnswer = true;
|
||||
private boolean saveAnswer;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import lombok.Data;
|
||||
import lombok.Getter;
|
||||
@@ -19,8 +20,8 @@ public class ParseResp {
|
||||
private String queryText;
|
||||
private Long queryId;
|
||||
private ParseState state;
|
||||
private List<SemanticParseInfo> selectedParses;
|
||||
private List<SemanticParseInfo> candidateParses;
|
||||
private List<SemanticParseInfo> selectedParses = Lists.newArrayList();
|
||||
private List<SemanticParseInfo> candidateParses = Lists.newArrayList();
|
||||
private List<SolvedQueryRecallResp> similarSolvedQuery;
|
||||
private ParseTimeCostDO parseTimeCost;
|
||||
|
||||
@@ -29,4 +30,11 @@ public class ParseResp {
|
||||
PENDING,
|
||||
FAILED
|
||||
}
|
||||
|
||||
public List<SemanticParseInfo> getSelectedParses() {
|
||||
selectedParses = Lists.newArrayList();
|
||||
selectedParses.addAll(candidateParses);
|
||||
candidateParses.clear();
|
||||
return selectedParses;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,8 +31,7 @@ public interface ChatQueryRepository {
|
||||
|
||||
List<ChatParseDO> batchSaveParseInfo(ChatContext chatCtx, QueryReq queryReq,
|
||||
ParseResp parseResult,
|
||||
List<SemanticParseInfo> candidateParses,
|
||||
List<SemanticParseInfo> selectedParses);
|
||||
List<SemanticParseInfo> candidateParses);
|
||||
|
||||
public ChatParseDO getParseInfo(Long questionId, int parseId);
|
||||
|
||||
|
||||
@@ -133,13 +133,10 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
||||
@Override
|
||||
public List<ChatParseDO> batchSaveParseInfo(ChatContext chatCtx, QueryReq queryReq,
|
||||
ParseResp parseResult,
|
||||
List<SemanticParseInfo> candidateParses,
|
||||
List<SemanticParseInfo> selectedParses) {
|
||||
List<SemanticParseInfo> candidateParses) {
|
||||
Long queryId = createChatParse(parseResult, chatCtx, queryReq);
|
||||
List<ChatParseDO> chatParseDOList = new ArrayList<>();
|
||||
log.info("candidateParses size:{},selectedParses size:{}", candidateParses.size(), selectedParses.size());
|
||||
getChatParseDO(chatCtx, queryReq, queryId, 0, 1, candidateParses, chatParseDOList);
|
||||
getChatParseDO(chatCtx, queryReq, queryId, candidateParses.size(), 0, selectedParses, chatParseDOList);
|
||||
getChatParseDO(chatCtx, queryReq, queryId, 0, candidateParses, chatParseDOList);
|
||||
chatParseMapper.batchSaveParseInfo(chatParseDOList);
|
||||
return chatParseDOList;
|
||||
}
|
||||
@@ -151,7 +148,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
||||
}
|
||||
}
|
||||
|
||||
public void getChatParseDO(ChatContext chatCtx, QueryReq queryReq, Long queryId, int base, int isCandidate,
|
||||
public void getChatParseDO(ChatContext chatCtx, QueryReq queryReq, Long queryId, int base,
|
||||
List<SemanticParseInfo> parses, List<ChatParseDO> chatParseDOList) {
|
||||
for (int i = 0; i < parses.size(); i++) {
|
||||
ChatParseDO chatParseDO = new ChatParseDO();
|
||||
@@ -160,7 +157,10 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
||||
chatParseDO.setQuestionId(queryId);
|
||||
chatParseDO.setQueryText(queryReq.getQueryText());
|
||||
chatParseDO.setParseInfo(JsonUtil.toString(parses.get(i)));
|
||||
chatParseDO.setIsCandidate(isCandidate);
|
||||
chatParseDO.setIsCandidate(1);
|
||||
if (i == 0) {
|
||||
chatParseDO.setIsCandidate(0);
|
||||
}
|
||||
chatParseDO.setParseId(base + i + 1);
|
||||
chatParseDO.setCreateTime(new java.util.Date());
|
||||
chatParseDO.setUserName(queryReq.getUser().getName());
|
||||
|
||||
@@ -1,84 +0,0 @@
|
||||
package com.tencent.supersonic.chat.query;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.ArrayList;
|
||||
import java.util.OptionalDouble;
|
||||
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public class HeuristicQuerySelector implements QuerySelector {
|
||||
|
||||
@Override
|
||||
public List<SemanticQuery> select(List<SemanticQuery> candidateQueries, QueryReq queryReq) {
|
||||
log.debug("pick before [{}]", candidateQueries.stream().collect(Collectors.toList()));
|
||||
List<SemanticQuery> selectedQueries = new ArrayList<>();
|
||||
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
|
||||
Double candidateThreshold = optimizationConfig.getCandidateThreshold();
|
||||
if (CollectionUtils.isNotEmpty(candidateQueries) && candidateQueries.size() == 1) {
|
||||
selectedQueries.addAll(candidateQueries);
|
||||
} else {
|
||||
OptionalDouble maxScoreOp = candidateQueries.stream().mapToDouble(
|
||||
q -> q.getParseInfo().getScore()).max();
|
||||
if (maxScoreOp.isPresent()) {
|
||||
double maxScore = maxScoreOp.getAsDouble();
|
||||
|
||||
candidateQueries.stream().forEach(query -> {
|
||||
SemanticParseInfo parseInfo = query.getParseInfo();
|
||||
if (!checkFullyInherited(query)
|
||||
&& (maxScore - parseInfo.getScore()) / maxScore <= candidateThreshold
|
||||
&& checkSatisfyOtherRules(query, candidateQueries)) {
|
||||
selectedQueries.add(query);
|
||||
}
|
||||
log.info("candidate query (Model={}, queryMode={}) with score={}",
|
||||
parseInfo.getModelName(), parseInfo.getQueryMode(), parseInfo.getScore());
|
||||
});
|
||||
}
|
||||
}
|
||||
log.debug("pick after [{}]", selectedQueries.stream().collect(Collectors.toList()));
|
||||
return selectedQueries;
|
||||
}
|
||||
|
||||
private boolean checkSatisfyOtherRules(SemanticQuery semanticQuery, List<SemanticQuery> candidateQueries) {
|
||||
if (!semanticQuery.getQueryMode().equals(MetricModelQuery.QUERY_MODE)) {
|
||||
return true;
|
||||
}
|
||||
for (SemanticQuery candidateQuery : candidateQueries) {
|
||||
if (candidateQuery.getQueryMode().equals(MetricEntityQuery.QUERY_MODE)
|
||||
&& semanticQuery.getParseInfo().getScore() == candidateQuery.getParseInfo().getScore()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private boolean checkFullyInherited(SemanticQuery query) {
|
||||
SemanticParseInfo parseInfo = query.getParseInfo();
|
||||
if (!(query instanceof RuleSemanticQuery)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (SchemaElementMatch match : parseInfo.getElementMatches()) {
|
||||
if (!match.isInherited()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (parseInfo.getDateInfo() != null && !parseInfo.getDateInfo().isInherited()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
package com.tencent.supersonic.chat.query;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
|
||||
import java.util.List;
|
||||
import java.util.ArrayList;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class QueryRanker {
|
||||
|
||||
@Value("${candidate.top.size:5}")
|
||||
private int candidateTopSize;
|
||||
|
||||
public List<SemanticQuery> rank(List<SemanticQuery> candidateQueries) {
|
||||
log.debug("pick before [{}]", candidateQueries);
|
||||
if (CollectionUtils.isEmpty(candidateQueries)) {
|
||||
return candidateQueries;
|
||||
}
|
||||
List<SemanticQuery> selectedQueries = new ArrayList<>();
|
||||
if (candidateQueries.size() == 1) {
|
||||
selectedQueries.addAll(candidateQueries);
|
||||
} else {
|
||||
selectedQueries = getTopCandidateQuery(candidateQueries);
|
||||
}
|
||||
log.debug("pick after [{}]", selectedQueries);
|
||||
return selectedQueries;
|
||||
}
|
||||
|
||||
public List<SemanticQuery> getTopCandidateQuery(List<SemanticQuery> semanticQueries) {
|
||||
return semanticQueries.stream()
|
||||
.filter(query -> !checkFullyInherited(query))
|
||||
.sorted((o1, o2) -> {
|
||||
if (o1.getParseInfo().getScore() < o2.getParseInfo().getScore()) {
|
||||
return 1;
|
||||
} else if (o1.getParseInfo().getScore() > o2.getParseInfo().getScore()) {
|
||||
return -1;
|
||||
}
|
||||
return 0;
|
||||
}).limit(candidateTopSize)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private boolean checkFullyInherited(SemanticQuery query) {
|
||||
SemanticParseInfo parseInfo = query.getParseInfo();
|
||||
if (!(query instanceof RuleSemanticQuery)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (SchemaElementMatch match : parseInfo.getElementMatches()) {
|
||||
if (!match.isInherited()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return parseInfo.getDateInfo() == null || parseInfo.getDateInfo().isInherited();
|
||||
}
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
package com.tencent.supersonic.chat.query;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* This interface defines the contract for a selector that picks the most suitable semantic query.
|
||||
**/
|
||||
public interface QuerySelector {
|
||||
|
||||
List<SemanticQuery> select(List<SemanticQuery> candidateQueries, QueryReq queryReq);
|
||||
}
|
||||
@@ -19,7 +19,7 @@ public class EntityInfoParseResponder implements ParseResponder {
|
||||
@Override
|
||||
public void fillResponse(ParseResp parseResp, QueryContext queryContext,
|
||||
List<ChatParseDO> chatParseDOS) {
|
||||
List<SemanticParseInfo> selectedParses = parseResp.getSelectedParses();
|
||||
List<SemanticParseInfo> selectedParses = parseResp.getCandidateParses();
|
||||
if (CollectionUtils.isEmpty(selectedParses)) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -24,7 +24,6 @@ public class SqlInfoParseResponder implements ParseResponder {
|
||||
List<ChatParseDO> chatParseDOS) {
|
||||
QueryReq queryReq = queryContext.getRequest();
|
||||
Long startTime = System.currentTimeMillis();
|
||||
addSqlInfo(queryReq, parseResp.getSelectedParses());
|
||||
addSqlInfo(queryReq, parseResp.getCandidateParses());
|
||||
parseResp.setParseTimeCost(new ParseTimeCostDO());
|
||||
parseResp.getParseTimeCost().setSqlTime(System.currentTimeMillis() - startTime);
|
||||
@@ -32,7 +31,6 @@ public class SqlInfoParseResponder implements ParseResponder {
|
||||
Map<Integer, ChatParseDO> chatParseDOMap = chatParseDOS.stream()
|
||||
.collect(Collectors.toMap(ChatParseDO::getParseId,
|
||||
Function.identity(), (oldValue, newValue) -> newValue));
|
||||
updateParseInfo(chatParseDOMap, parseResp.getSelectedParses());
|
||||
updateParseInfo(chatParseDOMap, parseResp.getCandidateParses());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,16 +1,9 @@
|
||||
package com.tencent.supersonic.chat.service;
|
||||
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import java.util.List;
|
||||
|
||||
public interface ParseInfoService {
|
||||
|
||||
List<SemanticParseInfo> getTopCandidateParseInfo(List<SemanticParseInfo> selectedParses,
|
||||
List<SemanticParseInfo> candidateParses);
|
||||
|
||||
List<SemanticParseInfo> sortParseInfo(List<SemanticQuery> semanticQueries);
|
||||
|
||||
void updateParseInfo(SemanticParseInfo parseInfo);
|
||||
|
||||
}
|
||||
|
||||
@@ -225,8 +225,7 @@ public class ChatServiceImpl implements ChatService {
|
||||
@Override
|
||||
public List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryReq queryReq, ParseResp parseResult) {
|
||||
List<SemanticParseInfo> candidateParses = parseResult.getCandidateParses();
|
||||
List<SemanticParseInfo> selectedParses = parseResult.getSelectedParses();
|
||||
return chatQueryRepository.batchSaveParseInfo(chatCtx, queryReq, parseResult, candidateParses, selectedParses);
|
||||
return chatQueryRepository.batchSaveParseInfo(chatCtx, queryReq, parseResult, candidateParses);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
package com.tencent.supersonic.chat.service.impl;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
@@ -20,7 +19,6 @@ import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -31,45 +29,12 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
public class ParserInfoServiceImpl implements ParseInfoService {
|
||||
|
||||
@Value("${candidate.top.size:5}")
|
||||
private int candidateTopSize;
|
||||
|
||||
public List<SemanticParseInfo> getTopCandidateParseInfo(List<SemanticParseInfo> selectedParses,
|
||||
List<SemanticParseInfo> candidateParses) {
|
||||
if (CollectionUtils.isEmpty(selectedParses) || CollectionUtils.isEmpty(candidateParses)) {
|
||||
return candidateParses;
|
||||
}
|
||||
int selectParseSize = selectedParses.size();
|
||||
Set<Double> selectParseScoreSet = selectedParses.stream()
|
||||
.map(SemanticParseInfo::getScore).collect(Collectors.toSet());
|
||||
int candidateParseSize = candidateTopSize - selectParseSize;
|
||||
candidateParses = candidateParses.stream()
|
||||
.filter(candidateParse -> !selectParseScoreSet.contains(candidateParse.getScore()))
|
||||
.collect(Collectors.toList());
|
||||
SemanticParseInfo semanticParseInfo = selectedParses.get(0);
|
||||
Long modelId = semanticParseInfo.getModelId();
|
||||
if (modelId == null || modelId <= 0) {
|
||||
return candidateParses;
|
||||
}
|
||||
return candidateParses.stream()
|
||||
.sorted(Comparator.comparing(parse -> !parse.getModelId().equals(modelId)))
|
||||
.limit(candidateParseSize)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public List<SemanticParseInfo> sortParseInfo(List<SemanticQuery> semanticQueries) {
|
||||
return semanticQueries.stream()
|
||||
.map(SemanticQuery::getParseInfo)
|
||||
.sorted(Comparator.comparingDouble(SemanticParseInfo::getScore).reversed())
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public void updateParseInfo(SemanticParseInfo parseInfo) {
|
||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||
@@ -81,9 +46,11 @@ public class ParserInfoServiceImpl implements ParseInfoService {
|
||||
List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(logicSql);
|
||||
//set dataInfo
|
||||
try {
|
||||
if (!org.springframework.util.CollectionUtils.isEmpty(expressions)) {
|
||||
if (!CollectionUtils.isEmpty(expressions)) {
|
||||
DateConf dateInfo = getDateInfo(expressions);
|
||||
parseInfo.setDateInfo(dateInfo);
|
||||
if (dateInfo != null && parseInfo.getDateInfo() == null) {
|
||||
parseInfo.setDateInfo(dateInfo);
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("set dateInfo error :", e);
|
||||
@@ -103,10 +70,10 @@ public class ParserInfoServiceImpl implements ParseInfoService {
|
||||
if (Objects.isNull(semanticSchema)) {
|
||||
return;
|
||||
}
|
||||
List<String> allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(sqlInfo.getCorrectS2SQL()));
|
||||
|
||||
Set<SchemaElement> metrics = getElements(parseInfo.getModelId(), allFields, semanticSchema.getMetrics());
|
||||
parseInfo.setMetrics(metrics);
|
||||
//cannot use metrics in sql to override parse info
|
||||
//List<String> allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(sqlInfo.getCorrectS2SQL()));
|
||||
//Set<SchemaElement> metrics = getElements(parseInfo.getModelId(), allFields, semanticSchema.getMetrics());
|
||||
//parseInfo.setMetrics(metrics);
|
||||
|
||||
if (SqlParserSelectFunctionHelper.hasAggregateFunction(sqlInfo.getCorrectS2SQL())) {
|
||||
parseInfo.setNativeQuery(false);
|
||||
@@ -167,8 +134,8 @@ public class ParserInfoServiceImpl implements ParseInfoService {
|
||||
List<FilterExpression> dateExpressions = filterExpressions.stream()
|
||||
.filter(expression -> TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(expression.getFieldName()))
|
||||
.collect(Collectors.toList());
|
||||
if (org.springframework.util.CollectionUtils.isEmpty(dateExpressions)) {
|
||||
return new DateConf();
|
||||
if (CollectionUtils.isEmpty(dateExpressions)) {
|
||||
return null;
|
||||
}
|
||||
DateConf dateInfo = new DateConf();
|
||||
dateInfo.setDateMode(DateMode.BETWEEN);
|
||||
|
||||
@@ -27,7 +27,7 @@ import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.CostType;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.QuerySelector;
|
||||
import com.tencent.supersonic.chat.query.QueryRanker;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
|
||||
import com.tencent.supersonic.chat.responder.execute.ExecuteResponder;
|
||||
import com.tencent.supersonic.chat.responder.parse.ParseResponder;
|
||||
@@ -105,13 +105,14 @@ public class QueryServiceImpl implements QueryService {
|
||||
private SolvedQueryManager solvedQueryManager;
|
||||
@Autowired
|
||||
private ParseInfoService parseInfoService;
|
||||
@Autowired
|
||||
private QueryRanker queryRanker;
|
||||
|
||||
@Value("${time.threshold: 100}")
|
||||
private Integer timeThreshold;
|
||||
|
||||
private List<SchemaMapper> schemaMappers = ComponentFactory.getSchemaMappers();
|
||||
private List<SemanticParser> semanticParsers = ComponentFactory.getSemanticParsers();
|
||||
private QuerySelector querySelector = ComponentFactory.getQuerySelector();
|
||||
private List<ParseResponder> parseResponders = ComponentFactory.getParseResponders();
|
||||
private List<ExecuteResponder> executeResponders = ComponentFactory.getExecuteResponders();
|
||||
private List<SemanticCorrector> semanticCorrectors = ComponentFactory.getSqlCorrections();
|
||||
@@ -157,18 +158,12 @@ public class QueryServiceImpl implements QueryService {
|
||||
ParseResp parseResult;
|
||||
List<ChatParseDO> chatParseDOS = Lists.newArrayList();
|
||||
if (candidateQueries.size() > 0) {
|
||||
List<SemanticQuery> selectedQueries = querySelector.select(candidateQueries, queryReq);
|
||||
List<SemanticParseInfo> selectedParses = parseInfoService.sortParseInfo(selectedQueries);
|
||||
List<SemanticParseInfo> candidateParses = parseInfoService.sortParseInfo(candidateQueries);
|
||||
candidateParses = parseInfoService.getTopCandidateParseInfo(selectedParses, candidateParses);
|
||||
candidateQueries.forEach(semanticQuery -> parseInfoService.updateParseInfo(semanticQuery.getParseInfo()));
|
||||
|
||||
parseResult = ParseResp.builder()
|
||||
.chatId(queryReq.getChatId())
|
||||
.queryText(queryReq.getQueryText())
|
||||
.state(selectedParses.size() > 1 ? ParseResp.ParseState.PENDING : ParseResp.ParseState.COMPLETED)
|
||||
.selectedParses(selectedParses)
|
||||
.candidateParses(candidateParses)
|
||||
candidateQueries = queryRanker.rank(candidateQueries);
|
||||
List<SemanticParseInfo> candidateParses = candidateQueries.stream()
|
||||
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
|
||||
candidateParses.forEach(parseInfo -> parseInfoService.updateParseInfo(parseInfo));
|
||||
parseResult = ParseResp.builder().chatId(queryReq.getChatId()).queryText(queryReq.getQueryText())
|
||||
.state(ParseResp.ParseState.COMPLETED).candidateParses(candidateParses)
|
||||
.build();
|
||||
chatParseDOS = chatService.batchAddParse(chatCtx, queryReq, parseResult);
|
||||
} else {
|
||||
|
||||
@@ -10,7 +10,6 @@ import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import com.tencent.supersonic.chat.parser.llm.s2sql.ModelResolver;
|
||||
import com.tencent.supersonic.chat.query.QuerySelector;
|
||||
import com.tencent.supersonic.chat.responder.execute.ExecuteResponder;
|
||||
import com.tencent.supersonic.chat.responder.parse.ParseResponder;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
@@ -24,7 +23,6 @@ public class ComponentFactory {
|
||||
private static SemanticInterpreter semanticInterpreter;
|
||||
private static List<ParseResponder> parseResponders = new ArrayList<>();
|
||||
private static List<ExecuteResponder> executeResponders = new ArrayList<>();
|
||||
private static QuerySelector querySelector;
|
||||
private static ModelResolver modelResolver;
|
||||
public static List<SchemaMapper> getSchemaMappers() {
|
||||
return CollectionUtils.isEmpty(schemaMappers) ? init(SchemaMapper.class, schemaMappers) : schemaMappers;
|
||||
@@ -59,12 +57,6 @@ public class ComponentFactory {
|
||||
semanticInterpreter = layer;
|
||||
}
|
||||
|
||||
public static QuerySelector getQuerySelector() {
|
||||
if (Objects.isNull(querySelector)) {
|
||||
querySelector = init(QuerySelector.class);
|
||||
}
|
||||
return querySelector;
|
||||
}
|
||||
|
||||
public static ModelResolver getModelResolver() {
|
||||
if (Objects.isNull(modelResolver)) {
|
||||
|
||||
@@ -20,9 +20,6 @@ com.tencent.supersonic.chat.api.component.SemanticCorrector=\
|
||||
com.tencent.supersonic.chat.api.component.SemanticInterpreter=\
|
||||
com.tencent.supersonic.knowledge.semantic.RemoteSemanticInterpreter
|
||||
|
||||
com.tencent.supersonic.chat.query.QuerySelector=\
|
||||
com.tencent.supersonic.chat.query.HeuristicQuerySelector
|
||||
|
||||
com.tencent.supersonic.chat.parser.llm.s2sql.ModelResolver=\
|
||||
com.tencent.supersonic.chat.parser.llm.s2sql.HeuristicModelResolver
|
||||
|
||||
|
||||
@@ -68,12 +68,13 @@ public class ConfigureDemo implements ApplicationListener<ApplicationReadyEvent>
|
||||
|
||||
ExecuteQueryReq executeReq = ExecuteQueryReq.builder().build();
|
||||
executeReq.setQueryId(parseResp.getQueryId());
|
||||
executeReq.setParseId(parseResp.getSelectedParses().get(0).getId());
|
||||
executeReq.setParseId(parseResp.getCandidateParses().get(0).getId());
|
||||
executeReq.setQueryText(queryRequest.getQueryText());
|
||||
executeReq.setParseInfo(parseResp.getSelectedParses().get(0));
|
||||
executeReq.setParseInfo(parseResp.getCandidateParses().get(0));
|
||||
executeReq.setChatId(parseResp.getChatId());
|
||||
executeReq.setUser(queryRequest.getUser());
|
||||
executeReq.setAgentId(1);
|
||||
executeReq.setSaveAnswer(true);
|
||||
queryService.performExecution(executeReq);
|
||||
}
|
||||
|
||||
|
||||
@@ -21,9 +21,6 @@ com.tencent.supersonic.chat.api.component.SemanticCorrector=\
|
||||
com.tencent.supersonic.chat.api.component.SemanticInterpreter=\
|
||||
com.tencent.supersonic.knowledge.semantic.LocalSemanticInterpreter
|
||||
|
||||
com.tencent.supersonic.chat.query.QuerySelector=\
|
||||
com.tencent.supersonic.chat.query.HeuristicQuerySelector
|
||||
|
||||
com.tencent.supersonic.chat.parser.llm.s2sql.ModelResolver=\
|
||||
com.tencent.supersonic.chat.parser.llm.s2sql.HeuristicModelResolver
|
||||
|
||||
|
||||
@@ -53,11 +53,12 @@ public class BaseQueryTest {
|
||||
|
||||
ExecuteQueryReq request = ExecuteQueryReq.builder()
|
||||
.queryId(parseResp.getQueryId())
|
||||
.parseId(parseResp.getSelectedParses().get(0).getId())
|
||||
.parseId(parseResp.getCandidateParses().get(0).getId())
|
||||
.chatId(parseResp.getChatId())
|
||||
.queryText(parseResp.getQueryText())
|
||||
.user(DataUtils.getUser())
|
||||
.parseInfo(parseResp.getSelectedParses().get(0))
|
||||
.parseInfo(parseResp.getCandidateParses().get(0))
|
||||
.saveAnswer(true)
|
||||
.build();
|
||||
|
||||
return queryService.performExecution(request);
|
||||
@@ -68,11 +69,12 @@ public class BaseQueryTest {
|
||||
|
||||
ExecuteQueryReq request = ExecuteQueryReq.builder()
|
||||
.queryId(parseResp.getQueryId())
|
||||
.parseId(parseResp.getSelectedParses().get(0).getId())
|
||||
.parseId(parseResp.getCandidateParses().get(0).getId())
|
||||
.chatId(parseResp.getChatId())
|
||||
.queryText(parseResp.getQueryText())
|
||||
.user(DataUtils.getUser())
|
||||
.parseInfo(parseResp.getSelectedParses().get(0))
|
||||
.parseInfo(parseResp.getCandidateParses().get(0))
|
||||
.saveAnswer(true)
|
||||
.build();
|
||||
|
||||
QueryResult result = queryService.performExecution(request);
|
||||
|
||||
@@ -57,8 +57,8 @@ public class MetricInterpretTest {
|
||||
.chatId(parseResp.getChatId())
|
||||
.queryId(parseResp.getQueryId())
|
||||
.queryText(parseResp.getQueryText())
|
||||
.parseInfo(parseResp.getSelectedParses().get(0))
|
||||
.parseId(parseResp.getSelectedParses().get(0).getId())
|
||||
.parseInfo(parseResp.getCandidateParses().get(0))
|
||||
.parseId(parseResp.getCandidateParses().get(0).getId())
|
||||
.build();
|
||||
QueryResult queryResult = queryService.performExecution(executeReq);
|
||||
Assert.assertEquals(queryResult.getQueryResults().get(0).get("answer"), lLmAnswerResp.getAssistantMessage());
|
||||
|
||||
@@ -58,8 +58,8 @@ public class MetricQueryTest extends BaseQueryTest {
|
||||
//agent only support METRIC_ENTITY, METRIC_FILTER
|
||||
MockConfiguration.mockAgent(agentService);
|
||||
ParseResp parseResp = submitParseWithAgent("alice的访问次数", DataUtils.getAgent().getId());
|
||||
Assert.assertNotNull(parseResp.getSelectedParses());
|
||||
List<String> queryModes = parseResp.getSelectedParses().stream()
|
||||
Assert.assertNotNull(parseResp.getCandidateParses());
|
||||
List<String> queryModes = parseResp.getCandidateParses().stream()
|
||||
.map(SemanticParseInfo::getQueryMode).collect(Collectors.toList());
|
||||
Assert.assertTrue(queryModes.contains("METRIC_FILTER"));
|
||||
}
|
||||
@@ -88,7 +88,7 @@ public class MetricQueryTest extends BaseQueryTest {
|
||||
//agent only support METRIC_ENTITY, METRIC_FILTER
|
||||
MockConfiguration.mockAgent(agentService);
|
||||
ParseResp parseResp = submitParseWithAgent("超音数的访问次数", DataUtils.getAgent().getId());
|
||||
List<String> queryModes = parseResp.getSelectedParses().stream()
|
||||
List<String> queryModes = parseResp.getCandidateParses().stream()
|
||||
.map(SemanticParseInfo::getQueryMode).collect(Collectors.toList());
|
||||
Assert.assertTrue(queryModes.contains("METRIC_MODEL"));
|
||||
}
|
||||
@@ -123,7 +123,7 @@ public class MetricQueryTest extends BaseQueryTest {
|
||||
|
||||
expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE);
|
||||
expectedParseInfo.setAggType(NONE);
|
||||
|
||||
expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("用户名"));
|
||||
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
|
||||
|
||||
List<String> list = new ArrayList<>();
|
||||
|
||||
@@ -44,7 +44,7 @@ public class PluginRecognizeTest extends BasePluginTest {
|
||||
.chatId(parseResp.getChatId())
|
||||
.queryId(parseResp.getQueryId())
|
||||
.queryText(parseResp.getQueryText())
|
||||
.parseInfo(parseResp.getSelectedParses().get(0))
|
||||
.parseInfo(parseResp.getCandidateParses().get(0))
|
||||
.build();
|
||||
QueryResult queryResult = queryService.performExecution(executeReq);
|
||||
|
||||
@@ -69,7 +69,7 @@ public class PluginRecognizeTest extends BasePluginTest {
|
||||
.chatId(parseResp.getChatId())
|
||||
.queryId(parseResp.getQueryId())
|
||||
.queryText(parseResp.getQueryText())
|
||||
.parseInfo(parseResp.getSelectedParses().get(0))
|
||||
.parseInfo(parseResp.getCandidateParses().get(0))
|
||||
.build();
|
||||
QueryResult queryResult = queryService.performExecution(executeReq);
|
||||
|
||||
@@ -84,8 +84,8 @@ public class PluginRecognizeTest extends BasePluginTest {
|
||||
QueryReq queryContextReq = DataUtils.getQueryReqWithAgent(1000, "alice最近的访问情况怎么样",
|
||||
DataUtils.getAgent().getId());
|
||||
ParseResp parseResp = queryService.performParsing(queryContextReq);
|
||||
Assert.assertTrue(parseResp.getSelectedParses() != null
|
||||
&& parseResp.getSelectedParses().size() > 0);
|
||||
Assert.assertTrue(parseResp.getCandidateParses() != null
|
||||
&& parseResp.getCandidateParses().size() > 0);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -12,9 +12,6 @@ com.tencent.supersonic.chat.api.component.QueryProcessor=\
|
||||
com.tencent.supersonic.chat.api.component.SemanticInterpreter=\
|
||||
com.tencent.supersonic.knowledge.semantic.LocalSemanticInterpreter
|
||||
|
||||
com.tencent.supersonic.chat.query.QuerySelector=\
|
||||
com.tencent.supersonic.chat.query.HeuristicQuerySelector
|
||||
|
||||
com.tencent.supersonic.chat.application.query.DomainResolver=\
|
||||
com.tencent.supersonic.chat.application.query.HeuristicDomainResolver
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.context.ApplicationListener;
|
||||
import org.springframework.scheduling.annotation.Async;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@@ -24,6 +25,7 @@ public class MetaEmbeddingListener implements ApplicationListener<DataEvent> {
|
||||
@Autowired
|
||||
private EmbeddingUtils embeddingUtils;
|
||||
|
||||
@Async
|
||||
@Override
|
||||
public void onApplicationEvent(DataEvent event) {
|
||||
if (CollectionUtils.isEmpty(event.getDataItems())) {
|
||||
|
||||
Reference in New Issue
Block a user