[improvement][project]Simplify code logic in multiple modules.

This commit is contained in:
jerryjzhang
2024-11-10 14:31:12 +08:00
parent ca4545bb15
commit 14a19a901f
18 changed files with 65 additions and 47 deletions

View File

@@ -26,7 +26,7 @@ import java.util.stream.Collectors;
public class PlainTextExecutor implements ChatQueryExecutor {
public static final String APP_KEY = "SMALL_TALK";
private static final String INSTRUCTION = "" + "#Role: You are a nice person to talk to."
private static final String INSTRUCTION = "#Role: You are a nice person to talk to."
+ "\n#Task: Respond quickly and nicely to the user."
+ "\n#Rules: 1.ALWAYS use the same language as the `#Current Input`."
+ "\n#History Inputs: %s" + "\n#Current Input: %s" + "\n#Response: ";

View File

@@ -83,6 +83,9 @@ public class NL2SQLParser implements ChatQueryParser {
if (Objects.isNull(parseContext.getRequest().getSelectedParse())) {
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE);
if (parseContext.enableLLM()) {
queryNLReq.setText2SQLType(Text2SQLType.NONE);
}
// for every requested dataSet, recursively invoke rule-based parser with different
// mapModes

View File

@@ -7,7 +7,7 @@ import java.util.Map;
public class PluginQueryManager {
private static Map<String, PluginSemanticQuery> pluginQueries = new HashMap<>();
private static final Map<String, PluginSemanticQuery> pluginQueries = new HashMap<>();
public static void register(String queryMode, PluginSemanticQuery pluginSemanticQuery) {
pluginQueries.put(queryMode, pluginSemanticQuery);

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.chat.server.processor.parse;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.common.jsqlparser.FieldExpression;
import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper;
@@ -22,7 +21,12 @@ import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.util.CollectionUtils;
import java.util.*;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
/**
@@ -33,8 +37,7 @@ public class ParseInfoFormatProcessor implements ParseResultProcessor {
@Override
public void process(ParseContext parseContext) {
parseContext.getResponse().getSelectedParses().forEach(p -> {
if (PluginQueryManager.isPluginQuery(p.getQueryMode())
|| "PLAIN_TEXT".equals(p.getQueryMode())) {
if (Objects.isNull(p.getDataSet()) || Objects.isNull(p.getSqlInfo().getParsedS2SQL())) {
return;
}

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.common.pojo.enums;
public enum Text2SQLType {
ONLY_RULE, LLM_OR_RULE;
ONLY_RULE, LLM_OR_RULE, NONE;
public boolean enableLLM() {
return this.equals(LLM_OR_RULE);

View File

@@ -26,26 +26,27 @@ import static com.tencent.supersonic.common.pojo.Constants.DEFAULT_METRIC_LIMIT;
public class SemanticParseInfo implements Serializable {
private Integer id;
private String queryMode = "PLAIN_TEXT";
private String queryMode = "";
private QueryConfig queryConfig;
private QueryType queryType = QueryType.DETAIL;
private QueryType queryType;
private SchemaElement dataSet;
private Set<SchemaElement> metrics = Sets.newTreeSet(new SchemaNameLengthComparator());
private Set<SchemaElement> dimensions = Sets.newTreeSet(new SchemaNameLengthComparator());
private Set<QueryFilter> dimensionFilters = Sets.newHashSet();
private Set<QueryFilter> metricFilters = Sets.newHashSet();
private FilterType filterType = FilterType.AND;
private AggregateTypeEnum aggType = AggregateTypeEnum.NONE;
private FilterType filterType = FilterType.AND;
private Set<Order> orders = Sets.newHashSet();
private DateConf dateInfo;
private long limit = DEFAULT_DETAIL_LIMIT;
private double score;
private List<SchemaElementMatch> elementMatches = Lists.newArrayList();
private DateConf dateInfo;
private SqlInfo sqlInfo = new SqlInfo();
private SqlEvaluation sqlEvaluation = new SqlEvaluation();
private String textInfo;
private SqlEvaluation sqlEvaluation = new SqlEvaluation();
private Map<String, Object> properties = Maps.newHashMap();
@Data

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.headless.chat;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
@@ -41,6 +42,10 @@ public class ChatQueryContext implements Serializable {
}
}
/**
 * Tells whether the current request still needs SQL to be produced.
 *
 * <p>Enum constants are compared by identity, so a request whose
 * {@code Text2SQLType} has not been set no longer triggers a
 * {@link NullPointerException} here.
 * NOTE(review): an unset type is treated as "SQL needed" — confirm that this is
 * the intended default.
 *
 * @return {@code false} only when the request's {@code Text2SQLType} is
 *         {@code NONE}; {@code true} otherwise
 */
public boolean needSQL() {
    return request.getText2SQLType() != Text2SQLType.NONE;
}
/**
 * Looks up the schema of one data set through the semantic schema held by this context.
 *
 * @param dataSetId identifier handed unchanged to {@code semanticSchema.getDataSetSchema}
 * @return the value returned by {@code semanticSchema.getDataSetSchema(dataSetId)}
 */
public DataSetSchema getDataSetSchema(Long dataSetId) {
    final DataSetSchema schema = semanticSchema.getDataSetSchema(dataSetId);
    return schema;
}

View File

@@ -10,7 +10,6 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.util.CollectionUtils;
@@ -31,9 +30,11 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
public void correct(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
try {
if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectedS2SQL())) {
String s2SQL = semanticParseInfo.getSqlInfo().getParsedS2SQL();
if (Objects.isNull(s2SQL)) {
return;
}
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(s2SQL);
doCorrect(chatQueryContext, semanticParseInfo);
log.debug("sqlCorrection:{} sql:{}", this.getClass().getSimpleName(),
semanticParseInfo.getSqlInfo());

View File

@@ -6,6 +6,8 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import lombok.extern.slf4j.Slf4j;
import java.util.Objects;
/** QueryTypeParser resolves query type as either AGGREGATE or DETAIL */
@Slf4j
public class QueryTypeParser implements SemanticParser {
@@ -15,12 +17,14 @@ public class QueryTypeParser implements SemanticParser {
chatQueryContext.getCandidateQueries().forEach(query -> {
SemanticParseInfo parseInfo = query.getParseInfo();
String s2SQL = parseInfo.getSqlInfo().getParsedS2SQL();
QueryType queryType = QueryType.DETAIL;
if (Objects.isNull(s2SQL)) {
return;
}
QueryType queryType = QueryType.DETAIL;
if (SqlSelectFunctionHelper.hasAggregateFunction(s2SQL)) {
queryType = QueryType.AGGREGATE;
}
parseInfo.setQueryType(queryType);
});
}

View File

@@ -49,7 +49,6 @@ public class LLMResponseService {
parseInfo.setScore(queryCtx.getRequest().getQueryText().length() * (1 + weight));
parseInfo.setQueryMode(semanticQuery.getQueryMode());
parseInfo.getSqlInfo().setParsedS2SQL(s2SQL);
parseInfo.getSqlInfo().setCorrectedS2SQL(s2SQL);
queryCtx.getCandidateQueries().add(semanticQuery);
}

View File

@@ -40,7 +40,9 @@ public class RuleSqlParser implements SemanticParser {
auxiliaryParsers.forEach(p -> p.parse(chatQueryContext));
candidateQueries.forEach(query -> query.buildS2Sql(
chatQueryContext.getDataSetSchema(query.getParseInfo().getDataSetId())));
if (chatQueryContext.needSQL()) {
candidateQueries.forEach(query -> query.buildS2Sql(
chatQueryContext.getDataSetSchema(query.getParseInfo().getDataSetId())));
}
}
}

View File

@@ -3,16 +3,17 @@ package com.tencent.supersonic.headless.chat.query.llm.s2sql;
import com.fasterxml.jackson.annotation.JsonValue;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import lombok.Data;
import org.apache.commons.collections4.CollectionUtils;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
@Data
@@ -45,22 +46,23 @@ public class LLMReq {
private SchemaElement primaryKey;
public List<String> getFieldNameList() {
List<String> fieldNameList = new ArrayList<>();
Set<String> fieldNameList = new HashSet<>();
if (CollectionUtils.isNotEmpty(metrics)) {
fieldNameList.addAll(metrics.stream().map(metric -> metric.getName())
.collect(Collectors.toList()));
fieldNameList.addAll(
metrics.stream().map(SchemaElement::getName).collect(Collectors.toList()));
}
if (CollectionUtils.isNotEmpty(dimensions)) {
fieldNameList.addAll(dimensions.stream().map(dimension -> dimension.getName())
fieldNameList.addAll(dimensions.stream().map(SchemaElement::getName)
.collect(Collectors.toList()));
}
if (CollectionUtils.isNotEmpty(values)) {
fieldNameList.addAll(values.stream().map(ElementValue::getFieldName)
.collect(Collectors.toList()));
}
if (Objects.nonNull(partitionTime)) {
fieldNameList.add(partitionTime.getName());
}
if (Objects.nonNull(primaryKey)) {
fieldNameList.add(primaryKey.getName());
}
return fieldNameList;
return new ArrayList<>(fieldNameList);
}
}
@@ -74,7 +76,7 @@ public class LLMReq {
public enum SqlGenType {
ONE_PASS_SELF_CONSISTENCY("1_pass_self_consistency");
private String name;
private final String name;
SqlGenType(String name) {
this.name = name;

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.headless.chat.query.llm.s2sql;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.chat.query.QueryManager;
import com.tencent.supersonic.headless.chat.query.llm.LLMSemanticQuery;
import lombok.extern.slf4j.Slf4j;
@@ -23,8 +22,5 @@ public class LLMSqlQuery extends LLMSemanticQuery {
}
@Override
public void buildS2Sql(DataSetSchema dataSetSchema) {
SqlInfo sqlInfo = parseInfo.getSqlInfo();
sqlInfo.setCorrectedS2SQL(sqlInfo.getParsedS2SQL());
}
public void buildS2Sql(DataSetSchema dataSetSchema) {
    // Intentionally a no-op: this override does nothing and ignores dataSetSchema.
    // NOTE(review): presumably the S2SQL of an LLM query is already set on parseInfo
    // before this is called — confirm against the callers.
}
}

View File

@@ -60,7 +60,6 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
convertBizNameToName(dataSetSchema, queryStructReq);
QuerySqlReq querySQLReq = queryStructReq.convert();
parseInfo.getSqlInfo().setParsedS2SQL(querySQLReq.getSql());
parseInfo.getSqlInfo().setCorrectedS2SQL(querySQLReq.getSql());
}
protected QueryStructReq convertQueryStruct() {

View File

@@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.chat.query.rule.metric;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
@@ -52,8 +51,6 @@ public class MetricTopNQuery extends MetricSemanticQuery {
super.fillParseInfo(chatQueryContext, dataSetId);
parseInfo.setScore(parseInfo.getScore() + 2.0);
parseInfo.setAggType(AggregateTypeEnum.SUM);
SchemaElement metric = parseInfo.getMetrics().iterator().next();
parseInfo.getOrders().add(new Order(metric.getBizName(), Constants.DESC_UPPER));
}

View File

@@ -14,7 +14,6 @@ import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SqlEvaluation;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.api.pojo.enums.ChatWorkflowState;
import com.tencent.supersonic.headless.api.pojo.request.QueryMapReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
@@ -128,10 +127,7 @@ public class S2ChatLayerService implements ChatLayerService {
schemaService.getSemanticSchema(Sets.newHashSet(querySqlReq.getDataSetId()));
queryCtx.setSemanticSchema(semanticSchema);
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
SqlInfo sqlInfo = new SqlInfo();
sqlInfo.setCorrectedS2SQL(querySqlReq.getSql());
sqlInfo.setParsedS2SQL(querySqlReq.getSql());
semanticParseInfo.setSqlInfo(sqlInfo);
semanticParseInfo.getSqlInfo().setParsedS2SQL(querySqlReq.getSql());
semanticParseInfo.setQueryType(QueryType.DETAIL);
Long dataSetId = querySqlReq.getDataSetId();
@@ -147,7 +143,7 @@ public class S2ChatLayerService implements ChatLayerService {
corrector.correct(queryCtx, semanticParseInfo);
}
});
log.info("chatQueryServiceImpl correct:{}", sqlInfo.getCorrectedS2SQL());
log.info("Corrected SQL:{}", semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
return semanticParseInfo;
}

View File

@@ -60,7 +60,12 @@ public class ChatWorkflowEngine {
List<SemanticParseInfo> parseInfos = queryCtx.getCandidateQueries().stream()
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
parseResult.setSelectedParses(parseInfos);
queryCtx.setChatWorkflowState(ChatWorkflowState.CORRECTING);
if (queryCtx.needSQL()) {
queryCtx.setChatWorkflowState(ChatWorkflowState.CORRECTING);
} else {
parseResult.setState(ParseResp.ParseState.COMPLETED);
queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED);
}
}
break;
case CORRECTING:

View File

@@ -20,6 +20,7 @@ import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.List;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.MAX;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM;
@@ -46,6 +47,7 @@ public class MetricTest extends BaseTest {
expectedParseInfo.setQueryType(QueryType.AGGREGATE);
assertQueryResult(expectedResult, actualResult);
assert actualResult.getQueryResults().size() == 1;
}
@Test
@@ -67,6 +69,7 @@ public class MetricTest extends BaseTest {
expectedParseInfo.setQueryType(QueryType.AGGREGATE);
assertQueryResult(expectedResult, actualResult);
assert actualResult.getQueryResults().size() == 4;
}
@Test
@@ -93,6 +96,7 @@ public class MetricTest extends BaseTest {
expectedParseInfo.setQueryType(QueryType.AGGREGATE);
assertQueryResult(expectedResult, actualResult);
assert actualResult.getQueryResults().size() == 2;
}
@Test
@@ -105,7 +109,7 @@ public class MetricTest extends BaseTest {
expectedResult.setChatContext(expectedParseInfo);
expectedResult.setQueryMode(MetricTopNQuery.QUERY_MODE);
expectedParseInfo.setAggType(SUM);
expectedParseInfo.setAggType(MAX);
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("用户"));
@@ -135,6 +139,7 @@ public class MetricTest extends BaseTest {
expectedParseInfo.setQueryType(QueryType.AGGREGATE);
assertQueryResult(expectedResult, actualResult);
assert actualResult.getQueryResults().size() == 4;
}
@Test
@@ -144,7 +149,7 @@ public class MetricTest extends BaseTest {
String dateStr = textFormat.format(format.parse(startDay));
QueryResult actualResult =
submitNewChat(String.format("想知道%salice的访问次数", dateStr), DataUtils.metricAgentId);
submitNewChat(String.format("alice在%s的访问次数", dateStr), DataUtils.metricAgentId);
QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();