(improvement)(headless)Remove unnecessary TranslateSqlReq, use SemanticQueryReq instead.

This commit is contained in:
jerryjzhang
2024-07-09 10:48:48 +08:00
parent 7a376bd9a3
commit f0b4eb46cf
32 changed files with 138 additions and 176 deletions

View File

@@ -22,7 +22,7 @@ public class AggCorrector extends BaseSemanticCorrector {
private void addAggregate(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
List<String> sqlGroupByFields = SqlSelectHelper.getGroupByFields(
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
if (CollectionUtils.isEmpty(sqlGroupByFields)) {
return;
}

View File

@@ -29,7 +29,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
public void correct(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
try {
if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectS2SQL())) {
if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectedS2SQL())) {
return;
}
doCorrect(chatQueryContext, semanticParseInfo);
@@ -74,7 +74,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
protected void addAggregateToMetric(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
//add aggregate to all metric
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
Long dataSetId = semanticParseInfo.getDataSet().getDataSet();
List<SchemaElement> metrics = getMetricElements(chatQueryContext, dataSetId);
@@ -98,7 +98,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
return;
}
String aggregateSql = SqlAddHelper.addAggregateToField(correctS2SQL, metricToAggregate);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(aggregateSql);
}
protected List<SchemaElement> getMetricElements(ChatQueryContext chatQueryContext, Long dataSetId) {

View File

@@ -34,8 +34,8 @@ public class GrammarCorrector extends BaseSemanticCorrector {
}
public void removeSameFieldFromSelect(SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
correctS2SQL = SqlRemoveHelper.removeSameFieldFromSelect(correctS2SQL);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
}
}

View File

@@ -35,7 +35,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
Long dataSetId = semanticParseInfo.getDataSetId();
//add dimension group by
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String correctS2SQL = sqlInfo.getCorrectS2SQL();
String correctS2SQL = sqlInfo.getCorrectedS2SQL();
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
// check has distinct
if (SqlSelectHelper.hasDistinct(correctS2SQL)) {
@@ -68,7 +68,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
Long dataSetId = semanticParseInfo.getDataSetId();
//add dimension group by
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String correctS2SQL = sqlInfo.getCorrectS2SQL();
String correctS2SQL = sqlInfo.getCorrectedS2SQL();
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
//add alias field name
Set<String> dimensions = getDimensions(dataSetId, semanticSchema);
@@ -83,7 +83,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
return true;
})
.collect(Collectors.toSet());
semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlAddHelper.addGroupBy(correctS2SQL, groupByFields));
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(SqlAddHelper.addGroupBy(correctS2SQL, groupByFields));
}
}

View File

@@ -49,19 +49,19 @@ public class HavingCorrector extends BaseSemanticCorrector {
if (CollectionUtils.isEmpty(metrics)) {
return;
}
String havingSql = SqlAddHelper.addHaving(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), metrics);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(havingSql);
String havingSql = SqlAddHelper.addHaving(semanticParseInfo.getSqlInfo().getCorrectedS2SQL(), metrics);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(havingSql);
}
private void addHavingToSelect(SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
if (!SqlSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
return;
}
List<Expression> havingExpressionList = SqlSelectHelper.getHavingExpression(correctS2SQL);
if (!CollectionUtils.isEmpty(havingExpressionList)) {
String replaceSql = SqlAddHelper.addFunctionToSelect(correctS2SQL, havingExpressionList);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceSql);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(replaceSql);
}
return;
}

View File

@@ -50,21 +50,21 @@ public class SchemaCorrector extends BaseSemanticCorrector {
private void correctAggFunction(SemanticParseInfo semanticParseInfo) {
Map<String, String> aggregateEnum = AggregateEnum.getAggregateEnum();
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String sql = SqlReplaceHelper.replaceFunction(sqlInfo.getCorrectS2SQL(), aggregateEnum);
sqlInfo.setCorrectS2SQL(sql);
String sql = SqlReplaceHelper.replaceFunction(sqlInfo.getCorrectedS2SQL(), aggregateEnum);
sqlInfo.setCorrectedS2SQL(sql);
}
private void replaceAlias(SemanticParseInfo semanticParseInfo) {
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String replaceAlias = SqlReplaceHelper.replaceAlias(sqlInfo.getCorrectS2SQL());
sqlInfo.setCorrectS2SQL(replaceAlias);
String replaceAlias = SqlReplaceHelper.replaceAlias(sqlInfo.getCorrectedS2SQL());
sqlInfo.setCorrectedS2SQL(replaceAlias);
}
private void correctFieldName(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
Map<String, String> fieldNameMap = getFieldNameMap(chatQueryContext, semanticParseInfo.getDataSetId());
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String sql = SqlReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap);
sqlInfo.setCorrectS2SQL(sql);
String sql = SqlReplaceHelper.replaceFields(sqlInfo.getCorrectedS2SQL(), fieldNameMap);
sqlInfo.setCorrectedS2SQL(sql);
}
private void updateFieldNameByLinkingValue(SemanticParseInfo semanticParseInfo) {
@@ -79,8 +79,8 @@ public class SchemaCorrector extends BaseSemanticCorrector {
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String sql = SqlReplaceHelper.replaceFieldNameByValue(sqlInfo.getCorrectS2SQL(), fieldValueToFieldNames);
sqlInfo.setCorrectS2SQL(sql);
String sql = SqlReplaceHelper.replaceFieldNameByValue(sqlInfo.getCorrectedS2SQL(), fieldValueToFieldNames);
sqlInfo.setCorrectedS2SQL(sql);
}
private List<LLMReq.ElementValue> getLinkingValues(SemanticParseInfo semanticParseInfo) {
@@ -111,14 +111,14 @@ public class SchemaCorrector extends BaseSemanticCorrector {
)));
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String sql = SqlReplaceHelper.replaceValue(sqlInfo.getCorrectS2SQL(), filedNameToValueMap, false);
sqlInfo.setCorrectS2SQL(sql);
String sql = SqlReplaceHelper.replaceValue(sqlInfo.getCorrectedS2SQL(), filedNameToValueMap, false);
sqlInfo.setCorrectedS2SQL(sql);
}
public void removeFilterIfNotInLinkingValue(ChatQueryContext chatQueryContext,
SemanticParseInfo semanticParseInfo) {
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String correctS2SQL = sqlInfo.getCorrectS2SQL();
String correctS2SQL = sqlInfo.getCorrectedS2SQL();
List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctS2SQL);
if (CollectionUtils.isEmpty(whereExpressionList)) {
return;
@@ -143,7 +143,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
.map(fieldExpression -> fieldExpression.getFieldName()).collect(Collectors.toSet());
String sql = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames);
sqlInfo.setCorrectS2SQL(sql);
sqlInfo.setCorrectedS2SQL(sql);
}
}

View File

@@ -33,7 +33,7 @@ public class SelectCorrector extends BaseSemanticCorrector {
@Override
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
// If the number of aggregated fields is equal to the number of queried fields, do not add fields to select.
@@ -44,7 +44,7 @@ public class SelectCorrector extends BaseSemanticCorrector {
}
correctS2SQL = addFieldsToSelect(chatQueryContext, semanticParseInfo, correctS2SQL);
String querySql = SqlReplaceHelper.dealAliasToOrderBy(correctS2SQL);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(querySql);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(querySql);
}
protected String addFieldsToSelect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo,
@@ -65,7 +65,7 @@ public class SelectCorrector extends BaseSemanticCorrector {
}
needAddFields.removeAll(selectFields);
String addFieldsToSelectSql = SqlAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
semanticParseInfo.getSqlInfo().setCorrectS2SQL(addFieldsToSelectSql);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(addFieldsToSelectSql);
return addFieldsToSelectSql;
}

View File

@@ -45,7 +45,7 @@ public class TimeCorrector extends BaseSemanticCorrector {
}
private void removeDateIfExist(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
//decide whether remove date field from where
Environment environment = ContextUtils.getBean(Environment.class);
String correctorDate = environment.getProperty("s2.corrector.date");
@@ -55,12 +55,12 @@ public class TimeCorrector extends BaseSemanticCorrector {
removeFieldNames.add(TimeDimensionEnum.WEEK.getChName());
removeFieldNames.add(TimeDimensionEnum.MONTH.getChName());
correctS2SQL = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
}
}
private void addDateIfNotExist(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
//decide whether add date field to where
@@ -88,11 +88,11 @@ public class TimeCorrector extends BaseSemanticCorrector {
}
}
}
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
}
private void addLowerBoundDate(SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
DateBoundInfo dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(correctS2SQL);
if (Objects.isNull(dateBoundInfo)) {
return;
@@ -108,14 +108,14 @@ public class TimeCorrector extends BaseSemanticCorrector {
} catch (JSQLParserException e) {
log.error("parseCondExpression", e);
}
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
}
}
private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
correctS2SQL = SqlReplaceHelper.replaceFunction(correctS2SQL);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
}
}

View File

@@ -39,7 +39,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
protected void addQueryFilter(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
String queryFilter = getQueryFilter(chatQueryContext.getQueryFilters());
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
if (StringUtils.isNotEmpty(queryFilter)) {
log.info("add queryFilter to correctS2SQL :{}", queryFilter);
@@ -50,7 +50,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
log.error("parseCondExpression", e);
}
correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, expression);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
}
}
@@ -71,9 +71,9 @@ public class WhereCorrector extends BaseSemanticCorrector {
}
Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions);
String correctS2SQL = SqlReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
String correctS2SQL = SqlReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectedS2SQL(),
aliasAndBizNameToTechName);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
}
private Map<String, Map<String, String>> getAliasAndBizNameToTechName(List<SchemaElement> dimensions) {

View File

@@ -46,14 +46,14 @@ public class QueryTypeParser implements SemanticParser {
private QueryType getQueryType(ChatQueryContext chatQueryContext, SemanticQuery semanticQuery) {
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
SqlInfo sqlInfo = parseInfo.getSqlInfo();
if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getS2SQL())) {
if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getParsedS2SQL())) {
return QueryType.DETAIL;
}
//1. entity queryType
Long dataSetId = parseInfo.getDataSetId();
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getS2SQL());
List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getParsedS2SQL());
List<String> whereFilterByTimeFields = filterByTimeFields(whereFields);
if (CollectionUtils.isNotEmpty(whereFilterByTimeFields)) {
Set<String> ids = semanticSchema.getEntities(dataSetId).stream().map(SchemaElement::getName)
@@ -63,7 +63,7 @@ public class QueryTypeParser implements SemanticParser {
return QueryType.ID;
}
}
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getParsedS2SQL());
selectFields.addAll(whereFields);
List<String> selectWhereFilterByTimeFields = filterByTimeFields(selectFields);
if (CollectionUtils.isNotEmpty(selectWhereFilterByTimeFields)) {
@@ -91,7 +91,7 @@ public class QueryTypeParser implements SemanticParser {
}
private static boolean selectContainsMetric(SqlInfo sqlInfo, Long dataSetId, SemanticSchema semanticSchema) {
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getParsedS2SQL());
List<SchemaElement> metrics = semanticSchema.getMetrics(dataSetId);
if (CollectionUtils.isNotEmpty(metrics)) {
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());

View File

@@ -38,7 +38,7 @@ public class LLMResponseService {
parseInfo.setProperties(properties);
parseInfo.setScore(queryCtx.getQueryText().length() * (1 + weight));
parseInfo.setQueryMode(semanticQuery.getQueryMode());
parseInfo.getSqlInfo().setS2SQL(s2SQL);
parseInfo.getSqlInfo().setParsedS2SQL(s2SQL);
queryCtx.getCandidateQueries().add(semanticQuery);
return parseInfo;
}

View File

@@ -83,8 +83,8 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
QueryStructReq queryStructReq = convertQueryStruct();
convertBizNameToName(semanticSchema, queryStructReq);
QuerySqlReq querySQLReq = queryStructReq.convert();
parseInfo.getSqlInfo().setS2SQL(querySQLReq.getSql());
parseInfo.getSqlInfo().setCorrectS2SQL(querySQLReq.getSql());
parseInfo.getSqlInfo().setParsedS2SQL(querySQLReq.getSql());
parseInfo.getSqlInfo().setCorrectedS2SQL(querySQLReq.getSql());
}
}

View File

@@ -33,6 +33,6 @@ public class LLMSqlQuery extends LLMSemanticQuery {
@Override
public void initS2Sql(SemanticSchema semanticSchema, User user) {
SqlInfo sqlInfo = parseInfo.getSqlInfo();
sqlInfo.setCorrectS2SQL(sqlInfo.getS2SQL());
sqlInfo.setCorrectedS2SQL(sqlInfo.getParsedS2SQL());
}
}

View File

@@ -152,8 +152,8 @@ public class QueryReqBuilder {
public static QuerySqlReq buildS2SQLReq(SqlInfo sqlInfo, Long dataSetId) {
QuerySqlReq querySQLReq = new QuerySqlReq();
if (Objects.nonNull(sqlInfo.getCorrectS2SQL())) {
querySQLReq.setSql(sqlInfo.getCorrectS2SQL());
if (Objects.nonNull(sqlInfo.getCorrectedS2SQL())) {
querySQLReq.setSql(sqlInfo.getCorrectedS2SQL());
}
querySQLReq.setSqlInfo(sqlInfo);
querySQLReq.setDataSetId(dataSetId);

View File

@@ -30,14 +30,14 @@ class AggCorrectorTest {
String sql = "SELECT 用户, 访问次数 FROM 超音数数据集 WHERE 部门 = 'sales' AND"
+ " datediff('day', 数据日期, '2024-06-04') <= 7"
+ " GROUP BY 用户 ORDER BY SUM(访问次数) DESC LIMIT 1";
sqlInfo.setS2SQL(sql);
sqlInfo.setCorrectS2SQL(sql);
sqlInfo.setParsedS2SQL(sql);
sqlInfo.setCorrectedS2SQL(sql);
semanticParseInfo.setSqlInfo(sqlInfo);
corrector.correct(chatQueryContext, semanticParseInfo);
Assert.assertEquals("SELECT 用户, SUM(访问次数) FROM 超音数数据集 WHERE 部门 = 'sales'"
+ " AND datediff('day', 数据日期, '2024-06-04') <= 7 GROUP BY 用户"
+ " ORDER BY SUM(访问次数) DESC LIMIT 1",
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
}
private ChatQueryContext buildQueryContext(Long dataSetId) {

View File

@@ -65,8 +65,8 @@ class SchemaCorrectorTest {
+ "and 商务组 = 'xxx' order by 播放量 desc limit 10";
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
SqlInfo sqlInfo = new SqlInfo();
sqlInfo.setS2SQL(sql);
sqlInfo.setCorrectS2SQL(sql);
sqlInfo.setParsedS2SQL(sql);
sqlInfo.setCorrectedS2SQL(sql);
semanticParseInfo.setSqlInfo(sqlInfo);
SchemaElement schemaElement = new SchemaElement();
@@ -80,7 +80,7 @@ class SchemaCorrectorTest {
schemaCorrector.removeFilterIfNotInLinkingValue(chatQueryContext, semanticParseInfo);
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' "
+ "ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectS2SQL());
+ "ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
parseResult = objectMapper.readValue(json, ParseResult.class);
@@ -92,11 +92,11 @@ class SchemaCorrectorTest {
parseResult.setLinkingValues(linkingValues);
semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(sql);
semanticParseInfo.getSqlInfo().setS2SQL(sql);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(sql);
semanticParseInfo.getSqlInfo().setParsedS2SQL(sql);
schemaCorrector.removeFilterIfNotInLinkingValue(chatQueryContext, semanticParseInfo);
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' "
+ "AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectS2SQL());
+ "AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
}

View File

@@ -44,12 +44,12 @@ class SelectCorrectorTest {
semanticParseInfo.setQueryType(QueryType.DETAIL);
SqlInfo sqlInfo = new SqlInfo();
String sql = "SELECT * FROM 艺人库 WHERE 艺人名='周杰伦'";
sqlInfo.setS2SQL(sql);
sqlInfo.setCorrectS2SQL(sql);
sqlInfo.setParsedS2SQL(sql);
sqlInfo.setCorrectedS2SQL(sql);
semanticParseInfo.setSqlInfo(sqlInfo);
corrector.correct(chatQueryContext, semanticParseInfo);
Assert.assertEquals("SELECT 粉丝数, 国籍, 艺人名, 性别 FROM 艺人库 WHERE 艺人名 = '周杰伦'",
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
}
private ChatQueryContext buildQueryContext(Long dataSetId) {

View File

@@ -19,84 +19,84 @@ class TimeCorrectorTest {
//1.数据日期 <=
String sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql);
sqlInfo.setCorrectedS2SQL(sql);
semanticParseInfo.setSqlInfo(sqlInfo);
corrector.doCorrect(chatQueryContext, semanticParseInfo);
Assert.assertEquals(
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE ((歌手名 = '张三') AND 数据日期 <= '2023-11-17') "
+ "AND 数据日期 >= '2023-11-17' GROUP BY 维度1",
sqlInfo.getCorrectS2SQL());
sqlInfo.getCorrectedS2SQL());
//2.数据日期 <
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 < '2023-11-17' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql);
sqlInfo.setCorrectedS2SQL(sql);
corrector.doCorrect(chatQueryContext, semanticParseInfo);
Assert.assertEquals(
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE ((歌手名 = '张三') AND 数据日期 < '2023-11-17') "
+ "AND 数据日期 >= '2023-11-17' GROUP BY 维度1",
sqlInfo.getCorrectS2SQL());
sqlInfo.getCorrectedS2SQL());
//3.数据日期 >=
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql);
sqlInfo.setCorrectedS2SQL(sql);
corrector.doCorrect(chatQueryContext, semanticParseInfo);
Assert.assertEquals(
"SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1",
sqlInfo.getCorrectS2SQL());
sqlInfo.getCorrectedS2SQL());
//4.数据日期 >
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql);
sqlInfo.setCorrectedS2SQL(sql);
corrector.doCorrect(chatQueryContext, semanticParseInfo);
Assert.assertEquals(
"SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1",
sqlInfo.getCorrectS2SQL());
sqlInfo.getCorrectedS2SQL());
//5.no 数据日期
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE 歌手名 = '张三' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql);
sqlInfo.setCorrectedS2SQL(sql);
corrector.doCorrect(chatQueryContext, semanticParseInfo);
Assert.assertEquals(
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE 歌手名 = '张三' GROUP BY 维度1",
sqlInfo.getCorrectS2SQL());
sqlInfo.getCorrectedS2SQL());
//6. 数据日期-月 <=
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE 歌手名 = '张三' AND 数据日期_月 <= '2024-01' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql);
sqlInfo.setCorrectedS2SQL(sql);
corrector.doCorrect(chatQueryContext, semanticParseInfo);
Assert.assertEquals(
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE (歌手名 = '张三' AND 数据日期_月 <= '2024-01') "
+ "AND 数据日期_月 >= '2024-01' GROUP BY 维度1",
sqlInfo.getCorrectS2SQL());
sqlInfo.getCorrectedS2SQL());
//7. 数据日期-月 >
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE 歌手名 = '张三' AND 数据日期_月 > '2024-01' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql);
sqlInfo.setCorrectedS2SQL(sql);
corrector.doCorrect(chatQueryContext, semanticParseInfo);
Assert.assertEquals(
"SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE 歌手名 = '张三' AND 数据日期_月 > '2024-01' GROUP BY 维度1",
sqlInfo.getCorrectS2SQL());
sqlInfo.getCorrectedS2SQL());
//8. no where
sql = "SELECT COUNT(1) FROM 数据库";
sqlInfo.setCorrectS2SQL(sql);
sqlInfo.setCorrectedS2SQL(sql);
corrector.doCorrect(chatQueryContext, semanticParseInfo);
Assert.assertEquals("SELECT COUNT(1) FROM 数据库", sqlInfo.getCorrectS2SQL());
Assert.assertEquals("SELECT COUNT(1) FROM 数据库", sqlInfo.getCorrectedS2SQL());
}
}

View File

@@ -19,7 +19,7 @@ class WhereCorrectorTest {
SqlInfo sqlInfo = new SqlInfo();
String sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql);
sqlInfo.setCorrectedS2SQL(sql);
semanticParseInfo.setSqlInfo(sqlInfo);
ChatQueryContext chatQueryContext = new ChatQueryContext();
@@ -54,7 +54,7 @@ class WhereCorrectorTest {
WhereCorrector whereCorrector = new WhereCorrector();
whereCorrector.addQueryFilter(chatQueryContext, semanticParseInfo);
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
Assert.assertEquals(correctS2SQL, "SELECT 维度1, SUM(播放量) FROM 数据库 WHERE "
+ "(歌手名 = '张三') AND 数据日期 <= '2023-11-17' AND age > 30 AND "