diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/component/SemanticQuery.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/component/SemanticQuery.java index 8622b0cda..6a35ae7bd 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/component/SemanticQuery.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/component/SemanticQuery.java @@ -20,7 +20,5 @@ public interface SemanticQuery { SemanticParseInfo getParseInfo(); - void updateParseInfo(); - void setParseInfo(SemanticParseInfo parseInfo); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/BaseSemanticQuery.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/BaseSemanticQuery.java index 3b4a8b98e..7e4206b48 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/BaseSemanticQuery.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/BaseSemanticQuery.java @@ -1,28 +1,17 @@ package com.tencent.supersonic.chat.query; -import com.google.common.collect.Lists; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.chat.api.component.SemanticInterpreter; import com.tencent.supersonic.chat.api.component.SemanticQuery; -import com.tencent.supersonic.chat.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.chat.api.pojo.SemanticSchema; -import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; import com.tencent.supersonic.chat.api.pojo.response.SqlInfo; import com.tencent.supersonic.chat.utils.ComponentFactory; import com.tencent.supersonic.chat.utils.QueryReqBuilder; import com.tencent.supersonic.common.pojo.Aggregator; -import com.tencent.supersonic.common.pojo.DateConf; -import com.tencent.supersonic.common.pojo.DateConf.DateMode; import com.tencent.supersonic.common.pojo.Filter; import com.tencent.supersonic.common.pojo.Order; -import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; -import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.common.util.ContextUtils; -import com.tencent.supersonic.common.util.jsqlparser.FilterExpression; -import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper; -import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.knowledge.service.SchemaService; import com.tencent.supersonic.semantic.api.model.enums.QueryTypeEnum; import com.tencent.supersonic.semantic.api.model.response.ExplainResp; @@ -30,19 +19,13 @@ import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq; import com.tencent.supersonic.semantic.api.query.request.QueryS2QLReq; import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; import java.io.Serializable; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.Set; import java.util.stream.Collectors; import lombok.ToString; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections4.CollectionUtils; -import org.apache.commons.lang3.StringUtils; -import org.apache.commons.lang3.tuple.Pair; @Slf4j @ToString @@ -87,167 +70,6 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable { return QueryReqBuilder.buildStructReq(parseInfo); } - public void updateParseInfo() { - SqlInfo sqlInfo = parseInfo.getSqlInfo(); - String logicSql = sqlInfo.getLogicSql(); - if (StringUtils.isBlank(logicSql)) { - return; - } - - List expressions = SqlParserSelectHelper.getFilterExpression(logicSql); - //set dataInfo - try { - if (!org.springframework.util.CollectionUtils.isEmpty(expressions)) { - DateConf dateInfo = getDateInfo(expressions); - parseInfo.setDateInfo(dateInfo); - } - } catch (Exception e) { - log.error("set dateInfo error :", e); - } - - //set filter - try { - Map fieldNameToElement = getNameToElement(parseInfo.getModelId()); - List result = getDimensionFilter(fieldNameToElement, expressions); - parseInfo.getDimensionFilters().addAll(result); - } catch (Exception e) { - log.error("set dimensionFilter error :", e); - } - - SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); - - if (Objects.isNull(semanticSchema)) { - return; - } - List allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(sqlInfo.getLogicSql())); - - Set metrics = getElements(parseInfo.getModelId(), allFields, semanticSchema.getMetrics()); - parseInfo.setMetrics(metrics); - - if (SqlParserSelectFunctionHelper.hasAggregateFunction(sqlInfo.getLogicSql())) { - parseInfo.setNativeQuery(false); - List groupByFields = SqlParserSelectHelper.getGroupByFields(sqlInfo.getLogicSql()); - List groupByDimensions = getFieldsExceptDate(groupByFields); - parseInfo.setDimensions( - getElements(parseInfo.getModelId(), groupByDimensions, semanticSchema.getDimensions())); - } else { - parseInfo.setNativeQuery(true); - List selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getLogicSql()); - List selectDimensions = getFieldsExceptDate(selectFields); - parseInfo.setDimensions( - getElements(parseInfo.getModelId(), selectDimensions, semanticSchema.getDimensions())); - } - } - - - private Set getElements(Long modelId, List allFields, List elements) { - return elements.stream() - .filter(schemaElement -> modelId.equals(schemaElement.getModel()) - && allFields.contains(schemaElement.getName()) - ).collect(Collectors.toSet()); - } - - private List getFieldsExceptDate(List allFields) { - if (org.springframework.util.CollectionUtils.isEmpty(allFields)) { - return new ArrayList<>(); - } - return allFields.stream() - .filter(entry -> !TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(entry)) - .collect(Collectors.toList()); - } - - - private List getDimensionFilter(Map fieldNameToElement, - List filterExpressions) { - List result = Lists.newArrayList(); - for (FilterExpression expression : filterExpressions) { - QueryFilter dimensionFilter = new QueryFilter(); - dimensionFilter.setValue(expression.getFieldValue()); - SchemaElement schemaElement = fieldNameToElement.get(expression.getFieldName()); - if (Objects.isNull(schemaElement)) { - continue; - } - dimensionFilter.setName(schemaElement.getName()); - dimensionFilter.setBizName(schemaElement.getBizName()); - dimensionFilter.setElementID(schemaElement.getId()); - - FilterOperatorEnum operatorEnum = FilterOperatorEnum.getSqlOperator(expression.getOperator()); - dimensionFilter.setOperator(operatorEnum); - dimensionFilter.setFunction(expression.getFunction()); - result.add(dimensionFilter); - } - return result; - } - - private DateConf getDateInfo(List filterExpressions) { - List dateExpressions = filterExpressions.stream() - .filter(expression -> TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(expression.getFieldName())) - .collect(Collectors.toList()); - if (org.springframework.util.CollectionUtils.isEmpty(dateExpressions)) { - return new DateConf(); - } - DateConf dateInfo = new DateConf(); - dateInfo.setDateMode(DateMode.BETWEEN); - FilterExpression firstExpression = dateExpressions.get(0); - - FilterOperatorEnum firstOperator = FilterOperatorEnum.getSqlOperator(firstExpression.getOperator()); - if (FilterOperatorEnum.EQUALS.equals(firstOperator) && Objects.nonNull(firstExpression.getFieldValue())) { - dateInfo.setStartDate(firstExpression.getFieldValue().toString()); - dateInfo.setEndDate(firstExpression.getFieldValue().toString()); - dateInfo.setDateMode(DateMode.BETWEEN); - return dateInfo; - } - if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.GREATER_THAN, - FilterOperatorEnum.GREATER_THAN_EQUALS)) { - dateInfo.setStartDate(firstExpression.getFieldValue().toString()); - if (hasSecondDate(dateExpressions)) { - dateInfo.setEndDate(dateExpressions.get(1).getFieldValue().toString()); - } - } - if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.MINOR_THAN, - FilterOperatorEnum.MINOR_THAN_EQUALS)) { - dateInfo.setEndDate(firstExpression.getFieldValue().toString()); - if (hasSecondDate(dateExpressions)) { - dateInfo.setStartDate(dateExpressions.get(1).getFieldValue().toString()); - } - } - return dateInfo; - } - - private boolean containOperators(FilterExpression expression, FilterOperatorEnum firstOperator, - FilterOperatorEnum... operatorEnums) { - return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(expression.getFieldValue())); - } - - private boolean hasSecondDate(List dateExpressions) { - return dateExpressions.size() > 1 && Objects.nonNull(dateExpressions.get(1).getFieldValue()); - } - - protected Map getNameToElement(Long modelId) { - SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); - List dimensions = semanticSchema.getDimensions(); - List metrics = semanticSchema.getMetrics(); - - List allElements = Lists.newArrayList(); - allElements.addAll(dimensions); - allElements.addAll(metrics); - //support alias - return allElements.stream() - .filter(schemaElement -> schemaElement.getModel().equals(modelId)) - .flatMap(schemaElement -> { - Set> result = new HashSet<>(); - result.add(Pair.of(schemaElement.getName(), schemaElement)); - List aliasList = schemaElement.getAlias(); - if (!org.springframework.util.CollectionUtils.isEmpty(aliasList)) { - for (String alias : aliasList) { - result.add(Pair.of(alias, schemaElement)); - } - } - return result.stream(); - }) - .collect(Collectors.toMap(pair -> pair.getLeft(), pair -> pair.getRight(), (value1, value2) -> value2)); - } - protected void convertBizNameToName(QueryStructReq queryStructReq) { SchemaService schemaService = ContextUtils.getBean(SchemaService.class); Map bizNameToName = schemaService.getSemanticSchema() diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/HeuristicQuerySelector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/HeuristicQuerySelector.java index 38005f126..0b75e2709 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/HeuristicQuerySelector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/HeuristicQuerySelector.java @@ -14,6 +14,7 @@ import java.util.OptionalDouble; import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery; import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery; import com.tencent.supersonic.common.util.ContextUtils; +import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; @@ -22,6 +23,7 @@ public class HeuristicQuerySelector implements QuerySelector { @Override public List select(List candidateQueries, QueryReq queryReq) { + log.debug("pick before [{}]", candidateQueries.stream().collect(Collectors.toList())); List selectedQueries = new ArrayList<>(); OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class); Double candidateThreshold = optimizationConfig.getCandidateThreshold(); @@ -45,6 +47,7 @@ public class HeuristicQuerySelector implements QuerySelector { }); } } + log.debug("pick after [{}]", selectedQueries.stream().collect(Collectors.toList())); return selectedQueries; } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/ChatService.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/ChatService.java index 4809cc907..2957a8118 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/ChatService.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/ChatService.java @@ -3,17 +3,16 @@ package com.tencent.supersonic.chat.service; import com.github.pagehelper.PageInfo; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.chat.api.pojo.ChatContext; -import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq; import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.response.ParseResp; +import com.tencent.supersonic.chat.api.pojo.response.QueryResp; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp; import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp; import com.tencent.supersonic.chat.persistence.dataobject.ChatDO; import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO; import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO; -import com.tencent.supersonic.chat.api.pojo.response.QueryResp; -import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq; import java.util.List; public interface ChatService { @@ -49,10 +48,7 @@ public interface ChatService { void addQuery(QueryResult queryResult, ChatContext chatCtx); - List batchAddParse(ChatContext chatCtx, QueryReq queryReq, - ParseResp parseResult, - List candidateParses, - List selectedParses); + List batchAddParse(ChatContext chatCtx, QueryReq queryReq, ParseResp parseResult); void updateChatParse(List chatParseDOS); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/ParseInfoService.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/ParseInfoService.java new file mode 100644 index 000000000..4df39777a --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/ParseInfoService.java @@ -0,0 +1,16 @@ +package com.tencent.supersonic.chat.service; + +import com.tencent.supersonic.chat.api.component.SemanticQuery; +import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; +import java.util.List; + +public interface ParseInfoService { + + List getTopCandidateParseInfo(List selectedParses, + List candidateParses); + + List sortParseInfo(List semanticQueries); + + void updateParseInfo(SemanticParseInfo parseInfo); + +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/ChatServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/ChatServiceImpl.java index 455fff0b7..d51b69956 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/ChatServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/ChatServiceImpl.java @@ -223,10 +223,9 @@ public class ChatServiceImpl implements ChatService { } @Override - public List batchAddParse(ChatContext chatCtx, QueryReq queryReq, - ParseResp parseResult, - List candidateParses, - List selectedParses) { + public List batchAddParse(ChatContext chatCtx, QueryReq queryReq, ParseResp parseResult) { + List candidateParses = parseResult.getCandidateParses(); + List selectedParses = parseResult.getSelectedParses(); return chatQueryRepository.batchSaveParseInfo(chatCtx, queryReq, parseResult, candidateParses, selectedParses); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/ParserInfoServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/ParserInfoServiceImpl.java new file mode 100644 index 000000000..13f4eed2e --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/ParserInfoServiceImpl.java @@ -0,0 +1,234 @@ + +package com.tencent.supersonic.chat.service.impl; + +import com.google.common.collect.Lists; +import com.tencent.supersonic.chat.api.component.SemanticQuery; +import com.tencent.supersonic.chat.api.pojo.SchemaElement; +import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.chat.api.pojo.SemanticSchema; +import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; +import com.tencent.supersonic.chat.api.pojo.response.SqlInfo; +import com.tencent.supersonic.chat.service.ParseInfoService; +import com.tencent.supersonic.common.pojo.DateConf; +import com.tencent.supersonic.common.pojo.DateConf.DateMode; +import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; +import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; +import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.common.util.jsqlparser.FilterExpression; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; +import com.tencent.supersonic.knowledge.service.SchemaService; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.collections.CollectionUtils; +import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.tuple.Pair; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Service; + +@Slf4j +@Service +public class ParserInfoServiceImpl implements ParseInfoService { + + @Value("${candidate.top.size:5}") + private int candidateTopSize; + + public List getTopCandidateParseInfo(List selectedParses, + List candidateParses) { + if (CollectionUtils.isEmpty(selectedParses) || CollectionUtils.isEmpty(candidateParses)) { + return candidateParses; + } + int selectParseSize = selectedParses.size(); + Set selectParseScoreSet = selectedParses.stream() + .map(SemanticParseInfo::getScore).collect(Collectors.toSet()); + int candidateParseSize = candidateTopSize - selectParseSize; + candidateParses = candidateParses.stream() + .filter(candidateParse -> !selectParseScoreSet.contains(candidateParse.getScore())) + .collect(Collectors.toList()); + SemanticParseInfo semanticParseInfo = selectedParses.get(0); + Long modelId = semanticParseInfo.getModelId(); + if (modelId == null || modelId <= 0) { + return candidateParses; + } + return candidateParses.stream() + .sorted(Comparator.comparing(parse -> !parse.getModelId().equals(modelId))) + .limit(candidateParseSize) + .collect(Collectors.toList()); + } + + public List sortParseInfo(List semanticQueries) { + return semanticQueries.stream() + .map(SemanticQuery::getParseInfo) + .sorted(Comparator.comparingDouble(SemanticParseInfo::getScore).reversed()) + .collect(Collectors.toList()); + } + + public void updateParseInfo(SemanticParseInfo parseInfo) { + SqlInfo sqlInfo = parseInfo.getSqlInfo(); + String logicSql = sqlInfo.getLogicSql(); + if (StringUtils.isBlank(logicSql)) { + return; + } + + List expressions = SqlParserSelectHelper.getFilterExpression(logicSql); + //set dataInfo + try { + if (!org.springframework.util.CollectionUtils.isEmpty(expressions)) { + DateConf dateInfo = getDateInfo(expressions); + parseInfo.setDateInfo(dateInfo); + } + } catch (Exception e) { + log.error("set dateInfo error :", e); + } + + //set filter + try { + Map fieldNameToElement = getNameToElement(parseInfo.getModelId()); + List result = getDimensionFilter(fieldNameToElement, expressions); + parseInfo.getDimensionFilters().addAll(result); + } catch (Exception e) { + log.error("set dimensionFilter error :", e); + } + + SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); + + if (Objects.isNull(semanticSchema)) { + return; + } + List allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(sqlInfo.getLogicSql())); + + Set metrics = getElements(parseInfo.getModelId(), allFields, semanticSchema.getMetrics()); + parseInfo.setMetrics(metrics); + + if (SqlParserSelectFunctionHelper.hasAggregateFunction(sqlInfo.getLogicSql())) { + parseInfo.setNativeQuery(false); + List groupByFields = SqlParserSelectHelper.getGroupByFields(sqlInfo.getLogicSql()); + List groupByDimensions = getFieldsExceptDate(groupByFields); + parseInfo.setDimensions( + getElements(parseInfo.getModelId(), groupByDimensions, semanticSchema.getDimensions())); + } else { + parseInfo.setNativeQuery(true); + List selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getLogicSql()); + List selectDimensions = getFieldsExceptDate(selectFields); + parseInfo.setDimensions( + getElements(parseInfo.getModelId(), selectDimensions, semanticSchema.getDimensions())); + } + } + + + private Set getElements(Long modelId, List allFields, List elements) { + return elements.stream() + .filter(schemaElement -> modelId.equals(schemaElement.getModel()) + && allFields.contains(schemaElement.getName()) + ).collect(Collectors.toSet()); + } + + private List getFieldsExceptDate(List allFields) { + if (org.springframework.util.CollectionUtils.isEmpty(allFields)) { + return new ArrayList<>(); + } + return allFields.stream() + .filter(entry -> !TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(entry)) + .collect(Collectors.toList()); + } + + + private List getDimensionFilter(Map fieldNameToElement, + List filterExpressions) { + List result = Lists.newArrayList(); + for (FilterExpression expression : filterExpressions) { + QueryFilter dimensionFilter = new QueryFilter(); + dimensionFilter.setValue(expression.getFieldValue()); + SchemaElement schemaElement = fieldNameToElement.get(expression.getFieldName()); + if (Objects.isNull(schemaElement)) { + continue; + } + dimensionFilter.setName(schemaElement.getName()); + dimensionFilter.setBizName(schemaElement.getBizName()); + dimensionFilter.setElementID(schemaElement.getId()); + + FilterOperatorEnum operatorEnum = FilterOperatorEnum.getSqlOperator(expression.getOperator()); + dimensionFilter.setOperator(operatorEnum); + dimensionFilter.setFunction(expression.getFunction()); + result.add(dimensionFilter); + } + return result; + } + + private DateConf getDateInfo(List filterExpressions) { + List dateExpressions = filterExpressions.stream() + .filter(expression -> TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(expression.getFieldName())) + .collect(Collectors.toList()); + if (org.springframework.util.CollectionUtils.isEmpty(dateExpressions)) { + return new DateConf(); + } + DateConf dateInfo = new DateConf(); + dateInfo.setDateMode(DateMode.BETWEEN); + FilterExpression firstExpression = dateExpressions.get(0); + + FilterOperatorEnum firstOperator = FilterOperatorEnum.getSqlOperator(firstExpression.getOperator()); + if (FilterOperatorEnum.EQUALS.equals(firstOperator) && Objects.nonNull(firstExpression.getFieldValue())) { + dateInfo.setStartDate(firstExpression.getFieldValue().toString()); + dateInfo.setEndDate(firstExpression.getFieldValue().toString()); + dateInfo.setDateMode(DateMode.BETWEEN); + return dateInfo; + } + if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.GREATER_THAN, + FilterOperatorEnum.GREATER_THAN_EQUALS)) { + dateInfo.setStartDate(firstExpression.getFieldValue().toString()); + if (hasSecondDate(dateExpressions)) { + dateInfo.setEndDate(dateExpressions.get(1).getFieldValue().toString()); + } + } + if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.MINOR_THAN, + FilterOperatorEnum.MINOR_THAN_EQUALS)) { + dateInfo.setEndDate(firstExpression.getFieldValue().toString()); + if (hasSecondDate(dateExpressions)) { + dateInfo.setStartDate(dateExpressions.get(1).getFieldValue().toString()); + } + } + return dateInfo; + } + + private boolean containOperators(FilterExpression expression, FilterOperatorEnum firstOperator, + FilterOperatorEnum... operatorEnums) { + return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(expression.getFieldValue())); + } + + private boolean hasSecondDate(List dateExpressions) { + return dateExpressions.size() > 1 && Objects.nonNull(dateExpressions.get(1).getFieldValue()); + } + + protected Map getNameToElement(Long modelId) { + SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); + List dimensions = semanticSchema.getDimensions(); + List metrics = semanticSchema.getMetrics(); + + List allElements = Lists.newArrayList(); + allElements.addAll(dimensions); + allElements.addAll(metrics); + //support alias + return allElements.stream() + .filter(schemaElement -> schemaElement.getModel().equals(modelId)) + .flatMap(schemaElement -> { + Set> result = new HashSet<>(); + result.add(Pair.of(schemaElement.getName(), schemaElement)); + List aliasList = schemaElement.getAlias(); + if (!org.springframework.util.CollectionUtils.isEmpty(aliasList)) { + for (String alias : aliasList) { + result.add(Pair.of(alias, schemaElement)); + } + } + return result.stream(); + }) + .collect(Collectors.toMap(pair -> pair.getLeft(), pair -> pair.getRight(), (value1, value2) -> value2)); + } +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java index c22ae344a..af79b4796 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java @@ -31,6 +31,7 @@ import com.tencent.supersonic.chat.query.llm.s2ql.S2QLQuery; import com.tencent.supersonic.chat.responder.execute.ExecuteResponder; import com.tencent.supersonic.chat.responder.parse.ParseResponder; import com.tencent.supersonic.chat.service.ChatService; +import com.tencent.supersonic.chat.service.ParseInfoService; import com.tencent.supersonic.chat.service.QueryService; import com.tencent.supersonic.chat.service.SemanticService; import com.tencent.supersonic.chat.service.StatisticsService; @@ -58,7 +59,6 @@ import com.tencent.supersonic.knowledge.utils.NatureHelper; import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; import java.util.ArrayList; -import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -102,6 +102,8 @@ public class QueryServiceImpl implements QueryService { private StatisticsService statisticsService; @Autowired private SolvedQueryManager solvedQueryManager; + @Autowired + private ParseInfoService parseInfoService; @Value("${time.threshold: 100}") private Integer timeThreshold; @@ -148,7 +150,6 @@ public class QueryServiceImpl implements QueryService { semanticCorrectors.stream().forEach(correction -> { correction.correct(queryReq, semanticQuery.getParseInfo()); }); - semanticQuery.updateParseInfo(); } } @@ -156,15 +157,14 @@ public class QueryServiceImpl implements QueryService { ParseResp parseResult; List chatParseDOS = Lists.newArrayList(); if (candidateQueries.size() > 0) { - log.debug("pick before [{}]", candidateQueries.stream().collect( - Collectors.toList())); List selectedQueries = querySelector.select(candidateQueries, queryReq); - log.debug("pick after [{}]", selectedQueries.stream().collect( - Collectors.toList())); - List selectedParses = convertParseInfo(selectedQueries); - List candidateParses = convertParseInfo(candidateQueries); - candidateParses = getTop5CandidateParseInfo(selectedParses, candidateParses); + candidateQueries.forEach(semanticQuery -> parseInfoService.updateParseInfo(semanticQuery.getParseInfo())); + List selectedParses = parseInfoService.sortParseInfo(selectedQueries); + List candidateParses = parseInfoService.sortParseInfo(candidateQueries); + + candidateParses = parseInfoService.getTopCandidateParseInfo(selectedParses, candidateParses); + parseResult = ParseResp.builder() .chatId(queryReq.getChatId()) .queryText(queryReq.getQueryText()) @@ -172,7 +172,7 @@ public class QueryServiceImpl implements QueryService { .selectedParses(selectedParses) .candidateParses(candidateParses) .build(); - chatParseDOS = chatService.batchAddParse(chatCtx, queryReq, parseResult, candidateParses, selectedParses); + chatParseDOS = chatService.batchAddParse(chatCtx, queryReq, parseResult); } else { parseResult = ParseResp.builder() .chatId(queryReq.getChatId()) @@ -198,35 +198,6 @@ public class QueryServiceImpl implements QueryService { return parseResult; } - private List convertParseInfo(List semanticQueries) { - return semanticQueries.stream() - .map(SemanticQuery::getParseInfo) - .sorted(Comparator.comparingDouble(SemanticParseInfo::getScore).reversed()) - .collect(Collectors.toList()); - } - - private List getTop5CandidateParseInfo(List selectedParses, - List candidateParses) { - if (CollectionUtils.isEmpty(selectedParses) || CollectionUtils.isEmpty(candidateParses)) { - return candidateParses; - } - int selectParseSize = selectedParses.size(); - Set selectParseScoreSet = selectedParses.stream() - .map(SemanticParseInfo::getScore).collect(Collectors.toSet()); - int candidateParseSize = 5 - selectParseSize; - candidateParses = candidateParses.stream() - .filter(candidateParse -> !selectParseScoreSet.contains(candidateParse.getScore())) - .collect(Collectors.toList()); - SemanticParseInfo semanticParseInfo = selectedParses.get(0); - Long modelId = semanticParseInfo.getModelId(); - if (modelId == null || modelId <= 0) { - return candidateParses; - } - return candidateParses.stream() - .sorted(Comparator.comparing(parse -> !parse.getModelId().equals(modelId))) - .limit(candidateParseSize) - .collect(Collectors.toList()); - } @Override @TimeCost