diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FieldReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FieldReplaceVisitor.java index ccd0a64da..5d46c6b65 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FieldReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FieldReplaceVisitor.java @@ -1,6 +1,5 @@ package com.tencent.supersonic.common.jsqlparser; -import com.tencent.supersonic.common.util.ContextUtils; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.expression.AnalyticExpression; import net.sf.jsqlparser.expression.ExpressionVisitorAdapter; @@ -14,6 +13,7 @@ import java.util.Map; @Slf4j public class FieldReplaceVisitor extends ExpressionVisitorAdapter { + private Map fieldNameMap; private ThreadLocal exactReplace = ThreadLocal.withInitial(() -> false); @@ -24,8 +24,7 @@ public class FieldReplaceVisitor extends ExpressionVisitorAdapter { @Override public void visit(Column column) { - ReplaceService replaceService = ContextUtils.getBean(ReplaceService.class); - replaceService.replaceColumn(column, fieldNameMap, exactReplace.get()); + SqlReplaceHelper.replaceColumn(column, fieldNameMap, exactReplace.get()); } @Override diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FieldValueReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FieldValueReplaceVisitor.java index 65b669c88..d1fd30ef9 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FieldValueReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FieldValueReplaceVisitor.java @@ -1,21 +1,8 @@ package com.tencent.supersonic.common.jsqlparser; -import com.tencent.supersonic.common.util.ContextUtils; import lombok.extern.slf4j.Slf4j; -import net.sf.jsqlparser.expression.DoubleValue; -import net.sf.jsqlparser.expression.Expression; -import net.sf.jsqlparser.expression.ExpressionVisitorAdapter; -import net.sf.jsqlparser.expression.Function; -import net.sf.jsqlparser.expression.LongValue; -import net.sf.jsqlparser.expression.StringValue; -import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator; -import net.sf.jsqlparser.expression.operators.relational.EqualsTo; -import net.sf.jsqlparser.expression.operators.relational.ExpressionList; -import net.sf.jsqlparser.expression.operators.relational.GreaterThan; -import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals; -import net.sf.jsqlparser.expression.operators.relational.InExpression; -import net.sf.jsqlparser.expression.operators.relational.MinorThan; -import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals; +import net.sf.jsqlparser.expression.*; +import net.sf.jsqlparser.expression.operators.relational.*; import net.sf.jsqlparser.schema.Column; import org.apache.commons.lang3.StringUtils; import org.springframework.util.CollectionUtils; @@ -28,13 +15,14 @@ import java.util.Objects; @Slf4j public class FieldValueReplaceVisitor extends ExpressionVisitorAdapter { - private boolean exactReplace; - private Map> filedNameToValueMap; + private final boolean exactReplace; + + private final Map> filedNameToValueMap; public FieldValueReplaceVisitor(boolean exactReplace, Map> filedNameToValueMap) { - this.exactReplace = exactReplace; this.filedNameToValueMap = filedNameToValueMap; + this.exactReplace = exactReplace; } @Override @@ -137,9 +125,8 @@ public class FieldValueReplaceVisitor extends ExpressionVisitorAdapter { private String getReplaceValue(Map valueMap, String beforeValue) { String afterValue = valueMap.get(String.valueOf(beforeValue)); - if (StringUtils.isEmpty(afterValue) && !exactReplace) { - ReplaceService replaceService = ContextUtils.getBean(ReplaceService.class); - return replaceService.getReplaceValue(beforeValue, valueMap, false); + if (StringUtils.isEmpty(afterValue)) { + return SqlReplaceHelper.getReplaceValue(beforeValue, valueMap, exactReplace); } return afterValue; } diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/GroupByReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/GroupByReplaceVisitor.java index b46c846b4..b521998de 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/GroupByReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/GroupByReplaceVisitor.java @@ -1,6 +1,5 @@ package com.tencent.supersonic.common.jsqlparser; -import com.tencent.supersonic.common.util.ContextUtils; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.Function; @@ -15,8 +14,8 @@ import java.util.Map; @Slf4j public class GroupByReplaceVisitor implements GroupByVisitor { - private Map fieldNameMap; - private boolean exactReplace; + private final boolean exactReplace; + private final Map fieldNameMap; public GroupByReplaceVisitor(Map fieldNameMap, boolean exactReplace) { this.fieldNameMap = fieldNameMap; @@ -34,11 +33,10 @@ public class GroupByReplaceVisitor implements GroupByVisitor { } private void replaceExpression(Expression expression) { - ReplaceService replaceService = ContextUtils.getBean(ReplaceService.class); if (expression instanceof Column) { - replaceService.replaceColumn((Column) expression, fieldNameMap, exactReplace); + SqlReplaceHelper.replaceColumn((Column) expression, fieldNameMap, exactReplace); } else if (expression instanceof Function) { - replaceService.replaceFunction((Function) expression, fieldNameMap, exactReplace); + SqlReplaceHelper.replaceFunction((Function) expression, fieldNameMap, exactReplace); } } } diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/OrderByReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/OrderByReplaceVisitor.java index 4c0ee2cf0..f3975596d 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/OrderByReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/OrderByReplaceVisitor.java @@ -1,6 +1,5 @@ package com.tencent.supersonic.common.jsqlparser; -import com.tencent.supersonic.common.util.ContextUtils; import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.Function; import net.sf.jsqlparser.schema.Column; @@ -10,8 +9,9 @@ import net.sf.jsqlparser.statement.select.OrderByVisitorAdapter; import java.util.Map; public class OrderByReplaceVisitor extends OrderByVisitorAdapter { - private Map fieldNameMap; - private boolean exactReplace; + + private final boolean exactReplace; + private final Map fieldNameMap; public OrderByReplaceVisitor(Map fieldNameMap, boolean exactReplace) { this.fieldNameMap = fieldNameMap; @@ -21,12 +21,11 @@ public class OrderByReplaceVisitor extends OrderByVisitorAdapter { @Override public void visit(OrderByElement orderBy) { Expression expression = orderBy.getExpression(); - ReplaceService replaceService = ContextUtils.getBean(ReplaceService.class); if (expression instanceof Column) { - replaceService.replaceColumn((Column) expression, fieldNameMap, exactReplace); + SqlReplaceHelper.replaceColumn((Column) expression, fieldNameMap, exactReplace); } if (expression instanceof Function) { - replaceService.replaceFunction((Function) expression, fieldNameMap, exactReplace); + SqlReplaceHelper.replaceFunction((Function) expression, fieldNameMap, exactReplace); } super.visit(orderBy); } diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/QueryExpressionReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/QueryExpressionReplaceVisitor.java index 6e2a13d36..07e3b6666 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/QueryExpressionReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/QueryExpressionReplaceVisitor.java @@ -12,6 +12,7 @@ import java.util.Objects; public class QueryExpressionReplaceVisitor extends ExpressionVisitorAdapter { private Map fieldExprMap; + private String lastColumnName; public QueryExpressionReplaceVisitor(Map fieldExprMap) { this.fieldExprMap = fieldExprMap; @@ -23,24 +24,26 @@ public class QueryExpressionReplaceVisitor extends ExpressionVisitorAdapter { } public void visit(SelectItem selectExpressionItem) { - Expression expression = selectExpressionItem.getExpression(); String toReplace = ""; - String columnName = ""; if (expression instanceof Function) { Function leftFunc = (Function) expression; if (Objects.nonNull(leftFunc.getParameters()) && leftFunc.getParameters().getExpressions().get(0) instanceof Column) { Column column = (Column) leftFunc.getParameters().getExpressions().get(0); - columnName = column.getColumnName(); + lastColumnName = column.getColumnName(); toReplace = getReplaceExpr(leftFunc, fieldExprMap); } - } - if (expression instanceof Column) { + } else if (expression instanceof Column) { Column column = (Column) expression; - columnName = column.getColumnName(); + lastColumnName = column.getColumnName(); toReplace = getReplaceExpr((Column) expression, fieldExprMap); + } else if (expression instanceof BinaryExpression) { + BinaryExpression expr = (BinaryExpression) expression; + expr.setLeftExpression(replace(expr.getLeftExpression(), fieldExprMap)); + expr.setRightExpression(replace(expr.getRightExpression(), fieldExprMap)); } + if (expression instanceof BinaryExpression) { BinaryExpression binaryExpression = (BinaryExpression) expression; visitBinaryExpression(binaryExpression); @@ -51,7 +54,7 @@ public class QueryExpressionReplaceVisitor extends ExpressionVisitorAdapter { if (Objects.nonNull(toReplaceExpr)) { selectExpressionItem.setExpression(toReplaceExpr); if (Objects.isNull(selectExpressionItem.getAlias())) { - selectExpressionItem.setAlias(new Alias(columnName, true)); + selectExpressionItem.setAlias(new Alias(lastColumnName, true)); } } } @@ -68,6 +71,18 @@ public class QueryExpressionReplaceVisitor extends ExpressionVisitorAdapter { if (expression instanceof Column) { toReplace = getReplaceExpr((Column) expression, fieldExprMap); } + if (expression instanceof BinaryExpression) { + BinaryExpression binaryExpression = (BinaryExpression) expression; + binaryExpression + .setLeftExpression(replace(binaryExpression.getLeftExpression(), fieldExprMap)); + binaryExpression.setRightExpression( + replace(binaryExpression.getRightExpression(), fieldExprMap)); + } + if (expression instanceof Parenthesis) { + Parenthesis parenthesis = (Parenthesis) expression; + parenthesis.setExpression(replace(parenthesis.getExpression(), fieldExprMap)); + } + if (!toReplace.isEmpty()) { Expression replace = getExpression(toReplace); if (Objects.nonNull(replace)) { diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/ReplaceService.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/ReplaceService.java deleted file mode 100644 index d7070331f..000000000 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/ReplaceService.java +++ /dev/null @@ -1,74 +0,0 @@ -package com.tencent.supersonic.common.jsqlparser; - -import com.tencent.supersonic.common.util.EditDistanceUtils; -import com.tencent.supersonic.common.util.StringUtil; -import lombok.Data; -import lombok.extern.slf4j.Slf4j; -import net.sf.jsqlparser.expression.Expression; -import net.sf.jsqlparser.expression.Function; -import net.sf.jsqlparser.expression.operators.relational.ExpressionList; -import net.sf.jsqlparser.schema.Column; -import org.apache.commons.lang3.StringUtils; -import org.springframework.beans.factory.annotation.Value; -import org.springframework.stereotype.Service; - -import java.util.Map; -import java.util.Map.Entry; -import java.util.Optional; -import java.util.stream.Collectors; - -@Slf4j -@Service -@Data -public class ReplaceService { - - @Value("${s2.replace.threshold:0.4}") - private double replaceColumnThreshold; - - public void replaceFunction(Function expression, Map fieldNameMap, - boolean exactReplace) { - Function function = expression; - ExpressionList expressions = function.getParameters(); - for (Expression column : expressions) { - if (column instanceof Column) { - replaceColumn((Column) column, fieldNameMap, exactReplace); - } - } - } - - public void replaceColumn(Column column, Map fieldNameMap, - boolean exactReplace) { - String columnName = StringUtil.replaceBackticks(column.getColumnName()); - String replaceColumn = getReplaceValue(columnName, fieldNameMap, exactReplace); - if (StringUtils.isNotBlank(replaceColumn)) { - column.setColumnName(replaceColumn); - } - } - - public String getReplaceValue(String beforeValue, Map valueMap, - boolean exactReplace) { - String replaceValue = valueMap.get(beforeValue); - if (StringUtils.isNotBlank(replaceValue)) { - return replaceValue; - } - if (exactReplace) { - return null; - } - Optional> first = valueMap.entrySet().stream().sorted((k1, k2) -> { - String k1Value = k1.getKey(); - String k2Value = k2.getKey(); - Double k1Similarity = EditDistanceUtils.getSimilarity(beforeValue, k1Value); - Double k2Similarity = EditDistanceUtils.getSimilarity(beforeValue, k2Value); - return k2Similarity.compareTo(k1Similarity); - }).collect(Collectors.toList()).stream().findFirst(); - - if (first.isPresent()) { - replaceValue = first.get().getValue(); - double similarity = EditDistanceUtils.getSimilarity(beforeValue, replaceValue); - if (similarity > replaceColumnThreshold) { - return replaceValue; - } - } - return beforeValue; - } -} diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java index 91fbea3ec..0a5f7f562 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.common.jsqlparser; import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; +import com.tencent.supersonic.common.util.EditDistanceUtils; import com.tencent.supersonic.common.util.StringUtil; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.JSQLParserException; @@ -25,21 +26,18 @@ import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.Pair; import org.springframework.util.CollectionUtils; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; +import java.util.*; import java.util.function.UnaryOperator; +import java.util.stream.Collectors; /** * Sql Parser replace Helper */ @Slf4j public class SqlReplaceHelper { + + private final static double replaceColumnThreshold = 0.4; + public static String replaceAggFields(String sql, Map> fieldNameToAggMap) { Select selectStatement = SqlSelectHelper.getSelect(sql); @@ -769,4 +767,54 @@ public class SqlReplaceHelper { } } } + + public static void replaceFunction(Function expression, Map fieldNameMap, + boolean exactReplace) { + Function function = expression; + ExpressionList expressions = function.getParameters(); + for (Expression column : expressions) { + if (column instanceof Column) { + replaceColumn((Column) column, fieldNameMap, exactReplace); + } + } + } + + public static void replaceColumn(Column column, Map fieldNameMap, + boolean exactReplace) { + String columnName = StringUtil.replaceBackticks(column.getColumnName()); + String replaceColumn = getReplaceValue(columnName, fieldNameMap, exactReplace); + if (StringUtils.isNotBlank(replaceColumn)) { + log.debug("Replaced column {} to {}", column.getColumnName(), replaceColumn); + column.setColumnName(replaceColumn); + } + } + + public static String getReplaceValue(String beforeValue, Map valueMap, + boolean exactReplace) { + String replaceValue = valueMap.get(beforeValue); + if (StringUtils.isNotBlank(replaceValue)) { + return replaceValue; + } + if (exactReplace) { + return null; + } + Optional> first = + valueMap.entrySet().stream().sorted((k1, k2) -> { + String k1Value = k1.getKey(); + String k2Value = k2.getKey(); + Double k1Similarity = EditDistanceUtils.getSimilarity(beforeValue, k1Value); + Double k2Similarity = EditDistanceUtils.getSimilarity(beforeValue, k2Value); + return k2Similarity.compareTo(k1Similarity); + }).collect(Collectors.toList()).stream().findFirst(); + + if (first.isPresent()) { + replaceValue = first.get().getValue(); + double similarity = EditDistanceUtils.getSimilarity(beforeValue, replaceValue); + if (similarity > replaceColumnThreshold) { + return replaceValue; + } + } + return beforeValue; + } + } diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java index d7b4085bd..02238b48b 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java @@ -459,6 +459,9 @@ public class SqlSelectHelper { .map(fieldExpression -> fieldExpression.getFieldName()).filter(Objects::nonNull) .collect(Collectors.toSet()); result.addAll(collect); + + Set aliases = getAliasFields(plainSelect); + result.removeAll(aliases); } public static List getOrderByExpressions(String sql) { diff --git a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelperTest.java index 631ab59cf..16557d7cc 100644 --- a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelperTest.java @@ -1,38 +1,16 @@ package com.tencent.supersonic.common.jsqlparser; import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; -import com.tencent.supersonic.common.util.ContextUtils; import org.apache.commons.lang3.tuple.Pair; import org.junit.Assert; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.mockito.MockedStatic; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; -import java.util.Set; - -import static org.mockito.Mockito.mockStatic; +import java.util.*; /** * SqlParserReplaceHelperTest */ class SqlReplaceHelperTest { - private MockedStatic mockedContextUtils; - - @BeforeEach - public void setUp() { - ReplaceService replaceService = new ReplaceService(); - replaceService.setReplaceColumnThreshold(0.4); - - // Mock the static method ContextUtils.getBean - mockedContextUtils = mockStatic(ContextUtils.class); - mockedContextUtils.when(() -> ContextUtils.getBean(ReplaceService.class)) - .thenReturn(replaceService); - } @Test void testReplaceAggField() { @@ -385,11 +363,4 @@ class SqlReplaceHelperTest { return fieldToBizName; } - @AfterEach - public void tearDown() { - // Close the mocked static context - if (mockedContextUtils != null) { - mockedContextUtils.close(); - } - } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrector.java index 180f796ad..0eb354db6 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrector.java @@ -113,7 +113,7 @@ public class SchemaCorrector extends BaseSemanticCorrector { sqlInfo.setCorrectedS2SQL(sql); } - public void removeFilterIfNotInLinkingValue(ChatQueryContext chatQueryContext, + public void removeUnmappedFilterValue(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); String correctS2SQL = sqlInfo.getCorrectedS2SQL(); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java index eeb1d48f2..87f00c615 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java @@ -3,12 +3,9 @@ package com.tencent.supersonic.headless.chat.corrector; import com.tencent.supersonic.common.jsqlparser.SqlAddHelper; import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; import com.tencent.supersonic.common.jsqlparser.SqlValidHelper; -import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.chat.ChatQueryContext; import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang3.StringUtils; -import org.springframework.core.env.Environment; import org.springframework.util.CollectionUtils; import java.util.ArrayList; @@ -45,13 +42,6 @@ public class SelectCorrector extends BaseSemanticCorrector { Set selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL)); Set needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL)); - // decide whether add order by expression field to select - Environment environment = ContextUtils.getBean(Environment.class); - String correctorAdditionalInfo = environment.getProperty(ADDITIONAL_INFORMATION); - if (StringUtils.isNotBlank(correctorAdditionalInfo) - && Boolean.parseBoolean(correctorAdditionalInfo)) { - needAddFields.addAll(SqlSelectHelper.getOrderByFields(correctS2SQL)); - } if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(needAddFields)) { return correctS2SQL; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/TimeRangeParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/TimeRangeParser.java index f2e70e901..f9abfdaae 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/TimeRangeParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/TimeRangeParser.java @@ -70,20 +70,24 @@ public class TimeRangeParser implements SemanticParser { } private DateConf parseDateCN(String queryText) { - List times = TimeNLPUtil.parse(queryText); - if (times.isEmpty()) { + try { + List times = TimeNLPUtil.parse(queryText); + if (times.isEmpty()) { + return null; + } + + Date startDate = times.get(0).getTime(); + String detectWord = times.get(0).getTimeExpression(); + Date endDate = times.size() > 1 ? times.get(1).getTime() : startDate; + + if (times.size() > 1) { + detectWord += "~" + times.get(1).getTimeExpression(); + } + + return getDateConf(startDate, endDate, detectWord); + } catch (Exception e) { return null; } - - Date startDate = times.get(0).getTime(); - String detectWord = times.get(0).getTimeExpression(); - Date endDate = times.size() > 1 ? times.get(1).getTime() : startDate; - - if (times.size() > 1) { - detectWord += "~" + times.get(1).getTimeExpression(); - } - - return getDateConf(startDate, endDate, detectWord); } private DateConf parseDateNumber(String queryText) { diff --git a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrectorTest.java b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrectorTest.java index 1ec1de163..718b60fa6 100644 --- a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrectorTest.java +++ b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrectorTest.java @@ -3,17 +3,12 @@ package com.tencent.supersonic.headless.chat.corrector; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.tencent.supersonic.common.pojo.Constants; -import com.tencent.supersonic.headless.api.pojo.DataSetSchema; -import com.tencent.supersonic.headless.api.pojo.QueryConfig; -import com.tencent.supersonic.headless.api.pojo.SchemaElement; -import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.headless.api.pojo.SemanticSchema; -import com.tencent.supersonic.headless.api.pojo.SqlInfo; +import com.tencent.supersonic.headless.api.pojo.*; import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.parser.llm.ParseResult; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; +import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery; import org.junit.Assert; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import java.util.ArrayList; @@ -21,7 +16,6 @@ import java.util.HashSet; import java.util.List; import java.util.Set; -@Disabled class SchemaCorrectorTest { private String json = "{\n" + " \"dataSetId\": 1,\n" + " \"llmReq\": {\n" @@ -40,52 +34,54 @@ class SchemaCorrectorTest { + " },\n" + " \"request\": null\n" + "}"; @Test - void doCorrect() throws JsonProcessingException { - Long dataSetId = 1L; - ChatQueryContext chatQueryContext = buildQueryContext(dataSetId); + void testCorrectWrongColumnName() { + String sql = "SELECT 歌曲 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' ORDER BY SUM(播放) DESC LIMIT 10"; + ChatQueryContext chatQueryContext = buildQueryContext(sql); + SemanticParseInfo parseInfo = chatQueryContext.getCandidateQueries().get(0).getParseInfo(); + + SchemaCorrector schemaCorrector = new SchemaCorrector(); + schemaCorrector.correct(chatQueryContext, parseInfo); + + Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' ORDER BY SUM(播放量) DESC LIMIT 10", + parseInfo.getSqlInfo().getCorrectedS2SQL()); + } + + @Test + void testRemoveUnmappedFilterValue() throws JsonProcessingException { + String sql = "SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10"; + ChatQueryContext chatQueryContext = buildQueryContext(sql); + SemanticParseInfo parseInfo = chatQueryContext.getCandidateQueries().get(0).getParseInfo(); + ObjectMapper objectMapper = new ObjectMapper(); ParseResult parseResult = objectMapper.readValue(json, ParseResult.class); - String sql = "select 歌曲名 from 歌曲 where 发行日期 >= '2024-01-01' " - + "and 商务组 = 'xxx' order by 播放量 desc limit 10"; - SemanticParseInfo semanticParseInfo = new SemanticParseInfo(); - SqlInfo sqlInfo = new SqlInfo(); - sqlInfo.setParsedS2SQL(sql); - sqlInfo.setCorrectedS2SQL(sql); - semanticParseInfo.setSqlInfo(sqlInfo); - - SchemaElement schemaElement = new SchemaElement(); - schemaElement.setDataSetId(dataSetId); - semanticParseInfo.setDataSet(schemaElement); - - semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult); + parseInfo.getProperties().put(Constants.CONTEXT, parseResult); SchemaCorrector schemaCorrector = new SchemaCorrector(); - schemaCorrector.removeFilterIfNotInLinkingValue(chatQueryContext, semanticParseInfo); + schemaCorrector.removeUnmappedFilterValue(chatQueryContext, parseInfo); - Assert.assertEquals( - "SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' " + "ORDER BY 播放量 DESC LIMIT 10", - semanticParseInfo.getSqlInfo().getCorrectedS2SQL()); - - parseResult = objectMapper.readValue(json, ParseResult.class); + Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' ORDER BY 播放量 DESC LIMIT 10", + parseInfo.getSqlInfo().getCorrectedS2SQL()); List linkingValues = new ArrayList<>(); LLMReq.ElementValue elementValue = new LLMReq.ElementValue(); elementValue.setFieldName("商务组"); elementValue.setFieldValue("xxx"); linkingValues.add(elementValue); - semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult); + parseResult.getLlmReq().getSchema().setValues(linkingValues); + parseInfo.getProperties().put(Constants.CONTEXT, parseResult); - semanticParseInfo.getSqlInfo().setCorrectedS2SQL(sql); - semanticParseInfo.getSqlInfo().setParsedS2SQL(sql); - schemaCorrector.removeFilterIfNotInLinkingValue(chatQueryContext, semanticParseInfo); - Assert.assertEquals( - "SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' " + parseInfo.getSqlInfo().setCorrectedS2SQL(sql); + parseInfo.getSqlInfo().setParsedS2SQL(sql); + schemaCorrector.removeUnmappedFilterValue(chatQueryContext, parseInfo); + Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' " + "AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10", - semanticParseInfo.getSqlInfo().getCorrectedS2SQL()); + parseInfo.getSqlInfo().getCorrectedS2SQL()); } - private ChatQueryContext buildQueryContext(Long dataSetId) { + private ChatQueryContext buildQueryContext(String sql) { + Long dataSetId = 1L; + ChatQueryContext chatQueryContext = new ChatQueryContext(); List dataSetSchemaList = new ArrayList<>(); DataSetSchema dataSetSchema = new DataSetSchema(); @@ -94,27 +90,29 @@ class SchemaCorrectorTest { SchemaElement schemaElement = new SchemaElement(); schemaElement.setDataSetId(dataSetId); dataSetSchema.setDataSet(schemaElement); + Set dimensions = new HashSet<>(); - SchemaElement element1 = new SchemaElement(); - element1.setDataSetId(1L); - element1.setName("歌曲名"); - dimensions.add(element1); - - SchemaElement element2 = new SchemaElement(); - element2.setDataSetId(1L); - element2.setName("商务组"); - dimensions.add(element2); - - SchemaElement element3 = new SchemaElement(); - element3.setDataSetId(1L); - element3.setName("发行日期"); - dimensions.add(element3); - + dimensions.add(SchemaElement.builder().name("歌曲名").dataSetId(dataSetId).build()); + dimensions.add(SchemaElement.builder().name("商务组").dataSetId(dataSetId).build()); + dimensions.add(SchemaElement.builder().name("发行日期").dataSetId(dataSetId).build()); + dimensions.add(SchemaElement.builder().name("播放量").dataSetId(dataSetId).build()); dataSetSchema.setDimensions(dimensions); dataSetSchemaList.add(dataSetSchema); SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList); chatQueryContext.setSemanticSchema(semanticSchema); + + SemanticParseInfo semanticParseInfo = new SemanticParseInfo(); + SqlInfo sqlInfo = new SqlInfo(); + sqlInfo.setParsedS2SQL(sql); + sqlInfo.setCorrectedS2SQL(sql); + semanticParseInfo.setSqlInfo(sqlInfo); + semanticParseInfo.setDataSet(dataSetSchema.getDataSet()); + LLMSqlQuery sqlQuery = new LLMSqlQuery(); + sqlQuery.setParseInfo(semanticParseInfo); + chatQueryContext.getCandidateQueries().add(sqlQuery); + return chatQueryContext; } + } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/cache/DefaultQueryCache.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/cache/DefaultQueryCache.java index 4503dce46..e5cfa0a51 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/cache/DefaultQueryCache.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/cache/DefaultQueryCache.java @@ -20,7 +20,7 @@ public class DefaultQueryCache implements QueryCache { if (isCache(semanticQueryReq)) { Object result = cacheManager.get(cacheKey); if (Objects.nonNull(result)) { - log.info("query from cache, key:{},result:{}", cacheKey, + log.debug("query from cache, key:{},result:{}", cacheKey, StringUtils.normalizeSpace(result.toString())); } return result; diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/executor/JdbcExecutor.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/executor/JdbcExecutor.java index f76ab1df8..ebbfbffbb 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/executor/JdbcExecutor.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/executor/JdbcExecutor.java @@ -45,7 +45,7 @@ public class JdbcExecutor implements QueryExecutor { sqlUtil.queryInternal(queryStatement.getSql(), queryResultWithColumns); queryResultWithColumns.setSql(sql); } catch (Exception e) { - log.error("queryInternal with error [{}]", StringUtils.normalizeSpace(e.getMessage())); + log.error("queryInternal with error ", e); queryResultWithColumns.setErrorMsg(e.getMessage()); } return queryResultWithColumns; diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/DefaultSemanticTranslator.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/DefaultSemanticTranslator.java index ba3ba93bd..f61d2f8bb 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/DefaultSemanticTranslator.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/DefaultSemanticTranslator.java @@ -50,7 +50,8 @@ public class DefaultSemanticTranslator implements SemanticTranslator { optimizer.rewrite(queryStatement); } } - log.info("translated query SQL: [{}]", StringUtils.normalizeSpace(queryStatement.getSql())); + log.debug("translated query SQL: [{}]", + StringUtils.normalizeSpace(queryStatement.getSql())); } private void mergeOntologyQuery(QueryStatement queryStatement) throws Exception { diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2SemanticLayerService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2SemanticLayerService.java index 5b96cc62e..d337d20aa 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2SemanticLayerService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2SemanticLayerService.java @@ -288,8 +288,6 @@ public class S2SemanticLayerService implements SemanticLayerService { queryStatement.setSql(semanticQueryReq.getSqlInfo().getQuerySQL()); queryStatement.setIsTranslated(true); } - queryStatement.setDataSetId(semanticQueryReq.getDataSetId()); - queryStatement.setDataSetName(semanticQueryReq.getDataSetName()); return queryStatement; } @@ -321,6 +319,11 @@ public class S2SemanticLayerService implements SemanticLayerService { Long dataSetId = dataSetService.getDataSetIdFromSql(querySqlReq.getSql(), user); querySqlReq.setDataSetId(dataSetId); } + if (querySqlReq.getDataSetId() != null) { + DataSetResp dataSetResp = dataSetService.getDataSet(querySqlReq.getDataSetId()); + queryStatement.setDataSetId(dataSetResp.getId()); + queryStatement.setDataSetName(dataSetResp.getName()); + } return queryStatement; } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DictUtils.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DictUtils.java index 3dc313cc3..838f9a085 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DictUtils.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DictUtils.java @@ -273,9 +273,10 @@ public class DictUtils { private QuerySqlReq constructQuerySqlReq(DictItemResp dictItemResp) { + ModelResp model = modelService.getModel(dictItemResp.getModelId()); String sqlPattern = - "select %s,count(1) from tbl %s group by %s order by count(1) desc limit %d"; - String bizName = dictItemResp.getBizName(); + "select %s,count(1) from %s %s group by %s order by count(1) desc limit %d"; + String dimBizName = dictItemResp.getBizName(); String whereStr = generateWhereStr(dictItemResp); String where = StringUtils.isEmpty(whereStr) ? "" : "WHERE" + whereStr; ItemValueConfig config = dictItemResp.getConfig(); @@ -286,7 +287,8 @@ public class DictUtils { limit = Integer.MAX_VALUE; } - String sql = String.format(sqlPattern, bizName, where, bizName, limit); + String sql = + String.format(sqlPattern, dimBizName, model.getBizName(), where, dimBizName, limit); Set modelIds = new HashSet<>(); modelIds.add(dictItemResp.getModelId()); QuerySqlReq querySqlReq = new QuerySqlReq(); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/QueryBySqlTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/QueryBySqlTest.java index 647956f27..a12daa7ea 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/QueryBySqlTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/QueryBySqlTest.java @@ -17,6 +17,7 @@ import static org.junit.Assert.assertTrue; public class QueryBySqlTest extends BaseTest { @Test + @SetSystemProperty(key = "s2.test", value = "true") public void testDetailQuery() throws Exception { SemanticQueryResp semanticQueryResp = queryBySql("SELECT 用户名,访问次数 FROM 超音数PVUV统计 WHERE 用户名='alice' ");