mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 12:37:55 +00:00
(improvement)(headless)Remove unnecessary TranslateSqlReq, use SemanticQueryReq instead.
This commit is contained in:
@@ -22,7 +22,7 @@ public class AggCorrector extends BaseSemanticCorrector {
|
||||
|
||||
private void addAggregate(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
List<String> sqlGroupByFields = SqlSelectHelper.getGroupByFields(
|
||||
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
||||
semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
|
||||
if (CollectionUtils.isEmpty(sqlGroupByFields)) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
|
||||
public void correct(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
try {
|
||||
if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectS2SQL())) {
|
||||
if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectedS2SQL())) {
|
||||
return;
|
||||
}
|
||||
doCorrect(chatQueryContext, semanticParseInfo);
|
||||
@@ -74,7 +74,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
|
||||
protected void addAggregateToMetric(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
//add aggregate to all metric
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
Long dataSetId = semanticParseInfo.getDataSet().getDataSet();
|
||||
List<SchemaElement> metrics = getMetricElements(chatQueryContext, dataSetId);
|
||||
|
||||
@@ -98,7 +98,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
return;
|
||||
}
|
||||
String aggregateSql = SqlAddHelper.addAggregateToField(correctS2SQL, metricToAggregate);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql);
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(aggregateSql);
|
||||
}
|
||||
|
||||
protected List<SchemaElement> getMetricElements(ChatQueryContext chatQueryContext, Long dataSetId) {
|
||||
|
||||
@@ -34,8 +34,8 @@ public class GrammarCorrector extends BaseSemanticCorrector {
|
||||
}
|
||||
|
||||
public void removeSameFieldFromSelect(SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
correctS2SQL = SqlRemoveHelper.removeSameFieldFromSelect(correctS2SQL);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,7 +35,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
Long dataSetId = semanticParseInfo.getDataSetId();
|
||||
//add dimension group by
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
||||
String correctS2SQL = sqlInfo.getCorrectedS2SQL();
|
||||
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||
// check has distinct
|
||||
if (SqlSelectHelper.hasDistinct(correctS2SQL)) {
|
||||
@@ -68,7 +68,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
Long dataSetId = semanticParseInfo.getDataSetId();
|
||||
//add dimension group by
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
||||
String correctS2SQL = sqlInfo.getCorrectedS2SQL();
|
||||
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||
//add alias field name
|
||||
Set<String> dimensions = getDimensions(dataSetId, semanticSchema);
|
||||
@@ -83,7 +83,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
return true;
|
||||
})
|
||||
.collect(Collectors.toSet());
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlAddHelper.addGroupBy(correctS2SQL, groupByFields));
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(SqlAddHelper.addGroupBy(correctS2SQL, groupByFields));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -49,19 +49,19 @@ public class HavingCorrector extends BaseSemanticCorrector {
|
||||
if (CollectionUtils.isEmpty(metrics)) {
|
||||
return;
|
||||
}
|
||||
String havingSql = SqlAddHelper.addHaving(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), metrics);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(havingSql);
|
||||
String havingSql = SqlAddHelper.addHaving(semanticParseInfo.getSqlInfo().getCorrectedS2SQL(), metrics);
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(havingSql);
|
||||
}
|
||||
|
||||
private void addHavingToSelect(SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
if (!SqlSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
|
||||
return;
|
||||
}
|
||||
List<Expression> havingExpressionList = SqlSelectHelper.getHavingExpression(correctS2SQL);
|
||||
if (!CollectionUtils.isEmpty(havingExpressionList)) {
|
||||
String replaceSql = SqlAddHelper.addFunctionToSelect(correctS2SQL, havingExpressionList);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceSql);
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(replaceSql);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -50,21 +50,21 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
private void correctAggFunction(SemanticParseInfo semanticParseInfo) {
|
||||
Map<String, String> aggregateEnum = AggregateEnum.getAggregateEnum();
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String sql = SqlReplaceHelper.replaceFunction(sqlInfo.getCorrectS2SQL(), aggregateEnum);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
String sql = SqlReplaceHelper.replaceFunction(sqlInfo.getCorrectedS2SQL(), aggregateEnum);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
}
|
||||
|
||||
private void replaceAlias(SemanticParseInfo semanticParseInfo) {
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String replaceAlias = SqlReplaceHelper.replaceAlias(sqlInfo.getCorrectS2SQL());
|
||||
sqlInfo.setCorrectS2SQL(replaceAlias);
|
||||
String replaceAlias = SqlReplaceHelper.replaceAlias(sqlInfo.getCorrectedS2SQL());
|
||||
sqlInfo.setCorrectedS2SQL(replaceAlias);
|
||||
}
|
||||
|
||||
private void correctFieldName(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
Map<String, String> fieldNameMap = getFieldNameMap(chatQueryContext, semanticParseInfo.getDataSetId());
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String sql = SqlReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
String sql = SqlReplaceHelper.replaceFields(sqlInfo.getCorrectedS2SQL(), fieldNameMap);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
}
|
||||
|
||||
private void updateFieldNameByLinkingValue(SemanticParseInfo semanticParseInfo) {
|
||||
@@ -79,8 +79,8 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
|
||||
String sql = SqlReplaceHelper.replaceFieldNameByValue(sqlInfo.getCorrectS2SQL(), fieldValueToFieldNames);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
String sql = SqlReplaceHelper.replaceFieldNameByValue(sqlInfo.getCorrectedS2SQL(), fieldValueToFieldNames);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
}
|
||||
|
||||
private List<LLMReq.ElementValue> getLinkingValues(SemanticParseInfo semanticParseInfo) {
|
||||
@@ -111,14 +111,14 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
)));
|
||||
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String sql = SqlReplaceHelper.replaceValue(sqlInfo.getCorrectS2SQL(), filedNameToValueMap, false);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
String sql = SqlReplaceHelper.replaceValue(sqlInfo.getCorrectedS2SQL(), filedNameToValueMap, false);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
}
|
||||
|
||||
public void removeFilterIfNotInLinkingValue(ChatQueryContext chatQueryContext,
|
||||
SemanticParseInfo semanticParseInfo) {
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
||||
String correctS2SQL = sqlInfo.getCorrectedS2SQL();
|
||||
List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctS2SQL);
|
||||
if (CollectionUtils.isEmpty(whereExpressionList)) {
|
||||
return;
|
||||
@@ -143,7 +143,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
.map(fieldExpression -> fieldExpression.getFieldName()).collect(Collectors.toSet());
|
||||
|
||||
String sql = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ public class SelectCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
|
||||
// If the number of aggregated fields is equal to the number of queried fields, do not add fields to select.
|
||||
@@ -44,7 +44,7 @@ public class SelectCorrector extends BaseSemanticCorrector {
|
||||
}
|
||||
correctS2SQL = addFieldsToSelect(chatQueryContext, semanticParseInfo, correctS2SQL);
|
||||
String querySql = SqlReplaceHelper.dealAliasToOrderBy(correctS2SQL);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(querySql);
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(querySql);
|
||||
}
|
||||
|
||||
protected String addFieldsToSelect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo,
|
||||
@@ -65,7 +65,7 @@ public class SelectCorrector extends BaseSemanticCorrector {
|
||||
}
|
||||
needAddFields.removeAll(selectFields);
|
||||
String addFieldsToSelectSql = SqlAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(addFieldsToSelectSql);
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(addFieldsToSelectSql);
|
||||
return addFieldsToSelectSql;
|
||||
}
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ public class TimeCorrector extends BaseSemanticCorrector {
|
||||
}
|
||||
|
||||
private void removeDateIfExist(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
//decide whether remove date field from where
|
||||
Environment environment = ContextUtils.getBean(Environment.class);
|
||||
String correctorDate = environment.getProperty("s2.corrector.date");
|
||||
@@ -55,12 +55,12 @@ public class TimeCorrector extends BaseSemanticCorrector {
|
||||
removeFieldNames.add(TimeDimensionEnum.WEEK.getChName());
|
||||
removeFieldNames.add(TimeDimensionEnum.MONTH.getChName());
|
||||
correctS2SQL = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
|
||||
}
|
||||
}
|
||||
|
||||
private void addDateIfNotExist(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
|
||||
|
||||
//decide whether add date field to where
|
||||
@@ -88,11 +88,11 @@ public class TimeCorrector extends BaseSemanticCorrector {
|
||||
}
|
||||
}
|
||||
}
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
|
||||
}
|
||||
|
||||
private void addLowerBoundDate(SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
DateBoundInfo dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(correctS2SQL);
|
||||
if (Objects.isNull(dateBoundInfo)) {
|
||||
return;
|
||||
@@ -108,14 +108,14 @@ public class TimeCorrector extends BaseSemanticCorrector {
|
||||
} catch (JSQLParserException e) {
|
||||
log.error("parseCondExpression", e);
|
||||
}
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
|
||||
}
|
||||
}
|
||||
|
||||
private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
correctS2SQL = SqlReplaceHelper.replaceFunction(correctS2SQL);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -39,7 +39,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
protected void addQueryFilter(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
String queryFilter = getQueryFilter(chatQueryContext.getQueryFilters());
|
||||
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
|
||||
if (StringUtils.isNotEmpty(queryFilter)) {
|
||||
log.info("add queryFilter to correctS2SQL :{}", queryFilter);
|
||||
@@ -50,7 +50,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
log.error("parseCondExpression", e);
|
||||
}
|
||||
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, expression);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,9 +71,9 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
}
|
||||
|
||||
Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions);
|
||||
String correctS2SQL = SqlReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
|
||||
String correctS2SQL = SqlReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectedS2SQL(),
|
||||
aliasAndBizNameToTechName);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
|
||||
}
|
||||
|
||||
private Map<String, Map<String, String>> getAliasAndBizNameToTechName(List<SchemaElement> dimensions) {
|
||||
|
||||
@@ -46,14 +46,14 @@ public class QueryTypeParser implements SemanticParser {
|
||||
private QueryType getQueryType(ChatQueryContext chatQueryContext, SemanticQuery semanticQuery) {
|
||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||
if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getS2SQL())) {
|
||||
if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getParsedS2SQL())) {
|
||||
return QueryType.DETAIL;
|
||||
}
|
||||
//1. entity queryType
|
||||
Long dataSetId = parseInfo.getDataSetId();
|
||||
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
|
||||
List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getS2SQL());
|
||||
List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getParsedS2SQL());
|
||||
List<String> whereFilterByTimeFields = filterByTimeFields(whereFields);
|
||||
if (CollectionUtils.isNotEmpty(whereFilterByTimeFields)) {
|
||||
Set<String> ids = semanticSchema.getEntities(dataSetId).stream().map(SchemaElement::getName)
|
||||
@@ -63,7 +63,7 @@ public class QueryTypeParser implements SemanticParser {
|
||||
return QueryType.ID;
|
||||
}
|
||||
}
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getParsedS2SQL());
|
||||
selectFields.addAll(whereFields);
|
||||
List<String> selectWhereFilterByTimeFields = filterByTimeFields(selectFields);
|
||||
if (CollectionUtils.isNotEmpty(selectWhereFilterByTimeFields)) {
|
||||
@@ -91,7 +91,7 @@ public class QueryTypeParser implements SemanticParser {
|
||||
}
|
||||
|
||||
private static boolean selectContainsMetric(SqlInfo sqlInfo, Long dataSetId, SemanticSchema semanticSchema) {
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getParsedS2SQL());
|
||||
List<SchemaElement> metrics = semanticSchema.getMetrics(dataSetId);
|
||||
if (CollectionUtils.isNotEmpty(metrics)) {
|
||||
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
|
||||
|
||||
@@ -38,7 +38,7 @@ public class LLMResponseService {
|
||||
parseInfo.setProperties(properties);
|
||||
parseInfo.setScore(queryCtx.getQueryText().length() * (1 + weight));
|
||||
parseInfo.setQueryMode(semanticQuery.getQueryMode());
|
||||
parseInfo.getSqlInfo().setS2SQL(s2SQL);
|
||||
parseInfo.getSqlInfo().setParsedS2SQL(s2SQL);
|
||||
queryCtx.getCandidateQueries().add(semanticQuery);
|
||||
return parseInfo;
|
||||
}
|
||||
|
||||
@@ -83,8 +83,8 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
|
||||
QueryStructReq queryStructReq = convertQueryStruct();
|
||||
convertBizNameToName(semanticSchema, queryStructReq);
|
||||
QuerySqlReq querySQLReq = queryStructReq.convert();
|
||||
parseInfo.getSqlInfo().setS2SQL(querySQLReq.getSql());
|
||||
parseInfo.getSqlInfo().setCorrectS2SQL(querySQLReq.getSql());
|
||||
parseInfo.getSqlInfo().setParsedS2SQL(querySQLReq.getSql());
|
||||
parseInfo.getSqlInfo().setCorrectedS2SQL(querySQLReq.getSql());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -33,6 +33,6 @@ public class LLMSqlQuery extends LLMSemanticQuery {
|
||||
@Override
|
||||
public void initS2Sql(SemanticSchema semanticSchema, User user) {
|
||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||
sqlInfo.setCorrectS2SQL(sqlInfo.getS2SQL());
|
||||
sqlInfo.setCorrectedS2SQL(sqlInfo.getParsedS2SQL());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -152,8 +152,8 @@ public class QueryReqBuilder {
|
||||
|
||||
public static QuerySqlReq buildS2SQLReq(SqlInfo sqlInfo, Long dataSetId) {
|
||||
QuerySqlReq querySQLReq = new QuerySqlReq();
|
||||
if (Objects.nonNull(sqlInfo.getCorrectS2SQL())) {
|
||||
querySQLReq.setSql(sqlInfo.getCorrectS2SQL());
|
||||
if (Objects.nonNull(sqlInfo.getCorrectedS2SQL())) {
|
||||
querySQLReq.setSql(sqlInfo.getCorrectedS2SQL());
|
||||
}
|
||||
querySQLReq.setSqlInfo(sqlInfo);
|
||||
querySQLReq.setDataSetId(dataSetId);
|
||||
|
||||
@@ -30,14 +30,14 @@ class AggCorrectorTest {
|
||||
String sql = "SELECT 用户, 访问次数 FROM 超音数数据集 WHERE 部门 = 'sales' AND"
|
||||
+ " datediff('day', 数据日期, '2024-06-04') <= 7"
|
||||
+ " GROUP BY 用户 ORDER BY SUM(访问次数) DESC LIMIT 1";
|
||||
sqlInfo.setS2SQL(sql);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
sqlInfo.setParsedS2SQL(sql);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
semanticParseInfo.setSqlInfo(sqlInfo);
|
||||
corrector.correct(chatQueryContext, semanticParseInfo);
|
||||
Assert.assertEquals("SELECT 用户, SUM(访问次数) FROM 超音数数据集 WHERE 部门 = 'sales'"
|
||||
+ " AND datediff('day', 数据日期, '2024-06-04') <= 7 GROUP BY 用户"
|
||||
+ " ORDER BY SUM(访问次数) DESC LIMIT 1",
|
||||
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
||||
semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
|
||||
}
|
||||
|
||||
private ChatQueryContext buildQueryContext(Long dataSetId) {
|
||||
|
||||
@@ -65,8 +65,8 @@ class SchemaCorrectorTest {
|
||||
+ "and 商务组 = 'xxx' order by 播放量 desc limit 10";
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
SqlInfo sqlInfo = new SqlInfo();
|
||||
sqlInfo.setS2SQL(sql);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
sqlInfo.setParsedS2SQL(sql);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
semanticParseInfo.setSqlInfo(sqlInfo);
|
||||
|
||||
SchemaElement schemaElement = new SchemaElement();
|
||||
@@ -80,7 +80,7 @@ class SchemaCorrectorTest {
|
||||
schemaCorrector.removeFilterIfNotInLinkingValue(chatQueryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' "
|
||||
+ "ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
||||
+ "ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
|
||||
|
||||
parseResult = objectMapper.readValue(json, ParseResult.class);
|
||||
|
||||
@@ -92,11 +92,11 @@ class SchemaCorrectorTest {
|
||||
parseResult.setLinkingValues(linkingValues);
|
||||
semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult);
|
||||
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(sql);
|
||||
semanticParseInfo.getSqlInfo().setS2SQL(sql);
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(sql);
|
||||
semanticParseInfo.getSqlInfo().setParsedS2SQL(sql);
|
||||
schemaCorrector.removeFilterIfNotInLinkingValue(chatQueryContext, semanticParseInfo);
|
||||
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' "
|
||||
+ "AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
||||
+ "AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -44,12 +44,12 @@ class SelectCorrectorTest {
|
||||
semanticParseInfo.setQueryType(QueryType.DETAIL);
|
||||
SqlInfo sqlInfo = new SqlInfo();
|
||||
String sql = "SELECT * FROM 艺人库 WHERE 艺人名='周杰伦'";
|
||||
sqlInfo.setS2SQL(sql);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
sqlInfo.setParsedS2SQL(sql);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
semanticParseInfo.setSqlInfo(sqlInfo);
|
||||
corrector.correct(chatQueryContext, semanticParseInfo);
|
||||
Assert.assertEquals("SELECT 粉丝数, 国籍, 艺人名, 性别 FROM 艺人库 WHERE 艺人名 = '周杰伦'",
|
||||
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
||||
semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
|
||||
}
|
||||
|
||||
private ChatQueryContext buildQueryContext(Long dataSetId) {
|
||||
|
||||
@@ -19,84 +19,84 @@ class TimeCorrectorTest {
|
||||
//1.数据日期 <=
|
||||
String sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
semanticParseInfo.setSqlInfo(sqlInfo);
|
||||
corrector.doCorrect(chatQueryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE ((歌手名 = '张三') AND 数据日期 <= '2023-11-17') "
|
||||
+ "AND 数据日期 >= '2023-11-17' GROUP BY 维度1",
|
||||
sqlInfo.getCorrectS2SQL());
|
||||
sqlInfo.getCorrectedS2SQL());
|
||||
|
||||
//2.数据日期 <
|
||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE (歌手名 = '张三') AND 数据日期 < '2023-11-17' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
corrector.doCorrect(chatQueryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE ((歌手名 = '张三') AND 数据日期 < '2023-11-17') "
|
||||
+ "AND 数据日期 >= '2023-11-17' GROUP BY 维度1",
|
||||
sqlInfo.getCorrectS2SQL());
|
||||
sqlInfo.getCorrectedS2SQL());
|
||||
|
||||
//3.数据日期 >=
|
||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
corrector.doCorrect(chatQueryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1",
|
||||
sqlInfo.getCorrectS2SQL());
|
||||
sqlInfo.getCorrectedS2SQL());
|
||||
|
||||
//4.数据日期 >
|
||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
corrector.doCorrect(chatQueryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1",
|
||||
sqlInfo.getCorrectS2SQL());
|
||||
sqlInfo.getCorrectedS2SQL());
|
||||
|
||||
//5.no 数据日期
|
||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE 歌手名 = '张三' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
corrector.doCorrect(chatQueryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE 歌手名 = '张三' GROUP BY 维度1",
|
||||
sqlInfo.getCorrectS2SQL());
|
||||
sqlInfo.getCorrectedS2SQL());
|
||||
|
||||
//6. 数据日期-月 <=
|
||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE 歌手名 = '张三' AND 数据日期_月 <= '2024-01' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
corrector.doCorrect(chatQueryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE (歌手名 = '张三' AND 数据日期_月 <= '2024-01') "
|
||||
+ "AND 数据日期_月 >= '2024-01' GROUP BY 维度1",
|
||||
sqlInfo.getCorrectS2SQL());
|
||||
sqlInfo.getCorrectedS2SQL());
|
||||
|
||||
//7. 数据日期-月 >
|
||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE 歌手名 = '张三' AND 数据日期_月 > '2024-01' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
corrector.doCorrect(chatQueryContext, semanticParseInfo);
|
||||
|
||||
Assert.assertEquals(
|
||||
"SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE 歌手名 = '张三' AND 数据日期_月 > '2024-01' GROUP BY 维度1",
|
||||
sqlInfo.getCorrectS2SQL());
|
||||
sqlInfo.getCorrectedS2SQL());
|
||||
|
||||
//8. no where
|
||||
sql = "SELECT COUNT(1) FROM 数据库";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
corrector.doCorrect(chatQueryContext, semanticParseInfo);
|
||||
Assert.assertEquals("SELECT COUNT(1) FROM 数据库", sqlInfo.getCorrectS2SQL());
|
||||
Assert.assertEquals("SELECT COUNT(1) FROM 数据库", sqlInfo.getCorrectedS2SQL());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ class WhereCorrectorTest {
|
||||
SqlInfo sqlInfo = new SqlInfo();
|
||||
String sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
semanticParseInfo.setSqlInfo(sqlInfo);
|
||||
|
||||
ChatQueryContext chatQueryContext = new ChatQueryContext();
|
||||
@@ -54,7 +54,7 @@ class WhereCorrectorTest {
|
||||
WhereCorrector whereCorrector = new WhereCorrector();
|
||||
whereCorrector.addQueryFilter(chatQueryContext, semanticParseInfo);
|
||||
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
|
||||
Assert.assertEquals(correctS2SQL, "SELECT 维度1, SUM(播放量) FROM 数据库 WHERE "
|
||||
+ "(歌手名 = '张三') AND 数据日期 <= '2023-11-17' AND age > 30 AND "
|
||||
|
||||
Reference in New Issue
Block a user