(improvement)(chat) Optimize the update parserInfo code and resolve compilation exceptions (#346)

This commit is contained in:
lexluo09
2023-11-09 17:35:38 +08:00
committed by GitHub
parent 6ad74bb206
commit 4e139c837a
8 changed files with 269 additions and 230 deletions

View File

@@ -20,7 +20,5 @@ public interface SemanticQuery {
SemanticParseInfo getParseInfo();
void updateParseInfo();
void setParseInfo(SemanticParseInfo parseInfo);
}

View File

@@ -1,28 +1,17 @@
package com.tencent.supersonic.chat.query;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
import com.tencent.supersonic.common.pojo.Aggregator;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
import com.tencent.supersonic.common.pojo.Filter;
import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.FilterExpression;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.model.enums.QueryTypeEnum;
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
@@ -30,19 +19,13 @@ import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq;
import com.tencent.supersonic.semantic.api.query.request.QueryS2QLReq;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.ToString;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
@Slf4j
@ToString
@@ -87,167 +70,6 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
return QueryReqBuilder.buildStructReq(parseInfo);
}
public void updateParseInfo() {
SqlInfo sqlInfo = parseInfo.getSqlInfo();
String logicSql = sqlInfo.getLogicSql();
if (StringUtils.isBlank(logicSql)) {
return;
}
List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(logicSql);
//set dataInfo
try {
if (!org.springframework.util.CollectionUtils.isEmpty(expressions)) {
DateConf dateInfo = getDateInfo(expressions);
parseInfo.setDateInfo(dateInfo);
}
} catch (Exception e) {
log.error("set dateInfo error :", e);
}
//set filter
try {
Map<String, SchemaElement> fieldNameToElement = getNameToElement(parseInfo.getModelId());
List<QueryFilter> result = getDimensionFilter(fieldNameToElement, expressions);
parseInfo.getDimensionFilters().addAll(result);
} catch (Exception e) {
log.error("set dimensionFilter error :", e);
}
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
if (Objects.isNull(semanticSchema)) {
return;
}
List<String> allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(sqlInfo.getLogicSql()));
Set<SchemaElement> metrics = getElements(parseInfo.getModelId(), allFields, semanticSchema.getMetrics());
parseInfo.setMetrics(metrics);
if (SqlParserSelectFunctionHelper.hasAggregateFunction(sqlInfo.getLogicSql())) {
parseInfo.setNativeQuery(false);
List<String> groupByFields = SqlParserSelectHelper.getGroupByFields(sqlInfo.getLogicSql());
List<String> groupByDimensions = getFieldsExceptDate(groupByFields);
parseInfo.setDimensions(
getElements(parseInfo.getModelId(), groupByDimensions, semanticSchema.getDimensions()));
} else {
parseInfo.setNativeQuery(true);
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getLogicSql());
List<String> selectDimensions = getFieldsExceptDate(selectFields);
parseInfo.setDimensions(
getElements(parseInfo.getModelId(), selectDimensions, semanticSchema.getDimensions()));
}
}
private Set<SchemaElement> getElements(Long modelId, List<String> allFields, List<SchemaElement> elements) {
return elements.stream()
.filter(schemaElement -> modelId.equals(schemaElement.getModel())
&& allFields.contains(schemaElement.getName())
).collect(Collectors.toSet());
}
private List<String> getFieldsExceptDate(List<String> allFields) {
if (org.springframework.util.CollectionUtils.isEmpty(allFields)) {
return new ArrayList<>();
}
return allFields.stream()
.filter(entry -> !TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(entry))
.collect(Collectors.toList());
}
private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> fieldNameToElement,
List<FilterExpression> filterExpressions) {
List<QueryFilter> result = Lists.newArrayList();
for (FilterExpression expression : filterExpressions) {
QueryFilter dimensionFilter = new QueryFilter();
dimensionFilter.setValue(expression.getFieldValue());
SchemaElement schemaElement = fieldNameToElement.get(expression.getFieldName());
if (Objects.isNull(schemaElement)) {
continue;
}
dimensionFilter.setName(schemaElement.getName());
dimensionFilter.setBizName(schemaElement.getBizName());
dimensionFilter.setElementID(schemaElement.getId());
FilterOperatorEnum operatorEnum = FilterOperatorEnum.getSqlOperator(expression.getOperator());
dimensionFilter.setOperator(operatorEnum);
dimensionFilter.setFunction(expression.getFunction());
result.add(dimensionFilter);
}
return result;
}
private DateConf getDateInfo(List<FilterExpression> filterExpressions) {
List<FilterExpression> dateExpressions = filterExpressions.stream()
.filter(expression -> TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(expression.getFieldName()))
.collect(Collectors.toList());
if (org.springframework.util.CollectionUtils.isEmpty(dateExpressions)) {
return new DateConf();
}
DateConf dateInfo = new DateConf();
dateInfo.setDateMode(DateMode.BETWEEN);
FilterExpression firstExpression = dateExpressions.get(0);
FilterOperatorEnum firstOperator = FilterOperatorEnum.getSqlOperator(firstExpression.getOperator());
if (FilterOperatorEnum.EQUALS.equals(firstOperator) && Objects.nonNull(firstExpression.getFieldValue())) {
dateInfo.setStartDate(firstExpression.getFieldValue().toString());
dateInfo.setEndDate(firstExpression.getFieldValue().toString());
dateInfo.setDateMode(DateMode.BETWEEN);
return dateInfo;
}
if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.GREATER_THAN,
FilterOperatorEnum.GREATER_THAN_EQUALS)) {
dateInfo.setStartDate(firstExpression.getFieldValue().toString());
if (hasSecondDate(dateExpressions)) {
dateInfo.setEndDate(dateExpressions.get(1).getFieldValue().toString());
}
}
if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.MINOR_THAN,
FilterOperatorEnum.MINOR_THAN_EQUALS)) {
dateInfo.setEndDate(firstExpression.getFieldValue().toString());
if (hasSecondDate(dateExpressions)) {
dateInfo.setStartDate(dateExpressions.get(1).getFieldValue().toString());
}
}
return dateInfo;
}
private boolean containOperators(FilterExpression expression, FilterOperatorEnum firstOperator,
FilterOperatorEnum... operatorEnums) {
return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(expression.getFieldValue()));
}
private boolean hasSecondDate(List<FilterExpression> dateExpressions) {
return dateExpressions.size() > 1 && Objects.nonNull(dateExpressions.get(1).getFieldValue());
}
protected Map<String, SchemaElement> getNameToElement(Long modelId) {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
List<SchemaElement> dimensions = semanticSchema.getDimensions();
List<SchemaElement> metrics = semanticSchema.getMetrics();
List<SchemaElement> allElements = Lists.newArrayList();
allElements.addAll(dimensions);
allElements.addAll(metrics);
//support alias
return allElements.stream()
.filter(schemaElement -> schemaElement.getModel().equals(modelId))
.flatMap(schemaElement -> {
Set<Pair<String, SchemaElement>> result = new HashSet<>();
result.add(Pair.of(schemaElement.getName(), schemaElement));
List<String> aliasList = schemaElement.getAlias();
if (!org.springframework.util.CollectionUtils.isEmpty(aliasList)) {
for (String alias : aliasList) {
result.add(Pair.of(alias, schemaElement));
}
}
return result.stream();
})
.collect(Collectors.toMap(pair -> pair.getLeft(), pair -> pair.getRight(), (value1, value2) -> value2));
}
protected void convertBizNameToName(QueryStructReq queryStructReq) {
SchemaService schemaService = ContextUtils.getBean(SchemaService.class);
Map<String, String> bizNameToName = schemaService.getSemanticSchema()

View File

@@ -14,6 +14,7 @@ import java.util.OptionalDouble;
import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
import com.tencent.supersonic.common.util.ContextUtils;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
@@ -22,6 +23,7 @@ public class HeuristicQuerySelector implements QuerySelector {
@Override
public List<SemanticQuery> select(List<SemanticQuery> candidateQueries, QueryReq queryReq) {
log.debug("pick before [{}]", candidateQueries.stream().collect(Collectors.toList()));
List<SemanticQuery> selectedQueries = new ArrayList<>();
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);
Double candidateThreshold = optimizationConfig.getCandidateThreshold();
@@ -45,6 +47,7 @@ public class HeuristicQuerySelector implements QuerySelector {
});
}
}
log.debug("pick after [{}]", selectedQueries.stream().collect(Collectors.toList()));
return selectedQueries;
}

View File

@@ -3,17 +3,16 @@ package com.tencent.supersonic.chat.service;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp;
import com.tencent.supersonic.chat.persistence.dataobject.ChatDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import java.util.List;
public interface ChatService {
@@ -49,10 +48,7 @@ public interface ChatService {
void addQuery(QueryResult queryResult, ChatContext chatCtx);
List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryReq queryReq,
ParseResp parseResult,
List<SemanticParseInfo> candidateParses,
List<SemanticParseInfo> selectedParses);
List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryReq queryReq, ParseResp parseResult);
void updateChatParse(List<ChatParseDO> chatParseDOS);

View File

@@ -0,0 +1,16 @@
package com.tencent.supersonic.chat.service;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import java.util.List;
public interface ParseInfoService {
List<SemanticParseInfo> getTopCandidateParseInfo(List<SemanticParseInfo> selectedParses,
List<SemanticParseInfo> candidateParses);
List<SemanticParseInfo> sortParseInfo(List<SemanticQuery> semanticQueries);
void updateParseInfo(SemanticParseInfo parseInfo);
}

View File

@@ -223,10 +223,9 @@ public class ChatServiceImpl implements ChatService {
}
@Override
public List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryReq queryReq,
ParseResp parseResult,
List<SemanticParseInfo> candidateParses,
List<SemanticParseInfo> selectedParses) {
public List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryReq queryReq, ParseResp parseResult) {
List<SemanticParseInfo> candidateParses = parseResult.getCandidateParses();
List<SemanticParseInfo> selectedParses = parseResult.getSelectedParses();
return chatQueryRepository.batchSaveParseInfo(chatCtx, queryReq, parseResult, candidateParses, selectedParses);
}

View File

@@ -0,0 +1,234 @@
package com.tencent.supersonic.chat.service.impl;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.service.ParseInfoService;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.FilterExpression;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
@Slf4j
@Service
public class ParserInfoServiceImpl implements ParseInfoService {
@Value("${candidate.top.size:5}")
private int candidateTopSize;
public List<SemanticParseInfo> getTopCandidateParseInfo(List<SemanticParseInfo> selectedParses,
List<SemanticParseInfo> candidateParses) {
if (CollectionUtils.isEmpty(selectedParses) || CollectionUtils.isEmpty(candidateParses)) {
return candidateParses;
}
int selectParseSize = selectedParses.size();
Set<Double> selectParseScoreSet = selectedParses.stream()
.map(SemanticParseInfo::getScore).collect(Collectors.toSet());
int candidateParseSize = candidateTopSize - selectParseSize;
candidateParses = candidateParses.stream()
.filter(candidateParse -> !selectParseScoreSet.contains(candidateParse.getScore()))
.collect(Collectors.toList());
SemanticParseInfo semanticParseInfo = selectedParses.get(0);
Long modelId = semanticParseInfo.getModelId();
if (modelId == null || modelId <= 0) {
return candidateParses;
}
return candidateParses.stream()
.sorted(Comparator.comparing(parse -> !parse.getModelId().equals(modelId)))
.limit(candidateParseSize)
.collect(Collectors.toList());
}
public List<SemanticParseInfo> sortParseInfo(List<SemanticQuery> semanticQueries) {
return semanticQueries.stream()
.map(SemanticQuery::getParseInfo)
.sorted(Comparator.comparingDouble(SemanticParseInfo::getScore).reversed())
.collect(Collectors.toList());
}
public void updateParseInfo(SemanticParseInfo parseInfo) {
SqlInfo sqlInfo = parseInfo.getSqlInfo();
String logicSql = sqlInfo.getLogicSql();
if (StringUtils.isBlank(logicSql)) {
return;
}
List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(logicSql);
//set dataInfo
try {
if (!org.springframework.util.CollectionUtils.isEmpty(expressions)) {
DateConf dateInfo = getDateInfo(expressions);
parseInfo.setDateInfo(dateInfo);
}
} catch (Exception e) {
log.error("set dateInfo error :", e);
}
//set filter
try {
Map<String, SchemaElement> fieldNameToElement = getNameToElement(parseInfo.getModelId());
List<QueryFilter> result = getDimensionFilter(fieldNameToElement, expressions);
parseInfo.getDimensionFilters().addAll(result);
} catch (Exception e) {
log.error("set dimensionFilter error :", e);
}
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
if (Objects.isNull(semanticSchema)) {
return;
}
List<String> allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(sqlInfo.getLogicSql()));
Set<SchemaElement> metrics = getElements(parseInfo.getModelId(), allFields, semanticSchema.getMetrics());
parseInfo.setMetrics(metrics);
if (SqlParserSelectFunctionHelper.hasAggregateFunction(sqlInfo.getLogicSql())) {
parseInfo.setNativeQuery(false);
List<String> groupByFields = SqlParserSelectHelper.getGroupByFields(sqlInfo.getLogicSql());
List<String> groupByDimensions = getFieldsExceptDate(groupByFields);
parseInfo.setDimensions(
getElements(parseInfo.getModelId(), groupByDimensions, semanticSchema.getDimensions()));
} else {
parseInfo.setNativeQuery(true);
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getLogicSql());
List<String> selectDimensions = getFieldsExceptDate(selectFields);
parseInfo.setDimensions(
getElements(parseInfo.getModelId(), selectDimensions, semanticSchema.getDimensions()));
}
}
private Set<SchemaElement> getElements(Long modelId, List<String> allFields, List<SchemaElement> elements) {
return elements.stream()
.filter(schemaElement -> modelId.equals(schemaElement.getModel())
&& allFields.contains(schemaElement.getName())
).collect(Collectors.toSet());
}
private List<String> getFieldsExceptDate(List<String> allFields) {
if (org.springframework.util.CollectionUtils.isEmpty(allFields)) {
return new ArrayList<>();
}
return allFields.stream()
.filter(entry -> !TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(entry))
.collect(Collectors.toList());
}
private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> fieldNameToElement,
List<FilterExpression> filterExpressions) {
List<QueryFilter> result = Lists.newArrayList();
for (FilterExpression expression : filterExpressions) {
QueryFilter dimensionFilter = new QueryFilter();
dimensionFilter.setValue(expression.getFieldValue());
SchemaElement schemaElement = fieldNameToElement.get(expression.getFieldName());
if (Objects.isNull(schemaElement)) {
continue;
}
dimensionFilter.setName(schemaElement.getName());
dimensionFilter.setBizName(schemaElement.getBizName());
dimensionFilter.setElementID(schemaElement.getId());
FilterOperatorEnum operatorEnum = FilterOperatorEnum.getSqlOperator(expression.getOperator());
dimensionFilter.setOperator(operatorEnum);
dimensionFilter.setFunction(expression.getFunction());
result.add(dimensionFilter);
}
return result;
}
private DateConf getDateInfo(List<FilterExpression> filterExpressions) {
List<FilterExpression> dateExpressions = filterExpressions.stream()
.filter(expression -> TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(expression.getFieldName()))
.collect(Collectors.toList());
if (org.springframework.util.CollectionUtils.isEmpty(dateExpressions)) {
return new DateConf();
}
DateConf dateInfo = new DateConf();
dateInfo.setDateMode(DateMode.BETWEEN);
FilterExpression firstExpression = dateExpressions.get(0);
FilterOperatorEnum firstOperator = FilterOperatorEnum.getSqlOperator(firstExpression.getOperator());
if (FilterOperatorEnum.EQUALS.equals(firstOperator) && Objects.nonNull(firstExpression.getFieldValue())) {
dateInfo.setStartDate(firstExpression.getFieldValue().toString());
dateInfo.setEndDate(firstExpression.getFieldValue().toString());
dateInfo.setDateMode(DateMode.BETWEEN);
return dateInfo;
}
if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.GREATER_THAN,
FilterOperatorEnum.GREATER_THAN_EQUALS)) {
dateInfo.setStartDate(firstExpression.getFieldValue().toString());
if (hasSecondDate(dateExpressions)) {
dateInfo.setEndDate(dateExpressions.get(1).getFieldValue().toString());
}
}
if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.MINOR_THAN,
FilterOperatorEnum.MINOR_THAN_EQUALS)) {
dateInfo.setEndDate(firstExpression.getFieldValue().toString());
if (hasSecondDate(dateExpressions)) {
dateInfo.setStartDate(dateExpressions.get(1).getFieldValue().toString());
}
}
return dateInfo;
}
private boolean containOperators(FilterExpression expression, FilterOperatorEnum firstOperator,
FilterOperatorEnum... operatorEnums) {
return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(expression.getFieldValue()));
}
private boolean hasSecondDate(List<FilterExpression> dateExpressions) {
return dateExpressions.size() > 1 && Objects.nonNull(dateExpressions.get(1).getFieldValue());
}
protected Map<String, SchemaElement> getNameToElement(Long modelId) {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
List<SchemaElement> dimensions = semanticSchema.getDimensions();
List<SchemaElement> metrics = semanticSchema.getMetrics();
List<SchemaElement> allElements = Lists.newArrayList();
allElements.addAll(dimensions);
allElements.addAll(metrics);
//support alias
return allElements.stream()
.filter(schemaElement -> schemaElement.getModel().equals(modelId))
.flatMap(schemaElement -> {
Set<Pair<String, SchemaElement>> result = new HashSet<>();
result.add(Pair.of(schemaElement.getName(), schemaElement));
List<String> aliasList = schemaElement.getAlias();
if (!org.springframework.util.CollectionUtils.isEmpty(aliasList)) {
for (String alias : aliasList) {
result.add(Pair.of(alias, schemaElement));
}
}
return result.stream();
})
.collect(Collectors.toMap(pair -> pair.getLeft(), pair -> pair.getRight(), (value1, value2) -> value2));
}
}

View File

@@ -31,6 +31,7 @@ import com.tencent.supersonic.chat.query.llm.s2ql.S2QLQuery;
import com.tencent.supersonic.chat.responder.execute.ExecuteResponder;
import com.tencent.supersonic.chat.responder.parse.ParseResponder;
import com.tencent.supersonic.chat.service.ChatService;
import com.tencent.supersonic.chat.service.ParseInfoService;
import com.tencent.supersonic.chat.service.QueryService;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.chat.service.StatisticsService;
@@ -58,7 +59,6 @@ import com.tencent.supersonic.knowledge.utils.NatureHelper;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
@@ -102,6 +102,8 @@ public class QueryServiceImpl implements QueryService {
private StatisticsService statisticsService;
@Autowired
private SolvedQueryManager solvedQueryManager;
@Autowired
private ParseInfoService parseInfoService;
@Value("${time.threshold: 100}")
private Integer timeThreshold;
@@ -148,7 +150,6 @@ public class QueryServiceImpl implements QueryService {
semanticCorrectors.stream().forEach(correction -> {
correction.correct(queryReq, semanticQuery.getParseInfo());
});
semanticQuery.updateParseInfo();
}
}
@@ -156,15 +157,14 @@ public class QueryServiceImpl implements QueryService {
ParseResp parseResult;
List<ChatParseDO> chatParseDOS = Lists.newArrayList();
if (candidateQueries.size() > 0) {
log.debug("pick before [{}]", candidateQueries.stream().collect(
Collectors.toList()));
List<SemanticQuery> selectedQueries = querySelector.select(candidateQueries, queryReq);
log.debug("pick after [{}]", selectedQueries.stream().collect(
Collectors.toList()));
List<SemanticParseInfo> selectedParses = convertParseInfo(selectedQueries);
List<SemanticParseInfo> candidateParses = convertParseInfo(candidateQueries);
candidateParses = getTop5CandidateParseInfo(selectedParses, candidateParses);
candidateQueries.forEach(semanticQuery -> parseInfoService.updateParseInfo(semanticQuery.getParseInfo()));
List<SemanticParseInfo> selectedParses = parseInfoService.sortParseInfo(selectedQueries);
List<SemanticParseInfo> candidateParses = parseInfoService.sortParseInfo(candidateQueries);
candidateParses = parseInfoService.getTopCandidateParseInfo(selectedParses, candidateParses);
parseResult = ParseResp.builder()
.chatId(queryReq.getChatId())
.queryText(queryReq.getQueryText())
@@ -172,7 +172,7 @@ public class QueryServiceImpl implements QueryService {
.selectedParses(selectedParses)
.candidateParses(candidateParses)
.build();
chatParseDOS = chatService.batchAddParse(chatCtx, queryReq, parseResult, candidateParses, selectedParses);
chatParseDOS = chatService.batchAddParse(chatCtx, queryReq, parseResult);
} else {
parseResult = ParseResp.builder()
.chatId(queryReq.getChatId())
@@ -198,35 +198,6 @@ public class QueryServiceImpl implements QueryService {
return parseResult;
}
private List<SemanticParseInfo> convertParseInfo(List<SemanticQuery> semanticQueries) {
return semanticQueries.stream()
.map(SemanticQuery::getParseInfo)
.sorted(Comparator.comparingDouble(SemanticParseInfo::getScore).reversed())
.collect(Collectors.toList());
}
private List<SemanticParseInfo> getTop5CandidateParseInfo(List<SemanticParseInfo> selectedParses,
List<SemanticParseInfo> candidateParses) {
if (CollectionUtils.isEmpty(selectedParses) || CollectionUtils.isEmpty(candidateParses)) {
return candidateParses;
}
int selectParseSize = selectedParses.size();
Set<Double> selectParseScoreSet = selectedParses.stream()
.map(SemanticParseInfo::getScore).collect(Collectors.toSet());
int candidateParseSize = 5 - selectParseSize;
candidateParses = candidateParses.stream()
.filter(candidateParse -> !selectParseScoreSet.contains(candidateParse.getScore()))
.collect(Collectors.toList());
SemanticParseInfo semanticParseInfo = selectedParses.get(0);
Long modelId = semanticParseInfo.getModelId();
if (modelId == null || modelId <= 0) {
return candidateParses;
}
return candidateParses.stream()
.sorted(Comparator.comparing(parse -> !parse.getModelId().equals(modelId)))
.limit(candidateParseSize)
.collect(Collectors.toList());
}
@Override
@TimeCost