[fix][headless]Fix schema corrector test cases. (#2027)

This commit is contained in:
Jun Zhang
2025-02-02 15:52:23 +08:00
committed by GitHub
parent d294fec2a0
commit 0417f12324
3 changed files with 14 additions and 12 deletions

View File

@@ -666,8 +666,9 @@ public class SqlSelectHelper {
}
if (withSelect instanceof ParenthesedSelect) {
ParenthesedSelect parenthesedSelect = (ParenthesedSelect) withSelect;
PlainSelect withPlainSelect = parenthesedSelect.getPlainSelect();
plainSelectList.add(withPlainSelect);
List<PlainSelect> plainSelects = new ArrayList<>();
SqlReplaceHelper.getFromSelect(parenthesedSelect, plainSelects);
plainSelectList.addAll(plainSelects);
}
}
}
@@ -893,7 +894,9 @@ public class SqlSelectHelper {
collectSelects(withItem.getSelect(), selects);
} else if (select instanceof ParenthesedSelect) {
ParenthesedSelect parenthesedSelect = (ParenthesedSelect) select;
collectSelects(parenthesedSelect.getPlainSelect(), selects);
List<PlainSelect> plainSelects = new ArrayList<>();
SqlReplaceHelper.getFromSelect(parenthesedSelect, plainSelects);
plainSelects.forEach(s -> collectSelects(s, selects));
}
}

View File

@@ -55,12 +55,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
SemanticParseInfo semanticParseInfo) {
Map<String, String> fieldNameMap =
getFieldNameMap(chatQueryContext, semanticParseInfo.getDataSetId());
// add as fieldName
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
List<String> asFields = SqlAsHelper.getAsFields(sqlInfo.getCorrectedS2SQL());
for (String asField : asFields) {
fieldNameMap.put(asField, asField);
}
String sql = SqlReplaceHelper.replaceFields(sqlInfo.getCorrectedS2SQL(), fieldNameMap);
sqlInfo.setCorrectedS2SQL(sql);
}

View File

@@ -42,13 +42,15 @@ class SchemaCorrectorTest {
SchemaCorrector schemaCorrector = new SchemaCorrector();
schemaCorrector.correct(chatQueryContext, parseInfo);
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' ORDER BY SUM(播放量) DESC LIMIT 10",
Assert.assertEquals(
"SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' ORDER BY SUM(播放量) DESC LIMIT 10",
parseInfo.getSqlInfo().getCorrectedS2SQL());
}
@Test
void testRemoveUnmappedFilterValue() throws JsonProcessingException {
String sql = "SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10";
String sql =
"SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10";
ChatQueryContext chatQueryContext = buildQueryContext(sql);
SemanticParseInfo parseInfo = chatQueryContext.getCandidateQueries().get(0).getParseInfo();
@@ -60,7 +62,8 @@ class SchemaCorrectorTest {
SchemaCorrector schemaCorrector = new SchemaCorrector();
schemaCorrector.removeUnmappedFilterValue(chatQueryContext, parseInfo);
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' ORDER BY 播放量 DESC LIMIT 10",
Assert.assertEquals(
"SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' ORDER BY 播放量 DESC LIMIT 10",
parseInfo.getSqlInfo().getCorrectedS2SQL());
List<LLMReq.ElementValue> linkingValues = new ArrayList<>();
@@ -74,7 +77,8 @@ class SchemaCorrectorTest {
parseInfo.getSqlInfo().setCorrectedS2SQL(sql);
parseInfo.getSqlInfo().setParsedS2SQL(sql);
schemaCorrector.removeUnmappedFilterValue(chatQueryContext, parseInfo);
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' "
Assert.assertEquals(
"SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' "
+ "AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10",
parseInfo.getSqlInfo().getCorrectedS2SQL());
}