mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
(improvement)(chat) Modify query type rules in QueryTypeParser (#570)
Co-authored-by: jolunoluo
This commit is contained in:
@@ -117,7 +117,8 @@ public class SemanticSchema implements Serializable {
|
|||||||
|
|
||||||
public List<SchemaElement> getTags(Set<Long> modelIds) {
|
public List<SchemaElement> getTags(Set<Long> modelIds) {
|
||||||
List<SchemaElement> tags = new ArrayList<>();
|
List<SchemaElement> tags = new ArrayList<>();
|
||||||
modelSchemaList.stream().filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
|
modelSchemaList.stream().filter(schemaElement ->
|
||||||
|
modelIds.contains(schemaElement.getModel().getModel()))
|
||||||
.forEach(d -> tags.addAll(d.getTags()));
|
.forEach(d -> tags.addAll(d.getTags()));
|
||||||
return tags;
|
return tags;
|
||||||
}
|
}
|
||||||
@@ -139,6 +140,11 @@ public class SemanticSchema implements Serializable {
|
|||||||
return entities;
|
return entities;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public List<SchemaElement> getEntities(Set<Long> modelIds) {
|
||||||
|
List<SchemaElement> entities = getEntities();
|
||||||
|
return getElementsByModelId(modelIds, entities);
|
||||||
|
}
|
||||||
|
|
||||||
private List<SchemaElement> getElementsByModelId(Set<Long> modelIds, List<SchemaElement> elements) {
|
private List<SchemaElement> getElementsByModelId(Set<Long> modelIds, List<SchemaElement> elements) {
|
||||||
return elements.stream()
|
return elements.stream()
|
||||||
.filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
|
.filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import com.tencent.supersonic.chat.query.llm.s2sql.LLMSqlQuery;
|
|||||||
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
|
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
|
||||||
import com.tencent.supersonic.chat.service.SemanticService;
|
import com.tencent.supersonic.chat.service.SemanticService;
|
||||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||||
@@ -56,13 +57,20 @@ public class QueryTypeParser implements SemanticParser {
|
|||||||
Set<Long> modelIds = parseInfo.getModel().getModelIds();
|
Set<Long> modelIds = parseInfo.getModel().getModelIds();
|
||||||
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
|
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
|
||||||
//If all the fields in the SELECT statement are of tag type.
|
//If all the fields in the SELECT statement are of tag type.
|
||||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
List<String> whereFields = SqlParserSelectHelper.getWhereFields(sqlInfo.getS2SQL())
|
||||||
|
.stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
|
||||||
|
.collect(Collectors.toList());
|
||||||
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
||||||
SemanticSchema semanticSchema = semanticService.getSemanticSchema();
|
SemanticSchema semanticSchema = semanticService.getSemanticSchema();
|
||||||
if (CollectionUtils.isNotEmpty(selectFields)) {
|
if (CollectionUtils.isNotEmpty(whereFields)) {
|
||||||
|
Set<String> ids = semanticSchema.getEntities(modelIds).stream().map(SchemaElement::getName)
|
||||||
|
.collect(Collectors.toSet());
|
||||||
|
if (CollectionUtils.isNotEmpty(ids) && ids.stream().anyMatch(whereFields::contains)) {
|
||||||
|
return QueryType.ID;
|
||||||
|
}
|
||||||
Set<String> tags = semanticSchema.getTags(modelIds).stream().map(SchemaElement::getName)
|
Set<String> tags = semanticSchema.getTags(modelIds).stream().map(SchemaElement::getName)
|
||||||
.collect(Collectors.toSet());
|
.collect(Collectors.toSet());
|
||||||
if (CollectionUtils.isNotEmpty(tags) && tags.containsAll(selectFields)) {
|
if (CollectionUtils.isNotEmpty(tags) && tags.containsAll(whereFields)) {
|
||||||
return QueryType.TAG;
|
return QueryType.TAG;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user