From f0b4eb46cfd69220e4cf3ab2a85e316f319dbdc1 Mon Sep 17 00:00:00 2001 From: jerryjzhang Date: Tue, 9 Jul 2024 10:48:48 +0800 Subject: [PATCH] (improvement)(headless)Remove unnecessary `TranslateSqlReq`, use `SemanticQueryReq` instead. --- .../chat/server/executor/SqlExecutor.java | 6 ++-- .../chat/server/parser/NL2SQLParser.java | 2 +- .../supersonic/headless/api/pojo/SqlInfo.java | 11 +++++-- .../api/pojo/request/QueryStructReq.java | 2 +- .../api/pojo/request/TranslateSqlReq.java | 20 ------------ ...teResp.java => SemanticTranslateResp.java} | 6 ++-- .../headless/chat/corrector/AggCorrector.java | 2 +- .../chat/corrector/BaseSemanticCorrector.java | 6 ++-- .../chat/corrector/GrammarCorrector.java | 4 +-- .../chat/corrector/GroupByCorrector.java | 6 ++-- .../chat/corrector/HavingCorrector.java | 8 ++--- .../chat/corrector/SchemaCorrector.java | 24 +++++++------- .../chat/corrector/SelectCorrector.java | 6 ++-- .../chat/corrector/TimeCorrector.java | 16 +++++----- .../chat/corrector/WhereCorrector.java | 8 ++--- .../headless/chat/parser/QueryTypeParser.java | 8 ++--- .../chat/parser/llm/LLMResponseService.java | 2 +- .../chat/query/BaseSemanticQuery.java | 4 +-- .../chat/query/llm/s2sql/LLMSqlQuery.java | 2 +- .../headless/chat/utils/QueryReqBuilder.java | 4 +-- .../chat/corrector/AggCorrectorTest.java | 6 ++-- .../chat/corrector/SchemaCorrectorTest.java | 12 +++---- .../chat/corrector/SelectCorrectorTest.java | 6 ++-- .../chat/corrector/TimeCorrectorTest.java | 32 +++++++++---------- .../chat/corrector/WhereCorrectorTest.java | 4 +-- .../server/aspect/S2DataPermissionAspect.java | 4 --- .../facade/service/SemanticLayerService.java | 7 ++-- .../service/impl/ChatQueryServiceImpl.java | 30 ++++++++--------- .../service/impl/S2SemanticLayerService.java | 12 +++---- .../server/processor/ParseInfoProcessor.java | 10 +++--- .../server/utils/ChatWorkflowEngine.java | 14 +++----- .../supersonic/headless/TranslateTest.java | 30 ++++++----------- 32 files changed, 138 insertions(+), 176 deletions(-) delete mode 100644 headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/TranslateSqlReq.java rename headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/{TranslateResp.java => SemanticTranslateResp.java} (78%) diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java index a6b938b09..2f1d5c449 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java @@ -44,7 +44,7 @@ public class SqlExecutor implements ChatExecutor { .agentId(chatExecuteContext.getAgentId()) .status(MemoryStatus.PENDING) .question(chatExecuteContext.getQueryText()) - .s2sql(chatExecuteContext.getParseInfo().getSqlInfo().getS2SQL()) + .s2sql(chatExecuteContext.getParseInfo().getSqlInfo().getParsedS2SQL()) .dbSchema(buildSchemaStr(chatExecuteContext.getParseInfo())) .createdBy(chatExecuteContext.getUser().getName()) .updatedBy(chatExecuteContext.getUser().getName()) @@ -64,12 +64,12 @@ public class SqlExecutor implements ChatExecutor { ChatContext chatCtx = chatContextService.getOrCreateContext(chatExecuteContext.getChatId()); SemanticParseInfo parseInfo = chatExecuteContext.getParseInfo(); if (Objects.isNull(parseInfo.getSqlInfo()) - || StringUtils.isBlank(parseInfo.getSqlInfo().getCorrectS2SQL())) { + || StringUtils.isBlank(parseInfo.getSqlInfo().getCorrectedS2SQL())) { return null; } QuerySqlReq sqlReq = QuerySqlReq.builder() - .sql(parseInfo.getSqlInfo().getCorrectS2SQL()) + .sql(parseInfo.getSqlInfo().getCorrectedS2SQL()) .build(); sqlReq.setSqlInfo(parseInfo.getSqlInfo()); sqlReq.setDataSetId(parseInfo.getDataSetId()); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index c5d4534db..5ac876438 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -161,7 +161,7 @@ public class NL2SQLParser implements ChatParser { String curtMapStr = generateSchemaPrompt(currentMapResult.getMapInfo().getMatchedElements(dataId)); String histMapStr = generateSchemaPrompt(lastParseResult.getSelectedParses().get(0).getElementMatches()); - String histSQL = lastParseResult.getSelectedParses().get(0).getSqlInfo().getCorrectS2SQL(); + String histSQL = lastParseResult.getSelectedParses().get(0).getSqlInfo().getCorrectedS2SQL(); String rewrittenQuery = rewriteQuery(RewriteContext.builder() .curtQuestion(currentMapResult.getQueryText()) .histQuestion(lastParseResult.getQueryText()) diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SqlInfo.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SqlInfo.java index f3c1bbea2..596cfe873 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SqlInfo.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SqlInfo.java @@ -5,8 +5,13 @@ import lombok.Data; @Data public class SqlInfo { - private String s2SQL; - private String correctS2SQL; + // S2SQL generated by semantic parsers + private String parsedS2SQL; + + // S2SQL corrected by semantic correctors + private String correctedS2SQL; + + // SQL to be executed finally private String querySQL; - private String sourceId; + } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryStructReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryStructReq.java index 077613a67..81d60d66a 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryStructReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryStructReq.java @@ -164,7 +164,7 @@ public class QueryStructReq extends SemanticQueryReq { result.setDataSetId(this.getDataSetId()); result.setModelIds(this.getModelIdSet()); result.setParams(new ArrayList<>()); - result.getSqlInfo().setCorrectS2SQL(sql); + result.getSqlInfo().setCorrectedS2SQL(sql); return result; } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/TranslateSqlReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/TranslateSqlReq.java deleted file mode 100644 index 1e41a9435..000000000 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/TranslateSqlReq.java +++ /dev/null @@ -1,20 +0,0 @@ -package com.tencent.supersonic.headless.api.pojo.request; - -import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod; -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; -import lombok.ToString; - -@Data -@ToString -@Builder -@AllArgsConstructor -@NoArgsConstructor -public class TranslateSqlReq { - - private QueryMethod queryTypeEnum; - - private T queryReq; -} diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/TranslateResp.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/SemanticTranslateResp.java similarity index 78% rename from headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/TranslateResp.java rename to headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/SemanticTranslateResp.java index ffda70269..3a643b07b 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/TranslateResp.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/SemanticTranslateResp.java @@ -13,10 +13,12 @@ import java.io.Serializable; @Builder @AllArgsConstructor @NoArgsConstructor -public class TranslateResp implements Serializable { +public class SemanticTranslateResp implements Serializable { + + private String querySQL; - private String sql; private boolean isOk; + private String errMsg; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/AggCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/AggCorrector.java index 3079df3af..a026726d7 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/AggCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/AggCorrector.java @@ -22,7 +22,7 @@ public class AggCorrector extends BaseSemanticCorrector { private void addAggregate(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { List sqlGroupByFields = SqlSelectHelper.getGroupByFields( - semanticParseInfo.getSqlInfo().getCorrectS2SQL()); + semanticParseInfo.getSqlInfo().getCorrectedS2SQL()); if (CollectionUtils.isEmpty(sqlGroupByFields)) { return; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/BaseSemanticCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/BaseSemanticCorrector.java index bca8709c2..4a409c0d4 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/BaseSemanticCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/BaseSemanticCorrector.java @@ -29,7 +29,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector { public void correct(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { try { - if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectS2SQL())) { + if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectedS2SQL())) { return; } doCorrect(chatQueryContext, semanticParseInfo); @@ -74,7 +74,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector { protected void addAggregateToMetric(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { //add aggregate to all metric - String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); + String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); Long dataSetId = semanticParseInfo.getDataSet().getDataSet(); List metrics = getMetricElements(chatQueryContext, dataSetId); @@ -98,7 +98,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector { return; } String aggregateSql = SqlAddHelper.addAggregateToField(correctS2SQL, metricToAggregate); - semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql); + semanticParseInfo.getSqlInfo().setCorrectedS2SQL(aggregateSql); } protected List getMetricElements(ChatQueryContext chatQueryContext, Long dataSetId) { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/GrammarCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/GrammarCorrector.java index 5bfd09400..86e1a600c 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/GrammarCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/GrammarCorrector.java @@ -34,8 +34,8 @@ public class GrammarCorrector extends BaseSemanticCorrector { } public void removeSameFieldFromSelect(SemanticParseInfo semanticParseInfo) { - String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); + String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); correctS2SQL = SqlRemoveHelper.removeSameFieldFromSelect(correctS2SQL); - semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL); + semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/GroupByCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/GroupByCorrector.java index 539044af6..f58e894d1 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/GroupByCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/GroupByCorrector.java @@ -35,7 +35,7 @@ public class GroupByCorrector extends BaseSemanticCorrector { Long dataSetId = semanticParseInfo.getDataSetId(); //add dimension group by SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); - String correctS2SQL = sqlInfo.getCorrectS2SQL(); + String correctS2SQL = sqlInfo.getCorrectedS2SQL(); SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema(); // check has distinct if (SqlSelectHelper.hasDistinct(correctS2SQL)) { @@ -68,7 +68,7 @@ public class GroupByCorrector extends BaseSemanticCorrector { Long dataSetId = semanticParseInfo.getDataSetId(); //add dimension group by SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); - String correctS2SQL = sqlInfo.getCorrectS2SQL(); + String correctS2SQL = sqlInfo.getCorrectedS2SQL(); SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema(); //add alias field name Set dimensions = getDimensions(dataSetId, semanticSchema); @@ -83,7 +83,7 @@ public class GroupByCorrector extends BaseSemanticCorrector { return true; }) .collect(Collectors.toSet()); - semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlAddHelper.addGroupBy(correctS2SQL, groupByFields)); + semanticParseInfo.getSqlInfo().setCorrectedS2SQL(SqlAddHelper.addGroupBy(correctS2SQL, groupByFields)); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/HavingCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/HavingCorrector.java index 73da6623e..6d99f8a1c 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/HavingCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/HavingCorrector.java @@ -49,19 +49,19 @@ public class HavingCorrector extends BaseSemanticCorrector { if (CollectionUtils.isEmpty(metrics)) { return; } - String havingSql = SqlAddHelper.addHaving(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), metrics); - semanticParseInfo.getSqlInfo().setCorrectS2SQL(havingSql); + String havingSql = SqlAddHelper.addHaving(semanticParseInfo.getSqlInfo().getCorrectedS2SQL(), metrics); + semanticParseInfo.getSqlInfo().setCorrectedS2SQL(havingSql); } private void addHavingToSelect(SemanticParseInfo semanticParseInfo) { - String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); + String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); if (!SqlSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) { return; } List havingExpressionList = SqlSelectHelper.getHavingExpression(correctS2SQL); if (!CollectionUtils.isEmpty(havingExpressionList)) { String replaceSql = SqlAddHelper.addFunctionToSelect(correctS2SQL, havingExpressionList); - semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceSql); + semanticParseInfo.getSqlInfo().setCorrectedS2SQL(replaceSql); } return; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrector.java index ad80a6b11..e06558f20 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrector.java @@ -50,21 +50,21 @@ public class SchemaCorrector extends BaseSemanticCorrector { private void correctAggFunction(SemanticParseInfo semanticParseInfo) { Map aggregateEnum = AggregateEnum.getAggregateEnum(); SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); - String sql = SqlReplaceHelper.replaceFunction(sqlInfo.getCorrectS2SQL(), aggregateEnum); - sqlInfo.setCorrectS2SQL(sql); + String sql = SqlReplaceHelper.replaceFunction(sqlInfo.getCorrectedS2SQL(), aggregateEnum); + sqlInfo.setCorrectedS2SQL(sql); } private void replaceAlias(SemanticParseInfo semanticParseInfo) { SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); - String replaceAlias = SqlReplaceHelper.replaceAlias(sqlInfo.getCorrectS2SQL()); - sqlInfo.setCorrectS2SQL(replaceAlias); + String replaceAlias = SqlReplaceHelper.replaceAlias(sqlInfo.getCorrectedS2SQL()); + sqlInfo.setCorrectedS2SQL(replaceAlias); } private void correctFieldName(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { Map fieldNameMap = getFieldNameMap(chatQueryContext, semanticParseInfo.getDataSetId()); SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); - String sql = SqlReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap); - sqlInfo.setCorrectS2SQL(sql); + String sql = SqlReplaceHelper.replaceFields(sqlInfo.getCorrectedS2SQL(), fieldNameMap); + sqlInfo.setCorrectedS2SQL(sql); } private void updateFieldNameByLinkingValue(SemanticParseInfo semanticParseInfo) { @@ -79,8 +79,8 @@ public class SchemaCorrector extends BaseSemanticCorrector { SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); - String sql = SqlReplaceHelper.replaceFieldNameByValue(sqlInfo.getCorrectS2SQL(), fieldValueToFieldNames); - sqlInfo.setCorrectS2SQL(sql); + String sql = SqlReplaceHelper.replaceFieldNameByValue(sqlInfo.getCorrectedS2SQL(), fieldValueToFieldNames); + sqlInfo.setCorrectedS2SQL(sql); } private List getLinkingValues(SemanticParseInfo semanticParseInfo) { @@ -111,14 +111,14 @@ public class SchemaCorrector extends BaseSemanticCorrector { ))); SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); - String sql = SqlReplaceHelper.replaceValue(sqlInfo.getCorrectS2SQL(), filedNameToValueMap, false); - sqlInfo.setCorrectS2SQL(sql); + String sql = SqlReplaceHelper.replaceValue(sqlInfo.getCorrectedS2SQL(), filedNameToValueMap, false); + sqlInfo.setCorrectedS2SQL(sql); } public void removeFilterIfNotInLinkingValue(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); - String correctS2SQL = sqlInfo.getCorrectS2SQL(); + String correctS2SQL = sqlInfo.getCorrectedS2SQL(); List whereExpressionList = SqlSelectHelper.getWhereExpressions(correctS2SQL); if (CollectionUtils.isEmpty(whereExpressionList)) { return; @@ -143,7 +143,7 @@ public class SchemaCorrector extends BaseSemanticCorrector { .map(fieldExpression -> fieldExpression.getFieldName()).collect(Collectors.toSet()); String sql = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames); - sqlInfo.setCorrectS2SQL(sql); + sqlInfo.setCorrectedS2SQL(sql); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java index c48e2e7b5..88bfe17a1 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java @@ -33,7 +33,7 @@ public class SelectCorrector extends BaseSemanticCorrector { @Override public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { - String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); + String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); List aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL); List selectFields = SqlSelectHelper.getSelectFields(correctS2SQL); // If the number of aggregated fields is equal to the number of queried fields, do not add fields to select. @@ -44,7 +44,7 @@ public class SelectCorrector extends BaseSemanticCorrector { } correctS2SQL = addFieldsToSelect(chatQueryContext, semanticParseInfo, correctS2SQL); String querySql = SqlReplaceHelper.dealAliasToOrderBy(correctS2SQL); - semanticParseInfo.getSqlInfo().setCorrectS2SQL(querySql); + semanticParseInfo.getSqlInfo().setCorrectedS2SQL(querySql); } protected String addFieldsToSelect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo, @@ -65,7 +65,7 @@ public class SelectCorrector extends BaseSemanticCorrector { } needAddFields.removeAll(selectFields); String addFieldsToSelectSql = SqlAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields)); - semanticParseInfo.getSqlInfo().setCorrectS2SQL(addFieldsToSelectSql); + semanticParseInfo.getSqlInfo().setCorrectedS2SQL(addFieldsToSelectSql); return addFieldsToSelectSql; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java index fc892c3a5..83395f425 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java @@ -45,7 +45,7 @@ public class TimeCorrector extends BaseSemanticCorrector { } private void removeDateIfExist(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { - String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); + String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); //decide whether remove date field from where Environment environment = ContextUtils.getBean(Environment.class); String correctorDate = environment.getProperty("s2.corrector.date"); @@ -55,12 +55,12 @@ public class TimeCorrector extends BaseSemanticCorrector { removeFieldNames.add(TimeDimensionEnum.WEEK.getChName()); removeFieldNames.add(TimeDimensionEnum.MONTH.getChName()); correctS2SQL = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames); - semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL); + semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL); } } private void addDateIfNotExist(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { - String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); + String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); List whereFields = SqlSelectHelper.getWhereFields(correctS2SQL); //decide whether add date field to where @@ -88,11 +88,11 @@ public class TimeCorrector extends BaseSemanticCorrector { } } } - semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL); + semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL); } private void addLowerBoundDate(SemanticParseInfo semanticParseInfo) { - String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); + String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); DateBoundInfo dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(correctS2SQL); if (Objects.isNull(dateBoundInfo)) { return; @@ -108,14 +108,14 @@ public class TimeCorrector extends BaseSemanticCorrector { } catch (JSQLParserException e) { log.error("parseCondExpression", e); } - semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL); + semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL); } } private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) { - String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); + String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); correctS2SQL = SqlReplaceHelper.replaceFunction(correctS2SQL); - semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL); + semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrector.java index bfd2d94e9..4c8421598 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrector.java @@ -39,7 +39,7 @@ public class WhereCorrector extends BaseSemanticCorrector { protected void addQueryFilter(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { String queryFilter = getQueryFilter(chatQueryContext.getQueryFilters()); - String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); + String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); if (StringUtils.isNotEmpty(queryFilter)) { log.info("add queryFilter to correctS2SQL :{}", queryFilter); @@ -50,7 +50,7 @@ public class WhereCorrector extends BaseSemanticCorrector { log.error("parseCondExpression", e); } correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, expression); - semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL); + semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL); } } @@ -71,9 +71,9 @@ public class WhereCorrector extends BaseSemanticCorrector { } Map> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions); - String correctS2SQL = SqlReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), + String correctS2SQL = SqlReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectedS2SQL(), aliasAndBizNameToTechName); - semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL); + semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL); } private Map> getAliasAndBizNameToTechName(List dimensions) { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java index 3d35587c7..c1e29523e 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java @@ -46,14 +46,14 @@ public class QueryTypeParser implements SemanticParser { private QueryType getQueryType(ChatQueryContext chatQueryContext, SemanticQuery semanticQuery) { SemanticParseInfo parseInfo = semanticQuery.getParseInfo(); SqlInfo sqlInfo = parseInfo.getSqlInfo(); - if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getS2SQL())) { + if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getParsedS2SQL())) { return QueryType.DETAIL; } //1. entity queryType Long dataSetId = parseInfo.getDataSetId(); SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema(); if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) { - List whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getS2SQL()); + List whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getParsedS2SQL()); List whereFilterByTimeFields = filterByTimeFields(whereFields); if (CollectionUtils.isNotEmpty(whereFilterByTimeFields)) { Set ids = semanticSchema.getEntities(dataSetId).stream().map(SchemaElement::getName) @@ -63,7 +63,7 @@ public class QueryTypeParser implements SemanticParser { return QueryType.ID; } } - List selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL()); + List selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getParsedS2SQL()); selectFields.addAll(whereFields); List selectWhereFilterByTimeFields = filterByTimeFields(selectFields); if (CollectionUtils.isNotEmpty(selectWhereFilterByTimeFields)) { @@ -91,7 +91,7 @@ public class QueryTypeParser implements SemanticParser { } private static boolean selectContainsMetric(SqlInfo sqlInfo, Long dataSetId, SemanticSchema semanticSchema) { - List selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL()); + List selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getParsedS2SQL()); List metrics = semanticSchema.getMetrics(dataSetId); if (CollectionUtils.isNotEmpty(metrics)) { Set metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet()); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java index 4419a3217..53eb2a910 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java @@ -38,7 +38,7 @@ public class LLMResponseService { parseInfo.setProperties(properties); parseInfo.setScore(queryCtx.getQueryText().length() * (1 + weight)); parseInfo.setQueryMode(semanticQuery.getQueryMode()); - parseInfo.getSqlInfo().setS2SQL(s2SQL); + parseInfo.getSqlInfo().setParsedS2SQL(s2SQL); queryCtx.getCandidateQueries().add(semanticQuery); return parseInfo; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/BaseSemanticQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/BaseSemanticQuery.java index 0a14d4bef..b479633ac 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/BaseSemanticQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/BaseSemanticQuery.java @@ -83,8 +83,8 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable { QueryStructReq queryStructReq = convertQueryStruct(); convertBizNameToName(semanticSchema, queryStructReq); QuerySqlReq querySQLReq = queryStructReq.convert(); - parseInfo.getSqlInfo().setS2SQL(querySQLReq.getSql()); - parseInfo.getSqlInfo().setCorrectS2SQL(querySQLReq.getSql()); + parseInfo.getSqlInfo().setParsedS2SQL(querySQLReq.getSql()); + parseInfo.getSqlInfo().setCorrectedS2SQL(querySQLReq.getSql()); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMSqlQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMSqlQuery.java index f2add7ff2..35e3e275e 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMSqlQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMSqlQuery.java @@ -33,6 +33,6 @@ public class LLMSqlQuery extends LLMSemanticQuery { @Override public void initS2Sql(SemanticSchema semanticSchema, User user) { SqlInfo sqlInfo = parseInfo.getSqlInfo(); - sqlInfo.setCorrectS2SQL(sqlInfo.getS2SQL()); + sqlInfo.setCorrectedS2SQL(sqlInfo.getParsedS2SQL()); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/QueryReqBuilder.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/QueryReqBuilder.java index ad7591d35..2fbaa2f89 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/QueryReqBuilder.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/QueryReqBuilder.java @@ -152,8 +152,8 @@ public class QueryReqBuilder { public static QuerySqlReq buildS2SQLReq(SqlInfo sqlInfo, Long dataSetId) { QuerySqlReq querySQLReq = new QuerySqlReq(); - if (Objects.nonNull(sqlInfo.getCorrectS2SQL())) { - querySQLReq.setSql(sqlInfo.getCorrectS2SQL()); + if (Objects.nonNull(sqlInfo.getCorrectedS2SQL())) { + querySQLReq.setSql(sqlInfo.getCorrectedS2SQL()); } querySQLReq.setSqlInfo(sqlInfo); querySQLReq.setDataSetId(dataSetId); diff --git a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/AggCorrectorTest.java b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/AggCorrectorTest.java index 013b01b18..6a6005892 100644 --- a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/AggCorrectorTest.java +++ b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/AggCorrectorTest.java @@ -30,14 +30,14 @@ class AggCorrectorTest { String sql = "SELECT 用户, 访问次数 FROM 超音数数据集 WHERE 部门 = 'sales' AND" + " datediff('day', 数据日期, '2024-06-04') <= 7" + " GROUP BY 用户 ORDER BY SUM(访问次数) DESC LIMIT 1"; - sqlInfo.setS2SQL(sql); - sqlInfo.setCorrectS2SQL(sql); + sqlInfo.setParsedS2SQL(sql); + sqlInfo.setCorrectedS2SQL(sql); semanticParseInfo.setSqlInfo(sqlInfo); corrector.correct(chatQueryContext, semanticParseInfo); Assert.assertEquals("SELECT 用户, SUM(访问次数) FROM 超音数数据集 WHERE 部门 = 'sales'" + " AND datediff('day', 数据日期, '2024-06-04') <= 7 GROUP BY 用户" + " ORDER BY SUM(访问次数) DESC LIMIT 1", - semanticParseInfo.getSqlInfo().getCorrectS2SQL()); + semanticParseInfo.getSqlInfo().getCorrectedS2SQL()); } private ChatQueryContext buildQueryContext(Long dataSetId) { diff --git a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrectorTest.java b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrectorTest.java index ce1ee737e..afd3b00c8 100644 --- a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrectorTest.java +++ b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrectorTest.java @@ -65,8 +65,8 @@ class SchemaCorrectorTest { + "and 商务组 = 'xxx' order by 播放量 desc limit 10"; SemanticParseInfo semanticParseInfo = new SemanticParseInfo(); SqlInfo sqlInfo = new SqlInfo(); - sqlInfo.setS2SQL(sql); - sqlInfo.setCorrectS2SQL(sql); + sqlInfo.setParsedS2SQL(sql); + sqlInfo.setCorrectedS2SQL(sql); semanticParseInfo.setSqlInfo(sqlInfo); SchemaElement schemaElement = new SchemaElement(); @@ -80,7 +80,7 @@ class SchemaCorrectorTest { schemaCorrector.removeFilterIfNotInLinkingValue(chatQueryContext, semanticParseInfo); Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' " - + "ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectS2SQL()); + + "ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectedS2SQL()); parseResult = objectMapper.readValue(json, ParseResult.class); @@ -92,11 +92,11 @@ class SchemaCorrectorTest { parseResult.setLinkingValues(linkingValues); semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult); - semanticParseInfo.getSqlInfo().setCorrectS2SQL(sql); - semanticParseInfo.getSqlInfo().setS2SQL(sql); + semanticParseInfo.getSqlInfo().setCorrectedS2SQL(sql); + semanticParseInfo.getSqlInfo().setParsedS2SQL(sql); schemaCorrector.removeFilterIfNotInLinkingValue(chatQueryContext, semanticParseInfo); Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' " - + "AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectS2SQL()); + + "AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectedS2SQL()); } diff --git a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrectorTest.java b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrectorTest.java index 669524c74..1d8d0256a 100644 --- a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrectorTest.java +++ b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrectorTest.java @@ -44,12 +44,12 @@ class SelectCorrectorTest { semanticParseInfo.setQueryType(QueryType.DETAIL); SqlInfo sqlInfo = new SqlInfo(); String sql = "SELECT * FROM 艺人库 WHERE 艺人名='周杰伦'"; - sqlInfo.setS2SQL(sql); - sqlInfo.setCorrectS2SQL(sql); + sqlInfo.setParsedS2SQL(sql); + sqlInfo.setCorrectedS2SQL(sql); semanticParseInfo.setSqlInfo(sqlInfo); corrector.correct(chatQueryContext, semanticParseInfo); Assert.assertEquals("SELECT 粉丝数, 国籍, 艺人名, 性别 FROM 艺人库 WHERE 艺人名 = '周杰伦'", - semanticParseInfo.getSqlInfo().getCorrectS2SQL()); + semanticParseInfo.getSqlInfo().getCorrectedS2SQL()); } private ChatQueryContext buildQueryContext(Long dataSetId) { diff --git a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrectorTest.java b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrectorTest.java index b540ef1d5..b5be7a633 100644 --- a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrectorTest.java +++ b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrectorTest.java @@ -19,84 +19,84 @@ class TimeCorrectorTest { //1.数据日期 <= String sql = "SELECT 维度1, SUM(播放量) FROM 数据库 " + "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1"; - sqlInfo.setCorrectS2SQL(sql); + sqlInfo.setCorrectedS2SQL(sql); semanticParseInfo.setSqlInfo(sqlInfo); corrector.doCorrect(chatQueryContext, semanticParseInfo); Assert.assertEquals( "SELECT 维度1, SUM(播放量) FROM 数据库 WHERE ((歌手名 = '张三') AND 数据日期 <= '2023-11-17') " + "AND 数据日期 >= '2023-11-17' GROUP BY 维度1", - sqlInfo.getCorrectS2SQL()); + sqlInfo.getCorrectedS2SQL()); //2.数据日期 < sql = "SELECT 维度1, SUM(播放量) FROM 数据库 " + "WHERE (歌手名 = '张三') AND 数据日期 < '2023-11-17' GROUP BY 维度1"; - sqlInfo.setCorrectS2SQL(sql); + sqlInfo.setCorrectedS2SQL(sql); corrector.doCorrect(chatQueryContext, semanticParseInfo); Assert.assertEquals( "SELECT 维度1, SUM(播放量) FROM 数据库 WHERE ((歌手名 = '张三') AND 数据日期 < '2023-11-17') " + "AND 数据日期 >= '2023-11-17' GROUP BY 维度1", - sqlInfo.getCorrectS2SQL()); + sqlInfo.getCorrectedS2SQL()); //3.数据日期 >= sql = "SELECT 维度1, SUM(播放量) FROM 数据库 " + "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1"; - sqlInfo.setCorrectS2SQL(sql); + sqlInfo.setCorrectedS2SQL(sql); corrector.doCorrect(chatQueryContext, semanticParseInfo); Assert.assertEquals( "SELECT 维度1, SUM(播放量) FROM 数据库 " + "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1", - sqlInfo.getCorrectS2SQL()); + sqlInfo.getCorrectedS2SQL()); //4.数据日期 > sql = "SELECT 维度1, SUM(播放量) FROM 数据库 " + "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1"; - sqlInfo.setCorrectS2SQL(sql); + sqlInfo.setCorrectedS2SQL(sql); corrector.doCorrect(chatQueryContext, semanticParseInfo); Assert.assertEquals( "SELECT 维度1, SUM(播放量) FROM 数据库 " + "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1", - sqlInfo.getCorrectS2SQL()); + sqlInfo.getCorrectedS2SQL()); //5.no 数据日期 sql = "SELECT 维度1, SUM(播放量) FROM 数据库 " + "WHERE 歌手名 = '张三' GROUP BY 维度1"; - sqlInfo.setCorrectS2SQL(sql); + sqlInfo.setCorrectedS2SQL(sql); corrector.doCorrect(chatQueryContext, semanticParseInfo); Assert.assertEquals( "SELECT 维度1, SUM(播放量) FROM 数据库 WHERE 歌手名 = '张三' GROUP BY 维度1", - sqlInfo.getCorrectS2SQL()); + sqlInfo.getCorrectedS2SQL()); //6. 数据日期-月 <= sql = "SELECT 维度1, SUM(播放量) FROM 数据库 " + "WHERE 歌手名 = '张三' AND 数据日期_月 <= '2024-01' GROUP BY 维度1"; - sqlInfo.setCorrectS2SQL(sql); + sqlInfo.setCorrectedS2SQL(sql); corrector.doCorrect(chatQueryContext, semanticParseInfo); Assert.assertEquals( "SELECT 维度1, SUM(播放量) FROM 数据库 WHERE (歌手名 = '张三' AND 数据日期_月 <= '2024-01') " + "AND 数据日期_月 >= '2024-01' GROUP BY 维度1", - sqlInfo.getCorrectS2SQL()); + sqlInfo.getCorrectedS2SQL()); //7. 数据日期-月 > sql = "SELECT 维度1, SUM(播放量) FROM 数据库 " + "WHERE 歌手名 = '张三' AND 数据日期_月 > '2024-01' GROUP BY 维度1"; - sqlInfo.setCorrectS2SQL(sql); + sqlInfo.setCorrectedS2SQL(sql); corrector.doCorrect(chatQueryContext, semanticParseInfo); Assert.assertEquals( "SELECT 维度1, SUM(播放量) FROM 数据库 " + "WHERE 歌手名 = '张三' AND 数据日期_月 > '2024-01' GROUP BY 维度1", - sqlInfo.getCorrectS2SQL()); + sqlInfo.getCorrectedS2SQL()); //8. no where sql = "SELECT COUNT(1) FROM 数据库"; - sqlInfo.setCorrectS2SQL(sql); + sqlInfo.setCorrectedS2SQL(sql); corrector.doCorrect(chatQueryContext, semanticParseInfo); - Assert.assertEquals("SELECT COUNT(1) FROM 数据库", sqlInfo.getCorrectS2SQL()); + Assert.assertEquals("SELECT COUNT(1) FROM 数据库", sqlInfo.getCorrectedS2SQL()); } } diff --git a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrectorTest.java b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrectorTest.java index 7b2a5e811..904b4325e 100644 --- a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrectorTest.java +++ b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrectorTest.java @@ -19,7 +19,7 @@ class WhereCorrectorTest { SqlInfo sqlInfo = new SqlInfo(); String sql = "SELECT 维度1, SUM(播放量) FROM 数据库 " + "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1"; - sqlInfo.setCorrectS2SQL(sql); + sqlInfo.setCorrectedS2SQL(sql); semanticParseInfo.setSqlInfo(sqlInfo); ChatQueryContext chatQueryContext = new ChatQueryContext(); @@ -54,7 +54,7 @@ class WhereCorrectorTest { WhereCorrector whereCorrector = new WhereCorrector(); whereCorrector.addQueryFilter(chatQueryContext, semanticParseInfo); - String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); + String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); Assert.assertEquals(correctS2SQL, "SELECT 维度1, SUM(播放量) FROM 数据库 WHERE " + "(歌手名 = '张三') AND 数据日期 <= '2023-11-17' AND age > 30 AND " diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/S2DataPermissionAspect.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/S2DataPermissionAspect.java index b65b8850b..82da75b1d 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/S2DataPermissionAspect.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/S2DataPermissionAspect.java @@ -19,7 +19,6 @@ import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq; import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq; import com.tencent.supersonic.headless.api.pojo.request.SchemaFilterReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; -import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq; import com.tencent.supersonic.headless.api.pojo.response.ModelResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticSchemaResp; @@ -73,9 +72,6 @@ public class S2DataPermissionAspect { SemanticQueryReq queryReq = null; if (objects[0] instanceof SemanticQueryReq) { queryReq = (SemanticQueryReq) objects[0]; - } else if (objects[0] instanceof TranslateSqlReq) { - queryReq = (SemanticQueryReq) ((TranslateSqlReq) objects[0]).getQueryReq(); - needQueryData = false; } if (queryReq == null) { throw new InvalidArgumentException("queryReq is not Invalid"); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/SemanticLayerService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/SemanticLayerService.java index 22adb9217..27e569756 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/SemanticLayerService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/SemanticLayerService.java @@ -4,10 +4,9 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.EntityInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq; import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; -import com.tencent.supersonic.headless.api.pojo.response.TranslateResp; +import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp; import com.tencent.supersonic.headless.api.pojo.response.ItemResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; @@ -20,12 +19,12 @@ public interface SemanticLayerService { DataSetSchema getDataSetSchema(Long id); + SemanticTranslateResp translate(SemanticQueryReq queryReq, User user) throws Exception; + SemanticQueryResp queryByReq(SemanticQueryReq queryReq, User user) throws Exception; SemanticQueryResp queryDimValue(QueryDimValueReq queryDimValueReq, User user); - TranslateResp translate(TranslateSqlReq translateSqlReq, User user) throws Exception; - EntityInfo getEntityInfo(SemanticParseInfo parseInfo, DataSetSchema dataSetSchema, User user); List getDomainDataSetTree(); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/ChatQueryServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/ChatQueryServiceImpl.java index 8f561f6ea..c78260090 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/ChatQueryServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/ChatQueryServiceImpl.java @@ -26,11 +26,9 @@ import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.SqlEvaluation; import com.tencent.supersonic.headless.api.pojo.SqlInfo; import com.tencent.supersonic.headless.api.pojo.enums.CostType; -import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod; import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq; import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq; import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq; -import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq; import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq; import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; @@ -41,7 +39,7 @@ import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.api.pojo.response.DataSetMapInfo; import com.tencent.supersonic.headless.api.pojo.response.DataSetResp; import com.tencent.supersonic.headless.api.pojo.response.DimensionResp; -import com.tencent.supersonic.headless.api.pojo.response.TranslateResp; +import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp; import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp; import com.tencent.supersonic.headless.api.pojo.response.MapResp; import com.tencent.supersonic.headless.api.pojo.response.ParseResp; @@ -247,8 +245,8 @@ public class ChatQueryServiceImpl implements ChatQueryService { List fields = new ArrayList<>(); if (Objects.nonNull(parseInfo.getSqlInfo()) - && StringUtils.isNotBlank(parseInfo.getSqlInfo().getCorrectS2SQL())) { - String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL(); + && StringUtils.isNotBlank(parseInfo.getSqlInfo().getCorrectedS2SQL())) { + String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL(); fields = SqlSelectHelper.getAllFields(correctorSql); } if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode()) @@ -260,13 +258,11 @@ public class ChatQueryServiceImpl implements ChatQueryService { } else if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(parseInfo.getQueryMode())) { log.info("llm begin revise filters!"); String correctorSql = reviseCorrectS2SQL(queryData, parseInfo); - parseInfo.getSqlInfo().setCorrectS2SQL(correctorSql); + parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql); semanticQuery.setParseInfo(parseInfo); SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq(); - TranslateSqlReq translateSqlReq = TranslateSqlReq.builder().queryReq(semanticQueryReq) - .queryTypeEnum(QueryMethod.SQL).build(); - TranslateResp explain = semanticLayerService.translate(translateSqlReq, user); - parseInfo.getSqlInfo().setQuerySQL(explain.getSql()); + SemanticTranslateResp explain = semanticLayerService.translate(semanticQueryReq, user); + parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL()); } else { log.info("rule begin replace metrics and revise filters!"); //remove unvalid filters @@ -303,7 +299,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { Map> filedNameToValueMap = new HashMap<>(); Map> havingFiledNameToValueMap = new HashMap<>(); - String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL(); + String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL(); log.info("correctorSql before replacing:{}", correctorSql); // get where filter and having filter List whereExpressionList = SqlSelectHelper.getWhereExpressions(correctorSql); @@ -334,7 +330,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { private void replaceMetrics(SemanticParseInfo parseInfo, SchemaElement metric) { List oriMetrics = parseInfo.getMetrics().stream() .map(SchemaElement::getName).collect(Collectors.toList()); - String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL(); + String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL(); log.info("before replaceMetrics:{}", correctorSql); log.info("filteredMetrics:{},metrics:{}", oriMetrics, metric); Map> fieldMap = new HashMap<>(); @@ -343,7 +339,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { correctorSql = SqlReplaceHelper.replaceAggFields(correctorSql, fieldMap); } log.info("after replaceMetrics:{}", correctorSql); - parseInfo.getSqlInfo().setCorrectS2SQL(correctorSql); + parseInfo.getSqlInfo().setCorrectedS2SQL(correctorSql); } private void updateDateInfo(QueryDataReq queryData, SemanticParseInfo parseInfo, @@ -598,7 +594,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { public void correct(QuerySqlReq querySqlReq, User user) { SemanticParseInfo semanticParseInfo = correctSqlReq(querySqlReq, user); - querySqlReq.setSql(semanticParseInfo.getSqlInfo().getCorrectS2SQL()); + querySqlReq.setSql(semanticParseInfo.getSqlInfo().getCorrectedS2SQL()); } @Override @@ -613,8 +609,8 @@ public class ChatQueryServiceImpl implements ChatQueryService { queryCtx.setSemanticSchema(semanticSchema); SemanticParseInfo semanticParseInfo = new SemanticParseInfo(); SqlInfo sqlInfo = new SqlInfo(); - sqlInfo.setCorrectS2SQL(querySqlReq.getSql()); - sqlInfo.setS2SQL(querySqlReq.getSql()); + sqlInfo.setCorrectedS2SQL(querySqlReq.getSql()); + sqlInfo.setParsedS2SQL(querySqlReq.getSql()); semanticParseInfo.setSqlInfo(sqlInfo); semanticParseInfo.setQueryType(QueryType.DETAIL); @@ -630,7 +626,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { corrector.correct(queryCtx, semanticParseInfo); } }); - log.info("chatQueryServiceImpl correct:{}", sqlInfo.getCorrectS2SQL()); + log.info("chatQueryServiceImpl correct:{}", sqlInfo.getCorrectedS2SQL()); return semanticParseInfo; } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2SemanticLayerService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2SemanticLayerService.java index fb8c60114..2f4ed0616 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2SemanticLayerService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2SemanticLayerService.java @@ -19,7 +19,6 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig; import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig; -import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq; import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq; @@ -28,7 +27,7 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq; import com.tencent.supersonic.headless.api.pojo.request.SchemaFilterReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.api.pojo.response.DimensionResp; -import com.tencent.supersonic.headless.api.pojo.response.TranslateResp; +import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp; import com.tencent.supersonic.headless.api.pojo.response.ItemResp; import com.tencent.supersonic.headless.api.pojo.response.ModelResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; @@ -228,12 +227,11 @@ public class S2SemanticLayerService implements SemanticLayerService { @S2DataPermission @Override - public TranslateResp translate(TranslateSqlReq translateSqlReq, User user) throws Exception { - T queryReq = translateSqlReq.getQueryReq(); - QueryStatement queryStatement = buildQueryStatement((SemanticQueryReq) queryReq, user); + public SemanticTranslateResp translate(SemanticQueryReq queryReq, User user) throws Exception { + QueryStatement queryStatement = buildQueryStatement(queryReq, user); semanticTranslator.translate(queryStatement); - return TranslateResp.builder() - .sql(queryStatement.getSql()) + return SemanticTranslateResp.builder() + .querySQL(queryStatement.getSql()) .isOk(queryStatement.isOk()) .errMsg(queryStatement.getErrMsg()) .build(); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/ParseInfoProcessor.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/ParseInfoProcessor.java index e7408fb81..d7f06a100 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/ParseInfoProcessor.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/ParseInfoProcessor.java @@ -52,12 +52,12 @@ public class ParseInfoProcessor implements ResultProcessor { public void updateParseInfo(SemanticParseInfo parseInfo) { SqlInfo sqlInfo = parseInfo.getSqlInfo(); - String correctS2SQL = sqlInfo.getCorrectS2SQL(); + String correctS2SQL = sqlInfo.getCorrectedS2SQL(); if (StringUtils.isBlank(correctS2SQL)) { return; } // if S2SQL equals correctS2SQL, then not update the parseInfo. - if (correctS2SQL.equals(sqlInfo.getS2SQL())) { + if (correctS2SQL.equals(sqlInfo.getParsedS2SQL())) { return; } List expressions = SqlSelectHelper.getFilterExpression(correctS2SQL); @@ -87,15 +87,15 @@ public class ParseInfoProcessor implements ResultProcessor { if (Objects.isNull(semanticSchema)) { return; } - List allFields = getFieldsExceptDate(SqlSelectHelper.getAllFields(sqlInfo.getCorrectS2SQL())); + List allFields = getFieldsExceptDate(SqlSelectHelper.getAllFields(sqlInfo.getCorrectedS2SQL())); Set metrics = getElements(dataSetId, allFields, semanticSchema.getMetrics()); parseInfo.setMetrics(metrics); if (QueryType.METRIC.equals(parseInfo.getQueryType())) { - List groupByFields = SqlSelectHelper.getGroupByFields(sqlInfo.getCorrectS2SQL()); + List groupByFields = SqlSelectHelper.getGroupByFields(sqlInfo.getCorrectedS2SQL()); List groupByDimensions = getFieldsExceptDate(groupByFields); parseInfo.setDimensions(getElements(dataSetId, groupByDimensions, semanticSchema.getDimensions())); } else if (QueryType.DETAIL.equals(parseInfo.getQueryType())) { - List selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getCorrectS2SQL()); + List selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getCorrectedS2SQL()); List selectDimensions = getFieldsExceptDate(selectFields); parseInfo.setDimensions(getElements(dataSetId, selectDimensions, semanticSchema.getDimensions())); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java index dbb80cbd1..5914cedd4 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java @@ -4,11 +4,9 @@ import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.enums.ChatWorkflowState; -import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; -import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq; import com.tencent.supersonic.headless.api.pojo.response.ParseResp; -import com.tencent.supersonic.headless.api.pojo.response.TranslateResp; +import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp; import com.tencent.supersonic.headless.chat.ChatContext; import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.corrector.SemanticCorrector; @@ -123,15 +121,13 @@ public class ChatWorkflowEngine { semanticQuery.setParseInfo(parseInfo); SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq(); SemanticLayerService queryService = ContextUtils.getBean(SemanticLayerService.class); - TranslateSqlReq translateSqlReq = TranslateSqlReq.builder().queryReq(semanticQueryReq) - .queryTypeEnum(QueryMethod.SQL).build(); - TranslateResp explain = queryService.translate(translateSqlReq, chatQueryContext.getUser()); - parseInfo.getSqlInfo().setQuerySQL(explain.getSql()); + SemanticTranslateResp explain = queryService.translate(semanticQueryReq, chatQueryContext.getUser()); + parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL()); keyPipelineLog.info("SqlInfoProcessor results:\n" + "Parsed S2SQL: {}\nCorrected S2SQL: {}\nQuery SQL: {}", - StringUtils.normalizeSpace(parseInfo.getSqlInfo().getS2SQL()), - StringUtils.normalizeSpace(parseInfo.getSqlInfo().getCorrectS2SQL()), + StringUtils.normalizeSpace(parseInfo.getSqlInfo().getParsedS2SQL()), + StringUtils.normalizeSpace(parseInfo.getSqlInfo().getCorrectedS2SQL()), StringUtils.normalizeSpace(parseInfo.getSqlInfo().getQuerySQL())); } catch (Exception e) { log.warn("get sql info failed:{}", parseInfo, e); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TranslateTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TranslateTest.java index cdc55d4f8..b8c4cd361 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TranslateTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TranslateTest.java @@ -2,11 +2,8 @@ package com.tencent.supersonic.headless; import com.tencent.supersonic.auth.api.authentication.pojo.User; -import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod; -import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq; -import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq; import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq; -import com.tencent.supersonic.headless.api.pojo.response.TranslateResp; +import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp; import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder; import com.tencent.supersonic.util.DataUtils; import org.junit.jupiter.api.Test; @@ -21,29 +18,22 @@ public class TranslateTest extends BaseTest { @Test public void testSqlExplain() throws Exception { String sql = "SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 "; - TranslateSqlReq translateSqlReq = TranslateSqlReq.builder() - .queryTypeEnum(QueryMethod.SQL) - .queryReq(QueryReqBuilder.buildS2SQLReq(sql, DataUtils.getMetricAgentView())) - .build(); - TranslateResp explain = semanticLayerService.translate(translateSqlReq, User.getFakeUser()); + SemanticTranslateResp explain = semanticLayerService.translate(QueryReqBuilder.buildS2SQLReq(sql, + DataUtils.getMetricAgentView()), User.getFakeUser()); assertNotNull(explain); - assertNotNull(explain.getSql()); - assertTrue(explain.getSql().contains("department")); - assertTrue(explain.getSql().contains("pv")); + assertNotNull(explain.getQuerySQL()); + assertTrue(explain.getQuerySQL().contains("department")); + assertTrue(explain.getQuerySQL().contains("pv")); } @Test public void testStructExplain() throws Exception { QueryStructReq queryStructReq = buildQueryStructReq(Arrays.asList("department")); - TranslateSqlReq translateSqlReq = TranslateSqlReq.builder() - .queryTypeEnum(QueryMethod.STRUCT) - .queryReq(queryStructReq) - .build(); - TranslateResp explain = semanticLayerService.translate(translateSqlReq, User.getFakeUser()); + SemanticTranslateResp explain = semanticLayerService.translate(queryStructReq, User.getFakeUser()); assertNotNull(explain); - assertNotNull(explain.getSql()); - assertTrue(explain.getSql().contains("department")); - assertTrue(explain.getSql().contains("pv")); + assertNotNull(explain.getQuerySQL()); + assertTrue(explain.getQuerySQL().contains("department")); + assertTrue(explain.getQuerySQL().contains("pv")); } }