diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java index 42c7c30d1..adcab30a3 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java @@ -26,7 +26,7 @@ import java.util.stream.Collectors; public class PlainTextExecutor implements ChatQueryExecutor { public static final String APP_KEY = "SMALL_TALK"; - private static final String INSTRUCTION = "" + "#Role: You are a nice person to talk to." + private static final String INSTRUCTION = "#Role: You are a nice person to talk to." + "\n#Task: Respond quickly and nicely to the user." + "\n#Rules: 1.ALWAYS use the same language as the `#Current Input`." + "\n#History Inputs: %s" + "\n#Current Input: %s" + "\n#Response: "; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index 50401ad23..de06afcaf 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -83,6 +83,9 @@ public class NL2SQLParser implements ChatQueryParser { if (Objects.isNull(parseContext.getRequest().getSelectedParse())) { QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext); queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE); + if (parseContext.enableLLM()) { + queryNLReq.setText2SQLType(Text2SQLType.NONE); + } // for every requested dataSet, recursively invoke rule-based parser with different // mapModes diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginQueryManager.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginQueryManager.java index 7e3099c70..2860db195 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginQueryManager.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginQueryManager.java @@ -7,7 +7,7 @@ import java.util.Map; public class PluginQueryManager { - private static Map pluginQueries = new HashMap<>(); + private static final Map pluginQueries = new HashMap<>(); public static void register(String queryMode, PluginSemanticQuery pluginSemanticQuery) { pluginQueries.put(queryMode, pluginSemanticQuery); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/ParseInfoFormatProcessor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/ParseInfoFormatProcessor.java index ed8d8694e..627f09fd7 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/ParseInfoFormatProcessor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/ParseInfoFormatProcessor.java @@ -1,7 +1,6 @@ package com.tencent.supersonic.chat.server.processor.parse; import com.google.common.collect.Lists; -import com.tencent.supersonic.chat.server.plugin.PluginQueryManager; import com.tencent.supersonic.chat.server.pojo.ParseContext; import com.tencent.supersonic.common.jsqlparser.FieldExpression; import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper; @@ -22,7 +21,12 @@ import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.Pair; import org.springframework.util.CollectionUtils; -import java.util.*; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; import java.util.stream.Collectors; /** @@ -33,8 +37,7 @@ public class ParseInfoFormatProcessor implements ParseResultProcessor { @Override public void process(ParseContext parseContext) { parseContext.getResponse().getSelectedParses().forEach(p -> { - if (PluginQueryManager.isPluginQuery(p.getQueryMode()) - || "PLAIN_TEXT".equals(p.getQueryMode())) { + if (Objects.isNull(p.getDataSet()) || Objects.isNull(p.getSqlInfo().getParsedS2SQL())) { return; } diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/Text2SQLType.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/Text2SQLType.java index 150f4f1a3..0b2940633 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/Text2SQLType.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/Text2SQLType.java @@ -1,7 +1,7 @@ package com.tencent.supersonic.common.pojo.enums; public enum Text2SQLType { - ONLY_RULE, LLM_OR_RULE; + ONLY_RULE, LLM_OR_RULE, NONE; public boolean enableLLM() { return this.equals(LLM_OR_RULE); diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticParseInfo.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticParseInfo.java index d25381317..46bc7fad1 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticParseInfo.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticParseInfo.java @@ -26,26 +26,27 @@ import static com.tencent.supersonic.common.pojo.Constants.DEFAULT_METRIC_LIMIT; public class SemanticParseInfo implements Serializable { private Integer id; - private String queryMode = "PLAIN_TEXT"; + private String queryMode = ""; private QueryConfig queryConfig; - private QueryType queryType = QueryType.DETAIL; + private QueryType queryType; private SchemaElement dataSet; private Set metrics = Sets.newTreeSet(new SchemaNameLengthComparator()); private Set dimensions = Sets.newTreeSet(new SchemaNameLengthComparator()); + private Set dimensionFilters = Sets.newHashSet(); private Set metricFilters = Sets.newHashSet(); + private FilterType filterType = FilterType.AND; private AggregateTypeEnum aggType = AggregateTypeEnum.NONE; - private FilterType filterType = FilterType.AND; private Set orders = Sets.newHashSet(); - private DateConf dateInfo; private long limit = DEFAULT_DETAIL_LIMIT; private double score; private List elementMatches = Lists.newArrayList(); + private DateConf dateInfo; private SqlInfo sqlInfo = new SqlInfo(); - private SqlEvaluation sqlEvaluation = new SqlEvaluation(); private String textInfo; + private SqlEvaluation sqlEvaluation = new SqlEvaluation(); private Map properties = Maps.newHashMap(); @Data diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java index b1167c78b..1dccdb346 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.headless.chat; import com.fasterxml.jackson.annotation.JsonIgnore; +import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; @@ -41,6 +42,10 @@ public class ChatQueryContext implements Serializable { } } + public boolean needSQL() { + return !request.getText2SQLType().equals(Text2SQLType.NONE); + } + public DataSetSchema getDataSetSchema(Long dataSetId) { return semanticSchema.getDataSetSchema(dataSetId); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/BaseSemanticCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/BaseSemanticCorrector.java index 9f324f8c3..fe8ab673f 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/BaseSemanticCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/BaseSemanticCorrector.java @@ -10,7 +10,6 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.chat.ChatQueryContext; import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.Pair; import org.springframework.util.CollectionUtils; @@ -31,9 +30,11 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector { public void correct(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { try { - if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectedS2SQL())) { + String s2SQL = semanticParseInfo.getSqlInfo().getParsedS2SQL(); + if (Objects.isNull(s2SQL)) { return; } + semanticParseInfo.getSqlInfo().setCorrectedS2SQL(s2SQL); doCorrect(chatQueryContext, semanticParseInfo); log.debug("sqlCorrection:{} sql:{}", this.getClass().getSimpleName(), semanticParseInfo.getSqlInfo()); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java index c11182988..178c6958c 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java @@ -6,6 +6,8 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.chat.ChatQueryContext; import lombok.extern.slf4j.Slf4j; +import java.util.Objects; + /** QueryTypeParser resolves query type as either AGGREGATE or DETAIL */ @Slf4j public class QueryTypeParser implements SemanticParser { @@ -15,12 +17,14 @@ public class QueryTypeParser implements SemanticParser { chatQueryContext.getCandidateQueries().forEach(query -> { SemanticParseInfo parseInfo = query.getParseInfo(); String s2SQL = parseInfo.getSqlInfo().getParsedS2SQL(); - QueryType queryType = QueryType.DETAIL; + if (Objects.isNull(s2SQL)) { + return; + } + QueryType queryType = QueryType.DETAIL; if (SqlSelectFunctionHelper.hasAggregateFunction(s2SQL)) { queryType = QueryType.AGGREGATE; } - parseInfo.setQueryType(queryType); }); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java index 59b1d8560..9b0aeb554 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java @@ -49,7 +49,6 @@ public class LLMResponseService { parseInfo.setScore(queryCtx.getRequest().getQueryText().length() * (1 + weight)); parseInfo.setQueryMode(semanticQuery.getQueryMode()); parseInfo.getSqlInfo().setParsedS2SQL(s2SQL); - parseInfo.getSqlInfo().setCorrectedS2SQL(s2SQL); queryCtx.getCandidateQueries().add(semanticQuery); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java index bb3e65305..c8feb94ba 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java @@ -40,7 +40,9 @@ public class RuleSqlParser implements SemanticParser { auxiliaryParsers.forEach(p -> p.parse(chatQueryContext)); - candidateQueries.forEach(query -> query.buildS2Sql( - chatQueryContext.getDataSetSchema(query.getParseInfo().getDataSetId()))); + if (chatQueryContext.needSQL()) { + candidateQueries.forEach(query -> query.buildS2Sql( + chatQueryContext.getDataSetSchema(query.getParseInfo().getDataSetId()))); + } } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java index 82622d932..944ef25fd 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java @@ -3,16 +3,17 @@ package com.tencent.supersonic.headless.chat.query.llm.s2sql; import com.fasterxml.jackson.annotation.JsonValue; import com.google.common.collect.Lists; import com.tencent.supersonic.common.pojo.ChatApp; -import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.headless.api.pojo.SchemaElement; import lombok.Data; import org.apache.commons.collections4.CollectionUtils; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; import java.util.stream.Collectors; @Data @@ -45,22 +46,23 @@ public class LLMReq { private SchemaElement primaryKey; public List getFieldNameList() { - List fieldNameList = new ArrayList<>(); + Set fieldNameList = new HashSet<>(); if (CollectionUtils.isNotEmpty(metrics)) { - fieldNameList.addAll(metrics.stream().map(metric -> metric.getName()) - .collect(Collectors.toList())); + fieldNameList.addAll( + metrics.stream().map(SchemaElement::getName).collect(Collectors.toList())); } if (CollectionUtils.isNotEmpty(dimensions)) { - fieldNameList.addAll(dimensions.stream().map(dimension -> dimension.getName()) + fieldNameList.addAll(dimensions.stream().map(SchemaElement::getName) + .collect(Collectors.toList())); + } + if (CollectionUtils.isNotEmpty(values)) { + fieldNameList.addAll(values.stream().map(ElementValue::getFieldName) .collect(Collectors.toList())); } if (Objects.nonNull(partitionTime)) { fieldNameList.add(partitionTime.getName()); } - if (Objects.nonNull(primaryKey)) { - fieldNameList.add(primaryKey.getName()); - } - return fieldNameList; + return new ArrayList<>(fieldNameList); } } @@ -74,7 +76,7 @@ public class LLMReq { public enum SqlGenType { ONE_PASS_SELF_CONSISTENCY("1_pass_self_consistency"); - private String name; + private final String name; SqlGenType(String name) { this.name = name; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMSqlQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMSqlQuery.java index 775bb4202..76f5fd880 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMSqlQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMSqlQuery.java @@ -1,7 +1,6 @@ package com.tencent.supersonic.headless.chat.query.llm.s2sql; import com.tencent.supersonic.headless.api.pojo.DataSetSchema; -import com.tencent.supersonic.headless.api.pojo.SqlInfo; import com.tencent.supersonic.headless.chat.query.QueryManager; import com.tencent.supersonic.headless.chat.query.llm.LLMSemanticQuery; import lombok.extern.slf4j.Slf4j; @@ -23,8 +22,5 @@ public class LLMSqlQuery extends LLMSemanticQuery { } @Override - public void buildS2Sql(DataSetSchema dataSetSchema) { - SqlInfo sqlInfo = parseInfo.getSqlInfo(); - sqlInfo.setCorrectedS2SQL(sqlInfo.getParsedS2SQL()); - } + public void buildS2Sql(DataSetSchema dataSetSchema) {} } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java index a2c0895ca..7552cb377 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java @@ -60,7 +60,6 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery { convertBizNameToName(dataSetSchema, queryStructReq); QuerySqlReq querySQLReq = queryStructReq.convert(); parseInfo.getSqlInfo().setParsedS2SQL(querySQLReq.getSql()); - parseInfo.getSqlInfo().setCorrectedS2SQL(querySQLReq.getSql()); } protected QueryStructReq convertQueryStruct() { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricTopNQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricTopNQuery.java index 6349ed1e7..90624f5f0 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricTopNQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricTopNQuery.java @@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.chat.query.rule.metric; import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Order; -import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum; import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.chat.ChatQueryContext; @@ -52,8 +51,6 @@ public class MetricTopNQuery extends MetricSemanticQuery { super.fillParseInfo(chatQueryContext, dataSetId); parseInfo.setScore(parseInfo.getScore() + 2.0); - parseInfo.setAggType(AggregateTypeEnum.SUM); - SchemaElement metric = parseInfo.getMetrics().iterator().next(); parseInfo.getOrders().add(new Order(metric.getBizName(), Constants.DESC_UPPER)); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2ChatLayerService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2ChatLayerService.java index ff4c4541a..515c18115 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2ChatLayerService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2ChatLayerService.java @@ -14,7 +14,6 @@ import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.SqlEvaluation; -import com.tencent.supersonic.headless.api.pojo.SqlInfo; import com.tencent.supersonic.headless.api.pojo.enums.ChatWorkflowState; import com.tencent.supersonic.headless.api.pojo.request.QueryMapReq; import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq; @@ -128,10 +127,7 @@ public class S2ChatLayerService implements ChatLayerService { schemaService.getSemanticSchema(Sets.newHashSet(querySqlReq.getDataSetId())); queryCtx.setSemanticSchema(semanticSchema); SemanticParseInfo semanticParseInfo = new SemanticParseInfo(); - SqlInfo sqlInfo = new SqlInfo(); - sqlInfo.setCorrectedS2SQL(querySqlReq.getSql()); - sqlInfo.setParsedS2SQL(querySqlReq.getSql()); - semanticParseInfo.setSqlInfo(sqlInfo); + semanticParseInfo.getSqlInfo().setParsedS2SQL(querySqlReq.getSql()); semanticParseInfo.setQueryType(QueryType.DETAIL); Long dataSetId = querySqlReq.getDataSetId(); @@ -147,7 +143,7 @@ public class S2ChatLayerService implements ChatLayerService { corrector.correct(queryCtx, semanticParseInfo); } }); - log.info("chatQueryServiceImpl correct:{}", sqlInfo.getCorrectedS2SQL()); + log.info("Corrected SQL:{}", semanticParseInfo.getSqlInfo().getCorrectedS2SQL()); return semanticParseInfo; } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java index 516df7cd3..b190949ee 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java @@ -60,7 +60,12 @@ public class ChatWorkflowEngine { List parseInfos = queryCtx.getCandidateQueries().stream() .map(SemanticQuery::getParseInfo).collect(Collectors.toList()); parseResult.setSelectedParses(parseInfos); - queryCtx.setChatWorkflowState(ChatWorkflowState.CORRECTING); + if (queryCtx.needSQL()) { + queryCtx.setChatWorkflowState(ChatWorkflowState.CORRECTING); + } else { + parseResult.setState(ParseResp.ParseState.COMPLETED); + queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED); + } } break; case CORRECTING: diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MetricTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MetricTest.java index 036945278..83c9042b4 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MetricTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MetricTest.java @@ -20,6 +20,7 @@ import java.text.SimpleDateFormat; import java.util.ArrayList; import java.util.List; +import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.MAX; import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE; import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM; @@ -46,6 +47,7 @@ public class MetricTest extends BaseTest { expectedParseInfo.setQueryType(QueryType.AGGREGATE); assertQueryResult(expectedResult, actualResult); + assert actualResult.getQueryResults().size() == 1; } @Test @@ -67,6 +69,7 @@ public class MetricTest extends BaseTest { expectedParseInfo.setQueryType(QueryType.AGGREGATE); assertQueryResult(expectedResult, actualResult); + assert actualResult.getQueryResults().size() == 4; } @Test @@ -93,6 +96,7 @@ public class MetricTest extends BaseTest { expectedParseInfo.setQueryType(QueryType.AGGREGATE); assertQueryResult(expectedResult, actualResult); + assert actualResult.getQueryResults().size() == 2; } @Test @@ -105,7 +109,7 @@ public class MetricTest extends BaseTest { expectedResult.setChatContext(expectedParseInfo); expectedResult.setQueryMode(MetricTopNQuery.QUERY_MODE); - expectedParseInfo.setAggType(SUM); + expectedParseInfo.setAggType(MAX); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("用户")); @@ -135,6 +139,7 @@ public class MetricTest extends BaseTest { expectedParseInfo.setQueryType(QueryType.AGGREGATE); assertQueryResult(expectedResult, actualResult); + assert actualResult.getQueryResults().size() == 4; } @Test @@ -144,7 +149,7 @@ public class MetricTest extends BaseTest { String dateStr = textFormat.format(format.parse(startDay)); QueryResult actualResult = - submitNewChat(String.format("想知道%salice的访问次数", dateStr), DataUtils.metricAgentId); + submitNewChat(String.format("alice在%s的访问次数", dateStr), DataUtils.metricAgentId); QueryResult expectedResult = new QueryResult(); SemanticParseInfo expectedParseInfo = new SemanticParseInfo();