mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 03:58:14 +00:00
(fix)(headless)Fix demo conversations with DETAIL query mode.
(fix)(headless)Fix demo conversations with DETAIL query mode.
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
package com.tencent.supersonic.headless.chat.corrector;
|
package com.tencent.supersonic.headless.chat.corrector;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
||||||
@@ -33,6 +34,10 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private Boolean needAddGroupBy(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
private Boolean needAddGroupBy(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
|
if (!QueryType.METRIC.equals(semanticParseInfo.getQueryType())) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
Long dataSetId = semanticParseInfo.getDataSetId();
|
Long dataSetId = semanticParseInfo.getDataSetId();
|
||||||
//add dimension group by
|
//add dimension group by
|
||||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import com.tencent.supersonic.headless.chat.query.QueryManager;
|
|||||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||||
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
||||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||||
|
import com.tencent.supersonic.headless.chat.query.rule.detail.DetailDimensionQuery;
|
||||||
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricModelQuery;
|
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricModelQuery;
|
||||||
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricSemanticQuery;
|
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricSemanticQuery;
|
||||||
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricIdQuery;
|
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricIdQuery;
|
||||||
@@ -103,7 +104,8 @@ public class ContextInheritParser implements SemanticParser {
|
|||||||
protected boolean shouldInherit(ChatQueryContext chatQueryContext) {
|
protected boolean shouldInherit(ChatQueryContext chatQueryContext) {
|
||||||
// if candidates only have MetricModel mode, count in context
|
// if candidates only have MetricModel mode, count in context
|
||||||
List<SemanticQuery> metricModelQueries = chatQueryContext.getCandidateQueries().stream()
|
List<SemanticQuery> metricModelQueries = chatQueryContext.getCandidateQueries().stream()
|
||||||
.filter(query -> query instanceof MetricModelQuery).collect(
|
.filter(query -> query instanceof MetricModelQuery
|
||||||
|
|| query instanceof DetailDimensionQuery).collect(
|
||||||
Collectors.toList());
|
Collectors.toList());
|
||||||
return metricModelQueries.size() == chatQueryContext.getCandidateQueries().size();
|
return metricModelQueries.size() == chatQueryContext.getCandidateQueries().size();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ public class QueryManager {
|
|||||||
return ruleQueryMap.get(queryMode) instanceof MetricSemanticQuery;
|
return ruleQueryMap.get(queryMode) instanceof MetricSemanticQuery;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static boolean isTagQuery(String queryMode) {
|
public static boolean isDetailQuery(String queryMode) {
|
||||||
if (queryMode == null || !ruleQueryMap.containsKey(queryMode)) {
|
if (queryMode == null || !ruleQueryMap.containsKey(queryMode)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -75,8 +75,8 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if ((QueryManager.isTagQuery(queryParseInfo.getQueryMode())
|
if ((QueryManager.isDetailQuery(queryParseInfo.getQueryMode())
|
||||||
&& QueryManager.isTagQuery(contextParseInfo.getQueryMode()))
|
&& QueryManager.isDetailQuery(contextParseInfo.getQueryMode()))
|
||||||
|| (QueryManager.isMetricQuery(queryParseInfo.getQueryMode())
|
|| (QueryManager.isMetricQuery(queryParseInfo.getQueryMode())
|
||||||
&& QueryManager.isMetricQuery(contextParseInfo.getQueryMode()))) {
|
&& QueryManager.isMetricQuery(contextParseInfo.getQueryMode()))) {
|
||||||
// inherit date info from context
|
// inherit date info from context
|
||||||
|
|||||||
@@ -34,12 +34,13 @@ public class JdbcExecutor implements QueryExecutor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
SqlUtils sqlUtils = ContextUtils.getBean(SqlUtils.class);
|
SqlUtils sqlUtils = ContextUtils.getBean(SqlUtils.class);
|
||||||
log.info("executing SQL: {}", StringUtils.normalizeSpace(queryStatement.getSql()));
|
String sql = StringUtils.normalizeSpace(queryStatement.getSql());
|
||||||
|
log.info("executing SQL: {}", sql);
|
||||||
Database database = queryStatement.getSemanticModel().getDatabase();
|
Database database = queryStatement.getSemanticModel().getDatabase();
|
||||||
SemanticQueryResp queryResultWithColumns = new SemanticQueryResp();
|
SemanticQueryResp queryResultWithColumns = new SemanticQueryResp();
|
||||||
SqlUtils sqlUtil = sqlUtils.init(database);
|
SqlUtils sqlUtil = sqlUtils.init(database);
|
||||||
sqlUtil.queryInternal(queryStatement.getSql(), queryResultWithColumns);
|
sqlUtil.queryInternal(queryStatement.getSql(), queryResultWithColumns);
|
||||||
queryResultWithColumns.setSql(queryStatement.getSql());
|
queryResultWithColumns.setSql(sql);
|
||||||
return queryResultWithColumns;
|
return queryResultWithColumns;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ public class EntityInfoProcessor implements ResultProcessor {
|
|||||||
public void process(ParseResp parseResp, ChatQueryContext chatQueryContext) {
|
public void process(ParseResp parseResp, ChatQueryContext chatQueryContext) {
|
||||||
parseResp.getSelectedParses().forEach(parseInfo -> {
|
parseResp.getSelectedParses().forEach(parseInfo -> {
|
||||||
String queryMode = parseInfo.getQueryMode();
|
String queryMode = parseInfo.getQueryMode();
|
||||||
if (!QueryManager.isTagQuery(queryMode) && !QueryManager.isMetricQuery(queryMode)) {
|
if (!QueryManager.isDetailQuery(queryMode) && !QueryManager.isMetricQuery(queryMode)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -77,7 +77,7 @@ public class S2ArtistDemo extends S2BaseDemo {
|
|||||||
private TagObjectResp addTagObjectSinger(DomainResp singerDomain) throws Exception {
|
private TagObjectResp addTagObjectSinger(DomainResp singerDomain) throws Exception {
|
||||||
TagObjectReq tagObjectReq = new TagObjectReq();
|
TagObjectReq tagObjectReq = new TagObjectReq();
|
||||||
tagObjectReq.setDomainId(singerDomain.getId());
|
tagObjectReq.setDomainId(singerDomain.getId());
|
||||||
tagObjectReq.setName("艺人");
|
tagObjectReq.setName("歌手");
|
||||||
tagObjectReq.setBizName("singer");
|
tagObjectReq.setBizName("singer");
|
||||||
User user = User.getFakeUser();
|
User user = User.getFakeUser();
|
||||||
return tagObjectService.create(tagObjectReq, user);
|
return tagObjectService.create(tagObjectReq, user);
|
||||||
@@ -85,7 +85,7 @@ public class S2ArtistDemo extends S2BaseDemo {
|
|||||||
|
|
||||||
public DomainResp addDomain() {
|
public DomainResp addDomain() {
|
||||||
DomainReq domainReq = new DomainReq();
|
DomainReq domainReq = new DomainReq();
|
||||||
domainReq.setName("艺人库");
|
domainReq.setName("歌手库");
|
||||||
domainReq.setBizName("singer");
|
domainReq.setBizName("singer");
|
||||||
domainReq.setParentId(0L);
|
domainReq.setParentId(0L);
|
||||||
domainReq.setStatus(StatusEnum.ONLINE.getCode());
|
domainReq.setStatus(StatusEnum.ONLINE.getCode());
|
||||||
@@ -100,9 +100,9 @@ public class S2ArtistDemo extends S2BaseDemo {
|
|||||||
public ModelResp addModel(DomainResp singerDomain,
|
public ModelResp addModel(DomainResp singerDomain,
|
||||||
DatabaseResp s2Database, TagObjectResp singerTagObject) throws Exception {
|
DatabaseResp s2Database, TagObjectResp singerTagObject) throws Exception {
|
||||||
ModelReq modelReq = new ModelReq();
|
ModelReq modelReq = new ModelReq();
|
||||||
modelReq.setName("艺人库");
|
modelReq.setName("歌手库");
|
||||||
modelReq.setBizName("singer");
|
modelReq.setBizName("singer");
|
||||||
modelReq.setDescription("艺人库");
|
modelReq.setDescription("歌手库");
|
||||||
modelReq.setDatabaseId(s2Database.getId());
|
modelReq.setDatabaseId(s2Database.getId());
|
||||||
modelReq.setDomainId(singerDomain.getId());
|
modelReq.setDomainId(singerDomain.getId());
|
||||||
modelReq.setTagObjectId(singerTagObject.getId());
|
modelReq.setTagObjectId(singerTagObject.getId());
|
||||||
@@ -113,7 +113,7 @@ public class S2ArtistDemo extends S2BaseDemo {
|
|||||||
ModelDetail modelDetail = new ModelDetail();
|
ModelDetail modelDetail = new ModelDetail();
|
||||||
List<Identify> identifiers = new ArrayList<>();
|
List<Identify> identifiers = new ArrayList<>();
|
||||||
Identify identify = new Identify("歌手名", IdentifyType.primary.name(), "singer_name", 1);
|
Identify identify = new Identify("歌手名", IdentifyType.primary.name(), "singer_name", 1);
|
||||||
identify.setEntityNames(Lists.newArrayList("歌手", "艺人"));
|
identify.setEntityNames(Lists.newArrayList("歌手"));
|
||||||
identifiers.add(identify);
|
identifiers.add(identify);
|
||||||
modelDetail.setIdentifiers(identifiers);
|
modelDetail.setIdentifiers(identifiers);
|
||||||
|
|
||||||
@@ -152,10 +152,10 @@ public class S2ArtistDemo extends S2BaseDemo {
|
|||||||
|
|
||||||
public long addDataSet(DomainResp singerDomain, ModelResp singerModel) {
|
public long addDataSet(DomainResp singerDomain, ModelResp singerModel) {
|
||||||
DataSetReq dataSetReq = new DataSetReq();
|
DataSetReq dataSetReq = new DataSetReq();
|
||||||
dataSetReq.setName("艺人库数据集");
|
dataSetReq.setName("歌手库数据集");
|
||||||
dataSetReq.setBizName("singer");
|
dataSetReq.setBizName("singer");
|
||||||
dataSetReq.setDomainId(singerDomain.getId());
|
dataSetReq.setDomainId(singerDomain.getId());
|
||||||
dataSetReq.setDescription("包含艺人相关标签和指标信息");
|
dataSetReq.setDescription("包含歌手相关标签和指标信息");
|
||||||
dataSetReq.setAdmins(Lists.newArrayList("admin", "jack"));
|
dataSetReq.setAdmins(Lists.newArrayList("admin", "jack"));
|
||||||
List<DataSetModelConfig> dataSetModelConfigs = getDataSetModelConfigs(singerDomain.getId());
|
List<DataSetModelConfig> dataSetModelConfigs = getDataSetModelConfigs(singerDomain.getId());
|
||||||
DataSetDetail dataSetDetail = new DataSetDetail();
|
DataSetDetail dataSetDetail = new DataSetDetail();
|
||||||
@@ -191,7 +191,7 @@ public class S2ArtistDemo extends S2BaseDemo {
|
|||||||
agent.setDescription("帮助您用自然语言进行圈选,支持多条件组合筛选");
|
agent.setDescription("帮助您用自然语言进行圈选,支持多条件组合筛选");
|
||||||
agent.setStatus(1);
|
agent.setStatus(1);
|
||||||
agent.setEnableSearch(1);
|
agent.setEnableSearch(1);
|
||||||
agent.setExamples(Lists.newArrayList("国风流派艺人", "港台地区的艺人", "流派为流行的艺人"));
|
agent.setExamples(Lists.newArrayList("国风流派歌手", "港台歌手", "周杰伦流派"));
|
||||||
AgentConfig agentConfig = new AgentConfig();
|
AgentConfig agentConfig = new AgentConfig();
|
||||||
RuleParserTool ruleQueryTool = new RuleParserTool();
|
RuleParserTool ruleQueryTool = new RuleParserTool();
|
||||||
ruleQueryTool.setId("0");
|
ruleQueryTool.setId("0");
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import org.springframework.boot.test.context.SpringBootTest;
|
|||||||
|
|
||||||
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
|
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class TagTest extends BaseTest {
|
public class DetailTest extends BaseTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void test_detail_dimension() throws Exception {
|
public void test_detail_dimension() throws Exception {
|
||||||
@@ -72,7 +72,7 @@ public class TagTest extends BaseTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void test_detail_list_filter() throws Exception {
|
public void test_detail_list_filter() throws Exception {
|
||||||
QueryResult actualResult = submitNewChat("国风艺人", DataUtils.tagAgentId);
|
QueryResult actualResult = submitNewChat("国风歌手", DataUtils.tagAgentId);
|
||||||
|
|
||||||
QueryResult expectedResult = new QueryResult();
|
QueryResult expectedResult = new QueryResult();
|
||||||
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
|
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
|
||||||
Reference in New Issue
Block a user