mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 12:07:42 +00:00
(improvement)(Chat) llmSqlParser is adapted for tag mode, and rule parsing filters based on the dataset query type. (#804)
This commit is contained in:
@@ -22,6 +22,9 @@ public class LLMParserConfig {
|
||||
@Value("${metric.topn:10}")
|
||||
private Integer metricTopN;
|
||||
|
||||
@Value("${tag.topn:20}")
|
||||
private Integer tagTopN;
|
||||
|
||||
@Value("${all.model:false}")
|
||||
private Boolean allModel;
|
||||
}
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
package com.tencent.supersonic.chat.core.parser;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||
@@ -27,10 +25,7 @@ public class QueryTypeParser implements SemanticParser {
|
||||
semanticQuery.initS2Sql(queryContext.getSemanticSchema(), user);
|
||||
// 2.set queryType
|
||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
Long dataSetId = parseInfo.getDataSetId();
|
||||
DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId);
|
||||
parseInfo.setQueryType(dataSetSchema.getQueryType());
|
||||
parseInfo.setQueryType(queryContext.getQueryType(parseInfo.getDataSetId()));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.core.parser.sql.llm;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.core.utils.S2SqlDateHelper;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
@@ -92,9 +93,8 @@ public class LLMRequestService {
|
||||
return llmParserTool.orElse(null);
|
||||
}
|
||||
|
||||
public LLMReq getLlmReq(QueryContext queryCtx, Long dataSetId,
|
||||
SemanticSchema semanticSchema, List<ElementValue> linkingValues) {
|
||||
Map<Long, String> dataSetIdToName = semanticSchema.getDataSetIdToName();
|
||||
public LLMReq getLlmReq(QueryContext queryCtx, Long dataSetId, List<ElementValue> linkingValues) {
|
||||
Map<Long, String> dataSetIdToName = queryCtx.getSemanticSchema().getDataSetIdToName();
|
||||
String queryText = queryCtx.getQueryText();
|
||||
|
||||
LLMReq llmReq = new LLMReq();
|
||||
@@ -190,7 +190,8 @@ public class LLMRequestService {
|
||||
.filter(elementMatch -> !elementMatch.isInherited())
|
||||
.filter(schemaElementMatch -> {
|
||||
SchemaElementType type = schemaElementMatch.getElement().getType();
|
||||
return SchemaElementType.VALUE.equals(type) || SchemaElementType.ID.equals(type);
|
||||
return SchemaElementType.VALUE.equals(type) || SchemaElementType.TAG_VALUE.equals(type)
|
||||
|| SchemaElementType.ID.equals(type);
|
||||
})
|
||||
.map(elementMatch -> {
|
||||
ElementValue elementValue = new ElementValue();
|
||||
@@ -203,25 +204,38 @@ public class LLMRequestService {
|
||||
|
||||
protected Map<Long, String> getItemIdToName(QueryContext queryCtx, Long dataSetId) {
|
||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||
return semanticSchema.getDimensions(dataSetId).stream()
|
||||
List<SchemaElement> elements = semanticSchema.getDimensions(dataSetId);
|
||||
if (QueryType.TAG.equals(queryCtx.getQueryType(dataSetId))) {
|
||||
elements = semanticSchema.getTags(dataSetId);
|
||||
}
|
||||
return elements.stream()
|
||||
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
|
||||
}
|
||||
|
||||
private Set<String> getTopNFieldNames(QueryContext queryCtx, Long dataSetId, LLMParserConfig llmParserConfig) {
|
||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||
Set<String> results = semanticSchema.getDimensions(dataSetId).stream()
|
||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||
.limit(llmParserConfig.getDimensionTopN())
|
||||
.map(entry -> entry.getName())
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
Set<String> metrics = semanticSchema.getMetrics(dataSetId).stream()
|
||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||
.limit(llmParserConfig.getMetricTopN())
|
||||
.map(entry -> entry.getName())
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
results.addAll(metrics);
|
||||
Set<String> results = new HashSet<>();
|
||||
if (QueryType.TAG.equals(queryCtx.getQueryType(dataSetId))) {
|
||||
Set<String> tags = semanticSchema.getTags(dataSetId).stream()
|
||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||
.limit(llmParserConfig.getDimensionTopN())
|
||||
.map(entry -> entry.getName())
|
||||
.collect(Collectors.toSet());
|
||||
results.addAll(tags);
|
||||
} else {
|
||||
Set<String> dimensions = semanticSchema.getDimensions(dataSetId).stream()
|
||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||
.limit(llmParserConfig.getDimensionTopN())
|
||||
.map(entry -> entry.getName())
|
||||
.collect(Collectors.toSet());
|
||||
results.addAll(dimensions);
|
||||
Set<String> metrics = semanticSchema.getMetrics(dataSetId).stream()
|
||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||
.limit(llmParserConfig.getMetricTopN())
|
||||
.map(entry -> entry.getName())
|
||||
.collect(Collectors.toSet());
|
||||
results.addAll(metrics);
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
@@ -236,12 +250,15 @@ public class LLMRequestService {
|
||||
SchemaElementType elementType = schemaElementMatch.getElement().getType();
|
||||
return SchemaElementType.METRIC.equals(elementType)
|
||||
|| SchemaElementType.DIMENSION.equals(elementType)
|
||||
|| SchemaElementType.VALUE.equals(elementType);
|
||||
|| SchemaElementType.VALUE.equals(elementType)
|
||||
|| SchemaElementType.TAG.equals(elementType)
|
||||
|| SchemaElementType.TAG_VALUE.equals(elementType);
|
||||
})
|
||||
.map(schemaElementMatch -> {
|
||||
SchemaElement element = schemaElementMatch.getElement();
|
||||
SchemaElementType elementType = element.getType();
|
||||
if (SchemaElementType.VALUE.equals(elementType)) {
|
||||
if (SchemaElementType.VALUE.equals(elementType) || SchemaElementType.TAG_VALUE.equals(
|
||||
elementType)) {
|
||||
return itemIdToName.get(element.getId());
|
||||
}
|
||||
return schemaElementMatch.getWord();
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.chat.core.parser.sql.llm;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.core.agent.NL2SQLTool;
|
||||
import com.tencent.supersonic.chat.core.parser.SemanticParser;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
@@ -10,12 +9,11 @@ import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlResp;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.MapUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.MapUtils;
|
||||
|
||||
@Slf4j
|
||||
public class LLMSqlParser implements SemanticParser {
|
||||
@@ -41,8 +39,7 @@ public class LLMSqlParser implements SemanticParser {
|
||||
}
|
||||
//4.construct a request, call the API for the large model, and retrieve the results.
|
||||
List<ElementValue> linkingValues = requestService.getValueList(queryCtx, dataSetId);
|
||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||
LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId, semanticSchema, linkingValues);
|
||||
LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId, linkingValues);
|
||||
LLMResp llmResp = requestService.requestLLM(llmReq, dataSetId);
|
||||
|
||||
if (Objects.isNull(llmResp)) {
|
||||
|
||||
@@ -8,14 +8,11 @@ import com.tencent.supersonic.chat.core.agent.RuleParserTool;
|
||||
import com.tencent.supersonic.chat.core.parser.SemanticParser;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public class AgentCheckParser implements SemanticParser {
|
||||
@@ -46,18 +43,6 @@ public class AgentCheckParser implements SemanticParser {
|
||||
&& !tool.getQueryModes().contains(query.getQueryMode())) {
|
||||
return true;
|
||||
}
|
||||
if (CollectionUtils.isNotEmpty(tool.getQueryTypes())) {
|
||||
if (QueryManager.isTagQuery(query.getQueryMode())) {
|
||||
if (!tool.getQueryTypes().contains(QueryType.TAG.name())) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
if (QueryManager.isMetricQuery(query.getQueryMode())) {
|
||||
if (!tool.getQueryTypes().contains(QueryType.METRIC.name())) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (CollectionUtils.isEmpty(tool.getDataSetIds())) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
package com.tencent.supersonic.chat.core.parser.sql.rule;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.core.parser.SemanticParser;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
/**
|
||||
* RuleSqlParser resolves a specific SemanticQuery according to co-appearance
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.core.pojo;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
@@ -10,17 +11,17 @@ import com.tencent.supersonic.chat.core.agent.Agent;
|
||||
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.core.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@@ -58,4 +59,10 @@ public class QueryContext {
|
||||
.collect(Collectors.toList());
|
||||
return candidateQueries;
|
||||
}
|
||||
|
||||
public QueryType getQueryType(Long dataSetId) {
|
||||
SemanticSchema semanticSchema = this.semanticSchema;
|
||||
DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId);
|
||||
return dataSetSchema.getQueryType();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -180,7 +180,10 @@
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-chroma</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-azure-open-ai</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.logging.log4j</groupId>
|
||||
<artifactId>log4j-api</artifactId>
|
||||
@@ -195,7 +198,6 @@
|
||||
<artifactId>hanlp</artifactId>
|
||||
<version>${hanlp.version}</version>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
</project>
|
||||
|
||||
@@ -20,9 +20,10 @@ import com.tencent.supersonic.chat.server.service.ChatService;
|
||||
import com.tencent.supersonic.chat.server.service.PluginService;
|
||||
import com.tencent.supersonic.chat.server.service.QueryService;
|
||||
import com.tencent.supersonic.common.pojo.SysParameter;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.service.SysParameterService;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Qualifier;
|
||||
@@ -32,9 +33,6 @@ import org.springframework.core.annotation.Order;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
@Order(3)
|
||||
@@ -170,7 +168,6 @@ public class ChatDemoLoader implements CommandLineRunner {
|
||||
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
|
||||
ruleQueryTool.setId("0");
|
||||
ruleQueryTool.setDataSetIds(Lists.newArrayList(1L));
|
||||
ruleQueryTool.setQueryTypes(Lists.newArrayList(QueryType.METRIC.name()));
|
||||
agentConfig.getTools().add(ruleQueryTool);
|
||||
if (demoEnabledNl2SqlLlm) {
|
||||
LLMParserTool llmParserTool = new LLMParserTool();
|
||||
@@ -196,7 +193,6 @@ public class ChatDemoLoader implements CommandLineRunner {
|
||||
ruleQueryTool.setId("0");
|
||||
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
|
||||
ruleQueryTool.setDataSetIds(Lists.newArrayList(2L));
|
||||
ruleQueryTool.setQueryTypes(Lists.newArrayList(QueryType.TAG.name()));
|
||||
agentConfig.getTools().add(ruleQueryTool);
|
||||
|
||||
if (demoEnabledNl2SqlLlm) {
|
||||
|
||||
@@ -58,7 +58,7 @@ public class TagTest extends BaseTest {
|
||||
List<String> list = new ArrayList<>();
|
||||
list.add("流行");
|
||||
QueryFilter dimensionFilter = DataUtils.getFilter("genre", FilterOperatorEnum.EQUALS,
|
||||
"流行", "风格", 6L);
|
||||
"流行", "风格", 2L);
|
||||
expectedParseInfo.getDimensionFilters().add(dimensionFilter);
|
||||
|
||||
SchemaElement metric = SchemaElement.builder().name("播放量").build();
|
||||
|
||||
5
pom.xml
5
pom.xml
@@ -149,6 +149,11 @@
|
||||
<artifactId>langchain4j-embeddings-bge-small-zh</artifactId>
|
||||
<version>${langchain4j.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-azure-open-ai</artifactId>
|
||||
<version>${langchain4j.version}</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</dependencyManagement>
|
||||
|
||||
|
||||
Reference in New Issue
Block a user