diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalBeforeCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalBeforeCorrector.java index 2d9742fd3..e74a46550 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalBeforeCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalBeforeCorrector.java @@ -27,6 +27,8 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector { updateFieldNameByLinkingValue(semanticCorrectInfo); + updateFieldValueByLinkingValue(semanticCorrectInfo); + correctFieldName(semanticCorrectInfo); } @@ -45,17 +47,7 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector { } private void updateFieldNameByLinkingValue(SemanticCorrectInfo semanticCorrectInfo) { - Object context = semanticCorrectInfo.getParseInfo().getProperties().get(Constants.CONTEXT); - if (Objects.isNull(context)) { - return; - } - - DSLParseResult dslParseResult = JsonUtil.toObject(JsonUtil.toString(context), DSLParseResult.class); - if (Objects.isNull(dslParseResult) || Objects.isNull(dslParseResult.getLlmReq())) { - return; - } - LLMReq llmReq = dslParseResult.getLlmReq(); - List linking = llmReq.getLinking(); + List linking = getLinkingValues(semanticCorrectInfo); if (CollectionUtils.isEmpty(linking)) { return; } @@ -68,4 +60,37 @@ public class GlobalBeforeCorrector extends BaseSemanticCorrector { fieldValueToFieldNames); semanticCorrectInfo.setSql(sql); } + + private List getLinkingValues(SemanticCorrectInfo semanticCorrectInfo) { + Object context = semanticCorrectInfo.getParseInfo().getProperties().get(Constants.CONTEXT); + if (Objects.isNull(context)) { + return null; + } + + DSLParseResult dslParseResult = JsonUtil.toObject(JsonUtil.toString(context), DSLParseResult.class); + if (Objects.isNull(dslParseResult) || Objects.isNull(dslParseResult.getLlmReq())) { + return null; + } + LLMReq llmReq = dslParseResult.getLlmReq(); + return llmReq.getLinking(); + } + + + private void updateFieldValueByLinkingValue(SemanticCorrectInfo semanticCorrectInfo) { + List linking = getLinkingValues(semanticCorrectInfo); + if (CollectionUtils.isEmpty(linking)) { + return; + } + + Map> filedNameToValueMap = linking.stream().collect( + Collectors.groupingBy(ElementValue::getFieldName, + Collectors.mapping(ElementValue::getFieldValue, Collectors.toMap( + oldValue -> oldValue, + newValue -> newValue, + (existingValue, newValue) -> newValue) + ))); + + String sql = SqlParserUpdateHelper.replaceValue(semanticCorrectInfo.getSql(), filedNameToValueMap, false); + semanticCorrectInfo.setSql(sql); + } } \ No newline at end of file diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/entity/EntityListQuery.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/entity/EntityListQuery.java index e1dce47ee..e5d16d454 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/entity/EntityListQuery.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/entity/EntityListQuery.java @@ -1,9 +1,9 @@ package com.tencent.supersonic.chat.query.rule.entity; -import com.tencent.supersonic.chat.api.pojo.SchemaElement; -import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.api.pojo.ChatContext; import com.tencent.supersonic.chat.api.pojo.ModelSchema; +import com.tencent.supersonic.chat.api.pojo.QueryContext; +import com.tencent.supersonic.chat.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp; import com.tencent.supersonic.chat.api.pojo.response.ChatDefaultRichConfigResp; @@ -12,10 +12,10 @@ import com.tencent.supersonic.chat.service.SemanticService; import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Order; import com.tencent.supersonic.common.util.ContextUtils; - import java.util.LinkedHashSet; import java.util.Objects; import java.util.Set; +import org.apache.commons.collections.CollectionUtils; public abstract class EntityListQuery extends EntitySemanticQuery { @@ -41,13 +41,17 @@ public abstract class EntityListQuery extends EntitySemanticQuery { ChatDefaultRichConfigResp chatDefaultConfig = chaConfigRichDesc .getChatDetailRichConfig().getChatDefaultConfig(); if (chatDefaultConfig != null) { - chatDefaultConfig.getMetrics().stream() - .forEach(metric -> { - metrics.add(metric); - orders.add(new Order(metric.getBizName(), Constants.DESC_UPPER)); - }); - chatDefaultConfig.getDimensions().stream() - .forEach(dimension -> dimensions.add(dimension)); + if (CollectionUtils.isNotEmpty(chatDefaultConfig.getMetrics())) { + chatDefaultConfig.getMetrics().stream() + .forEach(metric -> { + metrics.add(metric); + orders.add(new Order(metric.getBizName(), Constants.DESC_UPPER)); + }); + } + if (CollectionUtils.isNotEmpty(chatDefaultConfig.getDimensions())) { + chatDefaultConfig.getDimensions().stream() + .forEach(dimension -> dimensions.add(dimension)); + } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/execute/EntityInfoExecuteResponder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/execute/EntityInfoExecuteResponder.java index 7c3c7a2fe..4659e0bcc 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/execute/EntityInfoExecuteResponder.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/execute/EntityInfoExecuteResponder.java @@ -7,7 +7,6 @@ import com.tencent.supersonic.chat.api.pojo.response.EntityInfo; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.service.SemanticService; import com.tencent.supersonic.common.util.ContextUtils; -import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; import java.util.List; import java.util.Map; import java.util.Objects; @@ -43,15 +42,6 @@ public class EntityInfoExecuteResponder implements ExecuteResponder { if (CollectionUtils.isEmpty(entities)) { return; } - QueryResultWithSchemaResp queryResultWithSchemaResp = semanticService.getQueryResultWithSchemaResp(entityInfo, - semanticParseInfo.getModelId(), entities, user); - if (Objects.isNull(queryResultWithSchemaResp)) { - return; - } - List> entityResultList = queryResultWithSchemaResp.getResultList(); - if (CollectionUtils.isEmpty(entityResultList)) { - return; - } } } \ No newline at end of file diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldlValueReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldlValueReplaceVisitor.java index f78f717fe..7396d3ce1 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldlValueReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldlValueReplaceVisitor.java @@ -21,43 +21,19 @@ import org.springframework.util.CollectionUtils; public class FieldlValueReplaceVisitor extends ExpressionVisitorAdapter { + ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper(); + private boolean exactReplace; private Map> filedNameToValueMap; - public FieldlValueReplaceVisitor(Map> filedNameToValueMap) { + + public FieldlValueReplaceVisitor(boolean exactReplace, Map> filedNameToValueMap) { + this.exactReplace = exactReplace; this.filedNameToValueMap = filedNameToValueMap; } @Override public void visit(EqualsTo expr) { - Expression leftExpression = expr.getLeftExpression(); - Expression rightExpression = expr.getRightExpression(); - if (!(rightExpression instanceof StringValue)) { - return; - } - if (!(leftExpression instanceof Column)) { - return; - } - if (CollectionUtils.isEmpty(filedNameToValueMap)) { - return; - } - if (Objects.isNull(rightExpression) || Objects.isNull(leftExpression)) { - return; - } - Column leftColumnName = (Column) leftExpression; - StringValue rightStringValue = (StringValue) rightExpression; - - String columnName = leftColumnName.getColumnName(); - if (StringUtils.isEmpty(columnName)) { - return; - } - Map valueMap = filedNameToValueMap.get(columnName); - if (Objects.isNull(valueMap) || valueMap.isEmpty()) { - return; - } - String replaceValue = valueMap.get(rightStringValue.getValue()); - if (StringUtils.isNotEmpty(replaceValue)) { - rightStringValue.setValue(replaceValue); - } + replaceComparisonExpression(expr); } public void visit(GreaterThan expr) { @@ -77,47 +53,57 @@ public class FieldlValueReplaceVisitor extends ExpressionVisitorAdapter { } public void replaceComparisonExpression(T expression) { - if ((expression instanceof GreaterThanEquals) || (expression instanceof GreaterThan) - || (expression instanceof MinorThanEquals) || (expression instanceof MinorThan)) { - Expression leftExpression = ((ComparisonOperator) expression).getLeftExpression(); - Expression rightExpression = ((ComparisonOperator) expression).getRightExpression(); - if (!(leftExpression instanceof Column)) { - return; - } - if (CollectionUtils.isEmpty(filedNameToValueMap)) { - return; - } - if (Objects.isNull(rightExpression) || Objects.isNull(leftExpression)) { - return; - } - Column leftColumnName = (Column) leftExpression; + Expression leftExpression = ((ComparisonOperator) expression).getLeftExpression(); + Expression rightExpression = ((ComparisonOperator) expression).getRightExpression(); + if (!(leftExpression instanceof Column)) { + return; + } + if (CollectionUtils.isEmpty(filedNameToValueMap)) { + return; + } + if (Objects.isNull(rightExpression) || Objects.isNull(leftExpression)) { + return; + } + Column leftColumnName = (Column) leftExpression; - String columnName = leftColumnName.getColumnName(); - if (StringUtils.isEmpty(columnName)) { - return; - } - Map valueMap = filedNameToValueMap.get(columnName); - if (Objects.isNull(valueMap) || valueMap.isEmpty()) { - return; - } - for (String oriValue : valueMap.keySet()) { - String replaceValue = valueMap.get(oriValue); - if (StringUtils.isNotEmpty(replaceValue)) { - if (rightExpression instanceof LongValue) { - LongValue rightStringValue = (LongValue) rightExpression; - rightStringValue.setValue(Long.parseLong(replaceValue)); - } - if (rightExpression instanceof DoubleValue) { - DoubleValue rightStringValue = (DoubleValue) rightExpression; - rightStringValue.setValue(Double.parseDouble(replaceValue)); - } - if (rightExpression instanceof StringValue) { - StringValue rightStringValue = (StringValue) rightExpression; - rightStringValue.setValue(replaceValue); - } - } - } + String columnName = leftColumnName.getColumnName(); + if (StringUtils.isEmpty(columnName)) { + return; + } + Map valueMap = filedNameToValueMap.get(columnName); + if (Objects.isNull(valueMap) || valueMap.isEmpty()) { + return; + } + + if (rightExpression instanceof LongValue) { + LongValue rightStringValue = (LongValue) rightExpression; + String replaceValue = getReplaceValue(valueMap, String.valueOf(rightStringValue.getValue())); + if (StringUtils.isNotEmpty(replaceValue)) { + rightStringValue.setValue(Long.parseLong(replaceValue)); + } + } + if (rightExpression instanceof DoubleValue) { + DoubleValue rightStringValue = (DoubleValue) rightExpression; + String replaceValue = getReplaceValue(valueMap, String.valueOf(rightStringValue.getValue())); + if (StringUtils.isNotEmpty(replaceValue)) { + rightStringValue.setValue(Double.parseDouble(replaceValue)); + } + } + if (rightExpression instanceof StringValue) { + StringValue rightStringValue = (StringValue) rightExpression; + String replaceValue = getReplaceValue(valueMap, String.valueOf(rightStringValue.getValue())); + if (StringUtils.isNotEmpty(replaceValue)) { + rightStringValue.setValue(replaceValue); + } } } + + private String getReplaceValue(Map valueMap, String beforeValue) { + String afterValue = valueMap.get(String.valueOf(beforeValue)); + if (StringUtils.isEmpty(afterValue) && !exactReplace) { + return parseVisitorHelper.getReplaceValue(beforeValue, valueMap, false); + } + return afterValue; + } } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/GroupByReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/GroupByReplaceVisitor.java index eb46506fc..588b9e703 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/GroupByReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/GroupByReplaceVisitor.java @@ -38,7 +38,7 @@ public class GroupByReplaceVisitor implements GroupByVisitor { ((Function) expression).getParameters().getExpressions().get(0))) { columnName = ((Function) expression).getParameters().getExpressions().get(0).toString(); } - String replaceColumn = parseVisitorHelper.getReplaceColumn(columnName, fieldNameMap, + String replaceColumn = parseVisitorHelper.getReplaceValue(columnName, fieldNameMap, exactReplace); if (StringUtils.isNotEmpty(replaceColumn)) { if (expression instanceof Column) { diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/ParseVisitorHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/ParseVisitorHelper.java index f4a5c63e7..22c91911f 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/ParseVisitorHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/ParseVisitorHelper.java @@ -13,32 +13,32 @@ public class ParseVisitorHelper { public void replaceColumn(Column column, Map fieldNameMap, boolean exactReplace) { String columnName = column.getColumnName(); - String replaceColumn = getReplaceColumn(columnName, fieldNameMap, exactReplace); + String replaceColumn = getReplaceValue(columnName, fieldNameMap, exactReplace); if (StringUtils.isNotBlank(replaceColumn)) { column.setColumnName(replaceColumn); } } - public String getReplaceColumn(String columnName, Map fieldNameMap, boolean exactReplace) { - String fieldName = fieldNameMap.get(columnName); - if (StringUtils.isNotBlank(fieldName)) { - return fieldName; + public String getReplaceValue(String beforeValue, Map valueMap, boolean exactReplace) { + String value = valueMap.get(beforeValue); + if (StringUtils.isNotBlank(value)) { + return value; } if (exactReplace) { return null; } - Optional> first = fieldNameMap.entrySet().stream().sorted((k1, k2) -> { - String k1FieldNameDb = k1.getKey(); - String k2FieldNameDb = k2.getKey(); - Double k1Similarity = getSimilarity(columnName, k1FieldNameDb); - Double k2Similarity = getSimilarity(columnName, k2FieldNameDb); + Optional> first = valueMap.entrySet().stream().sorted((k1, k2) -> { + String k1Value = k1.getKey(); + String k2Value = k2.getKey(); + Double k1Similarity = getSimilarity(beforeValue, k1Value); + Double k2Similarity = getSimilarity(beforeValue, k2Value); return k2Similarity.compareTo(k1Similarity); }).collect(Collectors.toList()).stream().findFirst(); if (first.isPresent()) { return first.get().getValue(); } - return columnName; + return beforeValue; } public static int editDistance(String word1, String word2) { diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java index 94a474abe..fd3db6fe9 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelper.java @@ -39,6 +39,11 @@ import org.springframework.util.CollectionUtils; public class SqlParserUpdateHelper { public static String replaceValue(String sql, Map> filedNameToValueMap) { + return replaceValue(sql, filedNameToValueMap, true); + } + + public static String replaceValue(String sql, Map> filedNameToValueMap, + boolean exactReplace) { Select selectStatement = SqlParserSelectHelper.getSelect(sql); SelectBody selectBody = selectStatement.getSelectBody(); if (!(selectBody instanceof PlainSelect)) { @@ -46,7 +51,7 @@ public class SqlParserUpdateHelper { } PlainSelect plainSelect = (PlainSelect) selectBody; Expression where = plainSelect.getWhere(); - FieldlValueReplaceVisitor visitor = new FieldlValueReplaceVisitor(filedNameToValueMap); + FieldlValueReplaceVisitor visitor = new FieldlValueReplaceVisitor(exactReplace, filedNameToValueMap); if (Objects.nonNull(where)) { where.accept(visitor); } diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java index d1f9dcbcd..355d21157 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserUpdateHelperTest.java @@ -18,6 +18,47 @@ import org.junit.jupiter.api.Test; */ class SqlParserUpdateHelperTest { + @Test + void replaceValue() { + + String replaceSql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌手名 = '杰伦' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" + + " order by 播放量 desc limit 11"; + + Map> filedNameToValueMap = new HashMap<>(); + + Map valueMap = new HashMap<>(); + valueMap.put("杰伦", "周杰伦"); + filedNameToValueMap.put("歌手名", valueMap); + + replaceSql = SqlParserUpdateHelper.replaceValue(replaceSql, filedNameToValueMap); + + Assert.assertEquals( + "SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 AND " + + "歌手名 = '周杰伦' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01' " + + "ORDER BY 播放量 DESC LIMIT 11", replaceSql); + + replaceSql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌手名 = '周杰' and 歌手名 = '林俊' and 歌手名 = '陈' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" + + " order by 播放量 desc limit 11"; + + Map> filedNameToValueMap2 = new HashMap<>(); + + Map valueMap2 = new HashMap<>(); + valueMap2.put("周杰伦", "周杰伦"); + valueMap2.put("林俊杰", "林俊杰"); + valueMap2.put("陈奕迅", "陈奕迅"); + filedNameToValueMap2.put("歌手名", valueMap2); + + replaceSql = SqlParserUpdateHelper.replaceValue(replaceSql, filedNameToValueMap2, false); + + Assert.assertEquals( + "SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '周杰伦' " + + "AND 歌手名 = '林俊杰' AND 歌手名 = '陈奕迅' AND 数据日期 = '2023-08-09' AND " + + "歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11", replaceSql); + + } + @Test void replaceFieldNameByValue() {