From b13b38c64501c51eac158fb4e1506417e3f6c1a0 Mon Sep 17 00:00:00 2001 From: jerryjzhang Date: Sat, 10 Aug 2024 21:47:11 +0800 Subject: [PATCH] (fix)(headless)Fix demo conversations with DETAIL query mode. (fix)(headless)Fix demo conversations with DETAIL query mode. --- .../chat/corrector/GroupByCorrector.java | 5 +++++ .../chat/parser/rule/ContextInheritParser.java | 4 +++- .../headless/chat/query/QueryManager.java | 2 +- .../chat/query/rule/RuleSemanticQuery.java | 4 ++-- .../headless/core/executor/JdbcExecutor.java | 5 +++-- .../server/processor/EntityInfoProcessor.java | 2 +- .../tencent/supersonic/demo/S2ArtistDemo.java | 16 ++++++++-------- .../chat/{TagTest.java => DetailTest.java} | 4 ++-- 8 files changed, 25 insertions(+), 17 deletions(-) rename launchers/standalone/src/test/java/com/tencent/supersonic/chat/{TagTest.java => DetailTest.java} (97%) diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/GroupByCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/GroupByCorrector.java index c971df07a..ee795001f 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/GroupByCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/GroupByCorrector.java @@ -1,5 +1,6 @@ package com.tencent.supersonic.headless.chat.corrector; +import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.jsqlparser.SqlAddHelper; @@ -33,6 +34,10 @@ public class GroupByCorrector extends BaseSemanticCorrector { } private Boolean needAddGroupBy(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { + if (!QueryType.METRIC.equals(semanticParseInfo.getQueryType())) { + return false; + } + Long dataSetId = semanticParseInfo.getDataSetId(); //add dimension group by SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/ContextInheritParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/ContextInheritParser.java index 0a18ed30c..433263dd0 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/ContextInheritParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/ContextInheritParser.java @@ -7,6 +7,7 @@ import com.tencent.supersonic.headless.chat.query.QueryManager; import com.tencent.supersonic.headless.chat.query.SemanticQuery; import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery; import com.tencent.supersonic.headless.chat.parser.SemanticParser; +import com.tencent.supersonic.headless.chat.query.rule.detail.DetailDimensionQuery; import com.tencent.supersonic.headless.chat.query.rule.metric.MetricModelQuery; import com.tencent.supersonic.headless.chat.query.rule.metric.MetricSemanticQuery; import com.tencent.supersonic.headless.chat.query.rule.metric.MetricIdQuery; @@ -103,7 +104,8 @@ public class ContextInheritParser implements SemanticParser { protected boolean shouldInherit(ChatQueryContext chatQueryContext) { // if candidates only have MetricModel mode, count in context List metricModelQueries = chatQueryContext.getCandidateQueries().stream() - .filter(query -> query instanceof MetricModelQuery).collect( + .filter(query -> query instanceof MetricModelQuery + || query instanceof DetailDimensionQuery).collect( Collectors.toList()); return metricModelQueries.size() == chatQueryContext.getCandidateQueries().size(); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/QueryManager.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/QueryManager.java index c1d835923..436f40df6 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/QueryManager.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/QueryManager.java @@ -66,7 +66,7 @@ public class QueryManager { return ruleQueryMap.get(queryMode) instanceof MetricSemanticQuery; } - public static boolean isTagQuery(String queryMode) { + public static boolean isDetailQuery(String queryMode) { if (queryMode == null || !ruleQueryMap.containsKey(queryMode)) { return false; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java index ebb502583..73a024212 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java @@ -75,8 +75,8 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery { return; } - if ((QueryManager.isTagQuery(queryParseInfo.getQueryMode()) - && QueryManager.isTagQuery(contextParseInfo.getQueryMode())) + if ((QueryManager.isDetailQuery(queryParseInfo.getQueryMode()) + && QueryManager.isDetailQuery(contextParseInfo.getQueryMode())) || (QueryManager.isMetricQuery(queryParseInfo.getQueryMode()) && QueryManager.isMetricQuery(contextParseInfo.getQueryMode()))) { // inherit date info from context diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/executor/JdbcExecutor.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/executor/JdbcExecutor.java index bc7a7c70e..da629ca2a 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/executor/JdbcExecutor.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/executor/JdbcExecutor.java @@ -34,12 +34,13 @@ public class JdbcExecutor implements QueryExecutor { } SqlUtils sqlUtils = ContextUtils.getBean(SqlUtils.class); - log.info("executing SQL: {}", StringUtils.normalizeSpace(queryStatement.getSql())); + String sql = StringUtils.normalizeSpace(queryStatement.getSql()); + log.info("executing SQL: {}", sql); Database database = queryStatement.getSemanticModel().getDatabase(); SemanticQueryResp queryResultWithColumns = new SemanticQueryResp(); SqlUtils sqlUtil = sqlUtils.init(database); sqlUtil.queryInternal(queryStatement.getSql(), queryResultWithColumns); - queryResultWithColumns.setSql(queryStatement.getSql()); + queryResultWithColumns.setSql(sql); return queryResultWithColumns; } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/EntityInfoProcessor.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/EntityInfoProcessor.java index fe69f76dc..3681ed453 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/EntityInfoProcessor.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/EntityInfoProcessor.java @@ -18,7 +18,7 @@ public class EntityInfoProcessor implements ResultProcessor { public void process(ParseResp parseResp, ChatQueryContext chatQueryContext) { parseResp.getSelectedParses().forEach(parseInfo -> { String queryMode = parseInfo.getQueryMode(); - if (!QueryManager.isTagQuery(queryMode) && !QueryManager.isMetricQuery(queryMode)) { + if (!QueryManager.isDetailQuery(queryMode) && !QueryManager.isMetricQuery(queryMode)) { return; } diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2ArtistDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2ArtistDemo.java index eb0ab4516..312336d8e 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2ArtistDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2ArtistDemo.java @@ -77,7 +77,7 @@ public class S2ArtistDemo extends S2BaseDemo { private TagObjectResp addTagObjectSinger(DomainResp singerDomain) throws Exception { TagObjectReq tagObjectReq = new TagObjectReq(); tagObjectReq.setDomainId(singerDomain.getId()); - tagObjectReq.setName("艺人"); + tagObjectReq.setName("歌手"); tagObjectReq.setBizName("singer"); User user = User.getFakeUser(); return tagObjectService.create(tagObjectReq, user); @@ -85,7 +85,7 @@ public class S2ArtistDemo extends S2BaseDemo { public DomainResp addDomain() { DomainReq domainReq = new DomainReq(); - domainReq.setName("艺人库"); + domainReq.setName("歌手库"); domainReq.setBizName("singer"); domainReq.setParentId(0L); domainReq.setStatus(StatusEnum.ONLINE.getCode()); @@ -100,9 +100,9 @@ public class S2ArtistDemo extends S2BaseDemo { public ModelResp addModel(DomainResp singerDomain, DatabaseResp s2Database, TagObjectResp singerTagObject) throws Exception { ModelReq modelReq = new ModelReq(); - modelReq.setName("艺人库"); + modelReq.setName("歌手库"); modelReq.setBizName("singer"); - modelReq.setDescription("艺人库"); + modelReq.setDescription("歌手库"); modelReq.setDatabaseId(s2Database.getId()); modelReq.setDomainId(singerDomain.getId()); modelReq.setTagObjectId(singerTagObject.getId()); @@ -113,7 +113,7 @@ public class S2ArtistDemo extends S2BaseDemo { ModelDetail modelDetail = new ModelDetail(); List identifiers = new ArrayList<>(); Identify identify = new Identify("歌手名", IdentifyType.primary.name(), "singer_name", 1); - identify.setEntityNames(Lists.newArrayList("歌手", "艺人")); + identify.setEntityNames(Lists.newArrayList("歌手")); identifiers.add(identify); modelDetail.setIdentifiers(identifiers); @@ -152,10 +152,10 @@ public class S2ArtistDemo extends S2BaseDemo { public long addDataSet(DomainResp singerDomain, ModelResp singerModel) { DataSetReq dataSetReq = new DataSetReq(); - dataSetReq.setName("艺人库数据集"); + dataSetReq.setName("歌手库数据集"); dataSetReq.setBizName("singer"); dataSetReq.setDomainId(singerDomain.getId()); - dataSetReq.setDescription("包含艺人相关标签和指标信息"); + dataSetReq.setDescription("包含歌手相关标签和指标信息"); dataSetReq.setAdmins(Lists.newArrayList("admin", "jack")); List dataSetModelConfigs = getDataSetModelConfigs(singerDomain.getId()); DataSetDetail dataSetDetail = new DataSetDetail(); @@ -191,7 +191,7 @@ public class S2ArtistDemo extends S2BaseDemo { agent.setDescription("帮助您用自然语言进行圈选,支持多条件组合筛选"); agent.setStatus(1); agent.setEnableSearch(1); - agent.setExamples(Lists.newArrayList("国风流派艺人", "港台地区的艺人", "流派为流行的艺人")); + agent.setExamples(Lists.newArrayList("国风流派歌手", "港台歌手", "周杰伦流派")); AgentConfig agentConfig = new AgentConfig(); RuleParserTool ruleQueryTool = new RuleParserTool(); ruleQueryTool.setId("0"); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/TagTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/DetailTest.java similarity index 97% rename from launchers/standalone/src/test/java/com/tencent/supersonic/chat/TagTest.java rename to launchers/standalone/src/test/java/com/tencent/supersonic/chat/DetailTest.java index bfb945c3f..77aa6b707 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/TagTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/DetailTest.java @@ -18,7 +18,7 @@ import org.springframework.boot.test.context.SpringBootTest; @SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT) @Slf4j -public class TagTest extends BaseTest { +public class DetailTest extends BaseTest { @Test public void test_detail_dimension() throws Exception { @@ -72,7 +72,7 @@ public class TagTest extends BaseTest { @Test public void test_detail_list_filter() throws Exception { - QueryResult actualResult = submitNewChat("国风艺人", DataUtils.tagAgentId); + QueryResult actualResult = submitNewChat("国风歌手", DataUtils.tagAgentId); QueryResult expectedResult = new QueryResult(); SemanticParseInfo expectedParseInfo = new SemanticParseInfo();