diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java index 02238b48b..4f1362cca 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java @@ -666,8 +666,9 @@ public class SqlSelectHelper { } if (withSelect instanceof ParenthesedSelect) { ParenthesedSelect parenthesedSelect = (ParenthesedSelect) withSelect; - PlainSelect withPlainSelect = parenthesedSelect.getPlainSelect(); - plainSelectList.add(withPlainSelect); + List plainSelects = new ArrayList<>(); + SqlReplaceHelper.getFromSelect(parenthesedSelect, plainSelects); + plainSelectList.addAll(plainSelects); } } } @@ -893,7 +894,9 @@ public class SqlSelectHelper { collectSelects(withItem.getSelect(), selects); } else if (select instanceof ParenthesedSelect) { ParenthesedSelect parenthesedSelect = (ParenthesedSelect) select; - collectSelects(parenthesedSelect.getPlainSelect(), selects); + List plainSelects = new ArrayList<>(); + SqlReplaceHelper.getFromSelect(parenthesedSelect, plainSelects); + plainSelects.forEach(s -> collectSelects(s, selects)); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrector.java index 0eb354db6..8f044e431 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrector.java @@ -55,12 +55,7 @@ public class SchemaCorrector extends BaseSemanticCorrector { SemanticParseInfo semanticParseInfo) { Map fieldNameMap = getFieldNameMap(chatQueryContext, semanticParseInfo.getDataSetId()); - // add as fieldName SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); - List asFields = SqlAsHelper.getAsFields(sqlInfo.getCorrectedS2SQL()); - for (String asField : asFields) { - fieldNameMap.put(asField, asField); - } String sql = SqlReplaceHelper.replaceFields(sqlInfo.getCorrectedS2SQL(), fieldNameMap); sqlInfo.setCorrectedS2SQL(sql); } diff --git a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrectorTest.java b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrectorTest.java index 718b60fa6..a96c043aa 100644 --- a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrectorTest.java +++ b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrectorTest.java @@ -42,13 +42,15 @@ class SchemaCorrectorTest { SchemaCorrector schemaCorrector = new SchemaCorrector(); schemaCorrector.correct(chatQueryContext, parseInfo); - Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' ORDER BY SUM(播放量) DESC LIMIT 10", + Assert.assertEquals( + "SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' ORDER BY SUM(播放量) DESC LIMIT 10", parseInfo.getSqlInfo().getCorrectedS2SQL()); } @Test void testRemoveUnmappedFilterValue() throws JsonProcessingException { - String sql = "SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10"; + String sql = + "SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10"; ChatQueryContext chatQueryContext = buildQueryContext(sql); SemanticParseInfo parseInfo = chatQueryContext.getCandidateQueries().get(0).getParseInfo(); @@ -60,7 +62,8 @@ class SchemaCorrectorTest { SchemaCorrector schemaCorrector = new SchemaCorrector(); schemaCorrector.removeUnmappedFilterValue(chatQueryContext, parseInfo); - Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' ORDER BY 播放量 DESC LIMIT 10", + Assert.assertEquals( + "SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' ORDER BY 播放量 DESC LIMIT 10", parseInfo.getSqlInfo().getCorrectedS2SQL()); List linkingValues = new ArrayList<>(); @@ -74,7 +77,8 @@ class SchemaCorrectorTest { parseInfo.getSqlInfo().setCorrectedS2SQL(sql); parseInfo.getSqlInfo().setParsedS2SQL(sql); schemaCorrector.removeUnmappedFilterValue(chatQueryContext, parseInfo); - Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' " + Assert.assertEquals( + "SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' " + "AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10", parseInfo.getSqlInfo().getCorrectedS2SQL()); }