diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMResponseService.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMResponseService.java index 9df25f859..6210989eb 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMResponseService.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMResponseService.java @@ -7,19 +7,23 @@ import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.chat.query.QueryManager; import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery; +import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery; import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserEqualHelper; import com.tencent.supersonic.knowledge.service.SchemaService; import java.util.HashMap; import java.util.Map; import java.util.Objects; import lombok.extern.slf4j.Slf4j; +import org.apache.commons.collections4.MapUtils; import org.springframework.stereotype.Service; @Slf4j @Service public class LLMResponseService { + public SemanticParseInfo addParseInfo(QueryContext queryCtx, ParseResult parseResult, String s2SQL, Double weight) { if (Objects.isNull(weight)) { weight = 0D; @@ -51,4 +55,19 @@ public class LLMResponseService { queryCtx.getCandidateQueries().add(semanticQuery); return parseInfo; } + + public Map getDeduplicationSqlWeight(LLMResp llmResp) { + if (MapUtils.isEmpty(llmResp.getSqlWeight())) { + return llmResp.getSqlWeight(); + } + Map result = new HashMap<>(); + for (Map.Entry entry : llmResp.getSqlWeight().entrySet()) { + String key = entry.getKey(); + if (result.keySet().stream().anyMatch(existKey -> SqlParserEqualHelper.equals(existKey, key))) { + continue; + } + result.put(key, entry.getValue()); + } + return result; + } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMS2SQLParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMS2SQLParser.java index 948018c35..dec947b72 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMS2SQLParser.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMS2SQLParser.java @@ -13,6 +13,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; import lombok.extern.slf4j.Slf4j; +import org.apache.commons.collections4.MapUtils; @Slf4j public class LLMS2SQLParser implements SemanticParser { @@ -45,8 +46,9 @@ public class LLMS2SQLParser implements SemanticParser { if (Objects.isNull(llmResp)) { return; } - //5. get and update parserInfo - Map sqlWeight = llmResp.getSqlWeight(); + //5. deduplicate the SQL result list and build parserInfo + LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class); + Map deduplicationSqlWeight = responseService.getDeduplicationSqlWeight(llmResp); ParseResult parseResult = ParseResult.builder() .request(request) .modelId(modelId) @@ -56,12 +58,10 @@ public class LLMS2SQLParser implements SemanticParser { .linkingValues(linkingValues) .build(); - LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class); - - if (Objects.isNull(sqlWeight) || sqlWeight.isEmpty()) { + if (MapUtils.isEmpty(deduplicationSqlWeight)) { responseService.addParseInfo(queryCtx, parseResult, llmResp.getSqlOutput(), 1D); } else { - sqlWeight.forEach((sql, weight) -> { + deduplicationSqlWeight.forEach((sql, weight) -> { responseService.addParseInfo(queryCtx, parseResult, sql, weight); }); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/ParserInfoServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/ParserInfoServiceImpl.java index d1f328b23..f3746219c 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/ParserInfoServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/ParserInfoServiceImpl.java @@ -13,7 +13,7 @@ import com.tencent.supersonic.common.pojo.DateConf.DateMode; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.common.util.ContextUtils; -import com.tencent.supersonic.common.util.jsqlparser.FilterExpression; +import com.tencent.supersonic.common.util.jsqlparser.FieldExpression; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.knowledge.service.SchemaService; @@ -47,7 +47,7 @@ public class ParserInfoServiceImpl implements ParseInfoService { return; } - List expressions = SqlParserSelectHelper.getFilterExpression(correctS2SQL); + List expressions = SqlParserSelectHelper.getFilterExpression(correctS2SQL); //set dataInfo try { if (!CollectionUtils.isEmpty(expressions)) { @@ -112,9 +112,9 @@ public class ParserInfoServiceImpl implements ParseInfoService { private List getDimensionFilter(Map fieldNameToElement, - List filterExpressions) { + List fieldExpressions) { List result = Lists.newArrayList(); - for (FilterExpression expression : filterExpressions) { + for (FieldExpression expression : fieldExpressions) { QueryFilter dimensionFilter = new QueryFilter(); dimensionFilter.setValue(expression.getFieldValue()); SchemaElement schemaElement = fieldNameToElement.get(expression.getFieldName()); @@ -133,8 +133,8 @@ public class ParserInfoServiceImpl implements ParseInfoService { return result; } - private DateConf getDateInfo(List filterExpressions) { - List dateExpressions = filterExpressions.stream() + private DateConf getDateInfo(List fieldExpressions) { + List dateExpressions = fieldExpressions.stream() .filter(expression -> TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(expression.getFieldName())) .collect(Collectors.toList()); if (CollectionUtils.isEmpty(dateExpressions)) { @@ -142,7 +142,7 @@ public class ParserInfoServiceImpl implements ParseInfoService { } DateConf dateInfo = new DateConf(); dateInfo.setDateMode(DateMode.BETWEEN); - FilterExpression firstExpression = dateExpressions.get(0); + FieldExpression firstExpression = dateExpressions.get(0); FilterOperatorEnum firstOperator = FilterOperatorEnum.getSqlOperator(firstExpression.getOperator()); if (FilterOperatorEnum.EQUALS.equals(firstOperator) && Objects.nonNull(firstExpression.getFieldValue())) { @@ -168,12 +168,12 @@ public class ParserInfoServiceImpl implements ParseInfoService { return dateInfo; } - private boolean containOperators(FilterExpression expression, FilterOperatorEnum firstOperator, + private boolean containOperators(FieldExpression expression, FilterOperatorEnum firstOperator, FilterOperatorEnum... operatorEnums) { return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(expression.getFieldValue())); } - private boolean hasSecondDate(List dateExpressions) { + private boolean hasSecondDate(List dateExpressions) { return dateExpressions.size() > 1 && Objects.nonNull(dateExpressions.get(1).getFieldValue()); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java index 53d42ce41..cf64afd05 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java @@ -48,7 +48,7 @@ import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.DateUtils; import com.tencent.supersonic.common.util.JsonUtil; -import com.tencent.supersonic.common.util.jsqlparser.FilterExpression; +import com.tencent.supersonic.common.util.jsqlparser.FieldExpression; import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserRemoveHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper; @@ -296,8 +296,8 @@ public class QueryServiceImpl implements QueryService { String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL(); log.info("correctorSql before replacing:{}", correctorSql); // get where filter and having filter - List whereExpressionList = SqlParserSelectHelper.getWhereExpressions(correctorSql); - List havingExpressionList = SqlParserSelectHelper.getHavingExpressions(correctorSql); + List whereExpressionList = SqlParserSelectHelper.getWhereExpressions(correctorSql); + List havingExpressionList = SqlParserSelectHelper.getHavingExpressions(correctorSql); List addWhereConditions = new ArrayList<>(); List addHavingConditions = new ArrayList<>(); Set removeWhereFieldNames = new HashSet<>(); @@ -350,7 +350,7 @@ public class QueryServiceImpl implements QueryService { private void updateDateInfo(QueryDataReq queryData, SemanticParseInfo parseInfo, Map> filedNameToValueMap, - List filterExpressionList, + List fieldExpressionList, List addConditions, Set removeFieldNames) { if (Objects.isNull(queryData.getDateInfo())) { @@ -364,12 +364,12 @@ public class QueryServiceImpl implements QueryService { } // startDate equals to endDate if (queryData.getDateInfo().getStartDate().equals(queryData.getDateInfo().getEndDate())) { - for (FilterExpression filterExpression : filterExpressionList) { - if (TimeDimensionEnum.DAY.getChName().equals(filterExpression.getFieldName())) { + for (FieldExpression fieldExpression : fieldExpressionList) { + if (TimeDimensionEnum.DAY.getChName().equals(fieldExpression.getFieldName())) { //sql where condition exists 'equals' operator about date,just replace - if (filterExpression.getOperator().equals(FilterOperatorEnum.EQUALS)) { - dateField = filterExpression.getFieldName(); - map.put(filterExpression.getFieldValue().toString(), + if (fieldExpression.getOperator().equals(FilterOperatorEnum.EQUALS)) { + dateField = fieldExpression.getFieldName(); + map.put(fieldExpression.getFieldValue().toString(), queryData.getDateInfo().getStartDate()); filedNameToValueMap.put(dateField, map); } else { @@ -386,23 +386,23 @@ public class QueryServiceImpl implements QueryService { } } } else { - for (FilterExpression filterExpression : filterExpressionList) { - if (TimeDimensionEnum.DAY.getChName().equals(filterExpression.getFieldName())) { - dateField = filterExpression.getFieldName(); + for (FieldExpression fieldExpression : fieldExpressionList) { + if (TimeDimensionEnum.DAY.getChName().equals(fieldExpression.getFieldName())) { + dateField = fieldExpression.getFieldName(); //just replace - if (FilterOperatorEnum.GREATER_THAN_EQUALS.getValue().equals(filterExpression.getOperator()) - || FilterOperatorEnum.GREATER_THAN.getValue().equals(filterExpression.getOperator())) { - map.put(filterExpression.getFieldValue().toString(), + if (FilterOperatorEnum.GREATER_THAN_EQUALS.getValue().equals(fieldExpression.getOperator()) + || FilterOperatorEnum.GREATER_THAN.getValue().equals(fieldExpression.getOperator())) { + map.put(fieldExpression.getFieldValue().toString(), queryData.getDateInfo().getStartDate()); } - if (FilterOperatorEnum.MINOR_THAN_EQUALS.getValue().equals(filterExpression.getOperator()) - || FilterOperatorEnum.MINOR_THAN.getValue().equals(filterExpression.getOperator())) { - map.put(filterExpression.getFieldValue().toString(), + if (FilterOperatorEnum.MINOR_THAN_EQUALS.getValue().equals(fieldExpression.getOperator()) + || FilterOperatorEnum.MINOR_THAN.getValue().equals(fieldExpression.getOperator())) { + map.put(fieldExpression.getFieldValue().toString(), queryData.getDateInfo().getEndDate()); } filedNameToValueMap.put(dateField, map); // first remove,then add - if (FilterOperatorEnum.EQUALS.getValue().equals(filterExpression.getOperator())) { + if (FilterOperatorEnum.EQUALS.getValue().equals(fieldExpression.getOperator())) { removeFieldNames.add(TimeDimensionEnum.DAY.getChName()); GreaterThanEquals greaterThanEquals = new GreaterThanEquals(); addTimeFilters(queryData.getDateInfo().getStartDate(), greaterThanEquals, addConditions); @@ -425,7 +425,7 @@ public class QueryServiceImpl implements QueryService { addConditions.add(comparisonExpression); } - private void updateFilters(List filterExpressionList, + private void updateFilters(List fieldExpressionList, Set metricFilters, Set contextMetricFilters, List addConditions, @@ -434,9 +434,9 @@ public class QueryServiceImpl implements QueryService { return; } for (QueryFilter dslQueryFilter : metricFilters) { - for (FilterExpression filterExpression : filterExpressionList) { - if (filterExpression.getFieldName() != null - && filterExpression.getFieldName().contains(dslQueryFilter.getName())) { + for (FieldExpression fieldExpression : fieldExpressionList) { + if (fieldExpression.getFieldName() != null + && fieldExpression.getFieldName().contains(dslQueryFilter.getName())) { removeFieldNames.add(dslQueryFilter.getName()); if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.EQUALS)) { EqualsTo equalsTo = new EqualsTo(); diff --git a/chat/core/src/test/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMResponseServiceTest.java b/chat/core/src/test/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMResponseServiceTest.java new file mode 100644 index 000000000..c42911962 --- /dev/null +++ b/chat/core/src/test/java/com/tencent/supersonic/chat/parser/llm/s2sql/LLMResponseServiceTest.java @@ -0,0 +1,51 @@ +package com.tencent.supersonic.chat.parser.llm.s2sql; + +import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp; +import java.util.HashMap; +import java.util.Map; +import org.junit.Assert; +import org.junit.jupiter.api.Test; + +class LLMResponseServiceTest { + + @Test + void deduplicationSqlWeight() { + String sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a"; + String sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a"; + + LLMResp llmResp = new LLMResp(); + Map sqlWeight = new HashMap<>(); + sqlWeight.put(sql1, 0.2D); + sqlWeight.put(sql2, 0.8D); + llmResp.setSqlWeight(sqlWeight); + LLMResponseService llmResponseService = new LLMResponseService(); + Map deduplicationSqlWeight = llmResponseService.getDeduplicationSqlWeight(llmResp); + + Assert.assertEquals(deduplicationSqlWeight.size(), 1); + + sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a"; + sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a"; + + LLMResp llmResp2 = new LLMResp(); + Map sqlWeight2 = new HashMap<>(); + sqlWeight2.put(sql1, 0.2D); + sqlWeight2.put(sql2, 0.8D); + llmResp2.setSqlWeight(sqlWeight2); + deduplicationSqlWeight = llmResponseService.getDeduplicationSqlWeight(llmResp2); + + Assert.assertEquals(deduplicationSqlWeight.size(), 1); + + sql1 = "SELECT a,b,c,d,e FROM table1 WHERE column1 = 1 AND column2 = 2 order by a"; + sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a"; + + LLMResp llmResp3 = new LLMResp(); + Map sqlWeight3 = new HashMap<>(); + sqlWeight3.put(sql1, 0.2D); + sqlWeight3.put(sql2, 0.8D); + llmResp3.setSqlWeight(sqlWeight3); + deduplicationSqlWeight = llmResponseService.getDeduplicationSqlWeight(llmResp3); + + Assert.assertEquals(deduplicationSqlWeight.size(), 2); + + } +} \ No newline at end of file diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldAndValueAcquireVisitor.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldAndValueAcquireVisitor.java index 13c42a677..fcdeb5552 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldAndValueAcquireVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldAndValueAcquireVisitor.java @@ -28,38 +28,38 @@ import org.apache.commons.collections.CollectionUtils; public class FieldAndValueAcquireVisitor extends ExpressionVisitorAdapter { - private Set filterExpressions; + private Set fieldExpressions; - public FieldAndValueAcquireVisitor(Set filterExpressions) { - this.filterExpressions = filterExpressions; + public FieldAndValueAcquireVisitor(Set fieldExpressions) { + this.fieldExpressions = fieldExpressions; } public void visit(LikeExpression expr) { Expression leftExpression = expr.getLeftExpression(); Expression rightExpression = expr.getRightExpression(); - FilterExpression filterExpression = new FilterExpression(); + FieldExpression fieldExpression = new FieldExpression(); String columnName = null; if (leftExpression instanceof Column) { Column column = (Column) leftExpression; columnName = column.getColumnName(); - filterExpression.setFieldName(columnName); + fieldExpression.setFieldName(columnName); } - filterExpression.setFieldValue(getFieldValue(rightExpression)); - filterExpression.setOperator(expr.getStringExpression()); - filterExpressions.add(filterExpression); + fieldExpression.setFieldValue(getFieldValue(rightExpression)); + fieldExpression.setOperator(expr.getStringExpression()); + fieldExpressions.add(fieldExpression); } public void visit(InExpression expr) { - FilterExpression filterExpression = new FilterExpression(); + FieldExpression fieldExpression = new FieldExpression(); Expression leftExpression = expr.getLeftExpression(); if (!(leftExpression instanceof Column)) { return; } - filterExpression.setFieldName(((Column) leftExpression).getColumnName()); - filterExpression.setOperator(JsqlConstants.IN); + fieldExpression.setFieldName(((Column) leftExpression).getColumnName()); + fieldExpression.setOperator(JsqlConstants.IN); ItemsList rightItemsList = expr.getRightItemsList(); - filterExpression.setFieldValue(rightItemsList); + fieldExpression.setFieldValue(rightItemsList); List result = new ArrayList<>(); if (rightItemsList instanceof ExpressionList) { ExpressionList rightExpressionList = (ExpressionList) rightItemsList; @@ -70,78 +70,78 @@ public class FieldAndValueAcquireVisitor extends ExpressionVisitorAdapter { } } } - filterExpression.setFieldValue(result); - filterExpressions.add(filterExpression); + fieldExpression.setFieldValue(result); + fieldExpressions.add(fieldExpression); } @Override public void visit(MinorThan expr) { - FilterExpression filterExpression = getFilterExpression(expr); - filterExpressions.add(filterExpression); + FieldExpression fieldExpression = getFilterExpression(expr); + fieldExpressions.add(fieldExpression); } @Override public void visit(EqualsTo expr) { - FilterExpression filterExpression = getFilterExpression(expr); - filterExpressions.add(filterExpression); + FieldExpression fieldExpression = getFilterExpression(expr); + fieldExpressions.add(fieldExpression); } @Override public void visit(MinorThanEquals expr) { - FilterExpression filterExpression = getFilterExpression(expr); - filterExpressions.add(filterExpression); + FieldExpression fieldExpression = getFilterExpression(expr); + fieldExpressions.add(fieldExpression); } @Override public void visit(GreaterThan expr) { - FilterExpression filterExpression = getFilterExpression(expr); - filterExpressions.add(filterExpression); + FieldExpression fieldExpression = getFilterExpression(expr); + fieldExpressions.add(fieldExpression); } @Override public void visit(GreaterThanEquals expr) { - FilterExpression filterExpression = getFilterExpression(expr); - filterExpressions.add(filterExpression); + FieldExpression fieldExpression = getFilterExpression(expr); + fieldExpressions.add(fieldExpression); } - private FilterExpression getFilterExpression(ComparisonOperator expr) { + private FieldExpression getFilterExpression(ComparisonOperator expr) { Expression leftExpression = expr.getLeftExpression(); Expression rightExpression = expr.getRightExpression(); - FilterExpression filterExpression = new FilterExpression(); + FieldExpression fieldExpression = new FieldExpression(); String columnName = null; if (leftExpression instanceof Column) { Column column = (Column) leftExpression; columnName = column.getColumnName(); - filterExpression.setFieldName(columnName); + fieldExpression.setFieldName(columnName); } if (leftExpression instanceof Function) { Function leftExpressionFunction = (Function) leftExpression; Column field = getColumn(leftExpressionFunction); if (Objects.isNull(field)) { - return filterExpression; + return fieldExpression; } String functionName = leftExpressionFunction.getName().toUpperCase(); - filterExpression.setFieldName(field.getColumnName()); - filterExpression.setFunction(functionName); - filterExpression.setOperator(expr.getStringExpression()); + fieldExpression.setFieldName(field.getColumnName()); + fieldExpression.setFunction(functionName); + fieldExpression.setOperator(expr.getStringExpression()); //deal with DAY/WEEK function List collect = Arrays.stream(DatePeriodEnum.values()).collect(Collectors.toList()); DatePeriodEnum periodEnum = DatePeriodEnum.get(functionName); if (Objects.nonNull(periodEnum) && collect.contains(periodEnum)) { - filterExpression.setFieldValue(getFieldValue(rightExpression) + periodEnum.getChName()); - return filterExpression; + fieldExpression.setFieldValue(getFieldValue(rightExpression) + periodEnum.getChName()); + return fieldExpression; } else { //deal with aggregate function - filterExpression.setFieldValue(getFieldValue(rightExpression)); - return filterExpression; + fieldExpression.setFieldValue(getFieldValue(rightExpression)); + return fieldExpression; } } - filterExpression.setFieldValue(getFieldValue(rightExpression)); - filterExpression.setOperator(expr.getStringExpression()); - return filterExpression; + fieldExpression.setFieldValue(getFieldValue(rightExpression)); + fieldExpression.setOperator(expr.getStringExpression()); + return fieldExpression; } private Column getColumn(Function leftExpressionFunction) { diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FilterExpression.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldExpression.java similarity index 86% rename from common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FilterExpression.java rename to common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldExpression.java index 3622fe326..6e9c746eb 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FilterExpression.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldExpression.java @@ -3,7 +3,7 @@ package com.tencent.supersonic.common.util.jsqlparser; import lombok.Data; @Data -public class FilterExpression { +public class FieldExpression { private String operator; diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/OrderByAcquireVisitor.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/OrderByAcquireVisitor.java index 773110861..9a3fb2bb1 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/OrderByAcquireVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/OrderByAcquireVisitor.java @@ -1,5 +1,6 @@ package com.tencent.supersonic.common.util.jsqlparser; +import com.tencent.supersonic.common.pojo.Constants; import java.util.List; import java.util.Set; import net.sf.jsqlparser.expression.Expression; @@ -10,27 +11,34 @@ import net.sf.jsqlparser.statement.select.OrderByVisitorAdapter; public class OrderByAcquireVisitor extends OrderByVisitorAdapter { - private Set fields; + private Set fields; - public OrderByAcquireVisitor(Set fields) { + public OrderByAcquireVisitor(Set fields) { this.fields = fields; } @Override public void visit(OrderByElement orderBy) { Expression expression = orderBy.getExpression(); + FieldExpression fieldExpression = new FieldExpression(); if (expression instanceof Column) { - fields.add(((Column) expression).getColumnName()); + fieldExpression.setFieldName(((Column) expression).getColumnName()); } if (expression instanceof Function) { Function function = (Function) expression; List expressions = function.getParameters().getExpressions(); for (Expression column : expressions) { if (column instanceof Column) { - fields.add(((Column) column).getColumnName()); + fieldExpression.setFieldName(((Column) column).getColumnName()); } } } + String operator = Constants.ASC_UPPER; + if (!orderBy.isAsc()) { + operator = Constants.DESC_UPPER; + } + fieldExpression.setOperator(operator); + fields.add(fieldExpression); super.visit(orderBy); } } \ No newline at end of file diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/OrderByExpression.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/OrderByExpression.java new file mode 100644 index 000000000..970c37df6 --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/OrderByExpression.java @@ -0,0 +1,16 @@ +package com.tencent.supersonic.common.util.jsqlparser; + +import lombok.Data; + +@Data +public class OrderByExpression { + + private String operator; + + private String fieldName; + + private Object fieldValue; + + private String function; + +} \ No newline at end of file diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserEqualHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserEqualHelper.java new file mode 100644 index 000000000..928996512 --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserEqualHelper.java @@ -0,0 +1,67 @@ +package com.tencent.supersonic.common.util.jsqlparser; + +import java.util.List; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.collections.CollectionUtils; + +/** + * Sql Parser equal Helper + */ +@Slf4j +public class SqlParserEqualHelper { + + /** + * determine if two SQL statements are equal. + * + * @param thisSql + * @param otherSql + * @return + */ + public static boolean equals(String thisSql, String otherSql) { + //1. select fields + List thisSelectFields = SqlParserSelectHelper.getSelectFields(thisSql); + List otherSelectFields = SqlParserSelectHelper.getSelectFields(otherSql); + + if (!CollectionUtils.isEqualCollection(thisSelectFields, otherSelectFields)) { + return false; + } + + //2. all fields + List thisAllFields = SqlParserSelectHelper.getAllFields(thisSql); + List otherAllFields = SqlParserSelectHelper.getAllFields(otherSql); + + if (!CollectionUtils.isEqualCollection(thisAllFields, otherAllFields)) { + return false; + } + + //3. where + List thisFieldExpressions = SqlParserSelectHelper.getFilterExpression(thisSql); + List otherFieldExpressions = SqlParserSelectHelper.getFilterExpression(otherSql); + + if (!CollectionUtils.isEqualCollection(thisFieldExpressions, otherFieldExpressions)) { + return false; + } + //4. tableName + if (!SqlParserSelectHelper.getDbTableName(thisSql) + .equalsIgnoreCase(SqlParserSelectHelper.getDbTableName(otherSql))) { + return false; + } + //5. having + List thisHavingExpressions = SqlParserSelectHelper.getHavingExpressions(thisSql); + List otherHavingExpressions = SqlParserSelectHelper.getHavingExpressions(otherSql); + + if (!CollectionUtils.isEqualCollection(thisHavingExpressions, otherHavingExpressions)) { + return false; + } + //6. orderBy + List thisOrderByExpressions = SqlParserSelectHelper.getOrderByExpressions(thisSql); + List otherOrderByExpressions = SqlParserSelectHelper.getOrderByExpressions(otherSql); + + if (!CollectionUtils.isEqualCollection(thisOrderByExpressions, otherOrderByExpressions)) { + return false; + } + return true; + } + +} + diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java index 80ccf957f..165e1e7cf 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelper.java @@ -5,6 +5,7 @@ import java.util.HashSet; import java.util.List; import java.util.Objects; import java.util.Set; +import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.expression.Expression; @@ -40,12 +41,12 @@ import org.springframework.util.CollectionUtils; @Slf4j public class SqlParserSelectHelper { - public static List getFilterExpression(String sql) { + public static List getFilterExpression(String sql) { PlainSelect plainSelect = getPlainSelect(sql); if (Objects.isNull(plainSelect)) { return new ArrayList<>(); } - Set result = new HashSet<>(); + Set result = new HashSet<>(); Expression where = plainSelect.getWhere(); if (Objects.nonNull(where)) { where.accept(new FieldAndValueAcquireVisitor(result)); @@ -208,12 +209,12 @@ public class SqlParserSelectHelper { return null; } - public static List getWhereExpressions(String sql) { + public static List getWhereExpressions(String sql) { PlainSelect plainSelect = getPlainSelect(sql); if (Objects.isNull(plainSelect)) { return new ArrayList<>(); } - Set result = new HashSet<>(); + Set result = new HashSet<>(); Expression where = plainSelect.getWhere(); if (Objects.nonNull(where)) { where.accept(new FieldAndValueAcquireVisitor(result)); @@ -221,12 +222,12 @@ public class SqlParserSelectHelper { return new ArrayList<>(result); } - public static List getHavingExpressions(String sql) { + public static List getHavingExpressions(String sql) { PlainSelect plainSelect = getPlainSelect(sql); if (Objects.isNull(plainSelect)) { return new ArrayList<>(); } - Set result = new HashSet<>(); + Set result = new HashSet<>(); Expression having = plainSelect.getHaving(); if (Objects.nonNull(having)) { having.accept(new FieldAndValueAcquireVisitor(result)); @@ -244,13 +245,31 @@ public class SqlParserSelectHelper { return new ArrayList<>(result); } - private static void getOrderByFields(PlainSelect plainSelect, Set result) { + private static Set getOrderByFields(PlainSelect plainSelect) { + Set result = new HashSet<>(); List orderByElements = plainSelect.getOrderByElements(); if (!CollectionUtils.isEmpty(orderByElements)) { for (OrderByElement orderByElement : orderByElements) { orderByElement.accept(new OrderByAcquireVisitor(result)); } } + return result; + } + + private static void getOrderByFields(PlainSelect plainSelect, Set result) { + Set orderByFieldExpressions = getOrderByFields(plainSelect); + Set collect = orderByFieldExpressions.stream() + .map(fieldExpression -> fieldExpression.getFieldName()) + .collect(Collectors.toSet()); + result.addAll(collect); + } + + public static List getOrderByExpressions(String sql) { + PlainSelect plainSelect = getPlainSelect(sql); + if (Objects.isNull(plainSelect)) { + return new ArrayList<>(); + } + return new ArrayList<>(getOrderByFields(plainSelect)); } public static List getGroupByFields(String sql) { diff --git a/common/src/test/java/com/tencent/supersonic/common/util/DateUtilsTest.java b/common/src/test/java/com/tencent/supersonic/common/util/DateUtilsTest.java index 48f6fe2c1..fd9703533 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/DateUtilsTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/DateUtilsTest.java @@ -77,8 +77,7 @@ class DateUtilsTest { String startDate = "2023-07-01"; String endDate = "2023-10-01"; List actualDateList = DateUtils.getDateList(startDate, endDate, Constants.MONTH); - List expectedDateList = Lists.newArrayList("2023-07-01", "2023-08-01", - "2023-09-01", "2023-10-01"); + List expectedDateList = Lists.newArrayList("2023-07", "2023-08", "2023-09", "2023-10"); Assertions.assertEquals(actualDateList, expectedDateList); } } \ No newline at end of file diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserEqualHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserEqualHelperTest.java new file mode 100644 index 000000000..3995a0f55 --- /dev/null +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserEqualHelperTest.java @@ -0,0 +1,41 @@ +package com.tencent.supersonic.common.util.jsqlparser; + + +import cn.hutool.core.lang.Assert; +import org.junit.jupiter.api.Test; + +/** + * @author lex luo + * @date 2023/11/15 15:04 + */ +class SqlParserEqualHelperTest { + + @Test + void testEquals() { + String sql1 = "SELECT * FROM table1 WHERE column1 = 1 AND column2 = 2"; + String sql2 = "SELECT * FROM table1 WHERE column2 = 2 AND column1 = 1"; + Assert.equals(SqlParserEqualHelper.equals(sql1, sql2), true); + + sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a"; + sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a"; + Assert.equals(SqlParserEqualHelper.equals(sql1, sql2), true); + + sql1 = "SELECT a,sum(b),sum(c),sum(d) FROM table1 WHERE column1 = 1 AND column2 = 2 group by a order by a"; + + sql2 = "SELECT sum(d),sum(c),sum(b),a FROM table1 WHERE column2 = 2 AND column1 = 1 group by a order by a"; + + Assert.equals(SqlParserEqualHelper.equals(sql1, sql2), true); + + sql1 = "SELECT a,sum(b),sum(c),sum(d) FROM table1 WHERE column1 = 1 AND column2 = 2 group by a order by a"; + + sql2 = "SELECT sum(d),sum(c),sum(b),a FROM table1 WHERE column2 = 2 AND column1 = 1 group by a order by a"; + + Assert.equals(SqlParserEqualHelper.equals(sql1, sql2), true); + + sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a"; + sql2 = "SELECT d,c,b,f FROM table1 WHERE column2 = 2 AND column1 = 1 order by a"; + Assert.equals(SqlParserEqualHelper.equals(sql1, sql2), false); + + + } +} \ No newline at end of file diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java index f45ac2cb8..c1cb5bbec 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java @@ -18,106 +18,106 @@ class SqlParserSelectHelperTest { "select 用户名, 访问次数 from 超音数 where 用户名 in ('alice', 'lucy')"); System.out.println(selectStatement); - List filterExpression = SqlParserSelectHelper.getFilterExpression( + List fieldExpression = SqlParserSelectHelper.getFilterExpression( "SELECT department, user_id, field_a FROM s2 WHERE " + "sys_imp_date = '2023-08-08' AND YEAR(publish_date) = 2023 " + " AND user_id = 'alice' ORDER BY pv DESC LIMIT 1"); - System.out.println(filterExpression); + System.out.println(fieldExpression); - filterExpression = SqlParserSelectHelper.getFilterExpression( + fieldExpression = SqlParserSelectHelper.getFilterExpression( "SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08' " + " AND YEAR(publish_date) = 2023 " + " AND MONTH(publish_date) = 8" + " AND user_id = 'alice' ORDER BY pv DESC LIMIT 1"); - System.out.println(filterExpression); + System.out.println(fieldExpression); - filterExpression = SqlParserSelectHelper.getFilterExpression( + fieldExpression = SqlParserSelectHelper.getFilterExpression( "SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08'" + " AND YEAR(publish_date) = 2023 " + " AND MONTH(publish_date) = 8 AND DAY(publish_date) =20 " + " AND user_id = 'alice' ORDER BY pv DESC LIMIT 1"); - System.out.println(filterExpression); + System.out.println(fieldExpression); - filterExpression = SqlParserSelectHelper.getFilterExpression( + fieldExpression = SqlParserSelectHelper.getFilterExpression( "SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08' " + " AND user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1"); - System.out.println(filterExpression); + System.out.println(fieldExpression); - filterExpression = SqlParserSelectHelper.getFilterExpression( + fieldExpression = SqlParserSelectHelper.getFilterExpression( "SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08' " + "AND user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1"); - System.out.println(filterExpression); + System.out.println(fieldExpression); - filterExpression = SqlParserSelectHelper.getFilterExpression( + fieldExpression = SqlParserSelectHelper.getFilterExpression( "SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08' " + "AND user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1"); - System.out.println(filterExpression); + System.out.println(fieldExpression); - filterExpression = SqlParserSelectHelper.getFilterExpression( + fieldExpression = SqlParserSelectHelper.getFilterExpression( "SELECT department, user_id, field_a FROM s2 WHERE " + "user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1"); - System.out.println(filterExpression); + System.out.println(fieldExpression); - filterExpression = SqlParserSelectHelper.getFilterExpression( + fieldExpression = SqlParserSelectHelper.getFilterExpression( "SELECT department, user_id, field_a FROM s2 WHERE " + "user_id = 'alice' AND publish_date > 10000 ORDER BY pv DESC LIMIT 1"); - System.out.println(filterExpression); + System.out.println(fieldExpression); - filterExpression = SqlParserSelectHelper.getFilterExpression( + fieldExpression = SqlParserSelectHelper.getFilterExpression( "SELECT department, user_id, field_a FROM s2 WHERE " + "user_id like '%alice%' AND publish_date > 10000 ORDER BY pv DESC LIMIT 1"); - System.out.println(filterExpression); + System.out.println(fieldExpression); - filterExpression = SqlParserSelectHelper.getFilterExpression( + fieldExpression = SqlParserSelectHelper.getFilterExpression( "SELECT department, pv FROM s2 WHERE " + "user_id like '%alice%' AND publish_date > 10000 " + "group by department having sum(pv) > 2000 ORDER BY pv DESC LIMIT 1"); - System.out.println(filterExpression); + System.out.println(fieldExpression); - filterExpression = SqlParserSelectHelper.getFilterExpression( + fieldExpression = SqlParserSelectHelper.getFilterExpression( "SELECT department, pv FROM s2 WHERE " + "(user_id like '%alice%' AND publish_date > 10000) and sys_imp_date = '2023-08-08' " + "group by department having sum(pv) > 2000 ORDER BY pv DESC LIMIT 1"); - System.out.println(filterExpression); + System.out.println(fieldExpression); - filterExpression = SqlParserSelectHelper.getFilterExpression( + fieldExpression = SqlParserSelectHelper.getFilterExpression( "SELECT department, pv FROM s2 WHERE " + "(user_id like '%alice%' AND publish_date > 10000) and song_name in " + "('七里香','晴天') and sys_imp_date = '2023-08-08' " + "group by department having sum(pv) > 2000 ORDER BY pv DESC LIMIT 1"); - System.out.println(filterExpression); + System.out.println(fieldExpression); - filterExpression = SqlParserSelectHelper.getFilterExpression( + fieldExpression = SqlParserSelectHelper.getFilterExpression( "SELECT department, pv FROM s2 WHERE " + "(user_id like '%alice%' AND publish_date > 10000) and song_name in (1,2) " + "and sys_imp_date = '2023-08-08' " + "group by department having sum(pv) > 2000 ORDER BY pv DESC LIMIT 1"); - System.out.println(filterExpression); + System.out.println(fieldExpression); - filterExpression = SqlParserSelectHelper.getFilterExpression( + fieldExpression = SqlParserSelectHelper.getFilterExpression( "SELECT department, pv FROM s2 WHERE " + "(user_id like '%alice%' AND publish_date > 10000) and 1 in (1) " + "and sys_imp_date = '2023-08-08' " + "group by department having sum(pv) > 2000 ORDER BY pv DESC LIMIT 1"); - System.out.println(filterExpression); + System.out.println(fieldExpression); - filterExpression = SqlParserSelectHelper.getFilterExpression("SELECT sum(销量) / (SELECT sum(销量) FROM 营销月模型 " + fieldExpression = SqlParserSelectHelper.getFilterExpression("SELECT sum(销量) / (SELECT sum(销量) FROM 营销月模型 " + "WHERE MONTH(数据日期) = 9) FROM 营销月模型 WHERE 国家中文名 = '肯尼亚' AND MONTH(数据日期) = 9"); - System.out.println(filterExpression); + System.out.println(fieldExpression); } diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DimValueAspect.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DimValueAspect.java index 017f94cdb..050ec0873 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DimValueAspect.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DimValueAspect.java @@ -3,7 +3,7 @@ package com.tencent.supersonic.semantic.query.utils; import com.google.common.collect.Lists; import com.tencent.supersonic.common.pojo.QueryColumn; import com.tencent.supersonic.common.util.JsonUtil; -import com.tencent.supersonic.common.util.jsqlparser.FilterExpression; +import com.tencent.supersonic.common.util.jsqlparser.FieldExpression; import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.semantic.api.model.pojo.DimValueMap; @@ -60,11 +60,11 @@ public class DimValueAspect { String sql = queryS2SQLReq.getSql(); log.info("correctorSql before replacing:{}", sql); // if dimensionvalue is alias,consider the true dimensionvalue. - List filterExpressionList = SqlParserSelectHelper.getWhereExpressions(sql); + List fieldExpressionList = SqlParserSelectHelper.getWhereExpressions(sql); List dimensions = dimensionService.getDimensions(metaFilter); Set fieldNames = dimensions.stream().map(o -> o.getName()).collect(Collectors.toSet()); Map> filedNameToValueMap = new HashMap<>(); - filterExpressionList.stream().forEach(expression -> { + fieldExpressionList.stream().forEach(expression -> { if (fieldNames.contains(expression.getFieldName())) { dimensions.stream().forEach(dimension -> { if (expression.getFieldName().equals(dimension.getName()) @@ -98,7 +98,7 @@ public class DimValueAspect { return queryResultWithColumns; } - public void replaceInCondition(FilterExpression expression, DimensionResp dimension, + public void replaceInCondition(FieldExpression expression, DimensionResp dimension, Map> filedNameToValueMap) { if (expression.getOperator().equals(FilterOperatorEnum.IN.getValue())) { String fieldValue = JsonUtil.toString(expression.getFieldValue()); diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/QueryStructUtils.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/QueryStructUtils.java index 9fe29d396..18a03dea4 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/QueryStructUtils.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/QueryStructUtils.java @@ -14,7 +14,7 @@ import com.tencent.supersonic.common.pojo.enums.TypeEnums; import com.tencent.supersonic.common.util.DateModeUtils; import com.tencent.supersonic.common.util.SqlFilterUtils; import com.tencent.supersonic.common.util.StringUtil; -import com.tencent.supersonic.common.util.jsqlparser.FilterExpression; +import com.tencent.supersonic.common.util.jsqlparser.FieldExpression; import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserRemoveHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; @@ -424,13 +424,13 @@ public class QueryStructUtils { } public DateConf getDateConfBySql(String sql) { - List filterExpressions = SqlParserSelectHelper.getFilterExpression(sql); - if (!CollectionUtils.isEmpty(filterExpressions)) { + List fieldExpressions = SqlParserSelectHelper.getFilterExpression(sql); + if (!CollectionUtils.isEmpty(fieldExpressions)) { Set dateList = new HashSet<>(); String startDate = ""; String endDate = ""; String period = ""; - for (FilterExpression f : filterExpressions) { + for (FieldExpression f : fieldExpressions) { if (Objects.isNull(f.getFieldName()) || !internalCols.contains(f.getFieldName().toLowerCase())) { continue; }