mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-14 05:43:51 +00:00
[improvement](chat) remove duplicates from multiple SQL identified by LLM. (#391)
This commit is contained in:
@@ -7,19 +7,23 @@ import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
|||||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.chat.query.QueryManager;
|
import com.tencent.supersonic.chat.query.QueryManager;
|
||||||
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
|
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||||
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
|
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserEqualHelper;
|
||||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections4.MapUtils;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Service
|
@Service
|
||||||
public class LLMResponseService {
|
public class LLMResponseService {
|
||||||
|
|
||||||
public SemanticParseInfo addParseInfo(QueryContext queryCtx, ParseResult parseResult, String s2SQL, Double weight) {
|
public SemanticParseInfo addParseInfo(QueryContext queryCtx, ParseResult parseResult, String s2SQL, Double weight) {
|
||||||
if (Objects.isNull(weight)) {
|
if (Objects.isNull(weight)) {
|
||||||
weight = 0D;
|
weight = 0D;
|
||||||
@@ -51,4 +55,19 @@ public class LLMResponseService {
|
|||||||
queryCtx.getCandidateQueries().add(semanticQuery);
|
queryCtx.getCandidateQueries().add(semanticQuery);
|
||||||
return parseInfo;
|
return parseInfo;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Map<String, Double> getDeduplicationSqlWeight(LLMResp llmResp) {
|
||||||
|
if (MapUtils.isEmpty(llmResp.getSqlWeight())) {
|
||||||
|
return llmResp.getSqlWeight();
|
||||||
|
}
|
||||||
|
Map<String, Double> result = new HashMap<>();
|
||||||
|
for (Map.Entry<String, Double> entry : llmResp.getSqlWeight().entrySet()) {
|
||||||
|
String key = entry.getKey();
|
||||||
|
if (result.keySet().stream().anyMatch(existKey -> SqlParserEqualHelper.equals(existKey, key))) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
result.put(key, entry.getValue());
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import java.util.List;
|
|||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections4.MapUtils;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class LLMS2SQLParser implements SemanticParser {
|
public class LLMS2SQLParser implements SemanticParser {
|
||||||
@@ -45,8 +46,9 @@ public class LLMS2SQLParser implements SemanticParser {
|
|||||||
if (Objects.isNull(llmResp)) {
|
if (Objects.isNull(llmResp)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
//5. get and update parserInfo
|
//5. deduplicate the SQL result list and build parserInfo
|
||||||
Map<String, Double> sqlWeight = llmResp.getSqlWeight();
|
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
|
||||||
|
Map<String, Double> deduplicationSqlWeight = responseService.getDeduplicationSqlWeight(llmResp);
|
||||||
ParseResult parseResult = ParseResult.builder()
|
ParseResult parseResult = ParseResult.builder()
|
||||||
.request(request)
|
.request(request)
|
||||||
.modelId(modelId)
|
.modelId(modelId)
|
||||||
@@ -56,12 +58,10 @@ public class LLMS2SQLParser implements SemanticParser {
|
|||||||
.linkingValues(linkingValues)
|
.linkingValues(linkingValues)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
|
if (MapUtils.isEmpty(deduplicationSqlWeight)) {
|
||||||
|
|
||||||
if (Objects.isNull(sqlWeight) || sqlWeight.isEmpty()) {
|
|
||||||
responseService.addParseInfo(queryCtx, parseResult, llmResp.getSqlOutput(), 1D);
|
responseService.addParseInfo(queryCtx, parseResult, llmResp.getSqlOutput(), 1D);
|
||||||
} else {
|
} else {
|
||||||
sqlWeight.forEach((sql, weight) -> {
|
deduplicationSqlWeight.forEach((sql, weight) -> {
|
||||||
responseService.addParseInfo(queryCtx, parseResult, sql, weight);
|
responseService.addParseInfo(queryCtx, parseResult, sql, weight);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import com.tencent.supersonic.common.pojo.DateConf.DateMode;
|
|||||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.FilterExpression;
|
import com.tencent.supersonic.common.util.jsqlparser.FieldExpression;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||||
@@ -47,7 +47,7 @@ public class ParserInfoServiceImpl implements ParseInfoService {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctS2SQL);
|
List<FieldExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctS2SQL);
|
||||||
//set dataInfo
|
//set dataInfo
|
||||||
try {
|
try {
|
||||||
if (!CollectionUtils.isEmpty(expressions)) {
|
if (!CollectionUtils.isEmpty(expressions)) {
|
||||||
@@ -112,9 +112,9 @@ public class ParserInfoServiceImpl implements ParseInfoService {
|
|||||||
|
|
||||||
|
|
||||||
private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> fieldNameToElement,
|
private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> fieldNameToElement,
|
||||||
List<FilterExpression> filterExpressions) {
|
List<FieldExpression> fieldExpressions) {
|
||||||
List<QueryFilter> result = Lists.newArrayList();
|
List<QueryFilter> result = Lists.newArrayList();
|
||||||
for (FilterExpression expression : filterExpressions) {
|
for (FieldExpression expression : fieldExpressions) {
|
||||||
QueryFilter dimensionFilter = new QueryFilter();
|
QueryFilter dimensionFilter = new QueryFilter();
|
||||||
dimensionFilter.setValue(expression.getFieldValue());
|
dimensionFilter.setValue(expression.getFieldValue());
|
||||||
SchemaElement schemaElement = fieldNameToElement.get(expression.getFieldName());
|
SchemaElement schemaElement = fieldNameToElement.get(expression.getFieldName());
|
||||||
@@ -133,8 +133,8 @@ public class ParserInfoServiceImpl implements ParseInfoService {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
private DateConf getDateInfo(List<FilterExpression> filterExpressions) {
|
private DateConf getDateInfo(List<FieldExpression> fieldExpressions) {
|
||||||
List<FilterExpression> dateExpressions = filterExpressions.stream()
|
List<FieldExpression> dateExpressions = fieldExpressions.stream()
|
||||||
.filter(expression -> TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(expression.getFieldName()))
|
.filter(expression -> TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(expression.getFieldName()))
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
if (CollectionUtils.isEmpty(dateExpressions)) {
|
if (CollectionUtils.isEmpty(dateExpressions)) {
|
||||||
@@ -142,7 +142,7 @@ public class ParserInfoServiceImpl implements ParseInfoService {
|
|||||||
}
|
}
|
||||||
DateConf dateInfo = new DateConf();
|
DateConf dateInfo = new DateConf();
|
||||||
dateInfo.setDateMode(DateMode.BETWEEN);
|
dateInfo.setDateMode(DateMode.BETWEEN);
|
||||||
FilterExpression firstExpression = dateExpressions.get(0);
|
FieldExpression firstExpression = dateExpressions.get(0);
|
||||||
|
|
||||||
FilterOperatorEnum firstOperator = FilterOperatorEnum.getSqlOperator(firstExpression.getOperator());
|
FilterOperatorEnum firstOperator = FilterOperatorEnum.getSqlOperator(firstExpression.getOperator());
|
||||||
if (FilterOperatorEnum.EQUALS.equals(firstOperator) && Objects.nonNull(firstExpression.getFieldValue())) {
|
if (FilterOperatorEnum.EQUALS.equals(firstOperator) && Objects.nonNull(firstExpression.getFieldValue())) {
|
||||||
@@ -168,12 +168,12 @@ public class ParserInfoServiceImpl implements ParseInfoService {
|
|||||||
return dateInfo;
|
return dateInfo;
|
||||||
}
|
}
|
||||||
|
|
||||||
private boolean containOperators(FilterExpression expression, FilterOperatorEnum firstOperator,
|
private boolean containOperators(FieldExpression expression, FilterOperatorEnum firstOperator,
|
||||||
FilterOperatorEnum... operatorEnums) {
|
FilterOperatorEnum... operatorEnums) {
|
||||||
return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(expression.getFieldValue()));
|
return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(expression.getFieldValue()));
|
||||||
}
|
}
|
||||||
|
|
||||||
private boolean hasSecondDate(List<FilterExpression> dateExpressions) {
|
private boolean hasSecondDate(List<FieldExpression> dateExpressions) {
|
||||||
return dateExpressions.size() > 1 && Objects.nonNull(dateExpressions.get(1).getFieldValue());
|
return dateExpressions.size() > 1 && Objects.nonNull(dateExpressions.get(1).getFieldValue());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
|||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.DateUtils;
|
import com.tencent.supersonic.common.util.DateUtils;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.FilterExpression;
|
import com.tencent.supersonic.common.util.jsqlparser.FieldExpression;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserRemoveHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserRemoveHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
||||||
@@ -296,8 +296,8 @@ public class QueryServiceImpl implements QueryService {
|
|||||||
String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL();
|
String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
log.info("correctorSql before replacing:{}", correctorSql);
|
log.info("correctorSql before replacing:{}", correctorSql);
|
||||||
// get where filter and having filter
|
// get where filter and having filter
|
||||||
List<FilterExpression> whereExpressionList = SqlParserSelectHelper.getWhereExpressions(correctorSql);
|
List<FieldExpression> whereExpressionList = SqlParserSelectHelper.getWhereExpressions(correctorSql);
|
||||||
List<FilterExpression> havingExpressionList = SqlParserSelectHelper.getHavingExpressions(correctorSql);
|
List<FieldExpression> havingExpressionList = SqlParserSelectHelper.getHavingExpressions(correctorSql);
|
||||||
List<Expression> addWhereConditions = new ArrayList<>();
|
List<Expression> addWhereConditions = new ArrayList<>();
|
||||||
List<Expression> addHavingConditions = new ArrayList<>();
|
List<Expression> addHavingConditions = new ArrayList<>();
|
||||||
Set<String> removeWhereFieldNames = new HashSet<>();
|
Set<String> removeWhereFieldNames = new HashSet<>();
|
||||||
@@ -350,7 +350,7 @@ public class QueryServiceImpl implements QueryService {
|
|||||||
|
|
||||||
private void updateDateInfo(QueryDataReq queryData, SemanticParseInfo parseInfo,
|
private void updateDateInfo(QueryDataReq queryData, SemanticParseInfo parseInfo,
|
||||||
Map<String, Map<String, String>> filedNameToValueMap,
|
Map<String, Map<String, String>> filedNameToValueMap,
|
||||||
List<FilterExpression> filterExpressionList,
|
List<FieldExpression> fieldExpressionList,
|
||||||
List<Expression> addConditions,
|
List<Expression> addConditions,
|
||||||
Set<String> removeFieldNames) {
|
Set<String> removeFieldNames) {
|
||||||
if (Objects.isNull(queryData.getDateInfo())) {
|
if (Objects.isNull(queryData.getDateInfo())) {
|
||||||
@@ -364,12 +364,12 @@ public class QueryServiceImpl implements QueryService {
|
|||||||
}
|
}
|
||||||
// startDate equals to endDate
|
// startDate equals to endDate
|
||||||
if (queryData.getDateInfo().getStartDate().equals(queryData.getDateInfo().getEndDate())) {
|
if (queryData.getDateInfo().getStartDate().equals(queryData.getDateInfo().getEndDate())) {
|
||||||
for (FilterExpression filterExpression : filterExpressionList) {
|
for (FieldExpression fieldExpression : fieldExpressionList) {
|
||||||
if (TimeDimensionEnum.DAY.getChName().equals(filterExpression.getFieldName())) {
|
if (TimeDimensionEnum.DAY.getChName().equals(fieldExpression.getFieldName())) {
|
||||||
//sql where condition exists 'equals' operator about date,just replace
|
//sql where condition exists 'equals' operator about date,just replace
|
||||||
if (filterExpression.getOperator().equals(FilterOperatorEnum.EQUALS)) {
|
if (fieldExpression.getOperator().equals(FilterOperatorEnum.EQUALS)) {
|
||||||
dateField = filterExpression.getFieldName();
|
dateField = fieldExpression.getFieldName();
|
||||||
map.put(filterExpression.getFieldValue().toString(),
|
map.put(fieldExpression.getFieldValue().toString(),
|
||||||
queryData.getDateInfo().getStartDate());
|
queryData.getDateInfo().getStartDate());
|
||||||
filedNameToValueMap.put(dateField, map);
|
filedNameToValueMap.put(dateField, map);
|
||||||
} else {
|
} else {
|
||||||
@@ -386,23 +386,23 @@ public class QueryServiceImpl implements QueryService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (FilterExpression filterExpression : filterExpressionList) {
|
for (FieldExpression fieldExpression : fieldExpressionList) {
|
||||||
if (TimeDimensionEnum.DAY.getChName().equals(filterExpression.getFieldName())) {
|
if (TimeDimensionEnum.DAY.getChName().equals(fieldExpression.getFieldName())) {
|
||||||
dateField = filterExpression.getFieldName();
|
dateField = fieldExpression.getFieldName();
|
||||||
//just replace
|
//just replace
|
||||||
if (FilterOperatorEnum.GREATER_THAN_EQUALS.getValue().equals(filterExpression.getOperator())
|
if (FilterOperatorEnum.GREATER_THAN_EQUALS.getValue().equals(fieldExpression.getOperator())
|
||||||
|| FilterOperatorEnum.GREATER_THAN.getValue().equals(filterExpression.getOperator())) {
|
|| FilterOperatorEnum.GREATER_THAN.getValue().equals(fieldExpression.getOperator())) {
|
||||||
map.put(filterExpression.getFieldValue().toString(),
|
map.put(fieldExpression.getFieldValue().toString(),
|
||||||
queryData.getDateInfo().getStartDate());
|
queryData.getDateInfo().getStartDate());
|
||||||
}
|
}
|
||||||
if (FilterOperatorEnum.MINOR_THAN_EQUALS.getValue().equals(filterExpression.getOperator())
|
if (FilterOperatorEnum.MINOR_THAN_EQUALS.getValue().equals(fieldExpression.getOperator())
|
||||||
|| FilterOperatorEnum.MINOR_THAN.getValue().equals(filterExpression.getOperator())) {
|
|| FilterOperatorEnum.MINOR_THAN.getValue().equals(fieldExpression.getOperator())) {
|
||||||
map.put(filterExpression.getFieldValue().toString(),
|
map.put(fieldExpression.getFieldValue().toString(),
|
||||||
queryData.getDateInfo().getEndDate());
|
queryData.getDateInfo().getEndDate());
|
||||||
}
|
}
|
||||||
filedNameToValueMap.put(dateField, map);
|
filedNameToValueMap.put(dateField, map);
|
||||||
// first remove,then add
|
// first remove,then add
|
||||||
if (FilterOperatorEnum.EQUALS.getValue().equals(filterExpression.getOperator())) {
|
if (FilterOperatorEnum.EQUALS.getValue().equals(fieldExpression.getOperator())) {
|
||||||
removeFieldNames.add(TimeDimensionEnum.DAY.getChName());
|
removeFieldNames.add(TimeDimensionEnum.DAY.getChName());
|
||||||
GreaterThanEquals greaterThanEquals = new GreaterThanEquals();
|
GreaterThanEquals greaterThanEquals = new GreaterThanEquals();
|
||||||
addTimeFilters(queryData.getDateInfo().getStartDate(), greaterThanEquals, addConditions);
|
addTimeFilters(queryData.getDateInfo().getStartDate(), greaterThanEquals, addConditions);
|
||||||
@@ -425,7 +425,7 @@ public class QueryServiceImpl implements QueryService {
|
|||||||
addConditions.add(comparisonExpression);
|
addConditions.add(comparisonExpression);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void updateFilters(List<FilterExpression> filterExpressionList,
|
private void updateFilters(List<FieldExpression> fieldExpressionList,
|
||||||
Set<QueryFilter> metricFilters,
|
Set<QueryFilter> metricFilters,
|
||||||
Set<QueryFilter> contextMetricFilters,
|
Set<QueryFilter> contextMetricFilters,
|
||||||
List<Expression> addConditions,
|
List<Expression> addConditions,
|
||||||
@@ -434,9 +434,9 @@ public class QueryServiceImpl implements QueryService {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
for (QueryFilter dslQueryFilter : metricFilters) {
|
for (QueryFilter dslQueryFilter : metricFilters) {
|
||||||
for (FilterExpression filterExpression : filterExpressionList) {
|
for (FieldExpression fieldExpression : fieldExpressionList) {
|
||||||
if (filterExpression.getFieldName() != null
|
if (fieldExpression.getFieldName() != null
|
||||||
&& filterExpression.getFieldName().contains(dslQueryFilter.getName())) {
|
&& fieldExpression.getFieldName().contains(dslQueryFilter.getName())) {
|
||||||
removeFieldNames.add(dslQueryFilter.getName());
|
removeFieldNames.add(dslQueryFilter.getName());
|
||||||
if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.EQUALS)) {
|
if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.EQUALS)) {
|
||||||
EqualsTo equalsTo = new EqualsTo();
|
EqualsTo equalsTo = new EqualsTo();
|
||||||
|
|||||||
@@ -0,0 +1,51 @@
|
|||||||
|
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
class LLMResponseServiceTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void deduplicationSqlWeight() {
|
||||||
|
String sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
|
||||||
|
String sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
|
||||||
|
|
||||||
|
LLMResp llmResp = new LLMResp();
|
||||||
|
Map<String, Double> sqlWeight = new HashMap<>();
|
||||||
|
sqlWeight.put(sql1, 0.2D);
|
||||||
|
sqlWeight.put(sql2, 0.8D);
|
||||||
|
llmResp.setSqlWeight(sqlWeight);
|
||||||
|
LLMResponseService llmResponseService = new LLMResponseService();
|
||||||
|
Map<String, Double> deduplicationSqlWeight = llmResponseService.getDeduplicationSqlWeight(llmResp);
|
||||||
|
|
||||||
|
Assert.assertEquals(deduplicationSqlWeight.size(), 1);
|
||||||
|
|
||||||
|
sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
|
||||||
|
sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
|
||||||
|
|
||||||
|
LLMResp llmResp2 = new LLMResp();
|
||||||
|
Map<String, Double> sqlWeight2 = new HashMap<>();
|
||||||
|
sqlWeight2.put(sql1, 0.2D);
|
||||||
|
sqlWeight2.put(sql2, 0.8D);
|
||||||
|
llmResp2.setSqlWeight(sqlWeight2);
|
||||||
|
deduplicationSqlWeight = llmResponseService.getDeduplicationSqlWeight(llmResp2);
|
||||||
|
|
||||||
|
Assert.assertEquals(deduplicationSqlWeight.size(), 1);
|
||||||
|
|
||||||
|
sql1 = "SELECT a,b,c,d,e FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
|
||||||
|
sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
|
||||||
|
|
||||||
|
LLMResp llmResp3 = new LLMResp();
|
||||||
|
Map<String, Double> sqlWeight3 = new HashMap<>();
|
||||||
|
sqlWeight3.put(sql1, 0.2D);
|
||||||
|
sqlWeight3.put(sql2, 0.8D);
|
||||||
|
llmResp3.setSqlWeight(sqlWeight3);
|
||||||
|
deduplicationSqlWeight = llmResponseService.getDeduplicationSqlWeight(llmResp3);
|
||||||
|
|
||||||
|
Assert.assertEquals(deduplicationSqlWeight.size(), 2);
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -28,38 +28,38 @@ import org.apache.commons.collections.CollectionUtils;
|
|||||||
|
|
||||||
public class FieldAndValueAcquireVisitor extends ExpressionVisitorAdapter {
|
public class FieldAndValueAcquireVisitor extends ExpressionVisitorAdapter {
|
||||||
|
|
||||||
private Set<FilterExpression> filterExpressions;
|
private Set<FieldExpression> fieldExpressions;
|
||||||
|
|
||||||
public FieldAndValueAcquireVisitor(Set<FilterExpression> filterExpressions) {
|
public FieldAndValueAcquireVisitor(Set<FieldExpression> fieldExpressions) {
|
||||||
this.filterExpressions = filterExpressions;
|
this.fieldExpressions = fieldExpressions;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void visit(LikeExpression expr) {
|
public void visit(LikeExpression expr) {
|
||||||
Expression leftExpression = expr.getLeftExpression();
|
Expression leftExpression = expr.getLeftExpression();
|
||||||
Expression rightExpression = expr.getRightExpression();
|
Expression rightExpression = expr.getRightExpression();
|
||||||
|
|
||||||
FilterExpression filterExpression = new FilterExpression();
|
FieldExpression fieldExpression = new FieldExpression();
|
||||||
String columnName = null;
|
String columnName = null;
|
||||||
if (leftExpression instanceof Column) {
|
if (leftExpression instanceof Column) {
|
||||||
Column column = (Column) leftExpression;
|
Column column = (Column) leftExpression;
|
||||||
columnName = column.getColumnName();
|
columnName = column.getColumnName();
|
||||||
filterExpression.setFieldName(columnName);
|
fieldExpression.setFieldName(columnName);
|
||||||
}
|
}
|
||||||
filterExpression.setFieldValue(getFieldValue(rightExpression));
|
fieldExpression.setFieldValue(getFieldValue(rightExpression));
|
||||||
filterExpression.setOperator(expr.getStringExpression());
|
fieldExpression.setOperator(expr.getStringExpression());
|
||||||
filterExpressions.add(filterExpression);
|
fieldExpressions.add(fieldExpression);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void visit(InExpression expr) {
|
public void visit(InExpression expr) {
|
||||||
FilterExpression filterExpression = new FilterExpression();
|
FieldExpression fieldExpression = new FieldExpression();
|
||||||
Expression leftExpression = expr.getLeftExpression();
|
Expression leftExpression = expr.getLeftExpression();
|
||||||
if (!(leftExpression instanceof Column)) {
|
if (!(leftExpression instanceof Column)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
filterExpression.setFieldName(((Column) leftExpression).getColumnName());
|
fieldExpression.setFieldName(((Column) leftExpression).getColumnName());
|
||||||
filterExpression.setOperator(JsqlConstants.IN);
|
fieldExpression.setOperator(JsqlConstants.IN);
|
||||||
ItemsList rightItemsList = expr.getRightItemsList();
|
ItemsList rightItemsList = expr.getRightItemsList();
|
||||||
filterExpression.setFieldValue(rightItemsList);
|
fieldExpression.setFieldValue(rightItemsList);
|
||||||
List<Object> result = new ArrayList<>();
|
List<Object> result = new ArrayList<>();
|
||||||
if (rightItemsList instanceof ExpressionList) {
|
if (rightItemsList instanceof ExpressionList) {
|
||||||
ExpressionList rightExpressionList = (ExpressionList) rightItemsList;
|
ExpressionList rightExpressionList = (ExpressionList) rightItemsList;
|
||||||
@@ -70,78 +70,78 @@ public class FieldAndValueAcquireVisitor extends ExpressionVisitorAdapter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
filterExpression.setFieldValue(result);
|
fieldExpression.setFieldValue(result);
|
||||||
filterExpressions.add(filterExpression);
|
fieldExpressions.add(fieldExpression);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void visit(MinorThan expr) {
|
public void visit(MinorThan expr) {
|
||||||
FilterExpression filterExpression = getFilterExpression(expr);
|
FieldExpression fieldExpression = getFilterExpression(expr);
|
||||||
filterExpressions.add(filterExpression);
|
fieldExpressions.add(fieldExpression);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void visit(EqualsTo expr) {
|
public void visit(EqualsTo expr) {
|
||||||
FilterExpression filterExpression = getFilterExpression(expr);
|
FieldExpression fieldExpression = getFilterExpression(expr);
|
||||||
filterExpressions.add(filterExpression);
|
fieldExpressions.add(fieldExpression);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void visit(MinorThanEquals expr) {
|
public void visit(MinorThanEquals expr) {
|
||||||
FilterExpression filterExpression = getFilterExpression(expr);
|
FieldExpression fieldExpression = getFilterExpression(expr);
|
||||||
filterExpressions.add(filterExpression);
|
fieldExpressions.add(fieldExpression);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void visit(GreaterThan expr) {
|
public void visit(GreaterThan expr) {
|
||||||
FilterExpression filterExpression = getFilterExpression(expr);
|
FieldExpression fieldExpression = getFilterExpression(expr);
|
||||||
filterExpressions.add(filterExpression);
|
fieldExpressions.add(fieldExpression);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void visit(GreaterThanEquals expr) {
|
public void visit(GreaterThanEquals expr) {
|
||||||
FilterExpression filterExpression = getFilterExpression(expr);
|
FieldExpression fieldExpression = getFilterExpression(expr);
|
||||||
filterExpressions.add(filterExpression);
|
fieldExpressions.add(fieldExpression);
|
||||||
}
|
}
|
||||||
|
|
||||||
private FilterExpression getFilterExpression(ComparisonOperator expr) {
|
private FieldExpression getFilterExpression(ComparisonOperator expr) {
|
||||||
Expression leftExpression = expr.getLeftExpression();
|
Expression leftExpression = expr.getLeftExpression();
|
||||||
Expression rightExpression = expr.getRightExpression();
|
Expression rightExpression = expr.getRightExpression();
|
||||||
|
|
||||||
FilterExpression filterExpression = new FilterExpression();
|
FieldExpression fieldExpression = new FieldExpression();
|
||||||
String columnName = null;
|
String columnName = null;
|
||||||
if (leftExpression instanceof Column) {
|
if (leftExpression instanceof Column) {
|
||||||
Column column = (Column) leftExpression;
|
Column column = (Column) leftExpression;
|
||||||
columnName = column.getColumnName();
|
columnName = column.getColumnName();
|
||||||
filterExpression.setFieldName(columnName);
|
fieldExpression.setFieldName(columnName);
|
||||||
}
|
}
|
||||||
if (leftExpression instanceof Function) {
|
if (leftExpression instanceof Function) {
|
||||||
Function leftExpressionFunction = (Function) leftExpression;
|
Function leftExpressionFunction = (Function) leftExpression;
|
||||||
Column field = getColumn(leftExpressionFunction);
|
Column field = getColumn(leftExpressionFunction);
|
||||||
if (Objects.isNull(field)) {
|
if (Objects.isNull(field)) {
|
||||||
return filterExpression;
|
return fieldExpression;
|
||||||
}
|
}
|
||||||
String functionName = leftExpressionFunction.getName().toUpperCase();
|
String functionName = leftExpressionFunction.getName().toUpperCase();
|
||||||
filterExpression.setFieldName(field.getColumnName());
|
fieldExpression.setFieldName(field.getColumnName());
|
||||||
filterExpression.setFunction(functionName);
|
fieldExpression.setFunction(functionName);
|
||||||
filterExpression.setOperator(expr.getStringExpression());
|
fieldExpression.setOperator(expr.getStringExpression());
|
||||||
//deal with DAY/WEEK function
|
//deal with DAY/WEEK function
|
||||||
List<DatePeriodEnum> collect = Arrays.stream(DatePeriodEnum.values()).collect(Collectors.toList());
|
List<DatePeriodEnum> collect = Arrays.stream(DatePeriodEnum.values()).collect(Collectors.toList());
|
||||||
DatePeriodEnum periodEnum = DatePeriodEnum.get(functionName);
|
DatePeriodEnum periodEnum = DatePeriodEnum.get(functionName);
|
||||||
if (Objects.nonNull(periodEnum) && collect.contains(periodEnum)) {
|
if (Objects.nonNull(periodEnum) && collect.contains(periodEnum)) {
|
||||||
filterExpression.setFieldValue(getFieldValue(rightExpression) + periodEnum.getChName());
|
fieldExpression.setFieldValue(getFieldValue(rightExpression) + periodEnum.getChName());
|
||||||
return filterExpression;
|
return fieldExpression;
|
||||||
} else {
|
} else {
|
||||||
//deal with aggregate function
|
//deal with aggregate function
|
||||||
filterExpression.setFieldValue(getFieldValue(rightExpression));
|
fieldExpression.setFieldValue(getFieldValue(rightExpression));
|
||||||
return filterExpression;
|
return fieldExpression;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
filterExpression.setFieldValue(getFieldValue(rightExpression));
|
fieldExpression.setFieldValue(getFieldValue(rightExpression));
|
||||||
filterExpression.setOperator(expr.getStringExpression());
|
fieldExpression.setOperator(expr.getStringExpression());
|
||||||
return filterExpression;
|
return fieldExpression;
|
||||||
}
|
}
|
||||||
|
|
||||||
private Column getColumn(Function leftExpressionFunction) {
|
private Column getColumn(Function leftExpressionFunction) {
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ package com.tencent.supersonic.common.util.jsqlparser;
|
|||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class FilterExpression {
|
public class FieldExpression {
|
||||||
|
|
||||||
private String operator;
|
private String operator;
|
||||||
|
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
package com.tencent.supersonic.common.util.jsqlparser;
|
package com.tencent.supersonic.common.util.jsqlparser;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import net.sf.jsqlparser.expression.Expression;
|
import net.sf.jsqlparser.expression.Expression;
|
||||||
@@ -10,27 +11,34 @@ import net.sf.jsqlparser.statement.select.OrderByVisitorAdapter;
|
|||||||
|
|
||||||
public class OrderByAcquireVisitor extends OrderByVisitorAdapter {
|
public class OrderByAcquireVisitor extends OrderByVisitorAdapter {
|
||||||
|
|
||||||
private Set<String> fields;
|
private Set<FieldExpression> fields;
|
||||||
|
|
||||||
public OrderByAcquireVisitor(Set<String> fields) {
|
public OrderByAcquireVisitor(Set<FieldExpression> fields) {
|
||||||
this.fields = fields;
|
this.fields = fields;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void visit(OrderByElement orderBy) {
|
public void visit(OrderByElement orderBy) {
|
||||||
Expression expression = orderBy.getExpression();
|
Expression expression = orderBy.getExpression();
|
||||||
|
FieldExpression fieldExpression = new FieldExpression();
|
||||||
if (expression instanceof Column) {
|
if (expression instanceof Column) {
|
||||||
fields.add(((Column) expression).getColumnName());
|
fieldExpression.setFieldName(((Column) expression).getColumnName());
|
||||||
}
|
}
|
||||||
if (expression instanceof Function) {
|
if (expression instanceof Function) {
|
||||||
Function function = (Function) expression;
|
Function function = (Function) expression;
|
||||||
List<Expression> expressions = function.getParameters().getExpressions();
|
List<Expression> expressions = function.getParameters().getExpressions();
|
||||||
for (Expression column : expressions) {
|
for (Expression column : expressions) {
|
||||||
if (column instanceof Column) {
|
if (column instanceof Column) {
|
||||||
fields.add(((Column) column).getColumnName());
|
fieldExpression.setFieldName(((Column) column).getColumnName());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
String operator = Constants.ASC_UPPER;
|
||||||
|
if (!orderBy.isAsc()) {
|
||||||
|
operator = Constants.DESC_UPPER;
|
||||||
|
}
|
||||||
|
fieldExpression.setOperator(operator);
|
||||||
|
fields.add(fieldExpression);
|
||||||
super.visit(orderBy);
|
super.visit(orderBy);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
package com.tencent.supersonic.common.util.jsqlparser;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class OrderByExpression {
|
||||||
|
|
||||||
|
private String operator;
|
||||||
|
|
||||||
|
private String fieldName;
|
||||||
|
|
||||||
|
private Object fieldValue;
|
||||||
|
|
||||||
|
private String function;
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,67 @@
|
|||||||
|
package com.tencent.supersonic.common.util.jsqlparser;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sql Parser equal Helper
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class SqlParserEqualHelper {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* determine if two SQL statements are equal.
|
||||||
|
*
|
||||||
|
* @param thisSql
|
||||||
|
* @param otherSql
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public static boolean equals(String thisSql, String otherSql) {
|
||||||
|
//1. select fields
|
||||||
|
List<String> thisSelectFields = SqlParserSelectHelper.getSelectFields(thisSql);
|
||||||
|
List<String> otherSelectFields = SqlParserSelectHelper.getSelectFields(otherSql);
|
||||||
|
|
||||||
|
if (!CollectionUtils.isEqualCollection(thisSelectFields, otherSelectFields)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
//2. all fields
|
||||||
|
List<String> thisAllFields = SqlParserSelectHelper.getAllFields(thisSql);
|
||||||
|
List<String> otherAllFields = SqlParserSelectHelper.getAllFields(otherSql);
|
||||||
|
|
||||||
|
if (!CollectionUtils.isEqualCollection(thisAllFields, otherAllFields)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
//3. where
|
||||||
|
List<FieldExpression> thisFieldExpressions = SqlParserSelectHelper.getFilterExpression(thisSql);
|
||||||
|
List<FieldExpression> otherFieldExpressions = SqlParserSelectHelper.getFilterExpression(otherSql);
|
||||||
|
|
||||||
|
if (!CollectionUtils.isEqualCollection(thisFieldExpressions, otherFieldExpressions)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
//4. tableName
|
||||||
|
if (!SqlParserSelectHelper.getDbTableName(thisSql)
|
||||||
|
.equalsIgnoreCase(SqlParserSelectHelper.getDbTableName(otherSql))) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
//5. having
|
||||||
|
List<FieldExpression> thisHavingExpressions = SqlParserSelectHelper.getHavingExpressions(thisSql);
|
||||||
|
List<FieldExpression> otherHavingExpressions = SqlParserSelectHelper.getHavingExpressions(otherSql);
|
||||||
|
|
||||||
|
if (!CollectionUtils.isEqualCollection(thisHavingExpressions, otherHavingExpressions)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
//6. orderBy
|
||||||
|
List<FieldExpression> thisOrderByExpressions = SqlParserSelectHelper.getOrderByExpressions(thisSql);
|
||||||
|
List<FieldExpression> otherOrderByExpressions = SqlParserSelectHelper.getOrderByExpressions(otherSql);
|
||||||
|
|
||||||
|
if (!CollectionUtils.isEqualCollection(thisOrderByExpressions, otherOrderByExpressions)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
@@ -5,6 +5,7 @@ import java.util.HashSet;
|
|||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import net.sf.jsqlparser.JSQLParserException;
|
import net.sf.jsqlparser.JSQLParserException;
|
||||||
import net.sf.jsqlparser.expression.Expression;
|
import net.sf.jsqlparser.expression.Expression;
|
||||||
@@ -40,12 +41,12 @@ import org.springframework.util.CollectionUtils;
|
|||||||
@Slf4j
|
@Slf4j
|
||||||
public class SqlParserSelectHelper {
|
public class SqlParserSelectHelper {
|
||||||
|
|
||||||
public static List<FilterExpression> getFilterExpression(String sql) {
|
public static List<FieldExpression> getFilterExpression(String sql) {
|
||||||
PlainSelect plainSelect = getPlainSelect(sql);
|
PlainSelect plainSelect = getPlainSelect(sql);
|
||||||
if (Objects.isNull(plainSelect)) {
|
if (Objects.isNull(plainSelect)) {
|
||||||
return new ArrayList<>();
|
return new ArrayList<>();
|
||||||
}
|
}
|
||||||
Set<FilterExpression> result = new HashSet<>();
|
Set<FieldExpression> result = new HashSet<>();
|
||||||
Expression where = plainSelect.getWhere();
|
Expression where = plainSelect.getWhere();
|
||||||
if (Objects.nonNull(where)) {
|
if (Objects.nonNull(where)) {
|
||||||
where.accept(new FieldAndValueAcquireVisitor(result));
|
where.accept(new FieldAndValueAcquireVisitor(result));
|
||||||
@@ -208,12 +209,12 @@ public class SqlParserSelectHelper {
|
|||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static List<FilterExpression> getWhereExpressions(String sql) {
|
public static List<FieldExpression> getWhereExpressions(String sql) {
|
||||||
PlainSelect plainSelect = getPlainSelect(sql);
|
PlainSelect plainSelect = getPlainSelect(sql);
|
||||||
if (Objects.isNull(plainSelect)) {
|
if (Objects.isNull(plainSelect)) {
|
||||||
return new ArrayList<>();
|
return new ArrayList<>();
|
||||||
}
|
}
|
||||||
Set<FilterExpression> result = new HashSet<>();
|
Set<FieldExpression> result = new HashSet<>();
|
||||||
Expression where = plainSelect.getWhere();
|
Expression where = plainSelect.getWhere();
|
||||||
if (Objects.nonNull(where)) {
|
if (Objects.nonNull(where)) {
|
||||||
where.accept(new FieldAndValueAcquireVisitor(result));
|
where.accept(new FieldAndValueAcquireVisitor(result));
|
||||||
@@ -221,12 +222,12 @@ public class SqlParserSelectHelper {
|
|||||||
return new ArrayList<>(result);
|
return new ArrayList<>(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static List<FilterExpression> getHavingExpressions(String sql) {
|
public static List<FieldExpression> getHavingExpressions(String sql) {
|
||||||
PlainSelect plainSelect = getPlainSelect(sql);
|
PlainSelect plainSelect = getPlainSelect(sql);
|
||||||
if (Objects.isNull(plainSelect)) {
|
if (Objects.isNull(plainSelect)) {
|
||||||
return new ArrayList<>();
|
return new ArrayList<>();
|
||||||
}
|
}
|
||||||
Set<FilterExpression> result = new HashSet<>();
|
Set<FieldExpression> result = new HashSet<>();
|
||||||
Expression having = plainSelect.getHaving();
|
Expression having = plainSelect.getHaving();
|
||||||
if (Objects.nonNull(having)) {
|
if (Objects.nonNull(having)) {
|
||||||
having.accept(new FieldAndValueAcquireVisitor(result));
|
having.accept(new FieldAndValueAcquireVisitor(result));
|
||||||
@@ -244,13 +245,31 @@ public class SqlParserSelectHelper {
|
|||||||
return new ArrayList<>(result);
|
return new ArrayList<>(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void getOrderByFields(PlainSelect plainSelect, Set<String> result) {
|
private static Set<FieldExpression> getOrderByFields(PlainSelect plainSelect) {
|
||||||
|
Set<FieldExpression> result = new HashSet<>();
|
||||||
List<OrderByElement> orderByElements = plainSelect.getOrderByElements();
|
List<OrderByElement> orderByElements = plainSelect.getOrderByElements();
|
||||||
if (!CollectionUtils.isEmpty(orderByElements)) {
|
if (!CollectionUtils.isEmpty(orderByElements)) {
|
||||||
for (OrderByElement orderByElement : orderByElements) {
|
for (OrderByElement orderByElement : orderByElements) {
|
||||||
orderByElement.accept(new OrderByAcquireVisitor(result));
|
orderByElement.accept(new OrderByAcquireVisitor(result));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void getOrderByFields(PlainSelect plainSelect, Set<String> result) {
|
||||||
|
Set<FieldExpression> orderByFieldExpressions = getOrderByFields(plainSelect);
|
||||||
|
Set<String> collect = orderByFieldExpressions.stream()
|
||||||
|
.map(fieldExpression -> fieldExpression.getFieldName())
|
||||||
|
.collect(Collectors.toSet());
|
||||||
|
result.addAll(collect);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static List<FieldExpression> getOrderByExpressions(String sql) {
|
||||||
|
PlainSelect plainSelect = getPlainSelect(sql);
|
||||||
|
if (Objects.isNull(plainSelect)) {
|
||||||
|
return new ArrayList<>();
|
||||||
|
}
|
||||||
|
return new ArrayList<>(getOrderByFields(plainSelect));
|
||||||
}
|
}
|
||||||
|
|
||||||
public static List<String> getGroupByFields(String sql) {
|
public static List<String> getGroupByFields(String sql) {
|
||||||
|
|||||||
@@ -77,8 +77,7 @@ class DateUtilsTest {
|
|||||||
String startDate = "2023-07-01";
|
String startDate = "2023-07-01";
|
||||||
String endDate = "2023-10-01";
|
String endDate = "2023-10-01";
|
||||||
List<String> actualDateList = DateUtils.getDateList(startDate, endDate, Constants.MONTH);
|
List<String> actualDateList = DateUtils.getDateList(startDate, endDate, Constants.MONTH);
|
||||||
List<String> expectedDateList = Lists.newArrayList("2023-07-01", "2023-08-01",
|
List<String> expectedDateList = Lists.newArrayList("2023-07", "2023-08", "2023-09", "2023-10");
|
||||||
"2023-09-01", "2023-10-01");
|
|
||||||
Assertions.assertEquals(actualDateList, expectedDateList);
|
Assertions.assertEquals(actualDateList, expectedDateList);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
package com.tencent.supersonic.common.util.jsqlparser;
|
||||||
|
|
||||||
|
|
||||||
|
import cn.hutool.core.lang.Assert;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @author lex luo
|
||||||
|
* @date 2023/11/15 15:04
|
||||||
|
*/
|
||||||
|
class SqlParserEqualHelperTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testEquals() {
|
||||||
|
String sql1 = "SELECT * FROM table1 WHERE column1 = 1 AND column2 = 2";
|
||||||
|
String sql2 = "SELECT * FROM table1 WHERE column2 = 2 AND column1 = 1";
|
||||||
|
Assert.equals(SqlParserEqualHelper.equals(sql1, sql2), true);
|
||||||
|
|
||||||
|
sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
|
||||||
|
sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
|
||||||
|
Assert.equals(SqlParserEqualHelper.equals(sql1, sql2), true);
|
||||||
|
|
||||||
|
sql1 = "SELECT a,sum(b),sum(c),sum(d) FROM table1 WHERE column1 = 1 AND column2 = 2 group by a order by a";
|
||||||
|
|
||||||
|
sql2 = "SELECT sum(d),sum(c),sum(b),a FROM table1 WHERE column2 = 2 AND column1 = 1 group by a order by a";
|
||||||
|
|
||||||
|
Assert.equals(SqlParserEqualHelper.equals(sql1, sql2), true);
|
||||||
|
|
||||||
|
sql1 = "SELECT a,sum(b),sum(c),sum(d) FROM table1 WHERE column1 = 1 AND column2 = 2 group by a order by a";
|
||||||
|
|
||||||
|
sql2 = "SELECT sum(d),sum(c),sum(b),a FROM table1 WHERE column2 = 2 AND column1 = 1 group by a order by a";
|
||||||
|
|
||||||
|
Assert.equals(SqlParserEqualHelper.equals(sql1, sql2), true);
|
||||||
|
|
||||||
|
sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
|
||||||
|
sql2 = "SELECT d,c,b,f FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
|
||||||
|
Assert.equals(SqlParserEqualHelper.equals(sql1, sql2), false);
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -18,106 +18,106 @@ class SqlParserSelectHelperTest {
|
|||||||
"select 用户名, 访问次数 from 超音数 where 用户名 in ('alice', 'lucy')");
|
"select 用户名, 访问次数 from 超音数 where 用户名 in ('alice', 'lucy')");
|
||||||
System.out.println(selectStatement);
|
System.out.println(selectStatement);
|
||||||
|
|
||||||
List<FilterExpression> filterExpression = SqlParserSelectHelper.getFilterExpression(
|
List<FieldExpression> fieldExpression = SqlParserSelectHelper.getFilterExpression(
|
||||||
"SELECT department, user_id, field_a FROM s2 WHERE "
|
"SELECT department, user_id, field_a FROM s2 WHERE "
|
||||||
+ "sys_imp_date = '2023-08-08' AND YEAR(publish_date) = 2023 "
|
+ "sys_imp_date = '2023-08-08' AND YEAR(publish_date) = 2023 "
|
||||||
+ " AND user_id = 'alice' ORDER BY pv DESC LIMIT 1");
|
+ " AND user_id = 'alice' ORDER BY pv DESC LIMIT 1");
|
||||||
|
|
||||||
System.out.println(filterExpression);
|
System.out.println(fieldExpression);
|
||||||
|
|
||||||
filterExpression = SqlParserSelectHelper.getFilterExpression(
|
fieldExpression = SqlParserSelectHelper.getFilterExpression(
|
||||||
"SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08' "
|
"SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08' "
|
||||||
+ " AND YEAR(publish_date) = 2023 "
|
+ " AND YEAR(publish_date) = 2023 "
|
||||||
+ " AND MONTH(publish_date) = 8"
|
+ " AND MONTH(publish_date) = 8"
|
||||||
+ " AND user_id = 'alice' ORDER BY pv DESC LIMIT 1");
|
+ " AND user_id = 'alice' ORDER BY pv DESC LIMIT 1");
|
||||||
|
|
||||||
System.out.println(filterExpression);
|
System.out.println(fieldExpression);
|
||||||
|
|
||||||
filterExpression = SqlParserSelectHelper.getFilterExpression(
|
fieldExpression = SqlParserSelectHelper.getFilterExpression(
|
||||||
"SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08'"
|
"SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08'"
|
||||||
+ " AND YEAR(publish_date) = 2023 "
|
+ " AND YEAR(publish_date) = 2023 "
|
||||||
+ " AND MONTH(publish_date) = 8 AND DAY(publish_date) =20 "
|
+ " AND MONTH(publish_date) = 8 AND DAY(publish_date) =20 "
|
||||||
+ " AND user_id = 'alice' ORDER BY pv DESC LIMIT 1");
|
+ " AND user_id = 'alice' ORDER BY pv DESC LIMIT 1");
|
||||||
System.out.println(filterExpression);
|
System.out.println(fieldExpression);
|
||||||
|
|
||||||
filterExpression = SqlParserSelectHelper.getFilterExpression(
|
fieldExpression = SqlParserSelectHelper.getFilterExpression(
|
||||||
"SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08' "
|
"SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08' "
|
||||||
+ " AND user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1");
|
+ " AND user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1");
|
||||||
|
|
||||||
System.out.println(filterExpression);
|
System.out.println(fieldExpression);
|
||||||
|
|
||||||
filterExpression = SqlParserSelectHelper.getFilterExpression(
|
fieldExpression = SqlParserSelectHelper.getFilterExpression(
|
||||||
"SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08' "
|
"SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08' "
|
||||||
+ "AND user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1");
|
+ "AND user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1");
|
||||||
|
|
||||||
System.out.println(filterExpression);
|
System.out.println(fieldExpression);
|
||||||
|
|
||||||
filterExpression = SqlParserSelectHelper.getFilterExpression(
|
fieldExpression = SqlParserSelectHelper.getFilterExpression(
|
||||||
"SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08' "
|
"SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08' "
|
||||||
+ "AND user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1");
|
+ "AND user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1");
|
||||||
|
|
||||||
System.out.println(filterExpression);
|
System.out.println(fieldExpression);
|
||||||
|
|
||||||
filterExpression = SqlParserSelectHelper.getFilterExpression(
|
fieldExpression = SqlParserSelectHelper.getFilterExpression(
|
||||||
"SELECT department, user_id, field_a FROM s2 WHERE "
|
"SELECT department, user_id, field_a FROM s2 WHERE "
|
||||||
+ "user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1");
|
+ "user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1");
|
||||||
|
|
||||||
System.out.println(filterExpression);
|
System.out.println(fieldExpression);
|
||||||
|
|
||||||
filterExpression = SqlParserSelectHelper.getFilterExpression(
|
fieldExpression = SqlParserSelectHelper.getFilterExpression(
|
||||||
"SELECT department, user_id, field_a FROM s2 WHERE "
|
"SELECT department, user_id, field_a FROM s2 WHERE "
|
||||||
+ "user_id = 'alice' AND publish_date > 10000 ORDER BY pv DESC LIMIT 1");
|
+ "user_id = 'alice' AND publish_date > 10000 ORDER BY pv DESC LIMIT 1");
|
||||||
|
|
||||||
System.out.println(filterExpression);
|
System.out.println(fieldExpression);
|
||||||
|
|
||||||
filterExpression = SqlParserSelectHelper.getFilterExpression(
|
fieldExpression = SqlParserSelectHelper.getFilterExpression(
|
||||||
"SELECT department, user_id, field_a FROM s2 WHERE "
|
"SELECT department, user_id, field_a FROM s2 WHERE "
|
||||||
+ "user_id like '%alice%' AND publish_date > 10000 ORDER BY pv DESC LIMIT 1");
|
+ "user_id like '%alice%' AND publish_date > 10000 ORDER BY pv DESC LIMIT 1");
|
||||||
|
|
||||||
System.out.println(filterExpression);
|
System.out.println(fieldExpression);
|
||||||
|
|
||||||
filterExpression = SqlParserSelectHelper.getFilterExpression(
|
fieldExpression = SqlParserSelectHelper.getFilterExpression(
|
||||||
"SELECT department, pv FROM s2 WHERE "
|
"SELECT department, pv FROM s2 WHERE "
|
||||||
+ "user_id like '%alice%' AND publish_date > 10000 "
|
+ "user_id like '%alice%' AND publish_date > 10000 "
|
||||||
+ "group by department having sum(pv) > 2000 ORDER BY pv DESC LIMIT 1");
|
+ "group by department having sum(pv) > 2000 ORDER BY pv DESC LIMIT 1");
|
||||||
|
|
||||||
System.out.println(filterExpression);
|
System.out.println(fieldExpression);
|
||||||
|
|
||||||
filterExpression = SqlParserSelectHelper.getFilterExpression(
|
fieldExpression = SqlParserSelectHelper.getFilterExpression(
|
||||||
"SELECT department, pv FROM s2 WHERE "
|
"SELECT department, pv FROM s2 WHERE "
|
||||||
+ "(user_id like '%alice%' AND publish_date > 10000) and sys_imp_date = '2023-08-08' "
|
+ "(user_id like '%alice%' AND publish_date > 10000) and sys_imp_date = '2023-08-08' "
|
||||||
+ "group by department having sum(pv) > 2000 ORDER BY pv DESC LIMIT 1");
|
+ "group by department having sum(pv) > 2000 ORDER BY pv DESC LIMIT 1");
|
||||||
|
|
||||||
System.out.println(filterExpression);
|
System.out.println(fieldExpression);
|
||||||
|
|
||||||
filterExpression = SqlParserSelectHelper.getFilterExpression(
|
fieldExpression = SqlParserSelectHelper.getFilterExpression(
|
||||||
"SELECT department, pv FROM s2 WHERE "
|
"SELECT department, pv FROM s2 WHERE "
|
||||||
+ "(user_id like '%alice%' AND publish_date > 10000) and song_name in "
|
+ "(user_id like '%alice%' AND publish_date > 10000) and song_name in "
|
||||||
+ "('七里香','晴天') and sys_imp_date = '2023-08-08' "
|
+ "('七里香','晴天') and sys_imp_date = '2023-08-08' "
|
||||||
+ "group by department having sum(pv) > 2000 ORDER BY pv DESC LIMIT 1");
|
+ "group by department having sum(pv) > 2000 ORDER BY pv DESC LIMIT 1");
|
||||||
|
|
||||||
System.out.println(filterExpression);
|
System.out.println(fieldExpression);
|
||||||
|
|
||||||
filterExpression = SqlParserSelectHelper.getFilterExpression(
|
fieldExpression = SqlParserSelectHelper.getFilterExpression(
|
||||||
"SELECT department, pv FROM s2 WHERE "
|
"SELECT department, pv FROM s2 WHERE "
|
||||||
+ "(user_id like '%alice%' AND publish_date > 10000) and song_name in (1,2) "
|
+ "(user_id like '%alice%' AND publish_date > 10000) and song_name in (1,2) "
|
||||||
+ "and sys_imp_date = '2023-08-08' "
|
+ "and sys_imp_date = '2023-08-08' "
|
||||||
+ "group by department having sum(pv) > 2000 ORDER BY pv DESC LIMIT 1");
|
+ "group by department having sum(pv) > 2000 ORDER BY pv DESC LIMIT 1");
|
||||||
|
|
||||||
System.out.println(filterExpression);
|
System.out.println(fieldExpression);
|
||||||
|
|
||||||
filterExpression = SqlParserSelectHelper.getFilterExpression(
|
fieldExpression = SqlParserSelectHelper.getFilterExpression(
|
||||||
"SELECT department, pv FROM s2 WHERE "
|
"SELECT department, pv FROM s2 WHERE "
|
||||||
+ "(user_id like '%alice%' AND publish_date > 10000) and 1 in (1) "
|
+ "(user_id like '%alice%' AND publish_date > 10000) and 1 in (1) "
|
||||||
+ "and sys_imp_date = '2023-08-08' "
|
+ "and sys_imp_date = '2023-08-08' "
|
||||||
+ "group by department having sum(pv) > 2000 ORDER BY pv DESC LIMIT 1");
|
+ "group by department having sum(pv) > 2000 ORDER BY pv DESC LIMIT 1");
|
||||||
|
|
||||||
System.out.println(filterExpression);
|
System.out.println(fieldExpression);
|
||||||
|
|
||||||
filterExpression = SqlParserSelectHelper.getFilterExpression("SELECT sum(销量) / (SELECT sum(销量) FROM 营销月模型 "
|
fieldExpression = SqlParserSelectHelper.getFilterExpression("SELECT sum(销量) / (SELECT sum(销量) FROM 营销月模型 "
|
||||||
+ "WHERE MONTH(数据日期) = 9) FROM 营销月模型 WHERE 国家中文名 = '肯尼亚' AND MONTH(数据日期) = 9");
|
+ "WHERE MONTH(数据日期) = 9) FROM 营销月模型 WHERE 国家中文名 = '肯尼亚' AND MONTH(数据日期) = 9");
|
||||||
|
|
||||||
System.out.println(filterExpression);
|
System.out.println(fieldExpression);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ package com.tencent.supersonic.semantic.query.utils;
|
|||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.FilterExpression;
|
import com.tencent.supersonic.common.util.jsqlparser.FieldExpression;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||||
import com.tencent.supersonic.semantic.api.model.pojo.DimValueMap;
|
import com.tencent.supersonic.semantic.api.model.pojo.DimValueMap;
|
||||||
@@ -60,11 +60,11 @@ public class DimValueAspect {
|
|||||||
String sql = queryS2SQLReq.getSql();
|
String sql = queryS2SQLReq.getSql();
|
||||||
log.info("correctorSql before replacing:{}", sql);
|
log.info("correctorSql before replacing:{}", sql);
|
||||||
// if dimensionvalue is alias,consider the true dimensionvalue.
|
// if dimensionvalue is alias,consider the true dimensionvalue.
|
||||||
List<FilterExpression> filterExpressionList = SqlParserSelectHelper.getWhereExpressions(sql);
|
List<FieldExpression> fieldExpressionList = SqlParserSelectHelper.getWhereExpressions(sql);
|
||||||
List<DimensionResp> dimensions = dimensionService.getDimensions(metaFilter);
|
List<DimensionResp> dimensions = dimensionService.getDimensions(metaFilter);
|
||||||
Set<String> fieldNames = dimensions.stream().map(o -> o.getName()).collect(Collectors.toSet());
|
Set<String> fieldNames = dimensions.stream().map(o -> o.getName()).collect(Collectors.toSet());
|
||||||
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
|
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
|
||||||
filterExpressionList.stream().forEach(expression -> {
|
fieldExpressionList.stream().forEach(expression -> {
|
||||||
if (fieldNames.contains(expression.getFieldName())) {
|
if (fieldNames.contains(expression.getFieldName())) {
|
||||||
dimensions.stream().forEach(dimension -> {
|
dimensions.stream().forEach(dimension -> {
|
||||||
if (expression.getFieldName().equals(dimension.getName())
|
if (expression.getFieldName().equals(dimension.getName())
|
||||||
@@ -98,7 +98,7 @@ public class DimValueAspect {
|
|||||||
return queryResultWithColumns;
|
return queryResultWithColumns;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void replaceInCondition(FilterExpression expression, DimensionResp dimension,
|
public void replaceInCondition(FieldExpression expression, DimensionResp dimension,
|
||||||
Map<String, Map<String, String>> filedNameToValueMap) {
|
Map<String, Map<String, String>> filedNameToValueMap) {
|
||||||
if (expression.getOperator().equals(FilterOperatorEnum.IN.getValue())) {
|
if (expression.getOperator().equals(FilterOperatorEnum.IN.getValue())) {
|
||||||
String fieldValue = JsonUtil.toString(expression.getFieldValue());
|
String fieldValue = JsonUtil.toString(expression.getFieldValue());
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import com.tencent.supersonic.common.pojo.enums.TypeEnums;
|
|||||||
import com.tencent.supersonic.common.util.DateModeUtils;
|
import com.tencent.supersonic.common.util.DateModeUtils;
|
||||||
import com.tencent.supersonic.common.util.SqlFilterUtils;
|
import com.tencent.supersonic.common.util.SqlFilterUtils;
|
||||||
import com.tencent.supersonic.common.util.StringUtil;
|
import com.tencent.supersonic.common.util.StringUtil;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.FilterExpression;
|
import com.tencent.supersonic.common.util.jsqlparser.FieldExpression;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserRemoveHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserRemoveHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||||
@@ -424,13 +424,13 @@ public class QueryStructUtils {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public DateConf getDateConfBySql(String sql) {
|
public DateConf getDateConfBySql(String sql) {
|
||||||
List<FilterExpression> filterExpressions = SqlParserSelectHelper.getFilterExpression(sql);
|
List<FieldExpression> fieldExpressions = SqlParserSelectHelper.getFilterExpression(sql);
|
||||||
if (!CollectionUtils.isEmpty(filterExpressions)) {
|
if (!CollectionUtils.isEmpty(fieldExpressions)) {
|
||||||
Set<String> dateList = new HashSet<>();
|
Set<String> dateList = new HashSet<>();
|
||||||
String startDate = "";
|
String startDate = "";
|
||||||
String endDate = "";
|
String endDate = "";
|
||||||
String period = "";
|
String period = "";
|
||||||
for (FilterExpression f : filterExpressions) {
|
for (FieldExpression f : fieldExpressions) {
|
||||||
if (Objects.isNull(f.getFieldName()) || !internalCols.contains(f.getFieldName().toLowerCase())) {
|
if (Objects.isNull(f.getFieldName()) || !internalCols.contains(f.getFieldName().toLowerCase())) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user