(improvement)(chat) Modify query type rules in QueryTypeParser (#570)

Co-authored-by: jolunoluo
This commit is contained in:
LXW
2023-12-25 20:54:04 +08:00
committed by GitHub
parent 5ab1cade0a
commit 40c86810bb
2 changed files with 18 additions and 4 deletions

View File

@@ -117,7 +117,8 @@ public class SemanticSchema implements Serializable {
public List<SchemaElement> getTags(Set<Long> modelIds) {
List<SchemaElement> tags = new ArrayList<>();
modelSchemaList.stream().filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
modelSchemaList.stream().filter(schemaElement ->
modelIds.contains(schemaElement.getModel().getModel()))
.forEach(d -> tags.addAll(d.getTags()));
return tags;
}
@@ -139,6 +140,11 @@ public class SemanticSchema implements Serializable {
return entities;
}
public List<SchemaElement> getEntities(Set<Long> modelIds) {
List<SchemaElement> entities = getEntities();
return getElementsByModelId(modelIds, entities);
}
private List<SchemaElement> getElementsByModelId(Set<Long> modelIds, List<SchemaElement> elements) {
return elements.stream()
.filter(schemaElement -> modelIds.contains(schemaElement.getModel()))

View File

@@ -13,6 +13,7 @@ import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
@@ -56,13 +57,20 @@ public class QueryTypeParser implements SemanticParser {
Set<Long> modelIds = parseInfo.getModel().getModelIds();
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
//If all the fields in the SELECT statement are of tag type.
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL());
List<String> whereFields = SqlParserSelectHelper.getWhereFields(sqlInfo.getS2SQL())
.stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
.collect(Collectors.toList());
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
SemanticSchema semanticSchema = semanticService.getSemanticSchema();
if (CollectionUtils.isNotEmpty(selectFields)) {
if (CollectionUtils.isNotEmpty(whereFields)) {
Set<String> ids = semanticSchema.getEntities(modelIds).stream().map(SchemaElement::getName)
.collect(Collectors.toSet());
if (CollectionUtils.isNotEmpty(ids) && ids.stream().anyMatch(whereFields::contains)) {
return QueryType.ID;
}
Set<String> tags = semanticSchema.getTags(modelIds).stream().map(SchemaElement::getName)
.collect(Collectors.toSet());
if (CollectionUtils.isNotEmpty(tags) && tags.containsAll(selectFields)) {
if (CollectionUtils.isNotEmpty(tags) && tags.containsAll(whereFields)) {
return QueryType.TAG;
}
}