(improvement)(Chat) LLMSqlParser is adapted for tag mode, and rule parsing now filters by the dataset's query type. (#804)

lexluo09 authored 2024-03-12 17:03:08 +08:00
committed by GitHub
parent c2316c944d
commit bcc0f9caa9
11 changed files with 72 additions and 66 deletions
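
Taken together, the diffs below thread a dataset-level QueryType through the chat parsers: QueryContext gains a getQueryType(dataSetId) helper, LLMRequestService picks tag elements when that type is TAG, and AgentCheckParser drops its per-tool queryTypes check. A minimal sketch of the branching idea, using only types and accessors that appear in the hunks below (the sketch's own class and method names are illustrative):

import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import java.util.List;

class QueryTypeBranchSketch {

    // For a TAG dataset the candidate vocabulary comes from tags; otherwise
    // it comes from dimensions, as it did before this commit.
    static List<SchemaElement> candidateElements(QueryContext queryCtx, Long dataSetId) {
        SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
        if (QueryType.TAG.equals(queryCtx.getQueryType(dataSetId))) {
            return semanticSchema.getTags(dataSetId);
        }
        return semanticSchema.getDimensions(dataSetId);
    }
}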

View File

@@ -22,6 +22,9 @@ public class LLMParserConfig {
@Value("${metric.topn:10}")
private Integer metricTopN;
@Value("${tag.topn:20}")
private Integer tagTopN;
@Value("${all.model:false}")
private Boolean allModel;
}
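
The two new properties give tag datasets their own candidate cap (tag.topn, default 20) and introduce an all.model switch (default false). Assuming LLMParserConfig exposes Lombok-style getters such as getTagTopN() like its existing fields, which this hunk does not show, a consumer might read the new cap as follows (class and method names illustrative):

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

// Hypothetical consumer; assumes it lives in the same package as LLMParserConfig
// and uses the same field-injection style as the @Value fields above.
@Component
class TagPromptLimiterSketch {

    @Autowired
    private LLMParserConfig llmParserConfig;

    // Caps how many tag names go into the LLM prompt, mirroring how
    // dimensionTopN and metricTopN already cap dimensions and metrics.
    int tagLimit() {
        return llmParserConfig.getTagTopN();
    }
}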

View File

@@ -1,9 +1,7 @@
package com.tencent.supersonic.chat.core.parser;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
@@ -27,10 +25,7 @@ public class QueryTypeParser implements SemanticParser {
semanticQuery.initS2Sql(queryContext.getSemanticSchema(), user);
// 2.set queryType
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
Long dataSetId = parseInfo.getDataSetId();
DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId);
parseInfo.setQueryType(dataSetSchema.getQueryType());
parseInfo.setQueryType(queryContext.getQueryType(parseInfo.getDataSetId()));
}
}
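
The old inline lookup (local SemanticSchema, dataSetId and DataSetSchema variables) is folded into a single setQueryType call that delegates to the new QueryContext#getQueryType, whose body appears in the QueryContext hunk further down. A small sketch contrasting the two equivalent forms (class and method names illustrative):

import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.core.pojo.QueryContext;

class SetQueryTypeSketch {

    // Before this commit: the parser resolved the dataset schema itself.
    static void setQueryTypeOldWay(QueryContext queryContext, SemanticParseInfo parseInfo) {
        SemanticSchema semanticSchema = queryContext.getSemanticSchema();
        DataSetSchema dataSetSchema =
                semanticSchema.getDataSetSchemaMap().get(parseInfo.getDataSetId());
        parseInfo.setQueryType(dataSetSchema.getQueryType());
    }

    // After: the lookup lives on QueryContext, so any parser can share it.
    static void setQueryTypeNewWay(QueryContext queryContext, SemanticParseInfo parseInfo) {
        parseInfo.setQueryType(queryContext.getQueryType(parseInfo.getDataSetId()));
    }
}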

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.core.parser.sql.llm;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.core.utils.S2SqlDateHelper;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
@@ -92,9 +93,8 @@ public class LLMRequestService {
return llmParserTool.orElse(null);
}
public LLMReq getLlmReq(QueryContext queryCtx, Long dataSetId,
SemanticSchema semanticSchema, List<ElementValue> linkingValues) {
Map<Long, String> dataSetIdToName = semanticSchema.getDataSetIdToName();
public LLMReq getLlmReq(QueryContext queryCtx, Long dataSetId, List<ElementValue> linkingValues) {
Map<Long, String> dataSetIdToName = queryCtx.getSemanticSchema().getDataSetIdToName();
String queryText = queryCtx.getQueryText();
LLMReq llmReq = new LLMReq();
@@ -190,7 +190,8 @@ public class LLMRequestService {
.filter(elementMatch -> !elementMatch.isInherited())
.filter(schemaElementMatch -> {
SchemaElementType type = schemaElementMatch.getElement().getType();
return SchemaElementType.VALUE.equals(type) || SchemaElementType.ID.equals(type);
return SchemaElementType.VALUE.equals(type) || SchemaElementType.TAG_VALUE.equals(type)
|| SchemaElementType.ID.equals(type);
})
.map(elementMatch -> {
ElementValue elementValue = new ElementValue();
@@ -203,25 +204,38 @@ public class LLMRequestService {
protected Map<Long, String> getItemIdToName(QueryContext queryCtx, Long dataSetId) {
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
return semanticSchema.getDimensions(dataSetId).stream()
List<SchemaElement> elements = semanticSchema.getDimensions(dataSetId);
if (QueryType.TAG.equals(queryCtx.getQueryType(dataSetId))) {
elements = semanticSchema.getTags(dataSetId);
}
return elements.stream()
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
}
private Set<String> getTopNFieldNames(QueryContext queryCtx, Long dataSetId, LLMParserConfig llmParserConfig) {
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
Set<String> results = semanticSchema.getDimensions(dataSetId).stream()
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getDimensionTopN())
.map(entry -> entry.getName())
.collect(Collectors.toSet());
Set<String> metrics = semanticSchema.getMetrics(dataSetId).stream()
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getMetricTopN())
.map(entry -> entry.getName())
.collect(Collectors.toSet());
results.addAll(metrics);
Set<String> results = new HashSet<>();
if (QueryType.TAG.equals(queryCtx.getQueryType(dataSetId))) {
Set<String> tags = semanticSchema.getTags(dataSetId).stream()
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getDimensionTopN())
.map(entry -> entry.getName())
.collect(Collectors.toSet());
results.addAll(tags);
} else {
Set<String> dimensions = semanticSchema.getDimensions(dataSetId).stream()
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getDimensionTopN())
.map(entry -> entry.getName())
.collect(Collectors.toSet());
results.addAll(dimensions);
Set<String> metrics = semanticSchema.getMetrics(dataSetId).stream()
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getMetricTopN())
.map(entry -> entry.getName())
.collect(Collectors.toSet());
results.addAll(metrics);
}
return results;
}
@@ -236,12 +250,15 @@ public class LLMRequestService {
SchemaElementType elementType = schemaElementMatch.getElement().getType();
return SchemaElementType.METRIC.equals(elementType)
|| SchemaElementType.DIMENSION.equals(elementType)
|| SchemaElementType.VALUE.equals(elementType);
|| SchemaElementType.VALUE.equals(elementType)
|| SchemaElementType.TAG.equals(elementType)
|| SchemaElementType.TAG_VALUE.equals(elementType);
})
.map(schemaElementMatch -> {
SchemaElement element = schemaElementMatch.getElement();
SchemaElementType elementType = element.getType();
if (SchemaElementType.VALUE.equals(elementType)) {
if (SchemaElementType.VALUE.equals(elementType) || SchemaElementType.TAG_VALUE.equals(
elementType)) {
return itemIdToName.get(element.getId());
}
return schemaElementMatch.getWord();
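
Both getItemIdToName and getTopNFieldNames now branch on the dataset's QueryType, and the new TAG branch repeats the same sort-by-useCnt / limit / collect pipeline used for dimensions and metrics (reusing dimensionTopN as its limit, exactly as the hunk above does). A small sketch of that shared pipeline (helper class and method names illustrative):

import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import java.util.Comparator;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

class TopNNamesSketch {

    // "Hottest N element names": sort by usage count descending, keep the
    // first topN, collect the names into a set.
    static Set<String> topNNames(List<SchemaElement> elements, int topN) {
        return elements.stream()
                .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
                .limit(topN)
                .map(SchemaElement::getName)
                .collect(Collectors.toSet());
    }
}

With a helper like this, the three branches in getTopNFieldNames would differ only in which element list and which top-N limit they pass in.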

View File

@@ -1,6 +1,5 @@
package com.tencent.supersonic.chat.core.parser.sql.llm;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.core.agent.NL2SQLTool;
import com.tencent.supersonic.chat.core.parser.SemanticParser;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
@@ -10,12 +9,11 @@ import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlResp;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.MapUtils;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.MapUtils;
@Slf4j
public class LLMSqlParser implements SemanticParser {
@@ -41,8 +39,7 @@ public class LLMSqlParser implements SemanticParser {
}
//4.construct a request, call the API for the large model, and retrieve the results.
List<ElementValue> linkingValues = requestService.getValueList(queryCtx, dataSetId);
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId, semanticSchema, linkingValues);
LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId, linkingValues);
LLMResp llmResp = requestService.requestLLM(llmReq, dataSetId);
if (Objects.isNull(llmResp)) {

View File

@@ -8,14 +8,11 @@ import com.tencent.supersonic.chat.core.agent.RuleParserTool;
import com.tencent.supersonic.chat.core.parser.SemanticParser;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.QueryManager;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import java.util.List;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
@Slf4j
public class AgentCheckParser implements SemanticParser {
@@ -46,18 +43,6 @@ public class AgentCheckParser implements SemanticParser {
&& !tool.getQueryModes().contains(query.getQueryMode())) {
return true;
}
if (CollectionUtils.isNotEmpty(tool.getQueryTypes())) {
if (QueryManager.isTagQuery(query.getQueryMode())) {
if (!tool.getQueryTypes().contains(QueryType.TAG.name())) {
return true;
}
}
if (QueryManager.isMetricQuery(query.getQueryMode())) {
if (!tool.getQueryTypes().contains(QueryType.METRIC.name())) {
return true;
}
}
}
if (CollectionUtils.isEmpty(tool.getDataSetIds())) {
return true;
}
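
With the queryTypes block removed, the tool-level filter appears to reduce to the two rejections still visible above: a configured queryModes list that does not contain the query's mode (the first half of that condition is cut off by the hunk and is assumed here to be a non-empty check), and an empty dataSetIds list. Whether a query suits a TAG or METRIC dataset is now decided by the dataset's own query type rather than per tool. A condensed sketch of the remaining predicate (class and method names illustrative):

import com.tencent.supersonic.chat.core.agent.RuleParserTool;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import org.apache.commons.collections.CollectionUtils;

class ToolFilterSketch {

    // True means "drop this candidate query for this tool".
    static boolean reject(RuleParserTool tool, SemanticQuery query) {
        if (CollectionUtils.isNotEmpty(tool.getQueryModes())
                && !tool.getQueryModes().contains(query.getQueryMode())) {
            return true; // tool restricts query modes and this one is not listed
        }
        return CollectionUtils.isEmpty(tool.getDataSetIds()); // tool has no datasets at all
    }
}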

View File

@@ -1,15 +1,14 @@
package com.tencent.supersonic.chat.core.parser.sql.rule;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.core.parser.SemanticParser;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery;
import lombok.extern.slf4j.Slf4j;
import java.util.Arrays;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
/**
* RuleSqlParser resolves a specific SemanticQuery according to co-appearance

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.core.pojo;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
@@ -10,17 +11,17 @@ import com.tencent.supersonic.chat.core.agent.Agent;
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.chat.core.plugin.Plugin;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@Builder
@@ -58,4 +59,10 @@ public class QueryContext {
.collect(Collectors.toList());
return candidateQueries;
}
public QueryType getQueryType(Long dataSetId) {
SemanticSchema semanticSchema = this.semanticSchema;
DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId);
return dataSetSchema.getQueryType();
}
}
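
The new helper centralizes the dataset-schema lookup that QueryTypeParser previously did inline. Note that getDataSetSchemaMap().get(dataSetId) can return null for an unknown dataset id, in which case the commit's version would throw; a defensive variant (illustrative only, not part of the commit) might look like:

import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import java.util.Optional;

// Illustrative null-safe variant of QueryContext#getQueryType.
class QueryTypeLookupSketch {

    static Optional<QueryType> getQueryType(QueryContext queryContext, Long dataSetId) {
        DataSetSchema dataSetSchema =
                queryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
        return Optional.ofNullable(dataSetSchema).map(DataSetSchema::getQueryType);
    }
}

The commit's version returns the QueryType directly, matching how QueryTypeParser consumes it.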

View File

@@ -180,7 +180,10 @@
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-chroma</artifactId>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-azure-open-ai</artifactId>
</dependency>
<dependency>
<groupId>org.apache.logging.log4j</groupId>
<artifactId>log4j-api</artifactId>
@@ -195,7 +198,6 @@
<artifactId>hanlp</artifactId>
<version>${hanlp.version}</version>
</dependency>
</dependencies>
</project>

View File

@@ -20,9 +20,10 @@ import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.chat.server.service.PluginService;
import com.tencent.supersonic.chat.server.service.QueryService;
import com.tencent.supersonic.common.pojo.SysParameter;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.service.SysParameterService;
import com.tencent.supersonic.common.util.JsonUtil;
import java.util.Arrays;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
@@ -32,9 +33,6 @@ import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import java.util.Arrays;
import java.util.List;
@Component
@Slf4j
@Order(3)
@@ -170,7 +168,6 @@ public class ChatDemoLoader implements CommandLineRunner {
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
ruleQueryTool.setId("0");
ruleQueryTool.setDataSetIds(Lists.newArrayList(1L));
ruleQueryTool.setQueryTypes(Lists.newArrayList(QueryType.METRIC.name()));
agentConfig.getTools().add(ruleQueryTool);
if (demoEnabledNl2SqlLlm) {
LLMParserTool llmParserTool = new LLMParserTool();
@@ -196,7 +193,6 @@ public class ChatDemoLoader implements CommandLineRunner {
ruleQueryTool.setId("0");
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
ruleQueryTool.setDataSetIds(Lists.newArrayList(2L));
ruleQueryTool.setQueryTypes(Lists.newArrayList(QueryType.TAG.name()));
agentConfig.getTools().add(ruleQueryTool);
if (demoEnabledNl2SqlLlm) {

View File

@@ -58,7 +58,7 @@ public class TagTest extends BaseTest {
List<String> list = new ArrayList<>();
list.add("流行");
QueryFilter dimensionFilter = DataUtils.getFilter("genre", FilterOperatorEnum.EQUALS,
"流行", "风格", 6L);
"流行", "风格", 2L);
expectedParseInfo.getDimensionFilters().add(dimensionFilter);
SchemaElement metric = SchemaElement.builder().name("播放量").build();

View File

@@ -149,6 +149,11 @@
<artifactId>langchain4j-embeddings-bge-small-zh</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-azure-open-ai</artifactId>
<version>${langchain4j.version}</version>
</dependency>
</dependencies>
</dependencyManagement>