Merge fixes and improvements (#1910)

Co-authored-by: tristanliu <tristanliu@tencent.com>
This commit is contained in:
Jun Zhang
2024-11-16 13:57:54 +08:00
committed by GitHub
parent 5e22b412c6
commit ba1938f04b
40 changed files with 1382 additions and 2784 deletions

View File

@@ -30,11 +30,10 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
public void correct(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
try {
String s2SQL = semanticParseInfo.getSqlInfo().getParsedS2SQL();
String s2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
if (Objects.isNull(s2SQL)) {
return;
}
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(s2SQL);
doCorrect(chatQueryContext, semanticParseInfo);
log.debug("sqlCorrection:{} sql:{}", this.getClass().getSimpleName(),
semanticParseInfo.getSqlInfo());

View File

@@ -6,8 +6,6 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import lombok.extern.slf4j.Slf4j;
import java.util.Objects;
/** QueryTypeParser resolves query type as either AGGREGATE or DETAIL */
@Slf4j
public class QueryTypeParser implements SemanticParser {
@@ -17,14 +15,12 @@ public class QueryTypeParser implements SemanticParser {
chatQueryContext.getCandidateQueries().forEach(query -> {
SemanticParseInfo parseInfo = query.getParseInfo();
String s2SQL = parseInfo.getSqlInfo().getParsedS2SQL();
if (Objects.isNull(s2SQL)) {
return;
}
QueryType queryType = QueryType.DETAIL;
if (SqlSelectFunctionHelper.hasAggregateFunction(s2SQL)) {
queryType = QueryType.AGGREGATE;
}
parseInfo.setQueryType(queryType);
});
}

View File

@@ -49,6 +49,7 @@ public class LLMResponseService {
parseInfo.setScore(queryCtx.getRequest().getQueryText().length() * (1 + weight));
parseInfo.setQueryMode(semanticQuery.getQueryMode());
parseInfo.getSqlInfo().setParsedS2SQL(s2SQL);
parseInfo.getSqlInfo().setCorrectedS2SQL(s2SQL);
queryCtx.getCandidateQueries().add(semanticQuery);
}

View File

@@ -40,9 +40,7 @@ public class RuleSqlParser implements SemanticParser {
auxiliaryParsers.forEach(p -> p.parse(chatQueryContext));
if (chatQueryContext.needSQL()) {
candidateQueries.forEach(query -> query.buildS2Sql(
chatQueryContext.getDataSetSchema(query.getParseInfo().getDataSetId())));
}
candidateQueries.forEach(query -> query.buildS2Sql(
chatQueryContext.getDataSetSchema(query.getParseInfo().getDataSetId())));
}
}

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.headless.chat.query.llm.s2sql;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.chat.query.QueryManager;
import com.tencent.supersonic.headless.chat.query.llm.LLMSemanticQuery;
import lombok.extern.slf4j.Slf4j;
@@ -22,5 +23,8 @@ public class LLMSqlQuery extends LLMSemanticQuery {
}
@Override
public void buildS2Sql(DataSetSchema dataSetSchema) {}
public void buildS2Sql(DataSetSchema dataSetSchema) {
SqlInfo sqlInfo = parseInfo.getSqlInfo();
sqlInfo.setCorrectedS2SQL(sqlInfo.getParsedS2SQL());
}
}

View File

@@ -60,6 +60,7 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
convertBizNameToName(dataSetSchema, queryStructReq);
QuerySqlReq querySQLReq = queryStructReq.convert();
parseInfo.getSqlInfo().setParsedS2SQL(querySQLReq.getSql());
parseInfo.getSqlInfo().setCorrectedS2SQL(querySQLReq.getSql());
}
protected QueryStructReq convertQueryStruct() {

View File

@@ -147,6 +147,12 @@ public class QueryReqBuilder {
return aggregateType.name();
}
private static boolean isDateFieldAlreadyPresent(SemanticParseInfo parseInfo,
String dateField) {
return parseInfo.getDimensions().stream()
.anyMatch(dimension -> dimension.getBizName().equalsIgnoreCase(dateField));
}
public static Set<Order> getOrder(Set<Order> existingOrders, AggregateTypeEnum aggregator,
SchemaElement metric) {
if (existingOrders != null && !existingOrders.isEmpty()) {