(improvement)(Chat) llmSqlParser is adapted for tag mode, and rule parsing filters based on the dataset query type. (#804)

This commit is contained in:
lexluo09
2024-03-12 17:03:08 +08:00
committed by GitHub
parent c2316c944d
commit bcc0f9caa9
11 changed files with 72 additions and 66 deletions

View File

@@ -22,6 +22,9 @@ public class LLMParserConfig {
@Value("${metric.topn:10}") @Value("${metric.topn:10}")
private Integer metricTopN; private Integer metricTopN;
@Value("${tag.topn:20}")
private Integer tagTopN;
@Value("${all.model:false}") @Value("${all.model:false}")
private Boolean allModel; private Boolean allModel;
} }

View File

@@ -1,9 +1,7 @@
package com.tencent.supersonic.chat.core.parser; package com.tencent.supersonic.chat.core.parser;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.core.pojo.ChatContext; import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext; import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.SemanticQuery; import com.tencent.supersonic.chat.core.query.SemanticQuery;
@@ -27,10 +25,7 @@ public class QueryTypeParser implements SemanticParser {
semanticQuery.initS2Sql(queryContext.getSemanticSchema(), user); semanticQuery.initS2Sql(queryContext.getSemanticSchema(), user);
// 2.set queryType // 2.set queryType
SemanticParseInfo parseInfo = semanticQuery.getParseInfo(); SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
SemanticSchema semanticSchema = queryContext.getSemanticSchema(); parseInfo.setQueryType(queryContext.getQueryType(parseInfo.getDataSetId()));
Long dataSetId = parseInfo.getDataSetId();
DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId);
parseInfo.setQueryType(dataSetSchema.getQueryType());
} }
} }

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.core.parser.sql.llm;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.core.utils.S2SqlDateHelper; import com.tencent.supersonic.chat.core.utils.S2SqlDateHelper;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
@@ -92,9 +93,8 @@ public class LLMRequestService {
return llmParserTool.orElse(null); return llmParserTool.orElse(null);
} }
public LLMReq getLlmReq(QueryContext queryCtx, Long dataSetId, public LLMReq getLlmReq(QueryContext queryCtx, Long dataSetId, List<ElementValue> linkingValues) {
SemanticSchema semanticSchema, List<ElementValue> linkingValues) { Map<Long, String> dataSetIdToName = queryCtx.getSemanticSchema().getDataSetIdToName();
Map<Long, String> dataSetIdToName = semanticSchema.getDataSetIdToName();
String queryText = queryCtx.getQueryText(); String queryText = queryCtx.getQueryText();
LLMReq llmReq = new LLMReq(); LLMReq llmReq = new LLMReq();
@@ -190,7 +190,8 @@ public class LLMRequestService {
.filter(elementMatch -> !elementMatch.isInherited()) .filter(elementMatch -> !elementMatch.isInherited())
.filter(schemaElementMatch -> { .filter(schemaElementMatch -> {
SchemaElementType type = schemaElementMatch.getElement().getType(); SchemaElementType type = schemaElementMatch.getElement().getType();
return SchemaElementType.VALUE.equals(type) || SchemaElementType.ID.equals(type); return SchemaElementType.VALUE.equals(type) || SchemaElementType.TAG_VALUE.equals(type)
|| SchemaElementType.ID.equals(type);
}) })
.map(elementMatch -> { .map(elementMatch -> {
ElementValue elementValue = new ElementValue(); ElementValue elementValue = new ElementValue();
@@ -203,25 +204,38 @@ public class LLMRequestService {
protected Map<Long, String> getItemIdToName(QueryContext queryCtx, Long dataSetId) { protected Map<Long, String> getItemIdToName(QueryContext queryCtx, Long dataSetId) {
SemanticSchema semanticSchema = queryCtx.getSemanticSchema(); SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
return semanticSchema.getDimensions(dataSetId).stream() List<SchemaElement> elements = semanticSchema.getDimensions(dataSetId);
if (QueryType.TAG.equals(queryCtx.getQueryType(dataSetId))) {
elements = semanticSchema.getTags(dataSetId);
}
return elements.stream()
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2)); .collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
} }
private Set<String> getTopNFieldNames(QueryContext queryCtx, Long dataSetId, LLMParserConfig llmParserConfig) { private Set<String> getTopNFieldNames(QueryContext queryCtx, Long dataSetId, LLMParserConfig llmParserConfig) {
SemanticSchema semanticSchema = queryCtx.getSemanticSchema(); SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
Set<String> results = semanticSchema.getDimensions(dataSetId).stream() Set<String> results = new HashSet<>();
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()) if (QueryType.TAG.equals(queryCtx.getQueryType(dataSetId))) {
.limit(llmParserConfig.getDimensionTopN()) Set<String> tags = semanticSchema.getTags(dataSetId).stream()
.map(entry -> entry.getName()) .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.collect(Collectors.toSet()); .limit(llmParserConfig.getDimensionTopN())
.map(entry -> entry.getName())
Set<String> metrics = semanticSchema.getMetrics(dataSetId).stream() .collect(Collectors.toSet());
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()) results.addAll(tags);
.limit(llmParserConfig.getMetricTopN()) } else {
.map(entry -> entry.getName()) Set<String> dimensions = semanticSchema.getDimensions(dataSetId).stream()
.collect(Collectors.toSet()); .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getDimensionTopN())
results.addAll(metrics); .map(entry -> entry.getName())
.collect(Collectors.toSet());
results.addAll(dimensions);
Set<String> metrics = semanticSchema.getMetrics(dataSetId).stream()
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getMetricTopN())
.map(entry -> entry.getName())
.collect(Collectors.toSet());
results.addAll(metrics);
}
return results; return results;
} }
@@ -236,12 +250,15 @@ public class LLMRequestService {
SchemaElementType elementType = schemaElementMatch.getElement().getType(); SchemaElementType elementType = schemaElementMatch.getElement().getType();
return SchemaElementType.METRIC.equals(elementType) return SchemaElementType.METRIC.equals(elementType)
|| SchemaElementType.DIMENSION.equals(elementType) || SchemaElementType.DIMENSION.equals(elementType)
|| SchemaElementType.VALUE.equals(elementType); || SchemaElementType.VALUE.equals(elementType)
|| SchemaElementType.TAG.equals(elementType)
|| SchemaElementType.TAG_VALUE.equals(elementType);
}) })
.map(schemaElementMatch -> { .map(schemaElementMatch -> {
SchemaElement element = schemaElementMatch.getElement(); SchemaElement element = schemaElementMatch.getElement();
SchemaElementType elementType = element.getType(); SchemaElementType elementType = element.getType();
if (SchemaElementType.VALUE.equals(elementType)) { if (SchemaElementType.VALUE.equals(elementType) || SchemaElementType.TAG_VALUE.equals(
elementType)) {
return itemIdToName.get(element.getId()); return itemIdToName.get(element.getId());
} }
return schemaElementMatch.getWord(); return schemaElementMatch.getWord();

View File

@@ -1,6 +1,5 @@
package com.tencent.supersonic.chat.core.parser.sql.llm; package com.tencent.supersonic.chat.core.parser.sql.llm;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.core.agent.NL2SQLTool; import com.tencent.supersonic.chat.core.agent.NL2SQLTool;
import com.tencent.supersonic.chat.core.parser.SemanticParser; import com.tencent.supersonic.chat.core.parser.SemanticParser;
import com.tencent.supersonic.chat.core.pojo.ChatContext; import com.tencent.supersonic.chat.core.pojo.ChatContext;
@@ -10,12 +9,11 @@ import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp; import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlResp; import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlResp;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.MapUtils;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.MapUtils;
@Slf4j @Slf4j
public class LLMSqlParser implements SemanticParser { public class LLMSqlParser implements SemanticParser {
@@ -41,8 +39,7 @@ public class LLMSqlParser implements SemanticParser {
} }
//4.construct a request, call the API for the large model, and retrieve the results. //4.construct a request, call the API for the large model, and retrieve the results.
List<ElementValue> linkingValues = requestService.getValueList(queryCtx, dataSetId); List<ElementValue> linkingValues = requestService.getValueList(queryCtx, dataSetId);
SemanticSchema semanticSchema = queryCtx.getSemanticSchema(); LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId, linkingValues);
LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId, semanticSchema, linkingValues);
LLMResp llmResp = requestService.requestLLM(llmReq, dataSetId); LLMResp llmResp = requestService.requestLLM(llmReq, dataSetId);
if (Objects.isNull(llmResp)) { if (Objects.isNull(llmResp)) {

View File

@@ -8,14 +8,11 @@ import com.tencent.supersonic.chat.core.agent.RuleParserTool;
import com.tencent.supersonic.chat.core.parser.SemanticParser; import com.tencent.supersonic.chat.core.parser.SemanticParser;
import com.tencent.supersonic.chat.core.pojo.ChatContext; import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext; import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.QueryManager;
import com.tencent.supersonic.chat.core.query.SemanticQuery; import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import java.util.List; import java.util.List;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
@Slf4j @Slf4j
public class AgentCheckParser implements SemanticParser { public class AgentCheckParser implements SemanticParser {
@@ -46,18 +43,6 @@ public class AgentCheckParser implements SemanticParser {
&& !tool.getQueryModes().contains(query.getQueryMode())) { && !tool.getQueryModes().contains(query.getQueryMode())) {
return true; return true;
} }
if (CollectionUtils.isNotEmpty(tool.getQueryTypes())) {
if (QueryManager.isTagQuery(query.getQueryMode())) {
if (!tool.getQueryTypes().contains(QueryType.TAG.name())) {
return true;
}
}
if (QueryManager.isMetricQuery(query.getQueryMode())) {
if (!tool.getQueryTypes().contains(QueryType.METRIC.name())) {
return true;
}
}
}
if (CollectionUtils.isEmpty(tool.getDataSetIds())) { if (CollectionUtils.isEmpty(tool.getDataSetIds())) {
return true; return true;
} }

View File

@@ -1,15 +1,14 @@
package com.tencent.supersonic.chat.core.parser.sql.rule; package com.tencent.supersonic.chat.core.parser.sql.rule;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo; import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.core.parser.SemanticParser; import com.tencent.supersonic.chat.core.parser.SemanticParser;
import com.tencent.supersonic.chat.core.pojo.ChatContext; import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext; import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery; import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery;
import lombok.extern.slf4j.Slf4j;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import lombok.extern.slf4j.Slf4j;
/** /**
* RuleSqlParser resolves a specific SemanticQuery according to co-appearance * RuleSqlParser resolves a specific SemanticQuery according to co-appearance

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.core.pojo;
import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnore;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo; import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters; import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
@@ -10,17 +11,17 @@ import com.tencent.supersonic.chat.core.agent.Agent;
import com.tencent.supersonic.chat.core.config.OptimizationConfig; import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.chat.core.plugin.Plugin; import com.tencent.supersonic.chat.core.plugin.Plugin;
import com.tencent.supersonic.chat.core.query.SemanticQuery; import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Comparator; import java.util.Comparator;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data @Data
@Builder @Builder
@@ -58,4 +59,10 @@ public class QueryContext {
.collect(Collectors.toList()); .collect(Collectors.toList());
return candidateQueries; return candidateQueries;
} }
public QueryType getQueryType(Long dataSetId) {
SemanticSchema semanticSchema = this.semanticSchema;
DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId);
return dataSetSchema.getQueryType();
}
} }

View File

@@ -180,7 +180,10 @@
<groupId>dev.langchain4j</groupId> <groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-chroma</artifactId> <artifactId>langchain4j-chroma</artifactId>
</dependency> </dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-azure-open-ai</artifactId>
</dependency>
<dependency> <dependency>
<groupId>org.apache.logging.log4j</groupId> <groupId>org.apache.logging.log4j</groupId>
<artifactId>log4j-api</artifactId> <artifactId>log4j-api</artifactId>
@@ -195,7 +198,6 @@
<artifactId>hanlp</artifactId> <artifactId>hanlp</artifactId>
<version>${hanlp.version}</version> <version>${hanlp.version}</version>
</dependency> </dependency>
</dependencies> </dependencies>
</project> </project>

View File

@@ -20,9 +20,10 @@ import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.chat.server.service.PluginService; import com.tencent.supersonic.chat.server.service.PluginService;
import com.tencent.supersonic.chat.server.service.QueryService; import com.tencent.supersonic.chat.server.service.QueryService;
import com.tencent.supersonic.common.pojo.SysParameter; import com.tencent.supersonic.common.pojo.SysParameter;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.service.SysParameterService; import com.tencent.supersonic.common.service.SysParameterService;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
import java.util.Arrays;
import java.util.List;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.beans.factory.annotation.Qualifier;
@@ -32,9 +33,6 @@ import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.Arrays;
import java.util.List;
@Component @Component
@Slf4j @Slf4j
@Order(3) @Order(3)
@@ -170,7 +168,6 @@ public class ChatDemoLoader implements CommandLineRunner {
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE); ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
ruleQueryTool.setId("0"); ruleQueryTool.setId("0");
ruleQueryTool.setDataSetIds(Lists.newArrayList(1L)); ruleQueryTool.setDataSetIds(Lists.newArrayList(1L));
ruleQueryTool.setQueryTypes(Lists.newArrayList(QueryType.METRIC.name()));
agentConfig.getTools().add(ruleQueryTool); agentConfig.getTools().add(ruleQueryTool);
if (demoEnabledNl2SqlLlm) { if (demoEnabledNl2SqlLlm) {
LLMParserTool llmParserTool = new LLMParserTool(); LLMParserTool llmParserTool = new LLMParserTool();
@@ -196,7 +193,6 @@ public class ChatDemoLoader implements CommandLineRunner {
ruleQueryTool.setId("0"); ruleQueryTool.setId("0");
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE); ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
ruleQueryTool.setDataSetIds(Lists.newArrayList(2L)); ruleQueryTool.setDataSetIds(Lists.newArrayList(2L));
ruleQueryTool.setQueryTypes(Lists.newArrayList(QueryType.TAG.name()));
agentConfig.getTools().add(ruleQueryTool); agentConfig.getTools().add(ruleQueryTool);
if (demoEnabledNl2SqlLlm) { if (demoEnabledNl2SqlLlm) {

View File

@@ -58,7 +58,7 @@ public class TagTest extends BaseTest {
List<String> list = new ArrayList<>(); List<String> list = new ArrayList<>();
list.add("流行"); list.add("流行");
QueryFilter dimensionFilter = DataUtils.getFilter("genre", FilterOperatorEnum.EQUALS, QueryFilter dimensionFilter = DataUtils.getFilter("genre", FilterOperatorEnum.EQUALS,
"流行", "风格", 6L); "流行", "风格", 2L);
expectedParseInfo.getDimensionFilters().add(dimensionFilter); expectedParseInfo.getDimensionFilters().add(dimensionFilter);
SchemaElement metric = SchemaElement.builder().name("播放量").build(); SchemaElement metric = SchemaElement.builder().name("播放量").build();

View File

@@ -149,6 +149,11 @@
<artifactId>langchain4j-embeddings-bge-small-zh</artifactId> <artifactId>langchain4j-embeddings-bge-small-zh</artifactId>
<version>${langchain4j.version}</version> <version>${langchain4j.version}</version>
</dependency> </dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-azure-open-ai</artifactId>
<version>${langchain4j.version}</version>
</dependency>
</dependencies> </dependencies>
</dependencyManagement> </dependencyManagement>