mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 20:25:12 +00:00
[improvement][project]Simplify code logic in multiple modules.
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.headless.chat;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
@@ -41,6 +42,10 @@ public class ChatQueryContext implements Serializable {
|
||||
}
|
||||
}
|
||||
|
||||
public boolean needSQL() {
|
||||
return !request.getText2SQLType().equals(Text2SQLType.NONE);
|
||||
}
|
||||
|
||||
public DataSetSchema getDataSetSchema(Long dataSetId) {
|
||||
return semanticSchema.getDataSetSchema(dataSetId);
|
||||
}
|
||||
|
||||
@@ -10,7 +10,6 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@@ -31,9 +30,11 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
|
||||
public void correct(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
try {
|
||||
if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectedS2SQL())) {
|
||||
String s2SQL = semanticParseInfo.getSqlInfo().getParsedS2SQL();
|
||||
if (Objects.isNull(s2SQL)) {
|
||||
return;
|
||||
}
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(s2SQL);
|
||||
doCorrect(chatQueryContext, semanticParseInfo);
|
||||
log.debug("sqlCorrection:{} sql:{}", this.getClass().getSimpleName(),
|
||||
semanticParseInfo.getSqlInfo());
|
||||
|
||||
@@ -6,6 +6,8 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
/** QueryTypeParser resolves query type as either AGGREGATE or DETAIL */
|
||||
@Slf4j
|
||||
public class QueryTypeParser implements SemanticParser {
|
||||
@@ -15,12 +17,14 @@ public class QueryTypeParser implements SemanticParser {
|
||||
chatQueryContext.getCandidateQueries().forEach(query -> {
|
||||
SemanticParseInfo parseInfo = query.getParseInfo();
|
||||
String s2SQL = parseInfo.getSqlInfo().getParsedS2SQL();
|
||||
QueryType queryType = QueryType.DETAIL;
|
||||
if (Objects.isNull(s2SQL)) {
|
||||
return;
|
||||
}
|
||||
|
||||
QueryType queryType = QueryType.DETAIL;
|
||||
if (SqlSelectFunctionHelper.hasAggregateFunction(s2SQL)) {
|
||||
queryType = QueryType.AGGREGATE;
|
||||
}
|
||||
|
||||
parseInfo.setQueryType(queryType);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -49,7 +49,6 @@ public class LLMResponseService {
|
||||
parseInfo.setScore(queryCtx.getRequest().getQueryText().length() * (1 + weight));
|
||||
parseInfo.setQueryMode(semanticQuery.getQueryMode());
|
||||
parseInfo.getSqlInfo().setParsedS2SQL(s2SQL);
|
||||
parseInfo.getSqlInfo().setCorrectedS2SQL(s2SQL);
|
||||
queryCtx.getCandidateQueries().add(semanticQuery);
|
||||
}
|
||||
|
||||
|
||||
@@ -40,7 +40,9 @@ public class RuleSqlParser implements SemanticParser {
|
||||
|
||||
auxiliaryParsers.forEach(p -> p.parse(chatQueryContext));
|
||||
|
||||
candidateQueries.forEach(query -> query.buildS2Sql(
|
||||
chatQueryContext.getDataSetSchema(query.getParseInfo().getDataSetId())));
|
||||
if (chatQueryContext.needSQL()) {
|
||||
candidateQueries.forEach(query -> query.buildS2Sql(
|
||||
chatQueryContext.getDataSetSchema(query.getParseInfo().getDataSetId())));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,16 +3,17 @@ package com.tencent.supersonic.headless.chat.query.llm.s2sql;
|
||||
import com.fasterxml.jackson.annotation.JsonValue;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import lombok.Data;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Data
|
||||
@@ -45,22 +46,23 @@ public class LLMReq {
|
||||
private SchemaElement primaryKey;
|
||||
|
||||
public List<String> getFieldNameList() {
|
||||
List<String> fieldNameList = new ArrayList<>();
|
||||
Set<String> fieldNameList = new HashSet<>();
|
||||
if (CollectionUtils.isNotEmpty(metrics)) {
|
||||
fieldNameList.addAll(metrics.stream().map(metric -> metric.getName())
|
||||
.collect(Collectors.toList()));
|
||||
fieldNameList.addAll(
|
||||
metrics.stream().map(SchemaElement::getName).collect(Collectors.toList()));
|
||||
}
|
||||
if (CollectionUtils.isNotEmpty(dimensions)) {
|
||||
fieldNameList.addAll(dimensions.stream().map(dimension -> dimension.getName())
|
||||
fieldNameList.addAll(dimensions.stream().map(SchemaElement::getName)
|
||||
.collect(Collectors.toList()));
|
||||
}
|
||||
if (CollectionUtils.isNotEmpty(values)) {
|
||||
fieldNameList.addAll(values.stream().map(ElementValue::getFieldName)
|
||||
.collect(Collectors.toList()));
|
||||
}
|
||||
if (Objects.nonNull(partitionTime)) {
|
||||
fieldNameList.add(partitionTime.getName());
|
||||
}
|
||||
if (Objects.nonNull(primaryKey)) {
|
||||
fieldNameList.add(primaryKey.getName());
|
||||
}
|
||||
return fieldNameList;
|
||||
return new ArrayList<>(fieldNameList);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,7 +76,7 @@ public class LLMReq {
|
||||
public enum SqlGenType {
|
||||
ONE_PASS_SELF_CONSISTENCY("1_pass_self_consistency");
|
||||
|
||||
private String name;
|
||||
private final String name;
|
||||
|
||||
SqlGenType(String name) {
|
||||
this.name = name;
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package com.tencent.supersonic.headless.chat.query.llm.s2sql;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.LLMSemanticQuery;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -23,8 +22,5 @@ public class LLMSqlQuery extends LLMSemanticQuery {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void buildS2Sql(DataSetSchema dataSetSchema) {
|
||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||
sqlInfo.setCorrectedS2SQL(sqlInfo.getParsedS2SQL());
|
||||
}
|
||||
public void buildS2Sql(DataSetSchema dataSetSchema) {}
|
||||
}
|
||||
|
||||
@@ -60,7 +60,6 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
convertBizNameToName(dataSetSchema, queryStructReq);
|
||||
QuerySqlReq querySQLReq = queryStructReq.convert();
|
||||
parseInfo.getSqlInfo().setParsedS2SQL(querySQLReq.getSql());
|
||||
parseInfo.getSqlInfo().setCorrectedS2SQL(querySQLReq.getSql());
|
||||
}
|
||||
|
||||
protected QueryStructReq convertQueryStruct() {
|
||||
|
||||
@@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.chat.query.rule.metric;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.Order;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
@@ -52,8 +51,6 @@ public class MetricTopNQuery extends MetricSemanticQuery {
|
||||
super.fillParseInfo(chatQueryContext, dataSetId);
|
||||
|
||||
parseInfo.setScore(parseInfo.getScore() + 2.0);
|
||||
parseInfo.setAggType(AggregateTypeEnum.SUM);
|
||||
|
||||
SchemaElement metric = parseInfo.getMetrics().iterator().next();
|
||||
parseInfo.getOrders().add(new Order(metric.getBizName(), Constants.DESC_UPPER));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user