mirror of
https://github.com/tencentmusic/supersonic.git
synced 2026-04-29 20:44:25 +08:00
Compare commits
2 Commits
d294fec2a0
...
ef161fe1f2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ef161fe1f2 | ||
|
|
0417f12324 |
@@ -666,8 +666,9 @@ public class SqlSelectHelper {
|
|||||||
}
|
}
|
||||||
if (withSelect instanceof ParenthesedSelect) {
|
if (withSelect instanceof ParenthesedSelect) {
|
||||||
ParenthesedSelect parenthesedSelect = (ParenthesedSelect) withSelect;
|
ParenthesedSelect parenthesedSelect = (ParenthesedSelect) withSelect;
|
||||||
PlainSelect withPlainSelect = parenthesedSelect.getPlainSelect();
|
List<PlainSelect> plainSelects = new ArrayList<>();
|
||||||
plainSelectList.add(withPlainSelect);
|
SqlReplaceHelper.getFromSelect(parenthesedSelect, plainSelects);
|
||||||
|
plainSelectList.addAll(plainSelects);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -893,7 +894,9 @@ public class SqlSelectHelper {
|
|||||||
collectSelects(withItem.getSelect(), selects);
|
collectSelects(withItem.getSelect(), selects);
|
||||||
} else if (select instanceof ParenthesedSelect) {
|
} else if (select instanceof ParenthesedSelect) {
|
||||||
ParenthesedSelect parenthesedSelect = (ParenthesedSelect) select;
|
ParenthesedSelect parenthesedSelect = (ParenthesedSelect) select;
|
||||||
collectSelects(parenthesedSelect.getPlainSelect(), selects);
|
List<PlainSelect> plainSelects = new ArrayList<>();
|
||||||
|
SqlReplaceHelper.getFromSelect(parenthesedSelect, plainSelects);
|
||||||
|
plainSelects.forEach(s -> collectSelects(s, selects));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package com.tencent.supersonic.common.pojo.enums;
|
package com.tencent.supersonic.common.pojo.enums;
|
||||||
|
|
||||||
public enum AggOperatorEnum {
|
public enum AggOperatorEnum {
|
||||||
|
ANY("ANY"),
|
||||||
|
|
||||||
MAX("MAX"),
|
MAX("MAX"),
|
||||||
|
|
||||||
MIN("MIN"),
|
MIN("MIN"),
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ public class ColumnSchema {
|
|||||||
|
|
||||||
private FieldType filedType;
|
private FieldType filedType;
|
||||||
|
|
||||||
private AggOperatorEnum agg = AggOperatorEnum.SUM;
|
private AggOperatorEnum agg = AggOperatorEnum.ANY;
|
||||||
|
|
||||||
private String name;
|
private String name;
|
||||||
|
|
||||||
|
|||||||
@@ -55,12 +55,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
|||||||
SemanticParseInfo semanticParseInfo) {
|
SemanticParseInfo semanticParseInfo) {
|
||||||
Map<String, String> fieldNameMap =
|
Map<String, String> fieldNameMap =
|
||||||
getFieldNameMap(chatQueryContext, semanticParseInfo.getDataSetId());
|
getFieldNameMap(chatQueryContext, semanticParseInfo.getDataSetId());
|
||||||
// add as fieldName
|
|
||||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
List<String> asFields = SqlAsHelper.getAsFields(sqlInfo.getCorrectedS2SQL());
|
|
||||||
for (String asField : asFields) {
|
|
||||||
fieldNameMap.put(asField, asField);
|
|
||||||
}
|
|
||||||
String sql = SqlReplaceHelper.replaceFields(sqlInfo.getCorrectedS2SQL(), fieldNameMap);
|
String sql = SqlReplaceHelper.replaceFields(sqlInfo.getCorrectedS2SQL(), fieldNameMap);
|
||||||
sqlInfo.setCorrectedS2SQL(sql);
|
sqlInfo.setCorrectedS2SQL(sql);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,13 +42,15 @@ class SchemaCorrectorTest {
|
|||||||
SchemaCorrector schemaCorrector = new SchemaCorrector();
|
SchemaCorrector schemaCorrector = new SchemaCorrector();
|
||||||
schemaCorrector.correct(chatQueryContext, parseInfo);
|
schemaCorrector.correct(chatQueryContext, parseInfo);
|
||||||
|
|
||||||
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' ORDER BY SUM(播放量) DESC LIMIT 10",
|
Assert.assertEquals(
|
||||||
|
"SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' ORDER BY SUM(播放量) DESC LIMIT 10",
|
||||||
parseInfo.getSqlInfo().getCorrectedS2SQL());
|
parseInfo.getSqlInfo().getCorrectedS2SQL());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void testRemoveUnmappedFilterValue() throws JsonProcessingException {
|
void testRemoveUnmappedFilterValue() throws JsonProcessingException {
|
||||||
String sql = "SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10";
|
String sql =
|
||||||
|
"SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10";
|
||||||
ChatQueryContext chatQueryContext = buildQueryContext(sql);
|
ChatQueryContext chatQueryContext = buildQueryContext(sql);
|
||||||
SemanticParseInfo parseInfo = chatQueryContext.getCandidateQueries().get(0).getParseInfo();
|
SemanticParseInfo parseInfo = chatQueryContext.getCandidateQueries().get(0).getParseInfo();
|
||||||
|
|
||||||
@@ -60,7 +62,8 @@ class SchemaCorrectorTest {
|
|||||||
SchemaCorrector schemaCorrector = new SchemaCorrector();
|
SchemaCorrector schemaCorrector = new SchemaCorrector();
|
||||||
schemaCorrector.removeUnmappedFilterValue(chatQueryContext, parseInfo);
|
schemaCorrector.removeUnmappedFilterValue(chatQueryContext, parseInfo);
|
||||||
|
|
||||||
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' ORDER BY 播放量 DESC LIMIT 10",
|
Assert.assertEquals(
|
||||||
|
"SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' ORDER BY 播放量 DESC LIMIT 10",
|
||||||
parseInfo.getSqlInfo().getCorrectedS2SQL());
|
parseInfo.getSqlInfo().getCorrectedS2SQL());
|
||||||
|
|
||||||
List<LLMReq.ElementValue> linkingValues = new ArrayList<>();
|
List<LLMReq.ElementValue> linkingValues = new ArrayList<>();
|
||||||
@@ -74,7 +77,8 @@ class SchemaCorrectorTest {
|
|||||||
parseInfo.getSqlInfo().setCorrectedS2SQL(sql);
|
parseInfo.getSqlInfo().setCorrectedS2SQL(sql);
|
||||||
parseInfo.getSqlInfo().setParsedS2SQL(sql);
|
parseInfo.getSqlInfo().setParsedS2SQL(sql);
|
||||||
schemaCorrector.removeUnmappedFilterValue(chatQueryContext, parseInfo);
|
schemaCorrector.removeUnmappedFilterValue(chatQueryContext, parseInfo);
|
||||||
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' "
|
Assert.assertEquals(
|
||||||
|
"SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' "
|
||||||
+ "AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10",
|
+ "AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10",
|
||||||
parseInfo.getSqlInfo().getCorrectedS2SQL());
|
parseInfo.getSqlInfo().getCorrectedS2SQL());
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -163,13 +163,11 @@ public class ModelConverter {
|
|||||||
getIdentifyType(fieldType).name(), columnSchema.getColumnName(), 1);
|
getIdentifyType(fieldType).name(), columnSchema.getColumnName(), 1);
|
||||||
modelDetail.getIdentifiers().add(identify);
|
modelDetail.getIdentifiers().add(identify);
|
||||||
} else if (FieldType.measure.equals(fieldType)) {
|
} else if (FieldType.measure.equals(fieldType)) {
|
||||||
Measure measure = new Measure(columnSchema.getName(),
|
Measure measure = new Measure(columnSchema.getName(), columnSchema.getColumnName(),
|
||||||
modelReq.getBizName() + "_" + columnSchema.getColumnName(),
|
|
||||||
columnSchema.getColumnName(), columnSchema.getAgg().getOperator(), 1);
|
columnSchema.getColumnName(), columnSchema.getAgg().getOperator(), 1);
|
||||||
modelDetail.getMeasures().add(measure);
|
modelDetail.getMeasures().add(measure);
|
||||||
} else {
|
} else {
|
||||||
Dimension dim = new Dimension(columnSchema.getName(),
|
Dimension dim = new Dimension(columnSchema.getName(), columnSchema.getColumnName(),
|
||||||
modelReq.getBizName() + "_" + columnSchema.getColumnName(),
|
|
||||||
columnSchema.getColumnName(),
|
columnSchema.getColumnName(),
|
||||||
DimensionType.valueOf(columnSchema.getFiledType().name()), 1);
|
DimensionType.valueOf(columnSchema.getFiledType().name()), 1);
|
||||||
modelDetail.getDimensions().add(dim);
|
modelDetail.getDimensions().add(dim);
|
||||||
|
|||||||
Reference in New Issue
Block a user