mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 04:27:39 +00:00
[improvement][project]Simplify code logic in multiple modules.
This commit is contained in:
@@ -26,7 +26,7 @@ import java.util.stream.Collectors;
|
|||||||
public class PlainTextExecutor implements ChatQueryExecutor {
|
public class PlainTextExecutor implements ChatQueryExecutor {
|
||||||
|
|
||||||
public static final String APP_KEY = "SMALL_TALK";
|
public static final String APP_KEY = "SMALL_TALK";
|
||||||
private static final String INSTRUCTION = "" + "#Role: You are a nice person to talk to."
|
private static final String INSTRUCTION = "#Role: You are a nice person to talk to."
|
||||||
+ "\n#Task: Respond quickly and nicely to the user."
|
+ "\n#Task: Respond quickly and nicely to the user."
|
||||||
+ "\n#Rules: 1.ALWAYS use the same language as the `#Current Input`."
|
+ "\n#Rules: 1.ALWAYS use the same language as the `#Current Input`."
|
||||||
+ "\n#History Inputs: %s" + "\n#Current Input: %s" + "\n#Response: ";
|
+ "\n#History Inputs: %s" + "\n#Current Input: %s" + "\n#Response: ";
|
||||||
|
|||||||
@@ -83,6 +83,9 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
if (Objects.isNull(parseContext.getRequest().getSelectedParse())) {
|
if (Objects.isNull(parseContext.getRequest().getSelectedParse())) {
|
||||||
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
|
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
|
||||||
queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE);
|
queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE);
|
||||||
|
if (parseContext.enableLLM()) {
|
||||||
|
queryNLReq.setText2SQLType(Text2SQLType.NONE);
|
||||||
|
}
|
||||||
|
|
||||||
// for every requested dataSet, recursively invoke rule-based parser with different
|
// for every requested dataSet, recursively invoke rule-based parser with different
|
||||||
// mapModes
|
// mapModes
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import java.util.Map;
|
|||||||
|
|
||||||
public class PluginQueryManager {
|
public class PluginQueryManager {
|
||||||
|
|
||||||
private static Map<String, PluginSemanticQuery> pluginQueries = new HashMap<>();
|
private static final Map<String, PluginSemanticQuery> pluginQueries = new HashMap<>();
|
||||||
|
|
||||||
public static void register(String queryMode, PluginSemanticQuery pluginSemanticQuery) {
|
public static void register(String queryMode, PluginSemanticQuery pluginSemanticQuery) {
|
||||||
pluginQueries.put(queryMode, pluginSemanticQuery);
|
pluginQueries.put(queryMode, pluginSemanticQuery);
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package com.tencent.supersonic.chat.server.processor.parse;
|
package com.tencent.supersonic.chat.server.processor.parse;
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
|
||||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||||
import com.tencent.supersonic.common.jsqlparser.FieldExpression;
|
import com.tencent.supersonic.common.jsqlparser.FieldExpression;
|
||||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper;
|
import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper;
|
||||||
@@ -22,7 +21,12 @@ import org.apache.commons.lang3.StringUtils;
|
|||||||
import org.apache.commons.lang3.tuple.Pair;
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.*;
|
import java.util.Arrays;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -33,8 +37,7 @@ public class ParseInfoFormatProcessor implements ParseResultProcessor {
|
|||||||
@Override
|
@Override
|
||||||
public void process(ParseContext parseContext) {
|
public void process(ParseContext parseContext) {
|
||||||
parseContext.getResponse().getSelectedParses().forEach(p -> {
|
parseContext.getResponse().getSelectedParses().forEach(p -> {
|
||||||
if (PluginQueryManager.isPluginQuery(p.getQueryMode())
|
if (Objects.isNull(p.getDataSet()) || Objects.isNull(p.getSqlInfo().getParsedS2SQL())) {
|
||||||
|| "PLAIN_TEXT".equals(p.getQueryMode())) {
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package com.tencent.supersonic.common.pojo.enums;
|
package com.tencent.supersonic.common.pojo.enums;
|
||||||
|
|
||||||
public enum Text2SQLType {
|
public enum Text2SQLType {
|
||||||
ONLY_RULE, LLM_OR_RULE;
|
ONLY_RULE, LLM_OR_RULE, NONE;
|
||||||
|
|
||||||
public boolean enableLLM() {
|
public boolean enableLLM() {
|
||||||
return this.equals(LLM_OR_RULE);
|
return this.equals(LLM_OR_RULE);
|
||||||
|
|||||||
@@ -26,26 +26,27 @@ import static com.tencent.supersonic.common.pojo.Constants.DEFAULT_METRIC_LIMIT;
|
|||||||
public class SemanticParseInfo implements Serializable {
|
public class SemanticParseInfo implements Serializable {
|
||||||
|
|
||||||
private Integer id;
|
private Integer id;
|
||||||
private String queryMode = "PLAIN_TEXT";
|
private String queryMode = "";
|
||||||
private QueryConfig queryConfig;
|
private QueryConfig queryConfig;
|
||||||
private QueryType queryType = QueryType.DETAIL;
|
private QueryType queryType;
|
||||||
|
|
||||||
private SchemaElement dataSet;
|
private SchemaElement dataSet;
|
||||||
private Set<SchemaElement> metrics = Sets.newTreeSet(new SchemaNameLengthComparator());
|
private Set<SchemaElement> metrics = Sets.newTreeSet(new SchemaNameLengthComparator());
|
||||||
private Set<SchemaElement> dimensions = Sets.newTreeSet(new SchemaNameLengthComparator());
|
private Set<SchemaElement> dimensions = Sets.newTreeSet(new SchemaNameLengthComparator());
|
||||||
|
|
||||||
private Set<QueryFilter> dimensionFilters = Sets.newHashSet();
|
private Set<QueryFilter> dimensionFilters = Sets.newHashSet();
|
||||||
private Set<QueryFilter> metricFilters = Sets.newHashSet();
|
private Set<QueryFilter> metricFilters = Sets.newHashSet();
|
||||||
|
private FilterType filterType = FilterType.AND;
|
||||||
|
|
||||||
private AggregateTypeEnum aggType = AggregateTypeEnum.NONE;
|
private AggregateTypeEnum aggType = AggregateTypeEnum.NONE;
|
||||||
private FilterType filterType = FilterType.AND;
|
|
||||||
private Set<Order> orders = Sets.newHashSet();
|
private Set<Order> orders = Sets.newHashSet();
|
||||||
private DateConf dateInfo;
|
|
||||||
private long limit = DEFAULT_DETAIL_LIMIT;
|
private long limit = DEFAULT_DETAIL_LIMIT;
|
||||||
private double score;
|
private double score;
|
||||||
private List<SchemaElementMatch> elementMatches = Lists.newArrayList();
|
private List<SchemaElementMatch> elementMatches = Lists.newArrayList();
|
||||||
|
private DateConf dateInfo;
|
||||||
private SqlInfo sqlInfo = new SqlInfo();
|
private SqlInfo sqlInfo = new SqlInfo();
|
||||||
private SqlEvaluation sqlEvaluation = new SqlEvaluation();
|
|
||||||
private String textInfo;
|
private String textInfo;
|
||||||
|
private SqlEvaluation sqlEvaluation = new SqlEvaluation();
|
||||||
private Map<String, Object> properties = Maps.newHashMap();
|
private Map<String, Object> properties = Maps.newHashMap();
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package com.tencent.supersonic.headless.chat;
|
package com.tencent.supersonic.headless.chat;
|
||||||
|
|
||||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
@@ -41,6 +42,10 @@ public class ChatQueryContext implements Serializable {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public boolean needSQL() {
|
||||||
|
return !request.getText2SQLType().equals(Text2SQLType.NONE);
|
||||||
|
}
|
||||||
|
|
||||||
public DataSetSchema getDataSetSchema(Long dataSetId) {
|
public DataSetSchema getDataSetSchema(Long dataSetId) {
|
||||||
return semanticSchema.getDataSetSchema(dataSetId);
|
return semanticSchema.getDataSetSchema(dataSetId);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
@@ -31,9 +30,11 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
|||||||
|
|
||||||
public void correct(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
public void correct(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
try {
|
try {
|
||||||
if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectedS2SQL())) {
|
String s2SQL = semanticParseInfo.getSqlInfo().getParsedS2SQL();
|
||||||
|
if (Objects.isNull(s2SQL)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(s2SQL);
|
||||||
doCorrect(chatQueryContext, semanticParseInfo);
|
doCorrect(chatQueryContext, semanticParseInfo);
|
||||||
log.debug("sqlCorrection:{} sql:{}", this.getClass().getSimpleName(),
|
log.debug("sqlCorrection:{} sql:{}", this.getClass().getSimpleName(),
|
||||||
semanticParseInfo.getSqlInfo());
|
semanticParseInfo.getSqlInfo());
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
|||||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
/** QueryTypeParser resolves query type as either AGGREGATE or DETAIL */
|
/** QueryTypeParser resolves query type as either AGGREGATE or DETAIL */
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class QueryTypeParser implements SemanticParser {
|
public class QueryTypeParser implements SemanticParser {
|
||||||
@@ -15,12 +17,14 @@ public class QueryTypeParser implements SemanticParser {
|
|||||||
chatQueryContext.getCandidateQueries().forEach(query -> {
|
chatQueryContext.getCandidateQueries().forEach(query -> {
|
||||||
SemanticParseInfo parseInfo = query.getParseInfo();
|
SemanticParseInfo parseInfo = query.getParseInfo();
|
||||||
String s2SQL = parseInfo.getSqlInfo().getParsedS2SQL();
|
String s2SQL = parseInfo.getSqlInfo().getParsedS2SQL();
|
||||||
QueryType queryType = QueryType.DETAIL;
|
if (Objects.isNull(s2SQL)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
QueryType queryType = QueryType.DETAIL;
|
||||||
if (SqlSelectFunctionHelper.hasAggregateFunction(s2SQL)) {
|
if (SqlSelectFunctionHelper.hasAggregateFunction(s2SQL)) {
|
||||||
queryType = QueryType.AGGREGATE;
|
queryType = QueryType.AGGREGATE;
|
||||||
}
|
}
|
||||||
|
|
||||||
parseInfo.setQueryType(queryType);
|
parseInfo.setQueryType(queryType);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -49,7 +49,6 @@ public class LLMResponseService {
|
|||||||
parseInfo.setScore(queryCtx.getRequest().getQueryText().length() * (1 + weight));
|
parseInfo.setScore(queryCtx.getRequest().getQueryText().length() * (1 + weight));
|
||||||
parseInfo.setQueryMode(semanticQuery.getQueryMode());
|
parseInfo.setQueryMode(semanticQuery.getQueryMode());
|
||||||
parseInfo.getSqlInfo().setParsedS2SQL(s2SQL);
|
parseInfo.getSqlInfo().setParsedS2SQL(s2SQL);
|
||||||
parseInfo.getSqlInfo().setCorrectedS2SQL(s2SQL);
|
|
||||||
queryCtx.getCandidateQueries().add(semanticQuery);
|
queryCtx.getCandidateQueries().add(semanticQuery);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -40,7 +40,9 @@ public class RuleSqlParser implements SemanticParser {
|
|||||||
|
|
||||||
auxiliaryParsers.forEach(p -> p.parse(chatQueryContext));
|
auxiliaryParsers.forEach(p -> p.parse(chatQueryContext));
|
||||||
|
|
||||||
candidateQueries.forEach(query -> query.buildS2Sql(
|
if (chatQueryContext.needSQL()) {
|
||||||
chatQueryContext.getDataSetSchema(query.getParseInfo().getDataSetId())));
|
candidateQueries.forEach(query -> query.buildS2Sql(
|
||||||
|
chatQueryContext.getDataSetSchema(query.getParseInfo().getDataSetId())));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,16 +3,17 @@ package com.tencent.supersonic.headless.chat.query.llm.s2sql;
|
|||||||
import com.fasterxml.jackson.annotation.JsonValue;
|
import com.fasterxml.jackson.annotation.JsonValue;
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
|
||||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import org.apache.commons.collections4.CollectionUtils;
|
import org.apache.commons.collections4.CollectionUtils;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@@ -45,22 +46,23 @@ public class LLMReq {
|
|||||||
private SchemaElement primaryKey;
|
private SchemaElement primaryKey;
|
||||||
|
|
||||||
public List<String> getFieldNameList() {
|
public List<String> getFieldNameList() {
|
||||||
List<String> fieldNameList = new ArrayList<>();
|
Set<String> fieldNameList = new HashSet<>();
|
||||||
if (CollectionUtils.isNotEmpty(metrics)) {
|
if (CollectionUtils.isNotEmpty(metrics)) {
|
||||||
fieldNameList.addAll(metrics.stream().map(metric -> metric.getName())
|
fieldNameList.addAll(
|
||||||
.collect(Collectors.toList()));
|
metrics.stream().map(SchemaElement::getName).collect(Collectors.toList()));
|
||||||
}
|
}
|
||||||
if (CollectionUtils.isNotEmpty(dimensions)) {
|
if (CollectionUtils.isNotEmpty(dimensions)) {
|
||||||
fieldNameList.addAll(dimensions.stream().map(dimension -> dimension.getName())
|
fieldNameList.addAll(dimensions.stream().map(SchemaElement::getName)
|
||||||
|
.collect(Collectors.toList()));
|
||||||
|
}
|
||||||
|
if (CollectionUtils.isNotEmpty(values)) {
|
||||||
|
fieldNameList.addAll(values.stream().map(ElementValue::getFieldName)
|
||||||
.collect(Collectors.toList()));
|
.collect(Collectors.toList()));
|
||||||
}
|
}
|
||||||
if (Objects.nonNull(partitionTime)) {
|
if (Objects.nonNull(partitionTime)) {
|
||||||
fieldNameList.add(partitionTime.getName());
|
fieldNameList.add(partitionTime.getName());
|
||||||
}
|
}
|
||||||
if (Objects.nonNull(primaryKey)) {
|
return new ArrayList<>(fieldNameList);
|
||||||
fieldNameList.add(primaryKey.getName());
|
|
||||||
}
|
|
||||||
return fieldNameList;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -74,7 +76,7 @@ public class LLMReq {
|
|||||||
public enum SqlGenType {
|
public enum SqlGenType {
|
||||||
ONE_PASS_SELF_CONSISTENCY("1_pass_self_consistency");
|
ONE_PASS_SELF_CONSISTENCY("1_pass_self_consistency");
|
||||||
|
|
||||||
private String name;
|
private final String name;
|
||||||
|
|
||||||
SqlGenType(String name) {
|
SqlGenType(String name) {
|
||||||
this.name = name;
|
this.name = name;
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package com.tencent.supersonic.headless.chat.query.llm.s2sql;
|
package com.tencent.supersonic.headless.chat.query.llm.s2sql;
|
||||||
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
|
||||||
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.LLMSemanticQuery;
|
import com.tencent.supersonic.headless.chat.query.llm.LLMSemanticQuery;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
@@ -23,8 +22,5 @@ public class LLMSqlQuery extends LLMSemanticQuery {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void buildS2Sql(DataSetSchema dataSetSchema) {
|
public void buildS2Sql(DataSetSchema dataSetSchema) {}
|
||||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
|
||||||
sqlInfo.setCorrectedS2SQL(sqlInfo.getParsedS2SQL());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -60,7 +60,6 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
|||||||
convertBizNameToName(dataSetSchema, queryStructReq);
|
convertBizNameToName(dataSetSchema, queryStructReq);
|
||||||
QuerySqlReq querySQLReq = queryStructReq.convert();
|
QuerySqlReq querySQLReq = queryStructReq.convert();
|
||||||
parseInfo.getSqlInfo().setParsedS2SQL(querySQLReq.getSql());
|
parseInfo.getSqlInfo().setParsedS2SQL(querySQLReq.getSql());
|
||||||
parseInfo.getSqlInfo().setCorrectedS2SQL(querySQLReq.getSql());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protected QueryStructReq convertQueryStruct() {
|
protected QueryStructReq convertQueryStruct() {
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.chat.query.rule.metric;
|
|||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.common.pojo.Order;
|
import com.tencent.supersonic.common.pojo.Order;
|
||||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
@@ -52,8 +51,6 @@ public class MetricTopNQuery extends MetricSemanticQuery {
|
|||||||
super.fillParseInfo(chatQueryContext, dataSetId);
|
super.fillParseInfo(chatQueryContext, dataSetId);
|
||||||
|
|
||||||
parseInfo.setScore(parseInfo.getScore() + 2.0);
|
parseInfo.setScore(parseInfo.getScore() + 2.0);
|
||||||
parseInfo.setAggType(AggregateTypeEnum.SUM);
|
|
||||||
|
|
||||||
SchemaElement metric = parseInfo.getMetrics().iterator().next();
|
SchemaElement metric = parseInfo.getMetrics().iterator().next();
|
||||||
parseInfo.getOrders().add(new Order(metric.getBizName(), Constants.DESC_UPPER));
|
parseInfo.getOrders().add(new Order(metric.getBizName(), Constants.DESC_UPPER));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SqlEvaluation;
|
import com.tencent.supersonic.headless.api.pojo.SqlEvaluation;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.enums.ChatWorkflowState;
|
import com.tencent.supersonic.headless.api.pojo.enums.ChatWorkflowState;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryMapReq;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryMapReq;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||||
@@ -128,10 +127,7 @@ public class S2ChatLayerService implements ChatLayerService {
|
|||||||
schemaService.getSemanticSchema(Sets.newHashSet(querySqlReq.getDataSetId()));
|
schemaService.getSemanticSchema(Sets.newHashSet(querySqlReq.getDataSetId()));
|
||||||
queryCtx.setSemanticSchema(semanticSchema);
|
queryCtx.setSemanticSchema(semanticSchema);
|
||||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||||
SqlInfo sqlInfo = new SqlInfo();
|
semanticParseInfo.getSqlInfo().setParsedS2SQL(querySqlReq.getSql());
|
||||||
sqlInfo.setCorrectedS2SQL(querySqlReq.getSql());
|
|
||||||
sqlInfo.setParsedS2SQL(querySqlReq.getSql());
|
|
||||||
semanticParseInfo.setSqlInfo(sqlInfo);
|
|
||||||
semanticParseInfo.setQueryType(QueryType.DETAIL);
|
semanticParseInfo.setQueryType(QueryType.DETAIL);
|
||||||
|
|
||||||
Long dataSetId = querySqlReq.getDataSetId();
|
Long dataSetId = querySqlReq.getDataSetId();
|
||||||
@@ -147,7 +143,7 @@ public class S2ChatLayerService implements ChatLayerService {
|
|||||||
corrector.correct(queryCtx, semanticParseInfo);
|
corrector.correct(queryCtx, semanticParseInfo);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
log.info("chatQueryServiceImpl correct:{}", sqlInfo.getCorrectedS2SQL());
|
log.info("Corrected SQL:{}", semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
|
||||||
return semanticParseInfo;
|
return semanticParseInfo;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -60,7 +60,12 @@ public class ChatWorkflowEngine {
|
|||||||
List<SemanticParseInfo> parseInfos = queryCtx.getCandidateQueries().stream()
|
List<SemanticParseInfo> parseInfos = queryCtx.getCandidateQueries().stream()
|
||||||
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
|
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
|
||||||
parseResult.setSelectedParses(parseInfos);
|
parseResult.setSelectedParses(parseInfos);
|
||||||
queryCtx.setChatWorkflowState(ChatWorkflowState.CORRECTING);
|
if (queryCtx.needSQL()) {
|
||||||
|
queryCtx.setChatWorkflowState(ChatWorkflowState.CORRECTING);
|
||||||
|
} else {
|
||||||
|
parseResult.setState(ParseResp.ParseState.COMPLETED);
|
||||||
|
queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case CORRECTING:
|
case CORRECTING:
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import java.text.SimpleDateFormat;
|
|||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.MAX;
|
||||||
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;
|
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;
|
||||||
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM;
|
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM;
|
||||||
|
|
||||||
@@ -46,6 +47,7 @@ public class MetricTest extends BaseTest {
|
|||||||
expectedParseInfo.setQueryType(QueryType.AGGREGATE);
|
expectedParseInfo.setQueryType(QueryType.AGGREGATE);
|
||||||
|
|
||||||
assertQueryResult(expectedResult, actualResult);
|
assertQueryResult(expectedResult, actualResult);
|
||||||
|
assert actualResult.getQueryResults().size() == 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@@ -67,6 +69,7 @@ public class MetricTest extends BaseTest {
|
|||||||
expectedParseInfo.setQueryType(QueryType.AGGREGATE);
|
expectedParseInfo.setQueryType(QueryType.AGGREGATE);
|
||||||
|
|
||||||
assertQueryResult(expectedResult, actualResult);
|
assertQueryResult(expectedResult, actualResult);
|
||||||
|
assert actualResult.getQueryResults().size() == 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@@ -93,6 +96,7 @@ public class MetricTest extends BaseTest {
|
|||||||
expectedParseInfo.setQueryType(QueryType.AGGREGATE);
|
expectedParseInfo.setQueryType(QueryType.AGGREGATE);
|
||||||
|
|
||||||
assertQueryResult(expectedResult, actualResult);
|
assertQueryResult(expectedResult, actualResult);
|
||||||
|
assert actualResult.getQueryResults().size() == 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@@ -105,7 +109,7 @@ public class MetricTest extends BaseTest {
|
|||||||
expectedResult.setChatContext(expectedParseInfo);
|
expectedResult.setChatContext(expectedParseInfo);
|
||||||
|
|
||||||
expectedResult.setQueryMode(MetricTopNQuery.QUERY_MODE);
|
expectedResult.setQueryMode(MetricTopNQuery.QUERY_MODE);
|
||||||
expectedParseInfo.setAggType(SUM);
|
expectedParseInfo.setAggType(MAX);
|
||||||
|
|
||||||
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
|
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
|
||||||
expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("用户"));
|
expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("用户"));
|
||||||
@@ -135,6 +139,7 @@ public class MetricTest extends BaseTest {
|
|||||||
expectedParseInfo.setQueryType(QueryType.AGGREGATE);
|
expectedParseInfo.setQueryType(QueryType.AGGREGATE);
|
||||||
|
|
||||||
assertQueryResult(expectedResult, actualResult);
|
assertQueryResult(expectedResult, actualResult);
|
||||||
|
assert actualResult.getQueryResults().size() == 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@@ -144,7 +149,7 @@ public class MetricTest extends BaseTest {
|
|||||||
String dateStr = textFormat.format(format.parse(startDay));
|
String dateStr = textFormat.format(format.parse(startDay));
|
||||||
|
|
||||||
QueryResult actualResult =
|
QueryResult actualResult =
|
||||||
submitNewChat(String.format("想知道%salice的访问次数", dateStr), DataUtils.metricAgentId);
|
submitNewChat(String.format("alice在%s的访问次数", dateStr), DataUtils.metricAgentId);
|
||||||
|
|
||||||
QueryResult expectedResult = new QueryResult();
|
QueryResult expectedResult = new QueryResult();
|
||||||
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
|
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
|
||||||
|
|||||||
Reference in New Issue
Block a user