mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
[improvement][project]Simplify code logic in multiple modules.
This commit is contained in:
@@ -26,7 +26,7 @@ import java.util.stream.Collectors;
|
||||
public class PlainTextExecutor implements ChatQueryExecutor {
|
||||
|
||||
public static final String APP_KEY = "SMALL_TALK";
|
||||
private static final String INSTRUCTION = "" + "#Role: You are a nice person to talk to."
|
||||
private static final String INSTRUCTION = "#Role: You are a nice person to talk to."
|
||||
+ "\n#Task: Respond quickly and nicely to the user."
|
||||
+ "\n#Rules: 1.ALWAYS use the same language as the `#Current Input`."
|
||||
+ "\n#History Inputs: %s" + "\n#Current Input: %s" + "\n#Response: ";
|
||||
|
||||
@@ -83,6 +83,9 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
if (Objects.isNull(parseContext.getRequest().getSelectedParse())) {
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
|
||||
queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE);
|
||||
if (parseContext.enableLLM()) {
|
||||
queryNLReq.setText2SQLType(Text2SQLType.NONE);
|
||||
}
|
||||
|
||||
// for every requested dataSet, recursively invoke rule-based parser with different
|
||||
// mapModes
|
||||
|
||||
@@ -7,7 +7,7 @@ import java.util.Map;
|
||||
|
||||
public class PluginQueryManager {
|
||||
|
||||
private static Map<String, PluginSemanticQuery> pluginQueries = new HashMap<>();
|
||||
private static final Map<String, PluginSemanticQuery> pluginQueries = new HashMap<>();
|
||||
|
||||
public static void register(String queryMode, PluginSemanticQuery pluginSemanticQuery) {
|
||||
pluginQueries.put(queryMode, pluginSemanticQuery);
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package com.tencent.supersonic.chat.server.processor.parse;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.common.jsqlparser.FieldExpression;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper;
|
||||
@@ -22,7 +21,12 @@ import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
@@ -33,8 +37,7 @@ public class ParseInfoFormatProcessor implements ParseResultProcessor {
|
||||
@Override
|
||||
public void process(ParseContext parseContext) {
|
||||
parseContext.getResponse().getSelectedParses().forEach(p -> {
|
||||
if (PluginQueryManager.isPluginQuery(p.getQueryMode())
|
||||
|| "PLAIN_TEXT".equals(p.getQueryMode())) {
|
||||
if (Objects.isNull(p.getDataSet()) || Objects.isNull(p.getSqlInfo().getParsedS2SQL())) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.common.pojo.enums;
|
||||
|
||||
public enum Text2SQLType {
|
||||
ONLY_RULE, LLM_OR_RULE;
|
||||
ONLY_RULE, LLM_OR_RULE, NONE;
|
||||
|
||||
public boolean enableLLM() {
|
||||
return this.equals(LLM_OR_RULE);
|
||||
|
||||
@@ -26,26 +26,27 @@ import static com.tencent.supersonic.common.pojo.Constants.DEFAULT_METRIC_LIMIT;
|
||||
public class SemanticParseInfo implements Serializable {
|
||||
|
||||
private Integer id;
|
||||
private String queryMode = "PLAIN_TEXT";
|
||||
private String queryMode = "";
|
||||
private QueryConfig queryConfig;
|
||||
private QueryType queryType = QueryType.DETAIL;
|
||||
private QueryType queryType;
|
||||
|
||||
private SchemaElement dataSet;
|
||||
private Set<SchemaElement> metrics = Sets.newTreeSet(new SchemaNameLengthComparator());
|
||||
private Set<SchemaElement> dimensions = Sets.newTreeSet(new SchemaNameLengthComparator());
|
||||
|
||||
private Set<QueryFilter> dimensionFilters = Sets.newHashSet();
|
||||
private Set<QueryFilter> metricFilters = Sets.newHashSet();
|
||||
private FilterType filterType = FilterType.AND;
|
||||
|
||||
private AggregateTypeEnum aggType = AggregateTypeEnum.NONE;
|
||||
private FilterType filterType = FilterType.AND;
|
||||
private Set<Order> orders = Sets.newHashSet();
|
||||
private DateConf dateInfo;
|
||||
private long limit = DEFAULT_DETAIL_LIMIT;
|
||||
private double score;
|
||||
private List<SchemaElementMatch> elementMatches = Lists.newArrayList();
|
||||
private DateConf dateInfo;
|
||||
private SqlInfo sqlInfo = new SqlInfo();
|
||||
private SqlEvaluation sqlEvaluation = new SqlEvaluation();
|
||||
private String textInfo;
|
||||
private SqlEvaluation sqlEvaluation = new SqlEvaluation();
|
||||
private Map<String, Object> properties = Maps.newHashMap();
|
||||
|
||||
@Data
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.headless.chat;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
@@ -41,6 +42,10 @@ public class ChatQueryContext implements Serializable {
|
||||
}
|
||||
}
|
||||
|
||||
public boolean needSQL() {
|
||||
return !request.getText2SQLType().equals(Text2SQLType.NONE);
|
||||
}
|
||||
|
||||
public DataSetSchema getDataSetSchema(Long dataSetId) {
|
||||
return semanticSchema.getDataSetSchema(dataSetId);
|
||||
}
|
||||
|
||||
@@ -10,7 +10,6 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@@ -31,9 +30,11 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
|
||||
public void correct(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
try {
|
||||
if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectedS2SQL())) {
|
||||
String s2SQL = semanticParseInfo.getSqlInfo().getParsedS2SQL();
|
||||
if (Objects.isNull(s2SQL)) {
|
||||
return;
|
||||
}
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(s2SQL);
|
||||
doCorrect(chatQueryContext, semanticParseInfo);
|
||||
log.debug("sqlCorrection:{} sql:{}", this.getClass().getSimpleName(),
|
||||
semanticParseInfo.getSqlInfo());
|
||||
|
||||
@@ -6,6 +6,8 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
/** QueryTypeParser resolves query type as either AGGREGATE or DETAIL */
|
||||
@Slf4j
|
||||
public class QueryTypeParser implements SemanticParser {
|
||||
@@ -15,12 +17,14 @@ public class QueryTypeParser implements SemanticParser {
|
||||
chatQueryContext.getCandidateQueries().forEach(query -> {
|
||||
SemanticParseInfo parseInfo = query.getParseInfo();
|
||||
String s2SQL = parseInfo.getSqlInfo().getParsedS2SQL();
|
||||
QueryType queryType = QueryType.DETAIL;
|
||||
if (Objects.isNull(s2SQL)) {
|
||||
return;
|
||||
}
|
||||
|
||||
QueryType queryType = QueryType.DETAIL;
|
||||
if (SqlSelectFunctionHelper.hasAggregateFunction(s2SQL)) {
|
||||
queryType = QueryType.AGGREGATE;
|
||||
}
|
||||
|
||||
parseInfo.setQueryType(queryType);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -49,7 +49,6 @@ public class LLMResponseService {
|
||||
parseInfo.setScore(queryCtx.getRequest().getQueryText().length() * (1 + weight));
|
||||
parseInfo.setQueryMode(semanticQuery.getQueryMode());
|
||||
parseInfo.getSqlInfo().setParsedS2SQL(s2SQL);
|
||||
parseInfo.getSqlInfo().setCorrectedS2SQL(s2SQL);
|
||||
queryCtx.getCandidateQueries().add(semanticQuery);
|
||||
}
|
||||
|
||||
|
||||
@@ -40,7 +40,9 @@ public class RuleSqlParser implements SemanticParser {
|
||||
|
||||
auxiliaryParsers.forEach(p -> p.parse(chatQueryContext));
|
||||
|
||||
candidateQueries.forEach(query -> query.buildS2Sql(
|
||||
chatQueryContext.getDataSetSchema(query.getParseInfo().getDataSetId())));
|
||||
if (chatQueryContext.needSQL()) {
|
||||
candidateQueries.forEach(query -> query.buildS2Sql(
|
||||
chatQueryContext.getDataSetSchema(query.getParseInfo().getDataSetId())));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,16 +3,17 @@ package com.tencent.supersonic.headless.chat.query.llm.s2sql;
|
||||
import com.fasterxml.jackson.annotation.JsonValue;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import lombok.Data;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Data
|
||||
@@ -45,22 +46,23 @@ public class LLMReq {
|
||||
private SchemaElement primaryKey;
|
||||
|
||||
public List<String> getFieldNameList() {
|
||||
List<String> fieldNameList = new ArrayList<>();
|
||||
Set<String> fieldNameList = new HashSet<>();
|
||||
if (CollectionUtils.isNotEmpty(metrics)) {
|
||||
fieldNameList.addAll(metrics.stream().map(metric -> metric.getName())
|
||||
.collect(Collectors.toList()));
|
||||
fieldNameList.addAll(
|
||||
metrics.stream().map(SchemaElement::getName).collect(Collectors.toList()));
|
||||
}
|
||||
if (CollectionUtils.isNotEmpty(dimensions)) {
|
||||
fieldNameList.addAll(dimensions.stream().map(dimension -> dimension.getName())
|
||||
fieldNameList.addAll(dimensions.stream().map(SchemaElement::getName)
|
||||
.collect(Collectors.toList()));
|
||||
}
|
||||
if (CollectionUtils.isNotEmpty(values)) {
|
||||
fieldNameList.addAll(values.stream().map(ElementValue::getFieldName)
|
||||
.collect(Collectors.toList()));
|
||||
}
|
||||
if (Objects.nonNull(partitionTime)) {
|
||||
fieldNameList.add(partitionTime.getName());
|
||||
}
|
||||
if (Objects.nonNull(primaryKey)) {
|
||||
fieldNameList.add(primaryKey.getName());
|
||||
}
|
||||
return fieldNameList;
|
||||
return new ArrayList<>(fieldNameList);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,7 +76,7 @@ public class LLMReq {
|
||||
public enum SqlGenType {
|
||||
ONE_PASS_SELF_CONSISTENCY("1_pass_self_consistency");
|
||||
|
||||
private String name;
|
||||
private final String name;
|
||||
|
||||
SqlGenType(String name) {
|
||||
this.name = name;
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package com.tencent.supersonic.headless.chat.query.llm.s2sql;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.LLMSemanticQuery;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -23,8 +22,5 @@ public class LLMSqlQuery extends LLMSemanticQuery {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void buildS2Sql(DataSetSchema dataSetSchema) {
|
||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||
sqlInfo.setCorrectedS2SQL(sqlInfo.getParsedS2SQL());
|
||||
}
|
||||
public void buildS2Sql(DataSetSchema dataSetSchema) {}
|
||||
}
|
||||
|
||||
@@ -60,7 +60,6 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
convertBizNameToName(dataSetSchema, queryStructReq);
|
||||
QuerySqlReq querySQLReq = queryStructReq.convert();
|
||||
parseInfo.getSqlInfo().setParsedS2SQL(querySQLReq.getSql());
|
||||
parseInfo.getSqlInfo().setCorrectedS2SQL(querySQLReq.getSql());
|
||||
}
|
||||
|
||||
protected QueryStructReq convertQueryStruct() {
|
||||
|
||||
@@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.chat.query.rule.metric;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.Order;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
@@ -52,8 +51,6 @@ public class MetricTopNQuery extends MetricSemanticQuery {
|
||||
super.fillParseInfo(chatQueryContext, dataSetId);
|
||||
|
||||
parseInfo.setScore(parseInfo.getScore() + 2.0);
|
||||
parseInfo.setAggType(AggregateTypeEnum.SUM);
|
||||
|
||||
SchemaElement metric = parseInfo.getMetrics().iterator().next();
|
||||
parseInfo.getOrders().add(new Order(metric.getBizName(), Constants.DESC_UPPER));
|
||||
}
|
||||
|
||||
@@ -14,7 +14,6 @@ import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlEvaluation;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.ChatWorkflowState;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryMapReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||
@@ -128,10 +127,7 @@ public class S2ChatLayerService implements ChatLayerService {
|
||||
schemaService.getSemanticSchema(Sets.newHashSet(querySqlReq.getDataSetId()));
|
||||
queryCtx.setSemanticSchema(semanticSchema);
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
SqlInfo sqlInfo = new SqlInfo();
|
||||
sqlInfo.setCorrectedS2SQL(querySqlReq.getSql());
|
||||
sqlInfo.setParsedS2SQL(querySqlReq.getSql());
|
||||
semanticParseInfo.setSqlInfo(sqlInfo);
|
||||
semanticParseInfo.getSqlInfo().setParsedS2SQL(querySqlReq.getSql());
|
||||
semanticParseInfo.setQueryType(QueryType.DETAIL);
|
||||
|
||||
Long dataSetId = querySqlReq.getDataSetId();
|
||||
@@ -147,7 +143,7 @@ public class S2ChatLayerService implements ChatLayerService {
|
||||
corrector.correct(queryCtx, semanticParseInfo);
|
||||
}
|
||||
});
|
||||
log.info("chatQueryServiceImpl correct:{}", sqlInfo.getCorrectedS2SQL());
|
||||
log.info("Corrected SQL:{}", semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
|
||||
return semanticParseInfo;
|
||||
}
|
||||
|
||||
|
||||
@@ -60,7 +60,12 @@ public class ChatWorkflowEngine {
|
||||
List<SemanticParseInfo> parseInfos = queryCtx.getCandidateQueries().stream()
|
||||
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
|
||||
parseResult.setSelectedParses(parseInfos);
|
||||
queryCtx.setChatWorkflowState(ChatWorkflowState.CORRECTING);
|
||||
if (queryCtx.needSQL()) {
|
||||
queryCtx.setChatWorkflowState(ChatWorkflowState.CORRECTING);
|
||||
} else {
|
||||
parseResult.setState(ParseResp.ParseState.COMPLETED);
|
||||
queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case CORRECTING:
|
||||
|
||||
@@ -20,6 +20,7 @@ import java.text.SimpleDateFormat;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.MAX;
|
||||
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;
|
||||
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM;
|
||||
|
||||
@@ -46,6 +47,7 @@ public class MetricTest extends BaseTest {
|
||||
expectedParseInfo.setQueryType(QueryType.AGGREGATE);
|
||||
|
||||
assertQueryResult(expectedResult, actualResult);
|
||||
assert actualResult.getQueryResults().size() == 1;
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -67,6 +69,7 @@ public class MetricTest extends BaseTest {
|
||||
expectedParseInfo.setQueryType(QueryType.AGGREGATE);
|
||||
|
||||
assertQueryResult(expectedResult, actualResult);
|
||||
assert actualResult.getQueryResults().size() == 4;
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -93,6 +96,7 @@ public class MetricTest extends BaseTest {
|
||||
expectedParseInfo.setQueryType(QueryType.AGGREGATE);
|
||||
|
||||
assertQueryResult(expectedResult, actualResult);
|
||||
assert actualResult.getQueryResults().size() == 2;
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -105,7 +109,7 @@ public class MetricTest extends BaseTest {
|
||||
expectedResult.setChatContext(expectedParseInfo);
|
||||
|
||||
expectedResult.setQueryMode(MetricTopNQuery.QUERY_MODE);
|
||||
expectedParseInfo.setAggType(SUM);
|
||||
expectedParseInfo.setAggType(MAX);
|
||||
|
||||
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
|
||||
expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("用户"));
|
||||
@@ -135,6 +139,7 @@ public class MetricTest extends BaseTest {
|
||||
expectedParseInfo.setQueryType(QueryType.AGGREGATE);
|
||||
|
||||
assertQueryResult(expectedResult, actualResult);
|
||||
assert actualResult.getQueryResults().size() == 4;
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -144,7 +149,7 @@ public class MetricTest extends BaseTest {
|
||||
String dateStr = textFormat.format(format.parse(startDay));
|
||||
|
||||
QueryResult actualResult =
|
||||
submitNewChat(String.format("想知道%salice的访问次数", dateStr), DataUtils.metricAgentId);
|
||||
submitNewChat(String.format("alice在%s的访问次数", dateStr), DataUtils.metricAgentId);
|
||||
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
|
||||
|
||||
Reference in New Issue
Block a user