Feature/Refactor querySelect to queryRanker and fix some errors in integration tests (#369)

* (fix) (chat) fix the context saving failure caused by the loss of default values caused by @builder

* (fix) (chat) fix date and metrics result in parse info in integration test

* (improvement) (chat) refactor querySelect to queryRanker

---------

Co-authored-by: jolunoluo
This commit is contained in:
LXW
2023-11-12 22:47:58 +08:00
committed by GitHub
parent cb1ad94086
commit 731238de08
23 changed files with 127 additions and 214 deletions

View File

@@ -13,8 +13,8 @@ public class ExecuteQueryReq {
private Integer agentId;
private Integer chatId;
private String queryText;
private Long queryId = 7L;
private Integer parseId = 2;
private Long queryId;
private Integer parseId;
private SemanticParseInfo parseInfo;
private boolean saveAnswer = true;
private boolean saveAnswer;
}

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.chat.api.pojo.response;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import lombok.Data;
import lombok.Getter;
@@ -19,8 +20,8 @@ public class ParseResp {
private String queryText;
private Long queryId;
private ParseState state;
private List<SemanticParseInfo> selectedParses;
private List<SemanticParseInfo> candidateParses;
private List<SemanticParseInfo> selectedParses = Lists.newArrayList();
private List<SemanticParseInfo> candidateParses = Lists.newArrayList();
private List<SolvedQueryRecallResp> similarSolvedQuery;
private ParseTimeCostDO parseTimeCost;
@@ -29,4 +30,11 @@ public class ParseResp {
PENDING,
FAILED
}
public List<SemanticParseInfo> getSelectedParses() {
selectedParses = Lists.newArrayList();
selectedParses.addAll(candidateParses);
candidateParses.clear();
return selectedParses;
}
}

View File

@@ -31,8 +31,7 @@ public interface ChatQueryRepository {
List<ChatParseDO> batchSaveParseInfo(ChatContext chatCtx, QueryReq queryReq,
ParseResp parseResult,
List<SemanticParseInfo> candidateParses,
List<SemanticParseInfo> selectedParses);
List<SemanticParseInfo> candidateParses);
public ChatParseDO getParseInfo(Long questionId, int parseId);

View File

@@ -133,13 +133,10 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
@Override
public List<ChatParseDO> batchSaveParseInfo(ChatContext chatCtx, QueryReq queryReq,
ParseResp parseResult,
List<SemanticParseInfo> candidateParses,
List<SemanticParseInfo> selectedParses) {
List<SemanticParseInfo> candidateParses) {
Long queryId = createChatParse(parseResult, chatCtx, queryReq);
List<ChatParseDO> chatParseDOList = new ArrayList<>();
log.info("candidateParses size:{},selectedParses size:{}", candidateParses.size(), selectedParses.size());
getChatParseDO(chatCtx, queryReq, queryId, 0, 1, candidateParses, chatParseDOList);
getChatParseDO(chatCtx, queryReq, queryId, candidateParses.size(), 0, selectedParses, chatParseDOList);
getChatParseDO(chatCtx, queryReq, queryId, 0, candidateParses, chatParseDOList);
chatParseMapper.batchSaveParseInfo(chatParseDOList);
return chatParseDOList;
}
@@ -151,7 +148,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
}
}
public void getChatParseDO(ChatContext chatCtx, QueryReq queryReq, Long queryId, int base, int isCandidate,
public void getChatParseDO(ChatContext chatCtx, QueryReq queryReq, Long queryId, int base,
List<SemanticParseInfo> parses, List<ChatParseDO> chatParseDOList) {
for (int i = 0; i < parses.size(); i++) {
ChatParseDO chatParseDO = new ChatParseDO();
@@ -160,7 +157,10 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
chatParseDO.setQuestionId(queryId);
chatParseDO.setQueryText(queryReq.getQueryText());
chatParseDO.setParseInfo(JsonUtil.toString(parses.get(i)));
chatParseDO.setIsCandidate(isCandidate);
chatParseDO.setIsCandidate(1);
if (i == 0) {
chatParseDO.setIsCandidate(0);
}
chatParseDO.setParseId(base + i + 1);
chatParseDO.setCreateTime(new java.util.Date());
chatParseDO.setUserName(queryReq.getUser().getName());

View File

@@ -1,84 +0,0 @@
package com.tencent.supersonic.chat.query;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import java.util.List;
import java.util.ArrayList;
import java.util.OptionalDouble;
import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
import com.tencent.supersonic.common.util.ContextUtils;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
@Slf4j
public class HeuristicQuerySelector implements QuerySelector {
@Override
public List<SemanticQuery> select(List<SemanticQuery> candidateQueries, QueryReq queryReq) {
log.debug("pick before [{}]", candidateQueries.stream().collect(Collectors.toList()));
List<SemanticQuery> selectedQueries = new ArrayList<>();
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
Double candidateThreshold = optimizationConfig.getCandidateThreshold();
if (CollectionUtils.isNotEmpty(candidateQueries) && candidateQueries.size() == 1) {
selectedQueries.addAll(candidateQueries);
} else {
OptionalDouble maxScoreOp = candidateQueries.stream().mapToDouble(
q -> q.getParseInfo().getScore()).max();
if (maxScoreOp.isPresent()) {
double maxScore = maxScoreOp.getAsDouble();
candidateQueries.stream().forEach(query -> {
SemanticParseInfo parseInfo = query.getParseInfo();
if (!checkFullyInherited(query)
&& (maxScore - parseInfo.getScore()) / maxScore <= candidateThreshold
&& checkSatisfyOtherRules(query, candidateQueries)) {
selectedQueries.add(query);
}
log.info("candidate query (Model={}, queryMode={}) with score={}",
parseInfo.getModelName(), parseInfo.getQueryMode(), parseInfo.getScore());
});
}
}
log.debug("pick after [{}]", selectedQueries.stream().collect(Collectors.toList()));
return selectedQueries;
}
private boolean checkSatisfyOtherRules(SemanticQuery semanticQuery, List<SemanticQuery> candidateQueries) {
if (!semanticQuery.getQueryMode().equals(MetricModelQuery.QUERY_MODE)) {
return true;
}
for (SemanticQuery candidateQuery : candidateQueries) {
if (candidateQuery.getQueryMode().equals(MetricEntityQuery.QUERY_MODE)
&& semanticQuery.getParseInfo().getScore() == candidateQuery.getParseInfo().getScore()) {
return false;
}
}
return true;
}
private boolean checkFullyInherited(SemanticQuery query) {
SemanticParseInfo parseInfo = query.getParseInfo();
if (!(query instanceof RuleSemanticQuery)) {
return false;
}
for (SchemaElementMatch match : parseInfo.getElementMatches()) {
if (!match.isInherited()) {
return false;
}
}
if (parseInfo.getDateInfo() != null && !parseInfo.getDateInfo().isInherited()) {
return false;
}
return true;
}
}

View File

@@ -0,0 +1,64 @@
package com.tencent.supersonic.chat.query;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import java.util.List;
import java.util.ArrayList;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
@Slf4j
@Component
public class QueryRanker {
@Value("${candidate.top.size:5}")
private int candidateTopSize;
public List<SemanticQuery> rank(List<SemanticQuery> candidateQueries) {
log.debug("pick before [{}]", candidateQueries);
if (CollectionUtils.isEmpty(candidateQueries)) {
return candidateQueries;
}
List<SemanticQuery> selectedQueries = new ArrayList<>();
if (candidateQueries.size() == 1) {
selectedQueries.addAll(candidateQueries);
} else {
selectedQueries = getTopCandidateQuery(candidateQueries);
}
log.debug("pick after [{}]", selectedQueries);
return selectedQueries;
}
public List<SemanticQuery> getTopCandidateQuery(List<SemanticQuery> semanticQueries) {
return semanticQueries.stream()
.filter(query -> !checkFullyInherited(query))
.sorted((o1, o2) -> {
if (o1.getParseInfo().getScore() < o2.getParseInfo().getScore()) {
return 1;
} else if (o1.getParseInfo().getScore() > o2.getParseInfo().getScore()) {
return -1;
}
return 0;
}).limit(candidateTopSize)
.collect(Collectors.toList());
}
private boolean checkFullyInherited(SemanticQuery query) {
SemanticParseInfo parseInfo = query.getParseInfo();
if (!(query instanceof RuleSemanticQuery)) {
return false;
}
for (SchemaElementMatch match : parseInfo.getElementMatches()) {
if (!match.isInherited()) {
return false;
}
}
return parseInfo.getDateInfo() == null || parseInfo.getDateInfo().isInherited();
}
}

View File

@@ -1,14 +0,0 @@
package com.tencent.supersonic.chat.query;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import java.util.List;
/**
* This interface defines the contract for a selector that picks the most suitable semantic query.
**/
public interface QuerySelector {
List<SemanticQuery> select(List<SemanticQuery> candidateQueries, QueryReq queryReq);
}

View File

@@ -19,7 +19,7 @@ public class EntityInfoParseResponder implements ParseResponder {
@Override
public void fillResponse(ParseResp parseResp, QueryContext queryContext,
List<ChatParseDO> chatParseDOS) {
List<SemanticParseInfo> selectedParses = parseResp.getSelectedParses();
List<SemanticParseInfo> selectedParses = parseResp.getCandidateParses();
if (CollectionUtils.isEmpty(selectedParses)) {
return;
}

View File

@@ -24,7 +24,6 @@ public class SqlInfoParseResponder implements ParseResponder {
List<ChatParseDO> chatParseDOS) {
QueryReq queryReq = queryContext.getRequest();
Long startTime = System.currentTimeMillis();
addSqlInfo(queryReq, parseResp.getSelectedParses());
addSqlInfo(queryReq, parseResp.getCandidateParses());
parseResp.setParseTimeCost(new ParseTimeCostDO());
parseResp.getParseTimeCost().setSqlTime(System.currentTimeMillis() - startTime);
@@ -32,7 +31,6 @@ public class SqlInfoParseResponder implements ParseResponder {
Map<Integer, ChatParseDO> chatParseDOMap = chatParseDOS.stream()
.collect(Collectors.toMap(ChatParseDO::getParseId,
Function.identity(), (oldValue, newValue) -> newValue));
updateParseInfo(chatParseDOMap, parseResp.getSelectedParses());
updateParseInfo(chatParseDOMap, parseResp.getCandidateParses());
}
}

View File

@@ -1,16 +1,9 @@
package com.tencent.supersonic.chat.service;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import java.util.List;
public interface ParseInfoService {
List<SemanticParseInfo> getTopCandidateParseInfo(List<SemanticParseInfo> selectedParses,
List<SemanticParseInfo> candidateParses);
List<SemanticParseInfo> sortParseInfo(List<SemanticQuery> semanticQueries);
void updateParseInfo(SemanticParseInfo parseInfo);
}

View File

@@ -225,8 +225,7 @@ public class ChatServiceImpl implements ChatService {
@Override
public List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryReq queryReq, ParseResp parseResult) {
List<SemanticParseInfo> candidateParses = parseResult.getCandidateParses();
List<SemanticParseInfo> selectedParses = parseResult.getSelectedParses();
return chatQueryRepository.batchSaveParseInfo(chatCtx, queryReq, parseResult, candidateParses, selectedParses);
return chatQueryRepository.batchSaveParseInfo(chatCtx, queryReq, parseResult, candidateParses);
}
@Override

View File

@@ -2,7 +2,6 @@
package com.tencent.supersonic.chat.service.impl;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
@@ -20,7 +19,6 @@ import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
@@ -31,45 +29,12 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
@Slf4j
@Service
public class ParserInfoServiceImpl implements ParseInfoService {
@Value("${candidate.top.size:5}")
private int candidateTopSize;
public List<SemanticParseInfo> getTopCandidateParseInfo(List<SemanticParseInfo> selectedParses,
List<SemanticParseInfo> candidateParses) {
if (CollectionUtils.isEmpty(selectedParses) || CollectionUtils.isEmpty(candidateParses)) {
return candidateParses;
}
int selectParseSize = selectedParses.size();
Set<Double> selectParseScoreSet = selectedParses.stream()
.map(SemanticParseInfo::getScore).collect(Collectors.toSet());
int candidateParseSize = candidateTopSize - selectParseSize;
candidateParses = candidateParses.stream()
.filter(candidateParse -> !selectParseScoreSet.contains(candidateParse.getScore()))
.collect(Collectors.toList());
SemanticParseInfo semanticParseInfo = selectedParses.get(0);
Long modelId = semanticParseInfo.getModelId();
if (modelId == null || modelId <= 0) {
return candidateParses;
}
return candidateParses.stream()
.sorted(Comparator.comparing(parse -> !parse.getModelId().equals(modelId)))
.limit(candidateParseSize)
.collect(Collectors.toList());
}
public List<SemanticParseInfo> sortParseInfo(List<SemanticQuery> semanticQueries) {
return semanticQueries.stream()
.map(SemanticQuery::getParseInfo)
.sorted(Comparator.comparingDouble(SemanticParseInfo::getScore).reversed())
.collect(Collectors.toList());
}
public void updateParseInfo(SemanticParseInfo parseInfo) {
SqlInfo sqlInfo = parseInfo.getSqlInfo();
@@ -81,9 +46,11 @@ public class ParserInfoServiceImpl implements ParseInfoService {
List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(logicSql);
//set dataInfo
try {
if (!org.springframework.util.CollectionUtils.isEmpty(expressions)) {
if (!CollectionUtils.isEmpty(expressions)) {
DateConf dateInfo = getDateInfo(expressions);
parseInfo.setDateInfo(dateInfo);
if (dateInfo != null && parseInfo.getDateInfo() == null) {
parseInfo.setDateInfo(dateInfo);
}
}
} catch (Exception e) {
log.error("set dateInfo error :", e);
@@ -103,10 +70,10 @@ public class ParserInfoServiceImpl implements ParseInfoService {
if (Objects.isNull(semanticSchema)) {
return;
}
List<String> allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(sqlInfo.getCorrectS2SQL()));
Set<SchemaElement> metrics = getElements(parseInfo.getModelId(), allFields, semanticSchema.getMetrics());
parseInfo.setMetrics(metrics);
//cannot use metrics in sql to override parse info
//List<String> allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(sqlInfo.getCorrectS2SQL()));
//Set<SchemaElement> metrics = getElements(parseInfo.getModelId(), allFields, semanticSchema.getMetrics());
//parseInfo.setMetrics(metrics);
if (SqlParserSelectFunctionHelper.hasAggregateFunction(sqlInfo.getCorrectS2SQL())) {
parseInfo.setNativeQuery(false);
@@ -167,8 +134,8 @@ public class ParserInfoServiceImpl implements ParseInfoService {
List<FilterExpression> dateExpressions = filterExpressions.stream()
.filter(expression -> TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(expression.getFieldName()))
.collect(Collectors.toList());
if (org.springframework.util.CollectionUtils.isEmpty(dateExpressions)) {
return new DateConf();
if (CollectionUtils.isEmpty(dateExpressions)) {
return null;
}
DateConf dateInfo = new DateConf();
dateInfo.setDateMode(DateMode.BETWEEN);

View File

@@ -27,7 +27,7 @@ import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.persistence.dataobject.CostType;
import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.QuerySelector;
import com.tencent.supersonic.chat.query.QueryRanker;
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
import com.tencent.supersonic.chat.responder.execute.ExecuteResponder;
import com.tencent.supersonic.chat.responder.parse.ParseResponder;
@@ -105,13 +105,14 @@ public class QueryServiceImpl implements QueryService {
private SolvedQueryManager solvedQueryManager;
@Autowired
private ParseInfoService parseInfoService;
@Autowired
private QueryRanker queryRanker;
@Value("${time.threshold: 100}")
private Integer timeThreshold;
private List<SchemaMapper> schemaMappers = ComponentFactory.getSchemaMappers();
private List<SemanticParser> semanticParsers = ComponentFactory.getSemanticParsers();
private QuerySelector querySelector = ComponentFactory.getQuerySelector();
private List<ParseResponder> parseResponders = ComponentFactory.getParseResponders();
private List<ExecuteResponder> executeResponders = ComponentFactory.getExecuteResponders();
private List<SemanticCorrector> semanticCorrectors = ComponentFactory.getSqlCorrections();
@@ -157,18 +158,12 @@ public class QueryServiceImpl implements QueryService {
ParseResp parseResult;
List<ChatParseDO> chatParseDOS = Lists.newArrayList();
if (candidateQueries.size() > 0) {
List<SemanticQuery> selectedQueries = querySelector.select(candidateQueries, queryReq);
List<SemanticParseInfo> selectedParses = parseInfoService.sortParseInfo(selectedQueries);
List<SemanticParseInfo> candidateParses = parseInfoService.sortParseInfo(candidateQueries);
candidateParses = parseInfoService.getTopCandidateParseInfo(selectedParses, candidateParses);
candidateQueries.forEach(semanticQuery -> parseInfoService.updateParseInfo(semanticQuery.getParseInfo()));
parseResult = ParseResp.builder()
.chatId(queryReq.getChatId())
.queryText(queryReq.getQueryText())
.state(selectedParses.size() > 1 ? ParseResp.ParseState.PENDING : ParseResp.ParseState.COMPLETED)
.selectedParses(selectedParses)
.candidateParses(candidateParses)
candidateQueries = queryRanker.rank(candidateQueries);
List<SemanticParseInfo> candidateParses = candidateQueries.stream()
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
candidateParses.forEach(parseInfo -> parseInfoService.updateParseInfo(parseInfo));
parseResult = ParseResp.builder().chatId(queryReq.getChatId()).queryText(queryReq.getQueryText())
.state(ParseResp.ParseState.COMPLETED).candidateParses(candidateParses)
.build();
chatParseDOS = chatService.batchAddParse(chatCtx, queryReq, parseResult);
} else {

View File

@@ -10,7 +10,6 @@ import java.util.List;
import java.util.Objects;
import com.tencent.supersonic.chat.parser.llm.s2sql.ModelResolver;
import com.tencent.supersonic.chat.query.QuerySelector;
import com.tencent.supersonic.chat.responder.execute.ExecuteResponder;
import com.tencent.supersonic.chat.responder.parse.ParseResponder;
import org.apache.commons.collections.CollectionUtils;
@@ -24,7 +23,6 @@ public class ComponentFactory {
private static SemanticInterpreter semanticInterpreter;
private static List<ParseResponder> parseResponders = new ArrayList<>();
private static List<ExecuteResponder> executeResponders = new ArrayList<>();
private static QuerySelector querySelector;
private static ModelResolver modelResolver;
public static List<SchemaMapper> getSchemaMappers() {
return CollectionUtils.isEmpty(schemaMappers) ? init(SchemaMapper.class, schemaMappers) : schemaMappers;
@@ -59,12 +57,6 @@ public class ComponentFactory {
semanticInterpreter = layer;
}
public static QuerySelector getQuerySelector() {
if (Objects.isNull(querySelector)) {
querySelector = init(QuerySelector.class);
}
return querySelector;
}
public static ModelResolver getModelResolver() {
if (Objects.isNull(modelResolver)) {

View File

@@ -20,9 +20,6 @@ com.tencent.supersonic.chat.api.component.SemanticCorrector=\
com.tencent.supersonic.chat.api.component.SemanticInterpreter=\
com.tencent.supersonic.knowledge.semantic.RemoteSemanticInterpreter
com.tencent.supersonic.chat.query.QuerySelector=\
com.tencent.supersonic.chat.query.HeuristicQuerySelector
com.tencent.supersonic.chat.parser.llm.s2sql.ModelResolver=\
com.tencent.supersonic.chat.parser.llm.s2sql.HeuristicModelResolver

View File

@@ -68,12 +68,13 @@ public class ConfigureDemo implements ApplicationListener<ApplicationReadyEvent>
ExecuteQueryReq executeReq = ExecuteQueryReq.builder().build();
executeReq.setQueryId(parseResp.getQueryId());
executeReq.setParseId(parseResp.getSelectedParses().get(0).getId());
executeReq.setParseId(parseResp.getCandidateParses().get(0).getId());
executeReq.setQueryText(queryRequest.getQueryText());
executeReq.setParseInfo(parseResp.getSelectedParses().get(0));
executeReq.setParseInfo(parseResp.getCandidateParses().get(0));
executeReq.setChatId(parseResp.getChatId());
executeReq.setUser(queryRequest.getUser());
executeReq.setAgentId(1);
executeReq.setSaveAnswer(true);
queryService.performExecution(executeReq);
}

View File

@@ -21,9 +21,6 @@ com.tencent.supersonic.chat.api.component.SemanticCorrector=\
com.tencent.supersonic.chat.api.component.SemanticInterpreter=\
com.tencent.supersonic.knowledge.semantic.LocalSemanticInterpreter
com.tencent.supersonic.chat.query.QuerySelector=\
com.tencent.supersonic.chat.query.HeuristicQuerySelector
com.tencent.supersonic.chat.parser.llm.s2sql.ModelResolver=\
com.tencent.supersonic.chat.parser.llm.s2sql.HeuristicModelResolver

View File

@@ -53,11 +53,12 @@ public class BaseQueryTest {
ExecuteQueryReq request = ExecuteQueryReq.builder()
.queryId(parseResp.getQueryId())
.parseId(parseResp.getSelectedParses().get(0).getId())
.parseId(parseResp.getCandidateParses().get(0).getId())
.chatId(parseResp.getChatId())
.queryText(parseResp.getQueryText())
.user(DataUtils.getUser())
.parseInfo(parseResp.getSelectedParses().get(0))
.parseInfo(parseResp.getCandidateParses().get(0))
.saveAnswer(true)
.build();
return queryService.performExecution(request);
@@ -68,11 +69,12 @@ public class BaseQueryTest {
ExecuteQueryReq request = ExecuteQueryReq.builder()
.queryId(parseResp.getQueryId())
.parseId(parseResp.getSelectedParses().get(0).getId())
.parseId(parseResp.getCandidateParses().get(0).getId())
.chatId(parseResp.getChatId())
.queryText(parseResp.getQueryText())
.user(DataUtils.getUser())
.parseInfo(parseResp.getSelectedParses().get(0))
.parseInfo(parseResp.getCandidateParses().get(0))
.saveAnswer(true)
.build();
QueryResult result = queryService.performExecution(request);

View File

@@ -57,8 +57,8 @@ public class MetricInterpretTest {
.chatId(parseResp.getChatId())
.queryId(parseResp.getQueryId())
.queryText(parseResp.getQueryText())
.parseInfo(parseResp.getSelectedParses().get(0))
.parseId(parseResp.getSelectedParses().get(0).getId())
.parseInfo(parseResp.getCandidateParses().get(0))
.parseId(parseResp.getCandidateParses().get(0).getId())
.build();
QueryResult queryResult = queryService.performExecution(executeReq);
Assert.assertEquals(queryResult.getQueryResults().get(0).get("answer"), lLmAnswerResp.getAssistantMessage());

View File

@@ -58,8 +58,8 @@ public class MetricQueryTest extends BaseQueryTest {
//agent only support METRIC_ENTITY, METRIC_FILTER
MockConfiguration.mockAgent(agentService);
ParseResp parseResp = submitParseWithAgent("alice的访问次数", DataUtils.getAgent().getId());
Assert.assertNotNull(parseResp.getSelectedParses());
List<String> queryModes = parseResp.getSelectedParses().stream()
Assert.assertNotNull(parseResp.getCandidateParses());
List<String> queryModes = parseResp.getCandidateParses().stream()
.map(SemanticParseInfo::getQueryMode).collect(Collectors.toList());
Assert.assertTrue(queryModes.contains("METRIC_FILTER"));
}
@@ -88,7 +88,7 @@ public class MetricQueryTest extends BaseQueryTest {
//agent only support METRIC_ENTITY, METRIC_FILTER
MockConfiguration.mockAgent(agentService);
ParseResp parseResp = submitParseWithAgent("超音数的访问次数", DataUtils.getAgent().getId());
List<String> queryModes = parseResp.getSelectedParses().stream()
List<String> queryModes = parseResp.getCandidateParses().stream()
.map(SemanticParseInfo::getQueryMode).collect(Collectors.toList());
Assert.assertTrue(queryModes.contains("METRIC_MODEL"));
}
@@ -123,7 +123,7 @@ public class MetricQueryTest extends BaseQueryTest {
expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE);
expectedParseInfo.setAggType(NONE);
expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("用户名"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
List<String> list = new ArrayList<>();

View File

@@ -44,7 +44,7 @@ public class PluginRecognizeTest extends BasePluginTest {
.chatId(parseResp.getChatId())
.queryId(parseResp.getQueryId())
.queryText(parseResp.getQueryText())
.parseInfo(parseResp.getSelectedParses().get(0))
.parseInfo(parseResp.getCandidateParses().get(0))
.build();
QueryResult queryResult = queryService.performExecution(executeReq);
@@ -69,7 +69,7 @@ public class PluginRecognizeTest extends BasePluginTest {
.chatId(parseResp.getChatId())
.queryId(parseResp.getQueryId())
.queryText(parseResp.getQueryText())
.parseInfo(parseResp.getSelectedParses().get(0))
.parseInfo(parseResp.getCandidateParses().get(0))
.build();
QueryResult queryResult = queryService.performExecution(executeReq);
@@ -84,8 +84,8 @@ public class PluginRecognizeTest extends BasePluginTest {
QueryReq queryContextReq = DataUtils.getQueryReqWithAgent(1000, "alice最近的访问情况怎么样",
DataUtils.getAgent().getId());
ParseResp parseResp = queryService.performParsing(queryContextReq);
Assert.assertTrue(parseResp.getSelectedParses() != null
&& parseResp.getSelectedParses().size() > 0);
Assert.assertTrue(parseResp.getCandidateParses() != null
&& parseResp.getCandidateParses().size() > 0);
}
}

View File

@@ -12,9 +12,6 @@ com.tencent.supersonic.chat.api.component.QueryProcessor=\
com.tencent.supersonic.chat.api.component.SemanticInterpreter=\
com.tencent.supersonic.knowledge.semantic.LocalSemanticInterpreter
com.tencent.supersonic.chat.query.QuerySelector=\
com.tencent.supersonic.chat.query.HeuristicQuerySelector
com.tencent.supersonic.chat.application.query.DomainResolver=\
com.tencent.supersonic.chat.application.query.HeuristicDomainResolver

View File

@@ -12,6 +12,7 @@ import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationListener;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
@@ -24,6 +25,7 @@ public class MetaEmbeddingListener implements ApplicationListener<DataEvent> {
@Autowired
private EmbeddingUtils embeddingUtils;
@Async
@Override
public void onApplicationEvent(DataEvent event) {
if (CollectionUtils.isEmpty(event.getDataItems())) {