From 40c86810bbf06bbf6a2cbc3fbd0295b6c772d687 Mon Sep 17 00:00:00 2001 From: LXW <1264174498@qq.com> Date: Mon, 25 Dec 2023 20:54:04 +0800 Subject: [PATCH] (improvement)(chat) Modify query type rules in QueryTypeParser (#570) Co-authored-by: jolunoluo --- .../supersonic/chat/api/pojo/SemanticSchema.java | 8 +++++++- .../supersonic/chat/parser/QueryTypeParser.java | 14 +++++++++++--- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/SemanticSchema.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/SemanticSchema.java index dda6ee609..f8d1f16d3 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/SemanticSchema.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/SemanticSchema.java @@ -117,7 +117,8 @@ public class SemanticSchema implements Serializable { public List getTags(Set modelIds) { List tags = new ArrayList<>(); - modelSchemaList.stream().filter(schemaElement -> modelIds.contains(schemaElement.getModel())) + modelSchemaList.stream().filter(schemaElement -> + modelIds.contains(schemaElement.getModel().getModel())) .forEach(d -> tags.addAll(d.getTags())); return tags; } @@ -139,6 +140,11 @@ public class SemanticSchema implements Serializable { return entities; } + public List getEntities(Set modelIds) { + List entities = getEntities(); + return getElementsByModelId(modelIds, entities); + } + private List getElementsByModelId(Set modelIds, List elements) { return elements.stream() .filter(schemaElement -> modelIds.contains(schemaElement.getModel())) diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/QueryTypeParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/QueryTypeParser.java index e351c8066..95ba30cf3 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/QueryTypeParser.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/QueryTypeParser.java @@ -13,6 +13,7 @@ import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlQuery; import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery; import com.tencent.supersonic.chat.service.SemanticService; import com.tencent.supersonic.common.pojo.enums.QueryType; +import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.knowledge.service.SchemaService; @@ -56,13 +57,20 @@ public class QueryTypeParser implements SemanticParser { Set modelIds = parseInfo.getModel().getModelIds(); if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) { //If all the fields in the SELECT statement are of tag type. - List selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL()); + List whereFields = SqlParserSelectHelper.getWhereFields(sqlInfo.getS2SQL()) + .stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field)) + .collect(Collectors.toList()); SemanticService semanticService = ContextUtils.getBean(SemanticService.class); SemanticSchema semanticSchema = semanticService.getSemanticSchema(); - if (CollectionUtils.isNotEmpty(selectFields)) { + if (CollectionUtils.isNotEmpty(whereFields)) { + Set ids = semanticSchema.getEntities(modelIds).stream().map(SchemaElement::getName) + .collect(Collectors.toSet()); + if (CollectionUtils.isNotEmpty(ids) && ids.stream().anyMatch(whereFields::contains)) { + return QueryType.ID; + } Set tags = semanticSchema.getTags(modelIds).stream().map(SchemaElement::getName) .collect(Collectors.toSet()); - if (CollectionUtils.isNotEmpty(tags) && tags.containsAll(selectFields)) { + if (CollectionUtils.isNotEmpty(tags) && tags.containsAll(whereFields)) { return QueryType.TAG; } }