mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 03:58:14 +00:00
(fix)(headless)Fix schema corrector in that aliases should not be replaced.
This commit is contained in:
@@ -146,6 +146,10 @@ public class SqlReplaceHelper {
|
||||
public static String replaceFields(String sql, Map<String, String> fieldNameMap,
|
||||
boolean exactReplace) {
|
||||
Select selectStatement = SqlSelectHelper.getSelect(sql);
|
||||
// alias field should not be replaced
|
||||
Set<String> aliases = SqlSelectHelper.getAliasFields(sql);
|
||||
aliases.forEach(alias -> fieldNameMap.put(alias, alias));
|
||||
|
||||
Set<Select> plainSelectList = SqlSelectHelper.getAllSelect(selectStatement);
|
||||
for (Select plainSelect : plainSelectList) {
|
||||
if (plainSelect instanceof PlainSelect) {
|
||||
|
||||
@@ -41,6 +41,27 @@ class AggCorrectorTest {
|
||||
semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testSchemaCorrector() {
|
||||
SchemaCorrector corrector = new SchemaCorrector();
|
||||
Long dataSetId = 1L;
|
||||
ChatQueryContext chatQueryContext = buildQueryContext(dataSetId);
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
SchemaElement dataSet = new SchemaElement();
|
||||
dataSet.setDataSetId(dataSetId);
|
||||
semanticParseInfo.setDataSet(dataSet);
|
||||
SqlInfo sqlInfo = new SqlInfo();
|
||||
String sql = "WITH 总停留时长 AS (SELECT 用户, SUM(停留时长) AS _总停留时长_ FROM 超音数数据集 WHERE 用户 IN ('alice', 'lucy') AND 数据日期 >= '2025-03-01' AND 数据日期 <= '2025-03-12' GROUP BY 用户) SELECT 用户, _总停留时长_ FROM 总停留时长";
|
||||
sqlInfo.setParsedS2SQL(sql);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
semanticParseInfo.setSqlInfo(sqlInfo);
|
||||
corrector.correct(chatQueryContext, semanticParseInfo);
|
||||
Assert.assertEquals(
|
||||
"WITH 总停留时长 AS (SELECT 用户名, SUM(停留时长) AS _总停留时长_ FROM 超音数数据集 WHERE 用户名 IN ('alice', 'lucy') AND 数据日期 " +
|
||||
">= '2025-03-01' AND 数据日期 <= '2025-03-12' GROUP BY 用户名) SELECT 用户名, _总停留时长_ FROM 总停留时长",
|
||||
semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
|
||||
}
|
||||
|
||||
private ChatQueryContext buildQueryContext(Long dataSetId) {
|
||||
ChatQueryContext chatQueryContext = new ChatQueryContext();
|
||||
List<DataSetSchema> dataSetSchemaList = new ArrayList<>();
|
||||
@@ -51,18 +72,30 @@ class AggCorrectorTest {
|
||||
schemaElement.setDataSetId(dataSetId);
|
||||
dataSetSchema.setDataSet(schemaElement);
|
||||
Set<SchemaElement> dimensions = new HashSet<>();
|
||||
SchemaElement element1 = new SchemaElement();
|
||||
element1.setDataSetId(1L);
|
||||
element1.setName("部门");
|
||||
dimensions.add(element1);
|
||||
|
||||
dimensions.add(SchemaElement.builder()
|
||||
.dataSetId(1L)
|
||||
.name("部门")
|
||||
.build());
|
||||
|
||||
dimensions.add(SchemaElement.builder()
|
||||
.dataSetId(1L)
|
||||
.name("用户名")
|
||||
.build());
|
||||
|
||||
dataSetSchema.setDimensions(dimensions);
|
||||
|
||||
Set<SchemaElement> metrics = new HashSet<>();
|
||||
SchemaElement metric1 = new SchemaElement();
|
||||
metric1.setDataSetId(1L);
|
||||
metric1.setName("访问次数");
|
||||
metrics.add(metric1);
|
||||
|
||||
metrics.add(SchemaElement.builder()
|
||||
.dataSetId(1L)
|
||||
.name("访问次数")
|
||||
.build());
|
||||
|
||||
metrics.add(SchemaElement.builder()
|
||||
.dataSetId(1L)
|
||||
.name("停留时长")
|
||||
.build());
|
||||
|
||||
dataSetSchema.setMetrics(metrics);
|
||||
dataSetSchemaList.add(dataSetSchema);
|
||||
|
||||
@@ -333,7 +333,6 @@ public class S2VisitsDemo extends S2BaseDemo {
|
||||
metricReq.setSensitiveLevel(SensitiveLevelEnum.HIGH.getCode());
|
||||
metricReq.setDescription("每个用户平均访问的次数");
|
||||
metricReq.setClassifications(Collections.singletonList("核心指标"));
|
||||
metricReq.setAlias("平均访问次数");
|
||||
MetricDefineByMetricParams metricTypeParams = new MetricDefineByMetricParams();
|
||||
metricTypeParams.setExpr("pv/uv");
|
||||
List<MetricParam> metrics = new ArrayList<>();
|
||||
|
||||
@@ -155,16 +155,6 @@ public class Text2SQLEval extends BaseTest {
|
||||
assert result.getTextResult().contains("3");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test_detail_query() throws Exception {
|
||||
long start = System.currentTimeMillis();
|
||||
QueryResult result = submitNewChat("特斯拉旗下有哪些品牌", agent.getId());
|
||||
durations.add(System.currentTimeMillis() - start);
|
||||
assert result.getQueryColumns().size() >= 1;
|
||||
assert result.getTextResult().contains("Model Y");
|
||||
assert result.getTextResult().contains("Model 3");
|
||||
}
|
||||
|
||||
public Agent getLLMAgent() {
|
||||
Agent agent = new Agent();
|
||||
agent.setName("Agent for Test");
|
||||
|
||||
Reference in New Issue
Block a user