(fix)(headless)Fix demo conversations with DETAIL query mode.

(fix)(headless)Fix demo conversations with DETAIL query mode.
This commit is contained in:
jerryjzhang
2024-08-10 21:47:11 +08:00
parent 68952fdb55
commit b13b38c645
8 changed files with 25 additions and 17 deletions

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.headless.chat.corrector; package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper; import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
@@ -33,6 +34,10 @@ public class GroupByCorrector extends BaseSemanticCorrector {
} }
private Boolean needAddGroupBy(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { private Boolean needAddGroupBy(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
if (!QueryType.METRIC.equals(semanticParseInfo.getQueryType())) {
return false;
}
Long dataSetId = semanticParseInfo.getDataSetId(); Long dataSetId = semanticParseInfo.getDataSetId();
//add dimension group by //add dimension group by
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();

View File

@@ -7,6 +7,7 @@ import com.tencent.supersonic.headless.chat.query.QueryManager;
import com.tencent.supersonic.headless.chat.query.SemanticQuery; import com.tencent.supersonic.headless.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery; import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.headless.chat.parser.SemanticParser; import com.tencent.supersonic.headless.chat.parser.SemanticParser;
import com.tencent.supersonic.headless.chat.query.rule.detail.DetailDimensionQuery;
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricModelQuery; import com.tencent.supersonic.headless.chat.query.rule.metric.MetricModelQuery;
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricSemanticQuery; import com.tencent.supersonic.headless.chat.query.rule.metric.MetricSemanticQuery;
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricIdQuery; import com.tencent.supersonic.headless.chat.query.rule.metric.MetricIdQuery;
@@ -103,7 +104,8 @@ public class ContextInheritParser implements SemanticParser {
protected boolean shouldInherit(ChatQueryContext chatQueryContext) { protected boolean shouldInherit(ChatQueryContext chatQueryContext) {
// if candidates only have MetricModel mode, count in context // if candidates only have MetricModel mode, count in context
List<SemanticQuery> metricModelQueries = chatQueryContext.getCandidateQueries().stream() List<SemanticQuery> metricModelQueries = chatQueryContext.getCandidateQueries().stream()
.filter(query -> query instanceof MetricModelQuery).collect( .filter(query -> query instanceof MetricModelQuery
|| query instanceof DetailDimensionQuery).collect(
Collectors.toList()); Collectors.toList());
return metricModelQueries.size() == chatQueryContext.getCandidateQueries().size(); return metricModelQueries.size() == chatQueryContext.getCandidateQueries().size();
} }

View File

@@ -66,7 +66,7 @@ public class QueryManager {
return ruleQueryMap.get(queryMode) instanceof MetricSemanticQuery; return ruleQueryMap.get(queryMode) instanceof MetricSemanticQuery;
} }
public static boolean isTagQuery(String queryMode) { public static boolean isDetailQuery(String queryMode) {
if (queryMode == null || !ruleQueryMap.containsKey(queryMode)) { if (queryMode == null || !ruleQueryMap.containsKey(queryMode)) {
return false; return false;
} }

View File

@@ -75,8 +75,8 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
return; return;
} }
if ((QueryManager.isTagQuery(queryParseInfo.getQueryMode()) if ((QueryManager.isDetailQuery(queryParseInfo.getQueryMode())
&& QueryManager.isTagQuery(contextParseInfo.getQueryMode())) && QueryManager.isDetailQuery(contextParseInfo.getQueryMode()))
|| (QueryManager.isMetricQuery(queryParseInfo.getQueryMode()) || (QueryManager.isMetricQuery(queryParseInfo.getQueryMode())
&& QueryManager.isMetricQuery(contextParseInfo.getQueryMode()))) { && QueryManager.isMetricQuery(contextParseInfo.getQueryMode()))) {
// inherit date info from context // inherit date info from context

View File

@@ -34,12 +34,13 @@ public class JdbcExecutor implements QueryExecutor {
} }
SqlUtils sqlUtils = ContextUtils.getBean(SqlUtils.class); SqlUtils sqlUtils = ContextUtils.getBean(SqlUtils.class);
log.info("executing SQL: {}", StringUtils.normalizeSpace(queryStatement.getSql())); String sql = StringUtils.normalizeSpace(queryStatement.getSql());
log.info("executing SQL: {}", sql);
Database database = queryStatement.getSemanticModel().getDatabase(); Database database = queryStatement.getSemanticModel().getDatabase();
SemanticQueryResp queryResultWithColumns = new SemanticQueryResp(); SemanticQueryResp queryResultWithColumns = new SemanticQueryResp();
SqlUtils sqlUtil = sqlUtils.init(database); SqlUtils sqlUtil = sqlUtils.init(database);
sqlUtil.queryInternal(queryStatement.getSql(), queryResultWithColumns); sqlUtil.queryInternal(queryStatement.getSql(), queryResultWithColumns);
queryResultWithColumns.setSql(queryStatement.getSql()); queryResultWithColumns.setSql(sql);
return queryResultWithColumns; return queryResultWithColumns;
} }

View File

@@ -18,7 +18,7 @@ public class EntityInfoProcessor implements ResultProcessor {
public void process(ParseResp parseResp, ChatQueryContext chatQueryContext) { public void process(ParseResp parseResp, ChatQueryContext chatQueryContext) {
parseResp.getSelectedParses().forEach(parseInfo -> { parseResp.getSelectedParses().forEach(parseInfo -> {
String queryMode = parseInfo.getQueryMode(); String queryMode = parseInfo.getQueryMode();
if (!QueryManager.isTagQuery(queryMode) && !QueryManager.isMetricQuery(queryMode)) { if (!QueryManager.isDetailQuery(queryMode) && !QueryManager.isMetricQuery(queryMode)) {
return; return;
} }

View File

@@ -77,7 +77,7 @@ public class S2ArtistDemo extends S2BaseDemo {
private TagObjectResp addTagObjectSinger(DomainResp singerDomain) throws Exception { private TagObjectResp addTagObjectSinger(DomainResp singerDomain) throws Exception {
TagObjectReq tagObjectReq = new TagObjectReq(); TagObjectReq tagObjectReq = new TagObjectReq();
tagObjectReq.setDomainId(singerDomain.getId()); tagObjectReq.setDomainId(singerDomain.getId());
tagObjectReq.setName("艺人"); tagObjectReq.setName("歌手");
tagObjectReq.setBizName("singer"); tagObjectReq.setBizName("singer");
User user = User.getFakeUser(); User user = User.getFakeUser();
return tagObjectService.create(tagObjectReq, user); return tagObjectService.create(tagObjectReq, user);
@@ -85,7 +85,7 @@ public class S2ArtistDemo extends S2BaseDemo {
public DomainResp addDomain() { public DomainResp addDomain() {
DomainReq domainReq = new DomainReq(); DomainReq domainReq = new DomainReq();
domainReq.setName("艺人"); domainReq.setName("歌手");
domainReq.setBizName("singer"); domainReq.setBizName("singer");
domainReq.setParentId(0L); domainReq.setParentId(0L);
domainReq.setStatus(StatusEnum.ONLINE.getCode()); domainReq.setStatus(StatusEnum.ONLINE.getCode());
@@ -100,9 +100,9 @@ public class S2ArtistDemo extends S2BaseDemo {
public ModelResp addModel(DomainResp singerDomain, public ModelResp addModel(DomainResp singerDomain,
DatabaseResp s2Database, TagObjectResp singerTagObject) throws Exception { DatabaseResp s2Database, TagObjectResp singerTagObject) throws Exception {
ModelReq modelReq = new ModelReq(); ModelReq modelReq = new ModelReq();
modelReq.setName("艺人"); modelReq.setName("歌手");
modelReq.setBizName("singer"); modelReq.setBizName("singer");
modelReq.setDescription("艺人"); modelReq.setDescription("歌手");
modelReq.setDatabaseId(s2Database.getId()); modelReq.setDatabaseId(s2Database.getId());
modelReq.setDomainId(singerDomain.getId()); modelReq.setDomainId(singerDomain.getId());
modelReq.setTagObjectId(singerTagObject.getId()); modelReq.setTagObjectId(singerTagObject.getId());
@@ -113,7 +113,7 @@ public class S2ArtistDemo extends S2BaseDemo {
ModelDetail modelDetail = new ModelDetail(); ModelDetail modelDetail = new ModelDetail();
List<Identify> identifiers = new ArrayList<>(); List<Identify> identifiers = new ArrayList<>();
Identify identify = new Identify("歌手名", IdentifyType.primary.name(), "singer_name", 1); Identify identify = new Identify("歌手名", IdentifyType.primary.name(), "singer_name", 1);
identify.setEntityNames(Lists.newArrayList("歌手", "艺人")); identify.setEntityNames(Lists.newArrayList("歌手"));
identifiers.add(identify); identifiers.add(identify);
modelDetail.setIdentifiers(identifiers); modelDetail.setIdentifiers(identifiers);
@@ -152,10 +152,10 @@ public class S2ArtistDemo extends S2BaseDemo {
public long addDataSet(DomainResp singerDomain, ModelResp singerModel) { public long addDataSet(DomainResp singerDomain, ModelResp singerModel) {
DataSetReq dataSetReq = new DataSetReq(); DataSetReq dataSetReq = new DataSetReq();
dataSetReq.setName("艺人库数据集"); dataSetReq.setName("歌手库数据集");
dataSetReq.setBizName("singer"); dataSetReq.setBizName("singer");
dataSetReq.setDomainId(singerDomain.getId()); dataSetReq.setDomainId(singerDomain.getId());
dataSetReq.setDescription("包含艺人相关标签和指标信息"); dataSetReq.setDescription("包含歌手相关标签和指标信息");
dataSetReq.setAdmins(Lists.newArrayList("admin", "jack")); dataSetReq.setAdmins(Lists.newArrayList("admin", "jack"));
List<DataSetModelConfig> dataSetModelConfigs = getDataSetModelConfigs(singerDomain.getId()); List<DataSetModelConfig> dataSetModelConfigs = getDataSetModelConfigs(singerDomain.getId());
DataSetDetail dataSetDetail = new DataSetDetail(); DataSetDetail dataSetDetail = new DataSetDetail();
@@ -191,7 +191,7 @@ public class S2ArtistDemo extends S2BaseDemo {
agent.setDescription("帮助您用自然语言进行圈选,支持多条件组合筛选"); agent.setDescription("帮助您用自然语言进行圈选,支持多条件组合筛选");
agent.setStatus(1); agent.setStatus(1);
agent.setEnableSearch(1); agent.setEnableSearch(1);
agent.setExamples(Lists.newArrayList("国风流派艺人", "港台地区的艺人", "流派为流行的艺人")); agent.setExamples(Lists.newArrayList("国风流派歌手", "港台歌手", "周杰伦流派"));
AgentConfig agentConfig = new AgentConfig(); AgentConfig agentConfig = new AgentConfig();
RuleParserTool ruleQueryTool = new RuleParserTool(); RuleParserTool ruleQueryTool = new RuleParserTool();
ruleQueryTool.setId("0"); ruleQueryTool.setId("0");

View File

@@ -18,7 +18,7 @@ import org.springframework.boot.test.context.SpringBootTest;
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT) @SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
@Slf4j @Slf4j
public class TagTest extends BaseTest { public class DetailTest extends BaseTest {
@Test @Test
public void test_detail_dimension() throws Exception { public void test_detail_dimension() throws Exception {
@@ -72,7 +72,7 @@ public class TagTest extends BaseTest {
@Test @Test
public void test_detail_list_filter() throws Exception { public void test_detail_list_filter() throws Exception {
QueryResult actualResult = submitNewChat("国风艺人", DataUtils.tagAgentId); QueryResult actualResult = submitNewChat("国风歌手", DataUtils.tagAgentId);
QueryResult expectedResult = new QueryResult(); QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo(); SemanticParseInfo expectedParseInfo = new SemanticParseInfo();