mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-13 04:57:28 +00:00
(improvement)(Headless) Add Text2SQLType to control whether rules and large models are passed (#891)
Co-authored-by: jolunoluo
This commit is contained in:
@@ -70,6 +70,10 @@ public class Agent extends RecordInfo {
|
|||||||
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM));
|
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public boolean containsRuleTool() {
|
||||||
|
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_RULE));
|
||||||
|
}
|
||||||
|
|
||||||
public boolean containsNL2SQLTool() {
|
public boolean containsNL2SQLTool() {
|
||||||
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM))
|
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM))
|
||||||
|| !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_RULE));
|
|| !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_RULE));
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.server.util;
|
|||||||
|
|
||||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||||
import com.tencent.supersonic.common.util.BeanMapper;
|
import com.tencent.supersonic.common.util.BeanMapper;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
||||||
import org.apache.commons.collections.MapUtils;
|
import org.apache.commons.collections.MapUtils;
|
||||||
@@ -17,8 +18,12 @@ public class QueryReqConverter {
|
|||||||
if (agent == null) {
|
if (agent == null) {
|
||||||
return queryReq;
|
return queryReq;
|
||||||
}
|
}
|
||||||
if (agent.containsLLMParserTool()) {
|
if (agent.containsLLMParserTool() && agent.containsRuleTool()) {
|
||||||
queryReq.setEnableLLM(true);
|
queryReq.setText2SQLType(Text2SQLType.RULE_AND_LLM);
|
||||||
|
} else if (agent.containsLLMParserTool()) {
|
||||||
|
queryReq.setText2SQLType(Text2SQLType.ONLY_LLM);
|
||||||
|
} else if (agent.containsRuleTool()) {
|
||||||
|
queryReq.setText2SQLType(Text2SQLType.ONLY_RULE);
|
||||||
}
|
}
|
||||||
queryReq.setDataSetIds(agent.getDataSetIds());
|
queryReq.setDataSetIds(agent.getDataSetIds());
|
||||||
if (Objects.nonNull(queryReq.getMapInfo())
|
if (Objects.nonNull(queryReq.getMapInfo())
|
||||||
|
|||||||
@@ -0,0 +1,15 @@
|
|||||||
|
package com.tencent.supersonic.common.pojo.enums;
|
||||||
|
|
||||||
|
public enum Text2SQLType {
|
||||||
|
|
||||||
|
ONLY_RULE, ONLY_LLM, RULE_AND_LLM;
|
||||||
|
|
||||||
|
public boolean enableRule() {
|
||||||
|
return this.equals(ONLY_RULE) || this.equals(RULE_AND_LLM);
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean enableLLM() {
|
||||||
|
return this.equals(ONLY_LLM) || this.equals(RULE_AND_LLM);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.api.pojo.request;
|
|||||||
|
|
||||||
import com.google.common.collect.Sets;
|
import com.google.common.collect.Sets;
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
@@ -15,6 +16,6 @@ public class QueryReq {
|
|||||||
private User user;
|
private User user;
|
||||||
private QueryFilters queryFilters;
|
private QueryFilters queryFilters;
|
||||||
private boolean saveAnswer = true;
|
private boolean saveAnswer = true;
|
||||||
private boolean enableLLM;
|
private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM;
|
||||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,6 +16,12 @@ import com.tencent.supersonic.headless.core.config.OptimizationConfig;
|
|||||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.headless.core.utils.ComponentFactory;
|
import com.tencent.supersonic.headless.core.utils.ComponentFactory;
|
||||||
import com.tencent.supersonic.headless.core.utils.S2SqlDateHelper;
|
import com.tencent.supersonic.headless.core.utils.S2SqlDateHelper;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
@@ -24,12 +30,6 @@ import java.util.Map;
|
|||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
|
||||||
import org.springframework.stereotype.Service;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Service
|
@Service
|
||||||
@@ -41,7 +41,7 @@ public class LLMRequestService {
|
|||||||
private OptimizationConfig optimizationConfig;
|
private OptimizationConfig optimizationConfig;
|
||||||
|
|
||||||
public boolean isSkip(QueryContext queryCtx) {
|
public boolean isSkip(QueryContext queryCtx) {
|
||||||
if (!queryCtx.isEnableLLM()) {
|
if (!queryCtx.getText2SQLType().enableLLM()) {
|
||||||
log.info("not enable llm, skip");
|
log.info("not enable llm, skip");
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,9 +6,9 @@ import com.tencent.supersonic.headless.core.chat.parser.SemanticParser;
|
|||||||
import com.tencent.supersonic.headless.core.chat.query.rule.RuleSemanticQuery;
|
import com.tencent.supersonic.headless.core.chat.query.rule.RuleSemanticQuery;
|
||||||
import com.tencent.supersonic.headless.core.pojo.ChatContext;
|
import com.tencent.supersonic.headless.core.pojo.ChatContext;
|
||||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* RuleSqlParser resolves a specific SemanticQuery according to co-appearance
|
* RuleSqlParser resolves a specific SemanticQuery according to co-appearance
|
||||||
@@ -25,6 +25,10 @@ public class RuleSqlParser implements SemanticParser {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||||
|
if (!queryContext.getText2SQLType().enableRule()) {
|
||||||
|
log.info("not enable rule, skip");
|
||||||
|
return;
|
||||||
|
}
|
||||||
SchemaMapInfo mapInfo = queryContext.getMapInfo();
|
SchemaMapInfo mapInfo = queryContext.getMapInfo();
|
||||||
// iterate all schemaElementMatches to resolve query mode
|
// iterate all schemaElementMatches to resolve query mode
|
||||||
for (Long dataSetId : mapInfo.getMatchedDataSetInfos()) {
|
for (Long dataSetId : mapInfo.getMatchedDataSetInfos()) {
|
||||||
|
|||||||
@@ -2,22 +2,24 @@ package com.tencent.supersonic.headless.core.pojo;
|
|||||||
|
|
||||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||||
import com.tencent.supersonic.headless.core.chat.query.SemanticQuery;
|
import com.tencent.supersonic.headless.core.chat.query.SemanticQuery;
|
||||||
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
|
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
|
||||||
|
import lombok.AllArgsConstructor;
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
import lombok.Builder;
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@Builder
|
@Builder
|
||||||
@@ -31,7 +33,7 @@ public class QueryContext {
|
|||||||
private Map<Long, List<Long>> modelIdToDataSetIds;
|
private Map<Long, List<Long>> modelIdToDataSetIds;
|
||||||
private User user;
|
private User user;
|
||||||
private boolean saveAnswer;
|
private boolean saveAnswer;
|
||||||
private boolean enableLLM;
|
private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM;
|
||||||
private QueryFilters queryFilters;
|
private QueryFilters queryFilters;
|
||||||
private List<SemanticQuery> candidateQueries = new ArrayList<>();
|
private List<SemanticQuery> candidateQueries = new ArrayList<>();
|
||||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||||
|
|||||||
@@ -177,7 +177,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
.candidateQueries(new ArrayList<>())
|
.candidateQueries(new ArrayList<>())
|
||||||
.mapInfo(new SchemaMapInfo())
|
.mapInfo(new SchemaMapInfo())
|
||||||
.modelIdToDataSetIds(modelIdToDataSetIds)
|
.modelIdToDataSetIds(modelIdToDataSetIds)
|
||||||
.enableLLM(queryReq.isEnableLLM())
|
.text2SQLType(queryReq.getText2SQLType())
|
||||||
.build();
|
.build();
|
||||||
BeanUtils.copyProperties(queryReq, queryCtx);
|
BeanUtils.copyProperties(queryReq, queryCtx);
|
||||||
return queryCtx;
|
return queryCtx;
|
||||||
|
|||||||
Reference in New Issue
Block a user