[fix](headless)Fix a number of issues. (#2026)
Some checks are pending
supersonic CentOS CI / build (21) (push) Waiting to run
supersonic mac CI / build (21) (push) Waiting to run
supersonic ubuntu CI / build (21) (push) Waiting to run
supersonic windows CI / build (21) (push) Waiting to run

This commit is contained in:
Jun Zhang
2025-02-02 12:50:29 +08:00
committed by GitHub
parent de92b357df
commit d294fec2a0
19 changed files with 184 additions and 239 deletions

View File

@@ -1,6 +1,5 @@
package com.tencent.supersonic.common.jsqlparser; package com.tencent.supersonic.common.jsqlparser;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.AnalyticExpression; import net.sf.jsqlparser.expression.AnalyticExpression;
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter; import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
@@ -14,6 +13,7 @@ import java.util.Map;
@Slf4j @Slf4j
public class FieldReplaceVisitor extends ExpressionVisitorAdapter { public class FieldReplaceVisitor extends ExpressionVisitorAdapter {
private Map<String, String> fieldNameMap; private Map<String, String> fieldNameMap;
private ThreadLocal<Boolean> exactReplace = ThreadLocal.withInitial(() -> false); private ThreadLocal<Boolean> exactReplace = ThreadLocal.withInitial(() -> false);
@@ -24,8 +24,7 @@ public class FieldReplaceVisitor extends ExpressionVisitorAdapter {
@Override @Override
public void visit(Column column) { public void visit(Column column) {
ReplaceService replaceService = ContextUtils.getBean(ReplaceService.class); SqlReplaceHelper.replaceColumn(column, fieldNameMap, exactReplace.get());
replaceService.replaceColumn(column, fieldNameMap, exactReplace.get());
} }
@Override @Override

View File

@@ -1,21 +1,8 @@
package com.tencent.supersonic.common.jsqlparser; package com.tencent.supersonic.common.jsqlparser;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.DoubleValue; import net.sf.jsqlparser.expression.*;
import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.operators.relational.*;
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.expression.operators.relational.MinorThan;
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
import net.sf.jsqlparser.schema.Column; import net.sf.jsqlparser.schema.Column;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
@@ -28,13 +15,14 @@ import java.util.Objects;
@Slf4j @Slf4j
public class FieldValueReplaceVisitor extends ExpressionVisitorAdapter { public class FieldValueReplaceVisitor extends ExpressionVisitorAdapter {
private boolean exactReplace; private final boolean exactReplace;
private Map<String, Map<String, String>> filedNameToValueMap;
private final Map<String, Map<String, String>> filedNameToValueMap;
public FieldValueReplaceVisitor(boolean exactReplace, public FieldValueReplaceVisitor(boolean exactReplace,
Map<String, Map<String, String>> filedNameToValueMap) { Map<String, Map<String, String>> filedNameToValueMap) {
this.exactReplace = exactReplace;
this.filedNameToValueMap = filedNameToValueMap; this.filedNameToValueMap = filedNameToValueMap;
this.exactReplace = exactReplace;
} }
@Override @Override
@@ -137,9 +125,8 @@ public class FieldValueReplaceVisitor extends ExpressionVisitorAdapter {
private String getReplaceValue(Map<String, String> valueMap, String beforeValue) { private String getReplaceValue(Map<String, String> valueMap, String beforeValue) {
String afterValue = valueMap.get(String.valueOf(beforeValue)); String afterValue = valueMap.get(String.valueOf(beforeValue));
if (StringUtils.isEmpty(afterValue) && !exactReplace) { if (StringUtils.isEmpty(afterValue)) {
ReplaceService replaceService = ContextUtils.getBean(ReplaceService.class); return SqlReplaceHelper.getReplaceValue(beforeValue, valueMap, exactReplace);
return replaceService.getReplaceValue(beforeValue, valueMap, false);
} }
return afterValue; return afterValue;
} }

View File

@@ -1,6 +1,5 @@
package com.tencent.supersonic.common.jsqlparser; package com.tencent.supersonic.common.jsqlparser;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Function; import net.sf.jsqlparser.expression.Function;
@@ -15,8 +14,8 @@ import java.util.Map;
@Slf4j @Slf4j
public class GroupByReplaceVisitor implements GroupByVisitor { public class GroupByReplaceVisitor implements GroupByVisitor {
private Map<String, String> fieldNameMap; private final boolean exactReplace;
private boolean exactReplace; private final Map<String, String> fieldNameMap;
public GroupByReplaceVisitor(Map<String, String> fieldNameMap, boolean exactReplace) { public GroupByReplaceVisitor(Map<String, String> fieldNameMap, boolean exactReplace) {
this.fieldNameMap = fieldNameMap; this.fieldNameMap = fieldNameMap;
@@ -34,11 +33,10 @@ public class GroupByReplaceVisitor implements GroupByVisitor {
} }
private void replaceExpression(Expression expression) { private void replaceExpression(Expression expression) {
ReplaceService replaceService = ContextUtils.getBean(ReplaceService.class);
if (expression instanceof Column) { if (expression instanceof Column) {
replaceService.replaceColumn((Column) expression, fieldNameMap, exactReplace); SqlReplaceHelper.replaceColumn((Column) expression, fieldNameMap, exactReplace);
} else if (expression instanceof Function) { } else if (expression instanceof Function) {
replaceService.replaceFunction((Function) expression, fieldNameMap, exactReplace); SqlReplaceHelper.replaceFunction((Function) expression, fieldNameMap, exactReplace);
} }
} }
} }

View File

@@ -1,6 +1,5 @@
package com.tencent.supersonic.common.jsqlparser; package com.tencent.supersonic.common.jsqlparser;
import com.tencent.supersonic.common.util.ContextUtils;
import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Function; import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.schema.Column; import net.sf.jsqlparser.schema.Column;
@@ -10,8 +9,9 @@ import net.sf.jsqlparser.statement.select.OrderByVisitorAdapter;
import java.util.Map; import java.util.Map;
public class OrderByReplaceVisitor extends OrderByVisitorAdapter { public class OrderByReplaceVisitor extends OrderByVisitorAdapter {
private Map<String, String> fieldNameMap;
private boolean exactReplace; private final boolean exactReplace;
private final Map<String, String> fieldNameMap;
public OrderByReplaceVisitor(Map<String, String> fieldNameMap, boolean exactReplace) { public OrderByReplaceVisitor(Map<String, String> fieldNameMap, boolean exactReplace) {
this.fieldNameMap = fieldNameMap; this.fieldNameMap = fieldNameMap;
@@ -21,12 +21,11 @@ public class OrderByReplaceVisitor extends OrderByVisitorAdapter {
@Override @Override
public void visit(OrderByElement orderBy) { public void visit(OrderByElement orderBy) {
Expression expression = orderBy.getExpression(); Expression expression = orderBy.getExpression();
ReplaceService replaceService = ContextUtils.getBean(ReplaceService.class);
if (expression instanceof Column) { if (expression instanceof Column) {
replaceService.replaceColumn((Column) expression, fieldNameMap, exactReplace); SqlReplaceHelper.replaceColumn((Column) expression, fieldNameMap, exactReplace);
} }
if (expression instanceof Function) { if (expression instanceof Function) {
replaceService.replaceFunction((Function) expression, fieldNameMap, exactReplace); SqlReplaceHelper.replaceFunction((Function) expression, fieldNameMap, exactReplace);
} }
super.visit(orderBy); super.visit(orderBy);
} }

View File

@@ -12,6 +12,7 @@ import java.util.Objects;
public class QueryExpressionReplaceVisitor extends ExpressionVisitorAdapter { public class QueryExpressionReplaceVisitor extends ExpressionVisitorAdapter {
private Map<String, String> fieldExprMap; private Map<String, String> fieldExprMap;
private String lastColumnName;
public QueryExpressionReplaceVisitor(Map<String, String> fieldExprMap) { public QueryExpressionReplaceVisitor(Map<String, String> fieldExprMap) {
this.fieldExprMap = fieldExprMap; this.fieldExprMap = fieldExprMap;
@@ -23,24 +24,26 @@ public class QueryExpressionReplaceVisitor extends ExpressionVisitorAdapter {
} }
public void visit(SelectItem selectExpressionItem) { public void visit(SelectItem selectExpressionItem) {
Expression expression = selectExpressionItem.getExpression(); Expression expression = selectExpressionItem.getExpression();
String toReplace = ""; String toReplace = "";
String columnName = "";
if (expression instanceof Function) { if (expression instanceof Function) {
Function leftFunc = (Function) expression; Function leftFunc = (Function) expression;
if (Objects.nonNull(leftFunc.getParameters()) if (Objects.nonNull(leftFunc.getParameters())
&& leftFunc.getParameters().getExpressions().get(0) instanceof Column) { && leftFunc.getParameters().getExpressions().get(0) instanceof Column) {
Column column = (Column) leftFunc.getParameters().getExpressions().get(0); Column column = (Column) leftFunc.getParameters().getExpressions().get(0);
columnName = column.getColumnName(); lastColumnName = column.getColumnName();
toReplace = getReplaceExpr(leftFunc, fieldExprMap); toReplace = getReplaceExpr(leftFunc, fieldExprMap);
} }
} } else if (expression instanceof Column) {
if (expression instanceof Column) {
Column column = (Column) expression; Column column = (Column) expression;
columnName = column.getColumnName(); lastColumnName = column.getColumnName();
toReplace = getReplaceExpr((Column) expression, fieldExprMap); toReplace = getReplaceExpr((Column) expression, fieldExprMap);
} else if (expression instanceof BinaryExpression) {
BinaryExpression expr = (BinaryExpression) expression;
expr.setLeftExpression(replace(expr.getLeftExpression(), fieldExprMap));
expr.setRightExpression(replace(expr.getRightExpression(), fieldExprMap));
} }
if (expression instanceof BinaryExpression) { if (expression instanceof BinaryExpression) {
BinaryExpression binaryExpression = (BinaryExpression) expression; BinaryExpression binaryExpression = (BinaryExpression) expression;
visitBinaryExpression(binaryExpression); visitBinaryExpression(binaryExpression);
@@ -51,7 +54,7 @@ public class QueryExpressionReplaceVisitor extends ExpressionVisitorAdapter {
if (Objects.nonNull(toReplaceExpr)) { if (Objects.nonNull(toReplaceExpr)) {
selectExpressionItem.setExpression(toReplaceExpr); selectExpressionItem.setExpression(toReplaceExpr);
if (Objects.isNull(selectExpressionItem.getAlias())) { if (Objects.isNull(selectExpressionItem.getAlias())) {
selectExpressionItem.setAlias(new Alias(columnName, true)); selectExpressionItem.setAlias(new Alias(lastColumnName, true));
} }
} }
} }
@@ -68,6 +71,18 @@ public class QueryExpressionReplaceVisitor extends ExpressionVisitorAdapter {
if (expression instanceof Column) { if (expression instanceof Column) {
toReplace = getReplaceExpr((Column) expression, fieldExprMap); toReplace = getReplaceExpr((Column) expression, fieldExprMap);
} }
if (expression instanceof BinaryExpression) {
BinaryExpression binaryExpression = (BinaryExpression) expression;
binaryExpression
.setLeftExpression(replace(binaryExpression.getLeftExpression(), fieldExprMap));
binaryExpression.setRightExpression(
replace(binaryExpression.getRightExpression(), fieldExprMap));
}
if (expression instanceof Parenthesis) {
Parenthesis parenthesis = (Parenthesis) expression;
parenthesis.setExpression(replace(parenthesis.getExpression(), fieldExprMap));
}
if (!toReplace.isEmpty()) { if (!toReplace.isEmpty()) {
Expression replace = getExpression(toReplace); Expression replace = getExpression(toReplace);
if (Objects.nonNull(replace)) { if (Objects.nonNull(replace)) {

View File

@@ -1,74 +0,0 @@
package com.tencent.supersonic.common.jsqlparser;
import com.tencent.supersonic.common.util.EditDistanceUtils;
import com.tencent.supersonic.common.util.StringUtil;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Function;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.schema.Column;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.stream.Collectors;
@Slf4j
@Service
@Data
public class ReplaceService {
@Value("${s2.replace.threshold:0.4}")
private double replaceColumnThreshold;
public void replaceFunction(Function expression, Map<String, String> fieldNameMap,
boolean exactReplace) {
Function function = expression;
ExpressionList<?> expressions = function.getParameters();
for (Expression column : expressions) {
if (column instanceof Column) {
replaceColumn((Column) column, fieldNameMap, exactReplace);
}
}
}
public void replaceColumn(Column column, Map<String, String> fieldNameMap,
boolean exactReplace) {
String columnName = StringUtil.replaceBackticks(column.getColumnName());
String replaceColumn = getReplaceValue(columnName, fieldNameMap, exactReplace);
if (StringUtils.isNotBlank(replaceColumn)) {
column.setColumnName(replaceColumn);
}
}
public String getReplaceValue(String beforeValue, Map<String, String> valueMap,
boolean exactReplace) {
String replaceValue = valueMap.get(beforeValue);
if (StringUtils.isNotBlank(replaceValue)) {
return replaceValue;
}
if (exactReplace) {
return null;
}
Optional<Entry<String, String>> first = valueMap.entrySet().stream().sorted((k1, k2) -> {
String k1Value = k1.getKey();
String k2Value = k2.getKey();
Double k1Similarity = EditDistanceUtils.getSimilarity(beforeValue, k1Value);
Double k2Similarity = EditDistanceUtils.getSimilarity(beforeValue, k2Value);
return k2Similarity.compareTo(k1Similarity);
}).collect(Collectors.toList()).stream().findFirst();
if (first.isPresent()) {
replaceValue = first.get().getValue();
double similarity = EditDistanceUtils.getSimilarity(beforeValue, replaceValue);
if (similarity > replaceColumnThreshold) {
return replaceValue;
}
}
return beforeValue;
}
}

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.common.jsqlparser; package com.tencent.supersonic.common.jsqlparser;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.util.EditDistanceUtils;
import com.tencent.supersonic.common.util.StringUtil; import com.tencent.supersonic.common.util.StringUtil;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.JSQLParserException;
@@ -25,21 +26,18 @@ import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.Pair;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.ArrayList; import java.util.*;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.UnaryOperator; import java.util.function.UnaryOperator;
import java.util.stream.Collectors;
/** /**
* Sql Parser replace Helper * Sql Parser replace Helper
*/ */
@Slf4j @Slf4j
public class SqlReplaceHelper { public class SqlReplaceHelper {
private final static double replaceColumnThreshold = 0.4;
public static String replaceAggFields(String sql, public static String replaceAggFields(String sql,
Map<String, Pair<String, String>> fieldNameToAggMap) { Map<String, Pair<String, String>> fieldNameToAggMap) {
Select selectStatement = SqlSelectHelper.getSelect(sql); Select selectStatement = SqlSelectHelper.getSelect(sql);
@@ -769,4 +767,54 @@ public class SqlReplaceHelper {
} }
} }
} }
public static void replaceFunction(Function expression, Map<String, String> fieldNameMap,
boolean exactReplace) {
Function function = expression;
ExpressionList<?> expressions = function.getParameters();
for (Expression column : expressions) {
if (column instanceof Column) {
replaceColumn((Column) column, fieldNameMap, exactReplace);
}
}
}
public static void replaceColumn(Column column, Map<String, String> fieldNameMap,
boolean exactReplace) {
String columnName = StringUtil.replaceBackticks(column.getColumnName());
String replaceColumn = getReplaceValue(columnName, fieldNameMap, exactReplace);
if (StringUtils.isNotBlank(replaceColumn)) {
log.debug("Replaced column {} to {}", column.getColumnName(), replaceColumn);
column.setColumnName(replaceColumn);
}
}
public static String getReplaceValue(String beforeValue, Map<String, String> valueMap,
boolean exactReplace) {
String replaceValue = valueMap.get(beforeValue);
if (StringUtils.isNotBlank(replaceValue)) {
return replaceValue;
}
if (exactReplace) {
return null;
}
Optional<Map.Entry<String, String>> first =
valueMap.entrySet().stream().sorted((k1, k2) -> {
String k1Value = k1.getKey();
String k2Value = k2.getKey();
Double k1Similarity = EditDistanceUtils.getSimilarity(beforeValue, k1Value);
Double k2Similarity = EditDistanceUtils.getSimilarity(beforeValue, k2Value);
return k2Similarity.compareTo(k1Similarity);
}).collect(Collectors.toList()).stream().findFirst();
if (first.isPresent()) {
replaceValue = first.get().getValue();
double similarity = EditDistanceUtils.getSimilarity(beforeValue, replaceValue);
if (similarity > replaceColumnThreshold) {
return replaceValue;
}
}
return beforeValue;
}
} }

View File

@@ -459,6 +459,9 @@ public class SqlSelectHelper {
.map(fieldExpression -> fieldExpression.getFieldName()).filter(Objects::nonNull) .map(fieldExpression -> fieldExpression.getFieldName()).filter(Objects::nonNull)
.collect(Collectors.toSet()); .collect(Collectors.toSet());
result.addAll(collect); result.addAll(collect);
Set<String> aliases = getAliasFields(plainSelect);
result.removeAll(aliases);
} }
public static List<FieldExpression> getOrderByExpressions(String sql) { public static List<FieldExpression> getOrderByExpressions(String sql) {

View File

@@ -1,38 +1,16 @@
package com.tencent.supersonic.common.jsqlparser; package com.tencent.supersonic.common.jsqlparser;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.Pair;
import org.junit.Assert; import org.junit.Assert;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.mockito.MockedStatic;
import java.util.Collections; import java.util.*;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import static org.mockito.Mockito.mockStatic;
/** /**
* SqlParserReplaceHelperTest * SqlParserReplaceHelperTest
*/ */
class SqlReplaceHelperTest { class SqlReplaceHelperTest {
private MockedStatic<ContextUtils> mockedContextUtils;
@BeforeEach
public void setUp() {
ReplaceService replaceService = new ReplaceService();
replaceService.setReplaceColumnThreshold(0.4);
// Mock the static method ContextUtils.getBean
mockedContextUtils = mockStatic(ContextUtils.class);
mockedContextUtils.when(() -> ContextUtils.getBean(ReplaceService.class))
.thenReturn(replaceService);
}
@Test @Test
void testReplaceAggField() { void testReplaceAggField() {
@@ -385,11 +363,4 @@ class SqlReplaceHelperTest {
return fieldToBizName; return fieldToBizName;
} }
@AfterEach
public void tearDown() {
// Close the mocked static context
if (mockedContextUtils != null) {
mockedContextUtils.close();
}
}
} }

View File

@@ -113,7 +113,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
sqlInfo.setCorrectedS2SQL(sql); sqlInfo.setCorrectedS2SQL(sql);
} }
public void removeFilterIfNotInLinkingValue(ChatQueryContext chatQueryContext, public void removeUnmappedFilterValue(ChatQueryContext chatQueryContext,
SemanticParseInfo semanticParseInfo) { SemanticParseInfo semanticParseInfo) {
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String correctS2SQL = sqlInfo.getCorrectedS2SQL(); String correctS2SQL = sqlInfo.getCorrectedS2SQL();

View File

@@ -3,12 +3,9 @@ package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper; import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.jsqlparser.SqlValidHelper; import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.ChatQueryContext;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.core.env.Environment;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.ArrayList; import java.util.ArrayList;
@@ -45,13 +42,6 @@ public class SelectCorrector extends BaseSemanticCorrector {
Set<String> selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL)); Set<String> selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL));
Set<String> needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL)); Set<String> needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL));
// decide whether add order by expression field to select
Environment environment = ContextUtils.getBean(Environment.class);
String correctorAdditionalInfo = environment.getProperty(ADDITIONAL_INFORMATION);
if (StringUtils.isNotBlank(correctorAdditionalInfo)
&& Boolean.parseBoolean(correctorAdditionalInfo)) {
needAddFields.addAll(SqlSelectHelper.getOrderByFields(correctS2SQL));
}
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(needAddFields)) { if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(needAddFields)) {
return correctS2SQL; return correctS2SQL;
} }

View File

@@ -70,20 +70,24 @@ public class TimeRangeParser implements SemanticParser {
} }
private DateConf parseDateCN(String queryText) { private DateConf parseDateCN(String queryText) {
List<TimeNLP> times = TimeNLPUtil.parse(queryText); try {
if (times.isEmpty()) { List<TimeNLP> times = TimeNLPUtil.parse(queryText);
if (times.isEmpty()) {
return null;
}
Date startDate = times.get(0).getTime();
String detectWord = times.get(0).getTimeExpression();
Date endDate = times.size() > 1 ? times.get(1).getTime() : startDate;
if (times.size() > 1) {
detectWord += "~" + times.get(1).getTimeExpression();
}
return getDateConf(startDate, endDate, detectWord);
} catch (Exception e) {
return null; return null;
} }
Date startDate = times.get(0).getTime();
String detectWord = times.get(0).getTimeExpression();
Date endDate = times.size() > 1 ? times.get(1).getTime() : startDate;
if (times.size() > 1) {
detectWord += "~" + times.get(1).getTimeExpression();
}
return getDateConf(startDate, endDate, detectWord);
} }
private DateConf parseDateNumber(String queryText) { private DateConf parseDateNumber(String queryText) {

View File

@@ -3,17 +3,12 @@ package com.tencent.supersonic.headless.chat.corrector;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.*;
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.parser.llm.ParseResult; import com.tencent.supersonic.headless.chat.parser.llm.ParseResult;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
import org.junit.Assert; import org.junit.Assert;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import java.util.ArrayList; import java.util.ArrayList;
@@ -21,7 +16,6 @@ import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
@Disabled
class SchemaCorrectorTest { class SchemaCorrectorTest {
private String json = "{\n" + " \"dataSetId\": 1,\n" + " \"llmReq\": {\n" private String json = "{\n" + " \"dataSetId\": 1,\n" + " \"llmReq\": {\n"
@@ -40,52 +34,54 @@ class SchemaCorrectorTest {
+ " },\n" + " \"request\": null\n" + "}"; + " },\n" + " \"request\": null\n" + "}";
@Test @Test
void doCorrect() throws JsonProcessingException { void testCorrectWrongColumnName() {
Long dataSetId = 1L; String sql = "SELECT 歌曲 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' ORDER BY SUM(播放) DESC LIMIT 10";
ChatQueryContext chatQueryContext = buildQueryContext(dataSetId); ChatQueryContext chatQueryContext = buildQueryContext(sql);
SemanticParseInfo parseInfo = chatQueryContext.getCandidateQueries().get(0).getParseInfo();
SchemaCorrector schemaCorrector = new SchemaCorrector();
schemaCorrector.correct(chatQueryContext, parseInfo);
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' ORDER BY SUM(播放量) DESC LIMIT 10",
parseInfo.getSqlInfo().getCorrectedS2SQL());
}
@Test
void testRemoveUnmappedFilterValue() throws JsonProcessingException {
String sql = "SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10";
ChatQueryContext chatQueryContext = buildQueryContext(sql);
SemanticParseInfo parseInfo = chatQueryContext.getCandidateQueries().get(0).getParseInfo();
ObjectMapper objectMapper = new ObjectMapper(); ObjectMapper objectMapper = new ObjectMapper();
ParseResult parseResult = objectMapper.readValue(json, ParseResult.class); ParseResult parseResult = objectMapper.readValue(json, ParseResult.class);
String sql = "select 歌曲名 from 歌曲 where 发行日期 >= '2024-01-01' " parseInfo.getProperties().put(Constants.CONTEXT, parseResult);
+ "and 商务组 = 'xxx' order by 播放量 desc limit 10";
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
SqlInfo sqlInfo = new SqlInfo();
sqlInfo.setParsedS2SQL(sql);
sqlInfo.setCorrectedS2SQL(sql);
semanticParseInfo.setSqlInfo(sqlInfo);
SchemaElement schemaElement = new SchemaElement();
schemaElement.setDataSetId(dataSetId);
semanticParseInfo.setDataSet(schemaElement);
semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult);
SchemaCorrector schemaCorrector = new SchemaCorrector(); SchemaCorrector schemaCorrector = new SchemaCorrector();
schemaCorrector.removeFilterIfNotInLinkingValue(chatQueryContext, semanticParseInfo); schemaCorrector.removeUnmappedFilterValue(chatQueryContext, parseInfo);
Assert.assertEquals( Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' ORDER BY 播放量 DESC LIMIT 10",
"SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' " + "ORDER BY 播放量 DESC LIMIT 10", parseInfo.getSqlInfo().getCorrectedS2SQL());
semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
parseResult = objectMapper.readValue(json, ParseResult.class);
List<LLMReq.ElementValue> linkingValues = new ArrayList<>(); List<LLMReq.ElementValue> linkingValues = new ArrayList<>();
LLMReq.ElementValue elementValue = new LLMReq.ElementValue(); LLMReq.ElementValue elementValue = new LLMReq.ElementValue();
elementValue.setFieldName("商务组"); elementValue.setFieldName("商务组");
elementValue.setFieldValue("xxx"); elementValue.setFieldValue("xxx");
linkingValues.add(elementValue); linkingValues.add(elementValue);
semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult); parseResult.getLlmReq().getSchema().setValues(linkingValues);
parseInfo.getProperties().put(Constants.CONTEXT, parseResult);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(sql); parseInfo.getSqlInfo().setCorrectedS2SQL(sql);
semanticParseInfo.getSqlInfo().setParsedS2SQL(sql); parseInfo.getSqlInfo().setParsedS2SQL(sql);
schemaCorrector.removeFilterIfNotInLinkingValue(chatQueryContext, semanticParseInfo); schemaCorrector.removeUnmappedFilterValue(chatQueryContext, parseInfo);
Assert.assertEquals( Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' "
"SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' "
+ "AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10", + "AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10",
semanticParseInfo.getSqlInfo().getCorrectedS2SQL()); parseInfo.getSqlInfo().getCorrectedS2SQL());
} }
private ChatQueryContext buildQueryContext(Long dataSetId) { private ChatQueryContext buildQueryContext(String sql) {
Long dataSetId = 1L;
ChatQueryContext chatQueryContext = new ChatQueryContext(); ChatQueryContext chatQueryContext = new ChatQueryContext();
List<DataSetSchema> dataSetSchemaList = new ArrayList<>(); List<DataSetSchema> dataSetSchemaList = new ArrayList<>();
DataSetSchema dataSetSchema = new DataSetSchema(); DataSetSchema dataSetSchema = new DataSetSchema();
@@ -94,27 +90,29 @@ class SchemaCorrectorTest {
SchemaElement schemaElement = new SchemaElement(); SchemaElement schemaElement = new SchemaElement();
schemaElement.setDataSetId(dataSetId); schemaElement.setDataSetId(dataSetId);
dataSetSchema.setDataSet(schemaElement); dataSetSchema.setDataSet(schemaElement);
Set<SchemaElement> dimensions = new HashSet<>(); Set<SchemaElement> dimensions = new HashSet<>();
SchemaElement element1 = new SchemaElement(); dimensions.add(SchemaElement.builder().name("歌曲名").dataSetId(dataSetId).build());
element1.setDataSetId(1L); dimensions.add(SchemaElement.builder().name("商务组").dataSetId(dataSetId).build());
element1.setName("歌曲名"); dimensions.add(SchemaElement.builder().name("发行日期").dataSetId(dataSetId).build());
dimensions.add(element1); dimensions.add(SchemaElement.builder().name("播放量").dataSetId(dataSetId).build());
SchemaElement element2 = new SchemaElement();
element2.setDataSetId(1L);
element2.setName("商务组");
dimensions.add(element2);
SchemaElement element3 = new SchemaElement();
element3.setDataSetId(1L);
element3.setName("发行日期");
dimensions.add(element3);
dataSetSchema.setDimensions(dimensions); dataSetSchema.setDimensions(dimensions);
dataSetSchemaList.add(dataSetSchema); dataSetSchemaList.add(dataSetSchema);
SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList); SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList);
chatQueryContext.setSemanticSchema(semanticSchema); chatQueryContext.setSemanticSchema(semanticSchema);
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
SqlInfo sqlInfo = new SqlInfo();
sqlInfo.setParsedS2SQL(sql);
sqlInfo.setCorrectedS2SQL(sql);
semanticParseInfo.setSqlInfo(sqlInfo);
semanticParseInfo.setDataSet(dataSetSchema.getDataSet());
LLMSqlQuery sqlQuery = new LLMSqlQuery();
sqlQuery.setParseInfo(semanticParseInfo);
chatQueryContext.getCandidateQueries().add(sqlQuery);
return chatQueryContext; return chatQueryContext;
} }
} }

View File

@@ -20,7 +20,7 @@ public class DefaultQueryCache implements QueryCache {
if (isCache(semanticQueryReq)) { if (isCache(semanticQueryReq)) {
Object result = cacheManager.get(cacheKey); Object result = cacheManager.get(cacheKey);
if (Objects.nonNull(result)) { if (Objects.nonNull(result)) {
log.info("query from cache, key:{},result:{}", cacheKey, log.debug("query from cache, key:{},result:{}", cacheKey,
StringUtils.normalizeSpace(result.toString())); StringUtils.normalizeSpace(result.toString()));
} }
return result; return result;

View File

@@ -45,7 +45,7 @@ public class JdbcExecutor implements QueryExecutor {
sqlUtil.queryInternal(queryStatement.getSql(), queryResultWithColumns); sqlUtil.queryInternal(queryStatement.getSql(), queryResultWithColumns);
queryResultWithColumns.setSql(sql); queryResultWithColumns.setSql(sql);
} catch (Exception e) { } catch (Exception e) {
log.error("queryInternal with error [{}]", StringUtils.normalizeSpace(e.getMessage())); log.error("queryInternal with error ", e);
queryResultWithColumns.setErrorMsg(e.getMessage()); queryResultWithColumns.setErrorMsg(e.getMessage());
} }
return queryResultWithColumns; return queryResultWithColumns;

View File

@@ -50,7 +50,8 @@ public class DefaultSemanticTranslator implements SemanticTranslator {
optimizer.rewrite(queryStatement); optimizer.rewrite(queryStatement);
} }
} }
log.info("translated query SQL: [{}]", StringUtils.normalizeSpace(queryStatement.getSql())); log.debug("translated query SQL: [{}]",
StringUtils.normalizeSpace(queryStatement.getSql()));
} }
private void mergeOntologyQuery(QueryStatement queryStatement) throws Exception { private void mergeOntologyQuery(QueryStatement queryStatement) throws Exception {

View File

@@ -288,8 +288,6 @@ public class S2SemanticLayerService implements SemanticLayerService {
queryStatement.setSql(semanticQueryReq.getSqlInfo().getQuerySQL()); queryStatement.setSql(semanticQueryReq.getSqlInfo().getQuerySQL());
queryStatement.setIsTranslated(true); queryStatement.setIsTranslated(true);
} }
queryStatement.setDataSetId(semanticQueryReq.getDataSetId());
queryStatement.setDataSetName(semanticQueryReq.getDataSetName());
return queryStatement; return queryStatement;
} }
@@ -321,6 +319,11 @@ public class S2SemanticLayerService implements SemanticLayerService {
Long dataSetId = dataSetService.getDataSetIdFromSql(querySqlReq.getSql(), user); Long dataSetId = dataSetService.getDataSetIdFromSql(querySqlReq.getSql(), user);
querySqlReq.setDataSetId(dataSetId); querySqlReq.setDataSetId(dataSetId);
} }
if (querySqlReq.getDataSetId() != null) {
DataSetResp dataSetResp = dataSetService.getDataSet(querySqlReq.getDataSetId());
queryStatement.setDataSetId(dataSetResp.getId());
queryStatement.setDataSetName(dataSetResp.getName());
}
return queryStatement; return queryStatement;
} }

View File

@@ -273,9 +273,10 @@ public class DictUtils {
private QuerySqlReq constructQuerySqlReq(DictItemResp dictItemResp) { private QuerySqlReq constructQuerySqlReq(DictItemResp dictItemResp) {
ModelResp model = modelService.getModel(dictItemResp.getModelId());
String sqlPattern = String sqlPattern =
"select %s,count(1) from tbl %s group by %s order by count(1) desc limit %d"; "select %s,count(1) from %s %s group by %s order by count(1) desc limit %d";
String bizName = dictItemResp.getBizName(); String dimBizName = dictItemResp.getBizName();
String whereStr = generateWhereStr(dictItemResp); String whereStr = generateWhereStr(dictItemResp);
String where = StringUtils.isEmpty(whereStr) ? "" : "WHERE" + whereStr; String where = StringUtils.isEmpty(whereStr) ? "" : "WHERE" + whereStr;
ItemValueConfig config = dictItemResp.getConfig(); ItemValueConfig config = dictItemResp.getConfig();
@@ -286,7 +287,8 @@ public class DictUtils {
limit = Integer.MAX_VALUE; limit = Integer.MAX_VALUE;
} }
String sql = String.format(sqlPattern, bizName, where, bizName, limit); String sql =
String.format(sqlPattern, dimBizName, model.getBizName(), where, dimBizName, limit);
Set<Long> modelIds = new HashSet<>(); Set<Long> modelIds = new HashSet<>();
modelIds.add(dictItemResp.getModelId()); modelIds.add(dictItemResp.getModelId());
QuerySqlReq querySqlReq = new QuerySqlReq(); QuerySqlReq querySqlReq = new QuerySqlReq();

View File

@@ -17,6 +17,7 @@ import static org.junit.Assert.assertTrue;
public class QueryBySqlTest extends BaseTest { public class QueryBySqlTest extends BaseTest {
@Test @Test
@SetSystemProperty(key = "s2.test", value = "true")
public void testDetailQuery() throws Exception { public void testDetailQuery() throws Exception {
SemanticQueryResp semanticQueryResp = SemanticQueryResp semanticQueryResp =
queryBySql("SELECT 用户名,访问次数 FROM 超音数PVUV统计 WHERE 用户名='alice' "); queryBySql("SELECT 用户名,访问次数 FROM 超音数PVUV统计 WHERE 用户名='alice' ");