(improvement)(Headless) Add Text2SQLType to control whether rules and large models are passed (#891)

Co-authored-by: jolunoluo
This commit is contained in:
LXW
2024-04-07 14:33:17 +08:00
committed by GitHub
parent 12e25c0c50
commit 5f6e9ae194
8 changed files with 48 additions and 17 deletions

View File

@@ -70,6 +70,10 @@ public class Agent extends RecordInfo {
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM)); return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM));
} }
public boolean containsRuleTool() {
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_RULE));
}
public boolean containsNL2SQLTool() { public boolean containsNL2SQLTool() {
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM)) return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM))
|| !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_RULE)); || !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_RULE));

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.server.util;
import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext; import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.common.util.BeanMapper; import com.tencent.supersonic.common.util.BeanMapper;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq; import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
import org.apache.commons.collections.MapUtils; import org.apache.commons.collections.MapUtils;
@@ -17,8 +18,12 @@ public class QueryReqConverter {
if (agent == null) { if (agent == null) {
return queryReq; return queryReq;
} }
if (agent.containsLLMParserTool()) { if (agent.containsLLMParserTool() && agent.containsRuleTool()) {
queryReq.setEnableLLM(true); queryReq.setText2SQLType(Text2SQLType.RULE_AND_LLM);
} else if (agent.containsLLMParserTool()) {
queryReq.setText2SQLType(Text2SQLType.ONLY_LLM);
} else if (agent.containsRuleTool()) {
queryReq.setText2SQLType(Text2SQLType.ONLY_RULE);
} }
queryReq.setDataSetIds(agent.getDataSetIds()); queryReq.setDataSetIds(agent.getDataSetIds());
if (Objects.nonNull(queryReq.getMapInfo()) if (Objects.nonNull(queryReq.getMapInfo())

View File

@@ -0,0 +1,15 @@
package com.tencent.supersonic.common.pojo.enums;
public enum Text2SQLType {
ONLY_RULE, ONLY_LLM, RULE_AND_LLM;
public boolean enableRule() {
return this.equals(ONLY_RULE) || this.equals(RULE_AND_LLM);
}
public boolean enableLLM() {
return this.equals(ONLY_LLM) || this.equals(RULE_AND_LLM);
}
}

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.api.pojo.request;
import com.google.common.collect.Sets; import com.google.common.collect.Sets;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import lombok.Data; import lombok.Data;
@@ -15,6 +16,6 @@ public class QueryReq {
private User user; private User user;
private QueryFilters queryFilters; private QueryFilters queryFilters;
private boolean saveAnswer = true; private boolean saveAnswer = true;
private boolean enableLLM; private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM;
private SchemaMapInfo mapInfo = new SchemaMapInfo(); private SchemaMapInfo mapInfo = new SchemaMapInfo();
} }

View File

@@ -16,6 +16,12 @@ import com.tencent.supersonic.headless.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.pojo.QueryContext; import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.core.utils.ComponentFactory; import com.tencent.supersonic.headless.core.utils.ComponentFactory;
import com.tencent.supersonic.headless.core.utils.S2SqlDateHelper; import com.tencent.supersonic.headless.core.utils.S2SqlDateHelper;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Comparator; import java.util.Comparator;
import java.util.HashSet; import java.util.HashSet;
@@ -24,12 +30,6 @@ import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
@Slf4j @Slf4j
@Service @Service
@@ -41,7 +41,7 @@ public class LLMRequestService {
private OptimizationConfig optimizationConfig; private OptimizationConfig optimizationConfig;
public boolean isSkip(QueryContext queryCtx) { public boolean isSkip(QueryContext queryCtx) {
if (!queryCtx.isEnableLLM()) { if (!queryCtx.getText2SQLType().enableLLM()) {
log.info("not enable llm, skip"); log.info("not enable llm, skip");
return true; return true;
} }

View File

@@ -6,9 +6,9 @@ import com.tencent.supersonic.headless.core.chat.parser.SemanticParser;
import com.tencent.supersonic.headless.core.chat.query.rule.RuleSemanticQuery; import com.tencent.supersonic.headless.core.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.headless.core.pojo.ChatContext; import com.tencent.supersonic.headless.core.pojo.ChatContext;
import com.tencent.supersonic.headless.core.pojo.QueryContext; import com.tencent.supersonic.headless.core.pojo.QueryContext;
import lombok.extern.slf4j.Slf4j;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import lombok.extern.slf4j.Slf4j;
/** /**
* RuleSqlParser resolves a specific SemanticQuery according to co-appearance * RuleSqlParser resolves a specific SemanticQuery according to co-appearance
@@ -25,6 +25,10 @@ public class RuleSqlParser implements SemanticParser {
@Override @Override
public void parse(QueryContext queryContext, ChatContext chatContext) { public void parse(QueryContext queryContext, ChatContext chatContext) {
if (!queryContext.getText2SQLType().enableRule()) {
log.info("not enable rule, skip");
return;
}
SchemaMapInfo mapInfo = queryContext.getMapInfo(); SchemaMapInfo mapInfo = queryContext.getMapInfo();
// iterate all schemaElementMatches to resolve query mode // iterate all schemaElementMatches to resolve query mode
for (Long dataSetId : mapInfo.getMatchedDataSetInfos()) { for (Long dataSetId : mapInfo.getMatchedDataSetInfos()) {

View File

@@ -2,22 +2,24 @@ package com.tencent.supersonic.headless.core.pojo;
import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnore;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import com.tencent.supersonic.headless.core.chat.query.SemanticQuery; import com.tencent.supersonic.headless.core.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.core.config.OptimizationConfig; import com.tencent.supersonic.headless.core.config.OptimizationConfig;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Comparator; import java.util.Comparator;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data @Data
@Builder @Builder
@@ -31,7 +33,7 @@ public class QueryContext {
private Map<Long, List<Long>> modelIdToDataSetIds; private Map<Long, List<Long>> modelIdToDataSetIds;
private User user; private User user;
private boolean saveAnswer; private boolean saveAnswer;
private boolean enableLLM; private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM;
private QueryFilters queryFilters; private QueryFilters queryFilters;
private List<SemanticQuery> candidateQueries = new ArrayList<>(); private List<SemanticQuery> candidateQueries = new ArrayList<>();
private SchemaMapInfo mapInfo = new SchemaMapInfo(); private SchemaMapInfo mapInfo = new SchemaMapInfo();

View File

@@ -177,7 +177,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
.candidateQueries(new ArrayList<>()) .candidateQueries(new ArrayList<>())
.mapInfo(new SchemaMapInfo()) .mapInfo(new SchemaMapInfo())
.modelIdToDataSetIds(modelIdToDataSetIds) .modelIdToDataSetIds(modelIdToDataSetIds)
.enableLLM(queryReq.isEnableLLM()) .text2SQLType(queryReq.getText2SQLType())
.build(); .build();
BeanUtils.copyProperties(queryReq, queryCtx); BeanUtils.copyProperties(queryReq, queryCtx);
return queryCtx; return queryCtx;