Feature/Refactor querySelect to queryRanker and fix some errors in integration tests (#369)

* (fix) (chat) fix the context saving failure caused by the loss of default values caused by @builder

* (fix) (chat) fix date and metrics result in parse info in integration test

* (improvement) (chat) refactor querySelect to queryRanker

---------

Co-authored-by: jolunoluo
This commit is contained in:
LXW
2023-11-12 22:47:58 +08:00
committed by GitHub
parent cb1ad94086
commit 731238de08
23 changed files with 127 additions and 214 deletions

View File

@@ -31,8 +31,7 @@ public interface ChatQueryRepository {
List<ChatParseDO> batchSaveParseInfo(ChatContext chatCtx, QueryReq queryReq,
ParseResp parseResult,
List<SemanticParseInfo> candidateParses,
List<SemanticParseInfo> selectedParses);
List<SemanticParseInfo> candidateParses);
public ChatParseDO getParseInfo(Long questionId, int parseId);

View File

@@ -133,13 +133,10 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
@Override
public List<ChatParseDO> batchSaveParseInfo(ChatContext chatCtx, QueryReq queryReq,
ParseResp parseResult,
List<SemanticParseInfo> candidateParses,
List<SemanticParseInfo> selectedParses) {
List<SemanticParseInfo> candidateParses) {
Long queryId = createChatParse(parseResult, chatCtx, queryReq);
List<ChatParseDO> chatParseDOList = new ArrayList<>();
log.info("candidateParses size:{},selectedParses size:{}", candidateParses.size(), selectedParses.size());
getChatParseDO(chatCtx, queryReq, queryId, 0, 1, candidateParses, chatParseDOList);
getChatParseDO(chatCtx, queryReq, queryId, candidateParses.size(), 0, selectedParses, chatParseDOList);
getChatParseDO(chatCtx, queryReq, queryId, 0, candidateParses, chatParseDOList);
chatParseMapper.batchSaveParseInfo(chatParseDOList);
return chatParseDOList;
}
@@ -151,7 +148,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
}
}
public void getChatParseDO(ChatContext chatCtx, QueryReq queryReq, Long queryId, int base, int isCandidate,
public void getChatParseDO(ChatContext chatCtx, QueryReq queryReq, Long queryId, int base,
List<SemanticParseInfo> parses, List<ChatParseDO> chatParseDOList) {
for (int i = 0; i < parses.size(); i++) {
ChatParseDO chatParseDO = new ChatParseDO();
@@ -160,7 +157,10 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
chatParseDO.setQuestionId(queryId);
chatParseDO.setQueryText(queryReq.getQueryText());
chatParseDO.setParseInfo(JsonUtil.toString(parses.get(i)));
chatParseDO.setIsCandidate(isCandidate);
chatParseDO.setIsCandidate(1);
if (i == 0) {
chatParseDO.setIsCandidate(0);
}
chatParseDO.setParseId(base + i + 1);
chatParseDO.setCreateTime(new java.util.Date());
chatParseDO.setUserName(queryReq.getUser().getName());

View File

@@ -1,84 +0,0 @@
package com.tencent.supersonic.chat.query;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import java.util.List;
import java.util.ArrayList;
import java.util.OptionalDouble;
import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
import com.tencent.supersonic.common.util.ContextUtils;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
@Slf4j
public class HeuristicQuerySelector implements QuerySelector {
@Override
public List<SemanticQuery> select(List<SemanticQuery> candidateQueries, QueryReq queryReq) {
log.debug("pick before [{}]", candidateQueries.stream().collect(Collectors.toList()));
List<SemanticQuery> selectedQueries = new ArrayList<>();
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
Double candidateThreshold = optimizationConfig.getCandidateThreshold();
if (CollectionUtils.isNotEmpty(candidateQueries) && candidateQueries.size() == 1) {
selectedQueries.addAll(candidateQueries);
} else {
OptionalDouble maxScoreOp = candidateQueries.stream().mapToDouble(
q -> q.getParseInfo().getScore()).max();
if (maxScoreOp.isPresent()) {
double maxScore = maxScoreOp.getAsDouble();
candidateQueries.stream().forEach(query -> {
SemanticParseInfo parseInfo = query.getParseInfo();
if (!checkFullyInherited(query)
&& (maxScore - parseInfo.getScore()) / maxScore <= candidateThreshold
&& checkSatisfyOtherRules(query, candidateQueries)) {
selectedQueries.add(query);
}
log.info("candidate query (Model={}, queryMode={}) with score={}",
parseInfo.getModelName(), parseInfo.getQueryMode(), parseInfo.getScore());
});
}
}
log.debug("pick after [{}]", selectedQueries.stream().collect(Collectors.toList()));
return selectedQueries;
}
private boolean checkSatisfyOtherRules(SemanticQuery semanticQuery, List<SemanticQuery> candidateQueries) {
if (!semanticQuery.getQueryMode().equals(MetricModelQuery.QUERY_MODE)) {
return true;
}
for (SemanticQuery candidateQuery : candidateQueries) {
if (candidateQuery.getQueryMode().equals(MetricEntityQuery.QUERY_MODE)
&& semanticQuery.getParseInfo().getScore() == candidateQuery.getParseInfo().getScore()) {
return false;
}
}
return true;
}
private boolean checkFullyInherited(SemanticQuery query) {
SemanticParseInfo parseInfo = query.getParseInfo();
if (!(query instanceof RuleSemanticQuery)) {
return false;
}
for (SchemaElementMatch match : parseInfo.getElementMatches()) {
if (!match.isInherited()) {
return false;
}
}
if (parseInfo.getDateInfo() != null && !parseInfo.getDateInfo().isInherited()) {
return false;
}
return true;
}
}

View File

@@ -0,0 +1,64 @@
package com.tencent.supersonic.chat.query;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import java.util.List;
import java.util.ArrayList;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
@Slf4j
@Component
public class QueryRanker {
@Value("${candidate.top.size:5}")
private int candidateTopSize;
public List<SemanticQuery> rank(List<SemanticQuery> candidateQueries) {
log.debug("pick before [{}]", candidateQueries);
if (CollectionUtils.isEmpty(candidateQueries)) {
return candidateQueries;
}
List<SemanticQuery> selectedQueries = new ArrayList<>();
if (candidateQueries.size() == 1) {
selectedQueries.addAll(candidateQueries);
} else {
selectedQueries = getTopCandidateQuery(candidateQueries);
}
log.debug("pick after [{}]", selectedQueries);
return selectedQueries;
}
public List<SemanticQuery> getTopCandidateQuery(List<SemanticQuery> semanticQueries) {
return semanticQueries.stream()
.filter(query -> !checkFullyInherited(query))
.sorted((o1, o2) -> {
if (o1.getParseInfo().getScore() < o2.getParseInfo().getScore()) {
return 1;
} else if (o1.getParseInfo().getScore() > o2.getParseInfo().getScore()) {
return -1;
}
return 0;
}).limit(candidateTopSize)
.collect(Collectors.toList());
}
private boolean checkFullyInherited(SemanticQuery query) {
SemanticParseInfo parseInfo = query.getParseInfo();
if (!(query instanceof RuleSemanticQuery)) {
return false;
}
for (SchemaElementMatch match : parseInfo.getElementMatches()) {
if (!match.isInherited()) {
return false;
}
}
return parseInfo.getDateInfo() == null || parseInfo.getDateInfo().isInherited();
}
}

View File

@@ -1,14 +0,0 @@
package com.tencent.supersonic.chat.query;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import java.util.List;
/**
* This interface defines the contract for a selector that picks the most suitable semantic query.
**/
public interface QuerySelector {
List<SemanticQuery> select(List<SemanticQuery> candidateQueries, QueryReq queryReq);
}

View File

@@ -19,7 +19,7 @@ public class EntityInfoParseResponder implements ParseResponder {
@Override
public void fillResponse(ParseResp parseResp, QueryContext queryContext,
List<ChatParseDO> chatParseDOS) {
List<SemanticParseInfo> selectedParses = parseResp.getSelectedParses();
List<SemanticParseInfo> selectedParses = parseResp.getCandidateParses();
if (CollectionUtils.isEmpty(selectedParses)) {
return;
}

View File

@@ -24,7 +24,6 @@ public class SqlInfoParseResponder implements ParseResponder {
List<ChatParseDO> chatParseDOS) {
QueryReq queryReq = queryContext.getRequest();
Long startTime = System.currentTimeMillis();
addSqlInfo(queryReq, parseResp.getSelectedParses());
addSqlInfo(queryReq, parseResp.getCandidateParses());
parseResp.setParseTimeCost(new ParseTimeCostDO());
parseResp.getParseTimeCost().setSqlTime(System.currentTimeMillis() - startTime);
@@ -32,7 +31,6 @@ public class SqlInfoParseResponder implements ParseResponder {
Map<Integer, ChatParseDO> chatParseDOMap = chatParseDOS.stream()
.collect(Collectors.toMap(ChatParseDO::getParseId,
Function.identity(), (oldValue, newValue) -> newValue));
updateParseInfo(chatParseDOMap, parseResp.getSelectedParses());
updateParseInfo(chatParseDOMap, parseResp.getCandidateParses());
}
}

View File

@@ -1,16 +1,9 @@
package com.tencent.supersonic.chat.service;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import java.util.List;
public interface ParseInfoService {
List<SemanticParseInfo> getTopCandidateParseInfo(List<SemanticParseInfo> selectedParses,
List<SemanticParseInfo> candidateParses);
List<SemanticParseInfo> sortParseInfo(List<SemanticQuery> semanticQueries);
void updateParseInfo(SemanticParseInfo parseInfo);
}

View File

@@ -225,8 +225,7 @@ public class ChatServiceImpl implements ChatService {
@Override
public List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryReq queryReq, ParseResp parseResult) {
List<SemanticParseInfo> candidateParses = parseResult.getCandidateParses();
List<SemanticParseInfo> selectedParses = parseResult.getSelectedParses();
return chatQueryRepository.batchSaveParseInfo(chatCtx, queryReq, parseResult, candidateParses, selectedParses);
return chatQueryRepository.batchSaveParseInfo(chatCtx, queryReq, parseResult, candidateParses);
}
@Override

View File

@@ -2,7 +2,6 @@
package com.tencent.supersonic.chat.service.impl;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
@@ -20,7 +19,6 @@ import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
@@ -31,45 +29,12 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
@Slf4j
@Service
public class ParserInfoServiceImpl implements ParseInfoService {
@Value("${candidate.top.size:5}")
private int candidateTopSize;
public List<SemanticParseInfo> getTopCandidateParseInfo(List<SemanticParseInfo> selectedParses,
List<SemanticParseInfo> candidateParses) {
if (CollectionUtils.isEmpty(selectedParses) || CollectionUtils.isEmpty(candidateParses)) {
return candidateParses;
}
int selectParseSize = selectedParses.size();
Set<Double> selectParseScoreSet = selectedParses.stream()
.map(SemanticParseInfo::getScore).collect(Collectors.toSet());
int candidateParseSize = candidateTopSize - selectParseSize;
candidateParses = candidateParses.stream()
.filter(candidateParse -> !selectParseScoreSet.contains(candidateParse.getScore()))
.collect(Collectors.toList());
SemanticParseInfo semanticParseInfo = selectedParses.get(0);
Long modelId = semanticParseInfo.getModelId();
if (modelId == null || modelId <= 0) {
return candidateParses;
}
return candidateParses.stream()
.sorted(Comparator.comparing(parse -> !parse.getModelId().equals(modelId)))
.limit(candidateParseSize)
.collect(Collectors.toList());
}
public List<SemanticParseInfo> sortParseInfo(List<SemanticQuery> semanticQueries) {
return semanticQueries.stream()
.map(SemanticQuery::getParseInfo)
.sorted(Comparator.comparingDouble(SemanticParseInfo::getScore).reversed())
.collect(Collectors.toList());
}
public void updateParseInfo(SemanticParseInfo parseInfo) {
SqlInfo sqlInfo = parseInfo.getSqlInfo();
@@ -81,9 +46,11 @@ public class ParserInfoServiceImpl implements ParseInfoService {
List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(logicSql);
//set dataInfo
try {
if (!org.springframework.util.CollectionUtils.isEmpty(expressions)) {
if (!CollectionUtils.isEmpty(expressions)) {
DateConf dateInfo = getDateInfo(expressions);
parseInfo.setDateInfo(dateInfo);
if (dateInfo != null && parseInfo.getDateInfo() == null) {
parseInfo.setDateInfo(dateInfo);
}
}
} catch (Exception e) {
log.error("set dateInfo error :", e);
@@ -103,10 +70,10 @@ public class ParserInfoServiceImpl implements ParseInfoService {
if (Objects.isNull(semanticSchema)) {
return;
}
List<String> allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(sqlInfo.getCorrectS2SQL()));
Set<SchemaElement> metrics = getElements(parseInfo.getModelId(), allFields, semanticSchema.getMetrics());
parseInfo.setMetrics(metrics);
//cannot use metrics in sql to override parse info
//List<String> allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(sqlInfo.getCorrectS2SQL()));
//Set<SchemaElement> metrics = getElements(parseInfo.getModelId(), allFields, semanticSchema.getMetrics());
//parseInfo.setMetrics(metrics);
if (SqlParserSelectFunctionHelper.hasAggregateFunction(sqlInfo.getCorrectS2SQL())) {
parseInfo.setNativeQuery(false);
@@ -167,8 +134,8 @@ public class ParserInfoServiceImpl implements ParseInfoService {
List<FilterExpression> dateExpressions = filterExpressions.stream()
.filter(expression -> TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(expression.getFieldName()))
.collect(Collectors.toList());
if (org.springframework.util.CollectionUtils.isEmpty(dateExpressions)) {
return new DateConf();
if (CollectionUtils.isEmpty(dateExpressions)) {
return null;
}
DateConf dateInfo = new DateConf();
dateInfo.setDateMode(DateMode.BETWEEN);

View File

@@ -27,7 +27,7 @@ import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.persistence.dataobject.CostType;
import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.QuerySelector;
import com.tencent.supersonic.chat.query.QueryRanker;
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
import com.tencent.supersonic.chat.responder.execute.ExecuteResponder;
import com.tencent.supersonic.chat.responder.parse.ParseResponder;
@@ -105,13 +105,14 @@ public class QueryServiceImpl implements QueryService {
private SolvedQueryManager solvedQueryManager;
@Autowired
private ParseInfoService parseInfoService;
@Autowired
private QueryRanker queryRanker;
@Value("${time.threshold: 100}")
private Integer timeThreshold;
private List<SchemaMapper> schemaMappers = ComponentFactory.getSchemaMappers();
private List<SemanticParser> semanticParsers = ComponentFactory.getSemanticParsers();
private QuerySelector querySelector = ComponentFactory.getQuerySelector();
private List<ParseResponder> parseResponders = ComponentFactory.getParseResponders();
private List<ExecuteResponder> executeResponders = ComponentFactory.getExecuteResponders();
private List<SemanticCorrector> semanticCorrectors = ComponentFactory.getSqlCorrections();
@@ -157,18 +158,12 @@ public class QueryServiceImpl implements QueryService {
ParseResp parseResult;
List<ChatParseDO> chatParseDOS = Lists.newArrayList();
if (candidateQueries.size() > 0) {
List<SemanticQuery> selectedQueries = querySelector.select(candidateQueries, queryReq);
List<SemanticParseInfo> selectedParses = parseInfoService.sortParseInfo(selectedQueries);
List<SemanticParseInfo> candidateParses = parseInfoService.sortParseInfo(candidateQueries);
candidateParses = parseInfoService.getTopCandidateParseInfo(selectedParses, candidateParses);
candidateQueries.forEach(semanticQuery -> parseInfoService.updateParseInfo(semanticQuery.getParseInfo()));
parseResult = ParseResp.builder()
.chatId(queryReq.getChatId())
.queryText(queryReq.getQueryText())
.state(selectedParses.size() > 1 ? ParseResp.ParseState.PENDING : ParseResp.ParseState.COMPLETED)
.selectedParses(selectedParses)
.candidateParses(candidateParses)
candidateQueries = queryRanker.rank(candidateQueries);
List<SemanticParseInfo> candidateParses = candidateQueries.stream()
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
candidateParses.forEach(parseInfo -> parseInfoService.updateParseInfo(parseInfo));
parseResult = ParseResp.builder().chatId(queryReq.getChatId()).queryText(queryReq.getQueryText())
.state(ParseResp.ParseState.COMPLETED).candidateParses(candidateParses)
.build();
chatParseDOS = chatService.batchAddParse(chatCtx, queryReq, parseResult);
} else {

View File

@@ -10,7 +10,6 @@ import java.util.List;
import java.util.Objects;
import com.tencent.supersonic.chat.parser.llm.s2sql.ModelResolver;
import com.tencent.supersonic.chat.query.QuerySelector;
import com.tencent.supersonic.chat.responder.execute.ExecuteResponder;
import com.tencent.supersonic.chat.responder.parse.ParseResponder;
import org.apache.commons.collections.CollectionUtils;
@@ -24,7 +23,6 @@ public class ComponentFactory {
private static SemanticInterpreter semanticInterpreter;
private static List<ParseResponder> parseResponders = new ArrayList<>();
private static List<ExecuteResponder> executeResponders = new ArrayList<>();
private static QuerySelector querySelector;
private static ModelResolver modelResolver;
public static List<SchemaMapper> getSchemaMappers() {
return CollectionUtils.isEmpty(schemaMappers) ? init(SchemaMapper.class, schemaMappers) : schemaMappers;
@@ -59,12 +57,6 @@ public class ComponentFactory {
semanticInterpreter = layer;
}
public static QuerySelector getQuerySelector() {
if (Objects.isNull(querySelector)) {
querySelector = init(QuerySelector.class);
}
return querySelector;
}
public static ModelResolver getModelResolver() {
if (Objects.isNull(modelResolver)) {