From 03a4719aed135b355b9b757849da0c478bbd6baa Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Thu, 21 Sep 2023 21:57:06 +0800 Subject: [PATCH] [improvement][chat]llm parser corrector is simplified by sql distribution (#120) --- .../chat/api/pojo/SemanticSchema.java | 17 +++++ .../chat/corrector/DateFieldCorrector.java | 27 ------- .../chat/corrector/FieldCorrector.java | 18 ----- .../corrector/FunctionAliasCorrector.java | 16 ---- .../chat/corrector/FunctionCorrector.java | 17 ----- ...ameCorrector.java => GlobalCorrector.java} | 33 ++++++++- .../chat/corrector/GroupByCorrector.java | 15 ++++ .../chat/corrector/HavingCorrector.java | 14 ++++ .../chat/corrector/QueryFilterAppend.java | 48 ------------ ...endCorrector.java => SelectCorrector.java} | 3 +- ...NameCorrector.java => TableCorrector.java} | 2 +- ...alueCorrector.java => WhereCorrector.java} | 73 ++++++++++++++++++- .../chat/parser/llm/dsl/LLMDslParser.java | 53 ++++++++------ .../corrector/DateFieldCorrectorTest.java | 45 ------------ .../corrector/FieldNameCorrectorTest.java | 65 ----------------- .../corrector/FieldValueCorrectorTest.java | 71 ------------------ .../SelectFieldAppendCorrectorTest.java | 46 ------------ .../main/resources/META-INF/spring.factories | 15 ++-- .../main/resources/META-INF/spring.factories | 15 ++-- 19 files changed, 191 insertions(+), 402 deletions(-) delete mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/corrector/DateFieldCorrector.java delete mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/corrector/FieldCorrector.java delete mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/corrector/FunctionAliasCorrector.java delete mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/corrector/FunctionCorrector.java rename chat/core/src/main/java/com/tencent/supersonic/chat/corrector/{FieldNameCorrector.java => GlobalCorrector.java} (64%) create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/corrector/HavingCorrector.java delete mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/corrector/QueryFilterAppend.java rename chat/core/src/main/java/com/tencent/supersonic/chat/corrector/{SelectFieldAppendCorrector.java => SelectCorrector.java} (96%) rename chat/core/src/main/java/com/tencent/supersonic/chat/corrector/{TableNameCorrector.java => TableCorrector.java} (91%) rename chat/core/src/main/java/com/tencent/supersonic/chat/corrector/{FieldValueCorrector.java => WhereCorrector.java} (50%) delete mode 100644 chat/core/src/test/java/com/tencent/supersonic/chat/corrector/DateFieldCorrectorTest.java delete mode 100644 chat/core/src/test/java/com/tencent/supersonic/chat/corrector/FieldNameCorrectorTest.java delete mode 100644 chat/core/src/test/java/com/tencent/supersonic/chat/corrector/FieldValueCorrectorTest.java delete mode 100644 chat/core/src/test/java/com/tencent/supersonic/chat/corrector/SelectFieldAppendCorrectorTest.java diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/SemanticSchema.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/SemanticSchema.java index 4cf01216f..ebb29035c 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/SemanticSchema.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/SemanticSchema.java @@ -7,6 +7,7 @@ import java.util.Map; import java.util.stream.Collectors; public class SemanticSchema implements Serializable { + private List modelSchemaList; public SemanticSchema(List modelSchemaList) { @@ -34,12 +35,28 @@ public class SemanticSchema implements Serializable { return dimensions; } + public List getDimensions(Long modelId) { + List dimensions = getDimensions(); + return getElementsByModelId(modelId, dimensions); + } + public List getMetrics() { List metrics = new ArrayList<>(); modelSchemaList.stream().forEach(d -> metrics.addAll(d.getMetrics())); return metrics; } + public List getMetrics(Long modelId) { + List metrics = getMetrics(); + return getElementsByModelId(modelId, metrics); + } + + private List getElementsByModelId(Long modelId, List elements) { + return elements.stream() + .filter(schemaElement -> modelId.equals(schemaElement.getModel())) + .collect(Collectors.toList()); + } + public List getModels() { List models = new ArrayList<>(); modelSchemaList.stream().forEach(d -> models.add(d.getModel())); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/DateFieldCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/DateFieldCorrector.java deleted file mode 100644 index f4221a8d8..000000000 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/DateFieldCorrector.java +++ /dev/null @@ -1,27 +0,0 @@ -package com.tencent.supersonic.chat.corrector; - -import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; -import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper; -import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; -import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; -import java.util.List; -import lombok.extern.slf4j.Slf4j; -import org.springframework.util.CollectionUtils; - -@Slf4j -public class DateFieldCorrector extends BaseSemanticCorrector { - - @Override - public void correct(SemanticCorrectInfo semanticCorrectInfo) { - - String sql = semanticCorrectInfo.getSql(); - List whereFields = SqlParserSelectHelper.getWhereFields(sql); - if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(DATE_FIELD)) { - String currentDate = DSLDateHelper.getReferenceDate(semanticCorrectInfo.getParseInfo().getModelId()); - sql = SqlParserUpdateHelper.addWhere(sql, DATE_FIELD, currentDate); - } - semanticCorrectInfo.setPreSql(semanticCorrectInfo.getSql()); - semanticCorrectInfo.setSql(sql); - } - -} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/FieldCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/FieldCorrector.java deleted file mode 100644 index 77cb01c3d..000000000 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/FieldCorrector.java +++ /dev/null @@ -1,18 +0,0 @@ -package com.tencent.supersonic.chat.corrector; - -import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; -import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; -import lombok.extern.slf4j.Slf4j; - -@Slf4j -public class FieldCorrector extends BaseSemanticCorrector { - - @Override - public void correct(SemanticCorrectInfo semanticCorrectInfo) { - String preSql = semanticCorrectInfo.getSql(); - semanticCorrectInfo.setPreSql(preSql); - String sql = SqlParserUpdateHelper.replaceFields(preSql, - getFieldToBizName(semanticCorrectInfo.getParseInfo().getModelId())); - semanticCorrectInfo.setSql(sql); - } -} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/FunctionAliasCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/FunctionAliasCorrector.java deleted file mode 100644 index 7564942c4..000000000 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/FunctionAliasCorrector.java +++ /dev/null @@ -1,16 +0,0 @@ -package com.tencent.supersonic.chat.corrector; - -import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; -import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; -import lombok.extern.slf4j.Slf4j; - -@Slf4j -public class FunctionAliasCorrector extends BaseSemanticCorrector { - - @Override - public void correct(SemanticCorrectInfo semanticCorrectInfo) { - String replaceAlias = SqlParserUpdateHelper.replaceAlias(semanticCorrectInfo.getSql()); - semanticCorrectInfo.setSql(replaceAlias); - } - -} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/FunctionCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/FunctionCorrector.java deleted file mode 100644 index e0a3a3210..000000000 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/FunctionCorrector.java +++ /dev/null @@ -1,17 +0,0 @@ -package com.tencent.supersonic.chat.corrector; - -import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; -import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; -import lombok.extern.slf4j.Slf4j; - -@Slf4j -public class FunctionCorrector extends BaseSemanticCorrector { - - @Override - public void correct(SemanticCorrectInfo semanticCorrectInfo) { - String preSql = semanticCorrectInfo.getSql(); - semanticCorrectInfo.setPreSql(preSql); - String sql = SqlParserUpdateHelper.replaceFunction(preSql); - semanticCorrectInfo.setSql(sql); - } -} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/FieldNameCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalCorrector.java similarity index 64% rename from chat/core/src/main/java/com/tencent/supersonic/chat/corrector/FieldNameCorrector.java rename to chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalCorrector.java index f94b98253..774e12aeb 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/FieldNameCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalCorrector.java @@ -16,11 +16,39 @@ import lombok.extern.slf4j.Slf4j; import org.springframework.util.CollectionUtils; @Slf4j -public class FieldNameCorrector extends BaseSemanticCorrector { +public class GlobalCorrector extends BaseSemanticCorrector { @Override public void correct(SemanticCorrectInfo semanticCorrectInfo) { + replaceAlias(semanticCorrectInfo); + + updateFieldNameByLinkingValue(semanticCorrectInfo); + + updateFieldNameByBizName(semanticCorrectInfo); + + addAggregateToMetric(semanticCorrectInfo); + } + + private void addAggregateToMetric(SemanticCorrectInfo semanticCorrectInfo) { + + } + + private void replaceAlias(SemanticCorrectInfo semanticCorrectInfo) { + String replaceAlias = SqlParserUpdateHelper.replaceAlias(semanticCorrectInfo.getSql()); + semanticCorrectInfo.setSql(replaceAlias); + } + + private void updateFieldNameByBizName(SemanticCorrectInfo semanticCorrectInfo) { + + Map fieldToBizName = getFieldToBizName(semanticCorrectInfo.getParseInfo().getModelId()); + + String sql = SqlParserUpdateHelper.replaceFields(semanticCorrectInfo.getSql(), fieldToBizName); + + semanticCorrectInfo.setSql(sql); + } + + private void updateFieldNameByLinkingValue(SemanticCorrectInfo semanticCorrectInfo) { Object context = semanticCorrectInfo.getParseInfo().getProperties().get(Constants.CONTEXT); if (Objects.isNull(context)) { return; @@ -45,5 +73,4 @@ public class FieldNameCorrector extends BaseSemanticCorrector { String sql = SqlParserUpdateHelper.replaceFieldNameByValue(preSql, fieldValueToFieldNames); semanticCorrectInfo.setSql(sql); } - -} +} \ No newline at end of file diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java new file mode 100644 index 000000000..c931d2f0f --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java @@ -0,0 +1,15 @@ +package com.tencent.supersonic.chat.corrector; + +import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +public class GroupByCorrector extends BaseSemanticCorrector { + + @Override + public void correct(SemanticCorrectInfo semanticCorrectInfo) { + + + + } +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/HavingCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/HavingCorrector.java new file mode 100644 index 000000000..c5d8a514d --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/HavingCorrector.java @@ -0,0 +1,14 @@ +package com.tencent.supersonic.chat.corrector; + +import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +public class HavingCorrector extends BaseSemanticCorrector { + + @Override + public void correct(SemanticCorrectInfo semanticCorrectInfo) { + + } + +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/QueryFilterAppend.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/QueryFilterAppend.java deleted file mode 100644 index 4bb63515d..000000000 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/QueryFilterAppend.java +++ /dev/null @@ -1,48 +0,0 @@ -package com.tencent.supersonic.chat.corrector; - -import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; -import com.tencent.supersonic.chat.api.pojo.request.QueryFilters; -import com.tencent.supersonic.common.pojo.Constants; -import com.tencent.supersonic.common.util.StringUtil; -import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; -import java.util.Objects; -import java.util.stream.Collectors; -import lombok.extern.slf4j.Slf4j; -import net.sf.jsqlparser.JSQLParserException; -import net.sf.jsqlparser.expression.Expression; -import net.sf.jsqlparser.parser.CCJSqlParserUtil; -import org.apache.commons.collections.CollectionUtils; -import org.apache.commons.lang3.StringUtils; - -@Slf4j -public class QueryFilterAppend extends BaseSemanticCorrector { - - @Override - public void correct(SemanticCorrectInfo semanticCorrectInfo) throws JSQLParserException { - String queryFilter = getQueryFilter(semanticCorrectInfo.getQueryFilters()); - String preSql = semanticCorrectInfo.getSql(); - - if (StringUtils.isNotEmpty(queryFilter)) { - log.info("add queryFilter to preSql :{}", queryFilter); - Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter); - String sql = SqlParserUpdateHelper.addWhere(preSql, expression); - semanticCorrectInfo.setPreSql(preSql); - semanticCorrectInfo.setSql(sql); - } - } - - private String getQueryFilter(QueryFilters queryFilters) { - if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) { - return null; - } - return queryFilters.getFilters().stream() - .map(filter -> { - String bizNameWrap = StringUtil.getSpaceWrap(filter.getBizName()); - String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue()); - String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString()); - return bizNameWrap + operatorWrap + valueWrap; - }) - .collect(Collectors.joining(Constants.AND_UPPER)); - } - -} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SelectFieldAppendCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SelectCorrector.java similarity index 96% rename from chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SelectFieldAppendCorrector.java rename to chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SelectCorrector.java index 5476370fb..62407df25 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SelectFieldAppendCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SelectCorrector.java @@ -13,11 +13,12 @@ import net.sf.jsqlparser.expression.Expression; import org.springframework.util.CollectionUtils; @Slf4j -public class SelectFieldAppendCorrector extends BaseSemanticCorrector { +public class SelectCorrector extends BaseSemanticCorrector { @Override public void correct(SemanticCorrectInfo semanticCorrectInfo) { String preSql = semanticCorrectInfo.getSql(); + if (SqlParserSelectHelper.hasAggregateFunction(preSql)) { Expression havingExpression = SqlParserSelectHelper.getHavingExpression(preSql); if (Objects.nonNull(havingExpression)) { diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/TableNameCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/TableCorrector.java similarity index 91% rename from chat/core/src/main/java/com/tencent/supersonic/chat/corrector/TableNameCorrector.java rename to chat/core/src/main/java/com/tencent/supersonic/chat/corrector/TableCorrector.java index 03f9b7ecb..1a64727c3 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/TableNameCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/TableCorrector.java @@ -5,7 +5,7 @@ import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; import lombok.extern.slf4j.Slf4j; @Slf4j -public class TableNameCorrector extends BaseSemanticCorrector { +public class TableCorrector extends BaseSemanticCorrector { public static final String TABLE_PREFIX = "t_"; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/FieldValueCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java similarity index 50% rename from chat/core/src/main/java/com/tencent/supersonic/chat/corrector/FieldValueCorrector.java rename to chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java index b660f8946..fc607194a 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/FieldValueCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java @@ -1,26 +1,92 @@ package com.tencent.supersonic.chat.corrector; -import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; import com.tencent.supersonic.chat.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SchemaValueMap; +import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; import com.tencent.supersonic.chat.api.pojo.SemanticSchema; +import com.tencent.supersonic.chat.api.pojo.request.QueryFilters; +import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper; +import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.common.util.StringUtil; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; import com.tencent.supersonic.knowledge.service.SchemaService; +import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; +import net.sf.jsqlparser.JSQLParserException; +import net.sf.jsqlparser.expression.Expression; +import net.sf.jsqlparser.parser.CCJSqlParserUtil; +import org.apache.commons.lang3.StringUtils; import org.apache.logging.log4j.util.Strings; import org.springframework.util.CollectionUtils; @Slf4j -public class FieldValueCorrector extends BaseSemanticCorrector { +public class WhereCorrector extends BaseSemanticCorrector { @Override - public void correct(SemanticCorrectInfo semanticCorrectInfo) { + public void correct(SemanticCorrectInfo semanticCorrectInfo) throws JSQLParserException { + + addDateIfNotExist(semanticCorrectInfo); + + parserDateDiffFunction(semanticCorrectInfo); + + addQueryFilter(semanticCorrectInfo); + + updateFieldValueByTechName(semanticCorrectInfo); + } + + private void addQueryFilter(SemanticCorrectInfo semanticCorrectInfo) throws JSQLParserException { + String queryFilter = getQueryFilter(semanticCorrectInfo.getQueryFilters()); + + String preSql = semanticCorrectInfo.getSql(); + + if (StringUtils.isNotEmpty(queryFilter)) { + log.info("add queryFilter to preSql :{}", queryFilter); + Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter); + String sql = SqlParserUpdateHelper.addWhere(preSql, expression); + semanticCorrectInfo.setPreSql(preSql); + semanticCorrectInfo.setSql(sql); + } + } + + private void parserDateDiffFunction(SemanticCorrectInfo semanticCorrectInfo) { + String preSql = semanticCorrectInfo.getSql(); + semanticCorrectInfo.setPreSql(preSql); + String sql = SqlParserUpdateHelper.replaceFunction(preSql); + semanticCorrectInfo.setSql(sql); + } + + private void addDateIfNotExist(SemanticCorrectInfo semanticCorrectInfo) { + String sql = semanticCorrectInfo.getSql(); + List whereFields = SqlParserSelectHelper.getWhereFields(sql); + if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(TimeDimensionEnum.DAY.getName())) { + String currentDate = DSLDateHelper.getReferenceDate(semanticCorrectInfo.getParseInfo().getModelId()); + sql = SqlParserUpdateHelper.addWhere(sql, TimeDimensionEnum.DAY.getName(), currentDate); + } + semanticCorrectInfo.setSql(sql); + } + + private String getQueryFilter(QueryFilters queryFilters) { + if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) { + return null; + } + return queryFilters.getFilters().stream() + .map(filter -> { + String bizNameWrap = StringUtil.getSpaceWrap(filter.getBizName()); + String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue()); + String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString()); + return bizNameWrap + operatorWrap + valueWrap; + }) + .collect(Collectors.joining(Constants.AND_UPPER)); + } + + private void updateFieldValueByTechName(SemanticCorrectInfo semanticCorrectInfo) { SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); Long modelId = semanticCorrectInfo.getParseInfo().getModel().getId(); List dimensions = semanticSchema.getDimensions().stream() @@ -39,7 +105,6 @@ public class FieldValueCorrector extends BaseSemanticCorrector { return; } - private Map> getAliasAndBizNameToTechName(List dimensions) { if (CollectionUtils.isEmpty(dimensions)) { return new HashMap<>(); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParser.java index f286c1177..9b6c24422 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParser.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParser.java @@ -408,27 +408,20 @@ public class LLMDslParser implements SemanticParser { protected List getFieldNameList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema, LLMParserConfig llmParserConfig) { + + Set results = getTopNFieldNames(modelId, semanticSchema, llmParserConfig); + + Set fieldNameList = getMatchedFieldNames(queryCtx, modelId, semanticSchema); + + results.addAll(fieldNameList); + return new ArrayList<>(results); + } + + protected Set getMatchedFieldNames(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) { Map itemIdToName = getItemIdToName(modelId, semanticSchema); - - Set results = semanticSchema.getDimensions().stream() - .filter(schemaElement -> modelId.equals(schemaElement.getModel())) - .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()) - .limit(llmParserConfig.getDimensionTopN()) - .map(entry -> entry.getName()) - .collect(Collectors.toSet()); - - Set metrics = semanticSchema.getMetrics().stream() - .filter(schemaElement -> modelId.equals(schemaElement.getModel())) - .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()) - .limit(llmParserConfig.getMetricTopN()) - .map(entry -> entry.getName()) - .collect(Collectors.toSet()); - - results.addAll(metrics); - List matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId); if (CollectionUtils.isEmpty(matchedElements)) { - return new ArrayList<>(results); + return new HashSet<>(); } Set fieldNameList = matchedElements.stream() .filter(schemaElementMatch -> { @@ -447,13 +440,29 @@ public class LLMDslParser implements SemanticParser { }) .filter(name -> StringUtils.isNotEmpty(name) && !name.contains("%")) .collect(Collectors.toSet()); - results.addAll(fieldNameList); - return new ArrayList<>(results); + return fieldNameList; + } + + private Set getTopNFieldNames(Long modelId, SemanticSchema semanticSchema, + LLMParserConfig llmParserConfig) { + Set results = semanticSchema.getDimensions(modelId).stream() + .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()) + .limit(llmParserConfig.getDimensionTopN()) + .map(entry -> entry.getName()) + .collect(Collectors.toSet()); + + Set metrics = semanticSchema.getMetrics(modelId).stream() + .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()) + .limit(llmParserConfig.getMetricTopN()) + .map(entry -> entry.getName()) + .collect(Collectors.toSet()); + + results.addAll(metrics); + return results; } protected Map getItemIdToName(Long modelId, SemanticSchema semanticSchema) { - return semanticSchema.getDimensions().stream() - .filter(entry -> modelId.equals(entry.getModel())) + return semanticSchema.getDimensions(modelId).stream() .collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2)); } diff --git a/chat/core/src/test/java/com/tencent/supersonic/chat/corrector/DateFieldCorrectorTest.java b/chat/core/src/test/java/com/tencent/supersonic/chat/corrector/DateFieldCorrectorTest.java deleted file mode 100644 index 4bc2a3919..000000000 --- a/chat/core/src/test/java/com/tencent/supersonic/chat/corrector/DateFieldCorrectorTest.java +++ /dev/null @@ -1,45 +0,0 @@ -package com.tencent.supersonic.chat.corrector; - -import static org.mockito.ArgumentMatchers.any; - -import com.tencent.supersonic.chat.api.pojo.SchemaElement; -import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; -import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper; -import org.junit.Assert; -import org.junit.jupiter.api.Test; -import org.mockito.MockedStatic; -import org.mockito.Mockito; - -class DateFieldCorrectorTest { - - @Test - void corrector() { - MockedStatic dslDateHelper = Mockito.mockStatic(DSLDateHelper.class); - - dslDateHelper.when(() -> DSLDateHelper.getReferenceDate(any())).thenReturn("2023-08-14"); - DateFieldCorrector dateFieldCorrector = new DateFieldCorrector(); - SemanticParseInfo parseInfo = new SemanticParseInfo(); - SchemaElement model = new SchemaElement(); - model.setId(2L); - parseInfo.setModel(model); - SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder() - .sql("select count(歌曲名) from 歌曲库 ") - .parseInfo(parseInfo) - .build(); - - dateFieldCorrector.correct(semanticCorrectInfo); - - Assert.assertEquals("SELECT count(歌曲名) FROM 歌曲库 WHERE 数据日期 = '2023-08-14'", semanticCorrectInfo.getSql()); - - semanticCorrectInfo = SemanticCorrectInfo.builder() - .sql("select count(歌曲名) from 歌曲库 where 数据日期 = '2023-08-14'") - .parseInfo(parseInfo) - .build(); - - dateFieldCorrector.correct(semanticCorrectInfo); - - Assert.assertEquals("select count(歌曲名) from 歌曲库 where 数据日期 = '2023-08-14'", semanticCorrectInfo.getSql()); - - } -} diff --git a/chat/core/src/test/java/com/tencent/supersonic/chat/corrector/FieldNameCorrectorTest.java b/chat/core/src/test/java/com/tencent/supersonic/chat/corrector/FieldNameCorrectorTest.java deleted file mode 100644 index 7caae3c06..000000000 --- a/chat/core/src/test/java/com/tencent/supersonic/chat/corrector/FieldNameCorrectorTest.java +++ /dev/null @@ -1,65 +0,0 @@ -package com.tencent.supersonic.chat.corrector; - -import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; -import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult; -import com.tencent.supersonic.chat.query.llm.dsl.LLMReq; -import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue; -import com.tencent.supersonic.common.pojo.Constants; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.junit.Assert; -import org.junit.jupiter.api.Test; - -class FieldNameCorrectorTest { - - @Test - void corrector() { - - FieldNameCorrector corrector = new FieldNameCorrector(); - SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder() - .sql("select 歌曲名 from 歌曲库 where 专辑照片 = '七里香' and 专辑名 = '流行' and 数据日期 = '2023-08-19'") - .build(); - - SemanticParseInfo parseInfo = new SemanticParseInfo(); - - DSLParseResult dslParseResult = new DSLParseResult(); - LLMReq llmReq = new LLMReq(); - List linking = new ArrayList<>(); - ElementValue elementValue = new ElementValue(); - elementValue.setFieldValue("流行"); - elementValue.setFieldName("歌曲风格"); - linking.add(elementValue); - - ElementValue elementValue2 = new ElementValue(); - elementValue2.setFieldValue("七里香"); - elementValue2.setFieldName("歌曲名"); - linking.add(elementValue2); - - ElementValue elementValue3 = new ElementValue(); - elementValue3.setFieldValue("周杰伦"); - elementValue3.setFieldName("歌手名"); - linking.add(elementValue3); - - ElementValue elementValue4 = new ElementValue(); - elementValue4.setFieldValue("流行"); - elementValue4.setFieldName("歌曲流派"); - linking.add(elementValue4); - - llmReq.setLinking(linking); - dslParseResult.setLlmReq(llmReq); - - Map properties = new HashMap<>(); - properties.put(Constants.CONTEXT, dslParseResult); - - parseInfo.setProperties(properties); - semanticCorrectInfo.setParseInfo(parseInfo); - - corrector.correct(semanticCorrectInfo); - - Assert.assertEquals("SELECT 歌曲名 FROM 歌曲库 WHERE 歌曲名 = '七里香' AND 歌曲流派 = '流行' AND 数据日期 = '2023-08-19'", - semanticCorrectInfo.getSql()); - } -} \ No newline at end of file diff --git a/chat/core/src/test/java/com/tencent/supersonic/chat/corrector/FieldValueCorrectorTest.java b/chat/core/src/test/java/com/tencent/supersonic/chat/corrector/FieldValueCorrectorTest.java deleted file mode 100644 index d9afccf23..000000000 --- a/chat/core/src/test/java/com/tencent/supersonic/chat/corrector/FieldValueCorrectorTest.java +++ /dev/null @@ -1,71 +0,0 @@ -package com.tencent.supersonic.chat.corrector; - -import static org.mockito.Mockito.when; - -import com.tencent.supersonic.chat.api.pojo.SchemaElement; -import com.tencent.supersonic.chat.api.pojo.SchemaValueMap; -import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; -import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.chat.api.pojo.SemanticSchema; -import com.tencent.supersonic.common.util.ContextUtils; -import com.tencent.supersonic.knowledge.service.SchemaService; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import org.junit.Assert; -import org.junit.jupiter.api.Test; -import org.mockito.MockedStatic; -import org.mockito.Mockito; - -class FieldValueCorrectorTest { - - - @Test - void corrector() { - - MockedStatic mockContextUtils = Mockito.mockStatic(ContextUtils.class); - - SchemaService mockSchemaService = Mockito.mock(SchemaService.class); - - SemanticSchema mockSemanticSchema = Mockito.mock(SemanticSchema.class); - - List dimensions = new ArrayList<>(); - List schemaValueMaps = new ArrayList<>(); - SchemaValueMap value1 = new SchemaValueMap(); - value1.setBizName("杰伦"); - value1.setTechName("周杰伦"); - value1.setAlias(Arrays.asList("周杰倫", "Jay Chou", "周董", "周先生")); - schemaValueMaps.add(value1); - - SchemaElement schemaElement = SchemaElement.builder() - .bizName("singer_name") - .name("歌手名") - .model(2L) - .schemaValueMaps(schemaValueMaps) - .build(); - dimensions.add(schemaElement); - - when(mockSemanticSchema.getDimensions()).thenReturn(dimensions); - when(mockSchemaService.getSemanticSchema()).thenReturn(mockSemanticSchema); - mockContextUtils.when(() -> ContextUtils.getBean(SchemaService.class)).thenReturn(mockSchemaService); - - SemanticParseInfo parseInfo = new SemanticParseInfo(); - SchemaElement model = new SchemaElement(); - model.setId(2L); - parseInfo.setModel(model); - SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder() - .sql("select count(song_name) from 歌曲库 where singer_name = '周先生'") - .parseInfo(parseInfo) - .build(); - - FieldValueCorrector corrector = new FieldValueCorrector(); - corrector.correct(semanticCorrectInfo); - - Assert.assertEquals("SELECT count(song_name) FROM 歌曲库 WHERE singer_name = '周杰伦'", semanticCorrectInfo.getSql()); - - semanticCorrectInfo.setSql("select count(song_name) from 歌曲库 where singer_name = '杰伦'"); - corrector.correct(semanticCorrectInfo); - - Assert.assertEquals("SELECT count(song_name) FROM 歌曲库 WHERE singer_name = '周杰伦'", semanticCorrectInfo.getSql()); - } -} \ No newline at end of file diff --git a/chat/core/src/test/java/com/tencent/supersonic/chat/corrector/SelectFieldAppendCorrectorTest.java b/chat/core/src/test/java/com/tencent/supersonic/chat/corrector/SelectFieldAppendCorrectorTest.java deleted file mode 100644 index 39db3935d..000000000 --- a/chat/core/src/test/java/com/tencent/supersonic/chat/corrector/SelectFieldAppendCorrectorTest.java +++ /dev/null @@ -1,46 +0,0 @@ -package com.tencent.supersonic.chat.corrector; - -import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; -import org.junit.Assert; -import org.junit.jupiter.api.Test; - -class SelectFieldAppendCorrectorTest { - - @Test - void corrector() { - SelectFieldAppendCorrector corrector = new SelectFieldAppendCorrector(); - SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder() - .sql("select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 and 歌手名 = '邓紫棋' " - + "and sys_imp_date = '2023-08-09' and 歌曲发布时 = '2023-08-01' order by 播放量 desc limit 11") - .build(); - - corrector.correct(semanticCorrectInfo); - - Assert.assertEquals( - "SELECT 歌曲名, 歌手名, 播放量, 歌曲发布时, 发布日期 FROM 歌曲库 WHERE " - + "datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '邓紫棋' " - + "AND sys_imp_date = '2023-08-09' AND 歌曲发布时 = '2023-08-01'" - + " ORDER BY 播放量 DESC LIMIT 11", semanticCorrectInfo.getSql()); - - semanticCorrectInfo.setSql("select 用户名 from 内容库产品 where datediff('day', 数据日期, '2023-09-14') <= 30" - + " group by 用户名 having sum(访问次数) > 2000"); - - corrector.correct(semanticCorrectInfo); - - Assert.assertEquals( - "SELECT 用户名, sum(访问次数) FROM 内容库产品 WHERE " - + "datediff('day', 数据日期, '2023-09-14') <= 30 " - + "GROUP BY 用户名 HAVING sum(访问次数) > 2000", semanticCorrectInfo.getSql()); - - semanticCorrectInfo.setSql("SELECT 用户名, sum(访问次数) FROM 内容库产品 WHERE " - + "datediff('day', 数据日期, '2023-09-14') <= 30 " - + "GROUP BY 用户名 HAVING sum(访问次数) > 2000"); - - corrector.correct(semanticCorrectInfo); - - Assert.assertEquals( - "SELECT 用户名, sum(访问次数) FROM 内容库产品 WHERE " - + "datediff('day', 数据日期, '2023-09-14') <= 30 " - + "GROUP BY 用户名 HAVING sum(访问次数) > 2000", semanticCorrectInfo.getSql()); - } -} diff --git a/launchers/chat/src/main/resources/META-INF/spring.factories b/launchers/chat/src/main/resources/META-INF/spring.factories index c760ce104..aa0a5ff20 100644 --- a/launchers/chat/src/main/resources/META-INF/spring.factories +++ b/launchers/chat/src/main/resources/META-INF/spring.factories @@ -31,12 +31,9 @@ com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\ com.tencent.supersonic.chat.api.component.SemanticCorrector=\ - com.tencent.supersonic.chat.corrector.DateFieldCorrector, \ - com.tencent.supersonic.chat.corrector.FunctionAliasCorrector, \ - com.tencent.supersonic.chat.corrector.FieldNameCorrector, \ - com.tencent.supersonic.chat.corrector.FieldCorrector, \ - com.tencent.supersonic.chat.corrector.FunctionCorrector, \ - com.tencent.supersonic.chat.corrector.TableNameCorrector, \ - com.tencent.supersonic.chat.corrector.QueryFilterAppend, \ - com.tencent.supersonic.chat.corrector.SelectFieldAppendCorrector, \ - com.tencent.supersonic.chat.corrector.FieldValueCorrector \ No newline at end of file + com.tencent.supersonic.chat.corrector.GlobalCorrector, \ + com.tencent.supersonic.chat.corrector.TableCorrector, \ + com.tencent.supersonic.chat.corrector.GroupByCorrector, \ + com.tencent.supersonic.chat.corrector.SelectCorrector, \ + com.tencent.supersonic.chat.corrector.WhereCorrector, \ + com.tencent.supersonic.chat.corrector.HavingCorrector \ No newline at end of file diff --git a/launchers/standalone/src/main/resources/META-INF/spring.factories b/launchers/standalone/src/main/resources/META-INF/spring.factories index c0adbf1d6..3e8c7a5d6 100644 --- a/launchers/standalone/src/main/resources/META-INF/spring.factories +++ b/launchers/standalone/src/main/resources/META-INF/spring.factories @@ -31,12 +31,9 @@ com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\ com.tencent.supersonic.auth.authentication.adaptor.DefaultUserAdaptor com.tencent.supersonic.chat.api.component.SemanticCorrector=\ - com.tencent.supersonic.chat.corrector.DateFieldCorrector, \ - com.tencent.supersonic.chat.corrector.FunctionAliasCorrector, \ - com.tencent.supersonic.chat.corrector.FieldNameCorrector, \ - com.tencent.supersonic.chat.corrector.FieldCorrector, \ - com.tencent.supersonic.chat.corrector.FunctionCorrector, \ - com.tencent.supersonic.chat.corrector.TableNameCorrector, \ - com.tencent.supersonic.chat.corrector.QueryFilterAppend, \ - com.tencent.supersonic.chat.corrector.SelectFieldAppendCorrector, \ - com.tencent.supersonic.chat.corrector.FieldValueCorrector + com.tencent.supersonic.chat.corrector.GlobalCorrector, \ + com.tencent.supersonic.chat.corrector.TableCorrector, \ + com.tencent.supersonic.chat.corrector.GroupByCorrector, \ + com.tencent.supersonic.chat.corrector.SelectCorrector, \ + com.tencent.supersonic.chat.corrector.WhereCorrector, \ + com.tencent.supersonic.chat.corrector.HavingCorrector