mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(chat) Modify query type rules in QueryTypeParser (#570)
Co-authored-by: jolunoluo
This commit is contained in:
@@ -117,7 +117,8 @@ public class SemanticSchema implements Serializable {
|
||||
|
||||
public List<SchemaElement> getTags(Set<Long> modelIds) {
|
||||
List<SchemaElement> tags = new ArrayList<>();
|
||||
modelSchemaList.stream().filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
|
||||
modelSchemaList.stream().filter(schemaElement ->
|
||||
modelIds.contains(schemaElement.getModel().getModel()))
|
||||
.forEach(d -> tags.addAll(d.getTags()));
|
||||
return tags;
|
||||
}
|
||||
@@ -139,6 +140,11 @@ public class SemanticSchema implements Serializable {
|
||||
return entities;
|
||||
}
|
||||
|
||||
public List<SchemaElement> getEntities(Set<Long> modelIds) {
|
||||
List<SchemaElement> entities = getEntities();
|
||||
return getElementsByModelId(modelIds, entities);
|
||||
}
|
||||
|
||||
private List<SchemaElement> getElementsByModelId(Set<Long> modelIds, List<SchemaElement> elements) {
|
||||
return elements.stream()
|
||||
.filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
|
||||
|
||||
@@ -13,6 +13,7 @@ import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.chat.service.SemanticService;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
@@ -56,13 +57,20 @@ public class QueryTypeParser implements SemanticParser {
|
||||
Set<Long> modelIds = parseInfo.getModel().getModelIds();
|
||||
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
|
||||
//If all the fields in the SELECT statement are of tag type.
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||
List<String> whereFields = SqlParserSelectHelper.getWhereFields(sqlInfo.getS2SQL())
|
||||
.stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
|
||||
.collect(Collectors.toList());
|
||||
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
||||
SemanticSchema semanticSchema = semanticService.getSemanticSchema();
|
||||
if (CollectionUtils.isNotEmpty(selectFields)) {
|
||||
if (CollectionUtils.isNotEmpty(whereFields)) {
|
||||
Set<String> ids = semanticSchema.getEntities(modelIds).stream().map(SchemaElement::getName)
|
||||
.collect(Collectors.toSet());
|
||||
if (CollectionUtils.isNotEmpty(ids) && ids.stream().anyMatch(whereFields::contains)) {
|
||||
return QueryType.ID;
|
||||
}
|
||||
Set<String> tags = semanticSchema.getTags(modelIds).stream().map(SchemaElement::getName)
|
||||
.collect(Collectors.toSet());
|
||||
if (CollectionUtils.isNotEmpty(tags) && tags.containsAll(selectFields)) {
|
||||
if (CollectionUtils.isNotEmpty(tags) && tags.containsAll(whereFields)) {
|
||||
return QueryType.TAG;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user