diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrector.java index 8bacfe1d5..3ae09c4d7 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrector.java @@ -1,8 +1,6 @@ package com.tencent.supersonic.headless.chat.corrector; -import com.tencent.supersonic.common.pojo.Constants; -import com.tencent.supersonic.common.util.StringUtil; import com.tencent.supersonic.common.jsqlparser.SqlAddHelper; import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper; import com.tencent.supersonic.headless.api.pojo.SchemaElement; @@ -11,6 +9,7 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.utils.QueryFilterParser; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.expression.Expression; @@ -22,7 +21,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.stream.Collectors; /** * Perform SQL corrections on the "Where" section in S2SQL. @@ -38,7 +36,7 @@ public class WhereCorrector extends BaseSemanticCorrector { updateFieldValueByTechName(queryContext, semanticParseInfo); } - private void addQueryFilter(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { + protected void addQueryFilter(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { String queryFilter = getQueryFilter(queryContext.getQueryFilters()); String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); @@ -60,14 +58,7 @@ public class WhereCorrector extends BaseSemanticCorrector { if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) { return null; } - return queryFilters.getFilters().stream() - .map(filter -> { - String bizNameWrap = StringUtil.getSpaceWrap(filter.getName()); - String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue()); - String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString()); - return bizNameWrap + operatorWrap + valueWrap; - }) - .collect(Collectors.joining(Constants.AND_UPPER)); + return QueryFilterParser.parse(queryFilters); } private void updateFieldValueByTechName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/QueryFilterParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/QueryFilterParser.java new file mode 100644 index 000000000..4aab2bf67 --- /dev/null +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/QueryFilterParser.java @@ -0,0 +1,72 @@ +package com.tencent.supersonic.headless.chat.utils; + +import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; +import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; +import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; +import lombok.extern.slf4j.Slf4j; + +import java.util.List; +import java.util.stream.Collectors; + +@Slf4j +public class QueryFilterParser { + + public static String parse(QueryFilters queryFilters) { + try { + List conditions = queryFilters.getFilters().stream() + .map(QueryFilterParser::parseFilter) + .collect(Collectors.toList()); + return String.join(" AND ", conditions); + } catch (Exception e) { + log.error("", e); + } + return null; + } + + private static String parseFilter(QueryFilter filter) { + String column = filter.getName(); + FilterOperatorEnum operator = filter.getOperator(); + Object value = filter.getValue(); + + switch (operator) { + case IN: + case NOT_IN: + return column + " " + operator.getValue() + " (" + parseList(value) + ")"; + case BETWEEN: + if (value instanceof List && ((List) value).size() == 2) { + List values = (List) value; + return column + " BETWEEN " + formatValue(values.get(0)) + " AND " + formatValue(values.get(1)); + } + throw new IllegalArgumentException("BETWEEN operator requires a list of two values"); + case IS_NULL: + case IS_NOT_NULL: + return column + " " + operator.getValue(); + case EXISTS: + return "EXISTS (" + value + ")"; + case SQL_PART: + return value.toString(); + default: + return column + " " + operator.getValue() + " " + formatValue(value); + } + } + + private static String parseList(Object value) { + if (value instanceof List) { + return ((List) value).stream() + .map(QueryFilterParser::formatValue) + .collect(Collectors.joining(", ")); + } + throw new IllegalArgumentException("IN and NOT IN operators require a list of values"); + } + + private static String formatValue(Object value) { + if (value instanceof String) { + return "'" + value + "'"; + } else if (value instanceof Number) { + return value.toString(); + } else if (value instanceof Boolean) { + return (Boolean) value ? "TRUE" : "FALSE"; + } + throw new IllegalArgumentException("Unsupported value type: " + value.getClass()); + } +} \ No newline at end of file diff --git a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrectorTest.java b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrectorTest.java new file mode 100644 index 000000000..1b385f0c4 --- /dev/null +++ b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrectorTest.java @@ -0,0 +1,63 @@ +package com.tencent.supersonic.headless.chat.corrector; + +import com.google.common.collect.Lists; +import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; +import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.headless.api.pojo.SqlInfo; +import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; +import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; +import com.tencent.supersonic.headless.chat.QueryContext; +import org.junit.Assert; +import org.junit.jupiter.api.Test; + + +class WhereCorrectorTest { + + @Test + void addQueryFilter() { + SemanticParseInfo semanticParseInfo = new SemanticParseInfo(); + SqlInfo sqlInfo = new SqlInfo(); + String sql = "SELECT 维度1, SUM(播放量) FROM 数据库 " + + "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1"; + sqlInfo.setCorrectS2SQL(sql); + semanticParseInfo.setSqlInfo(sqlInfo); + + QueryContext queryContext = new QueryContext(); + + QueryFilter filter1 = new QueryFilter(); + filter1.setName("age"); + filter1.setOperator(FilterOperatorEnum.GREATER_THAN); + filter1.setValue(30); + + QueryFilter filter2 = new QueryFilter(); + filter2.setName("name"); + filter2.setOperator(FilterOperatorEnum.LIKE); + filter2.setValue("John%"); + + QueryFilter filter3 = new QueryFilter(); + filter3.setName("id"); + filter3.setOperator(FilterOperatorEnum.IN); + filter3.setValue(Lists.newArrayList(1, 2, 3, 4)); + + QueryFilter filter4 = new QueryFilter(); + filter4.setName("status"); + filter4.setOperator(FilterOperatorEnum.NOT_IN); + filter4.setValue(Lists.newArrayList("inactive", "deleted")); + + QueryFilters queryFilters = new QueryFilters(); + queryFilters.getFilters().add(filter1); + queryFilters.getFilters().add(filter2); + queryFilters.getFilters().add(filter3); + queryFilters.getFilters().add(filter4); + queryContext.setQueryFilters(queryFilters); + + WhereCorrector whereCorrector = new WhereCorrector(); + whereCorrector.addQueryFilter(queryContext, semanticParseInfo); + + String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); + + Assert.assertEquals(correctS2SQL, "SELECT 维度1, SUM(播放量) FROM 数据库 WHERE " + + "(歌手名 = '张三') AND 数据日期 <= '2023-11-17' AND age > 30 AND " + + "name LIKE 'John%' AND id IN (1, 2, 3, 4) AND status GROUP BY 维度1"); + } +} \ No newline at end of file diff --git a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/utils/QueryFilterParserTest.java b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/utils/QueryFilterParserTest.java new file mode 100644 index 000000000..afb6c0f73 --- /dev/null +++ b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/utils/QueryFilterParserTest.java @@ -0,0 +1,48 @@ +package com.tencent.supersonic.headless.chat.utils; + +import com.google.common.collect.Lists; +import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; +import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; +import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; +import org.junit.Assert; +import org.junit.jupiter.api.Test; + + +class QueryFilterParserTest { + + @Test + void parse() { + // Example usage + QueryFilter filter1 = new QueryFilter(); + filter1.setName("age"); + filter1.setOperator(FilterOperatorEnum.GREATER_THAN); + filter1.setValue(30); + + QueryFilter filter2 = new QueryFilter(); + filter2.setName("name"); + filter2.setOperator(FilterOperatorEnum.LIKE); + filter2.setValue("John%"); + + QueryFilter filter3 = new QueryFilter(); + filter3.setName("id"); + filter3.setOperator(FilterOperatorEnum.IN); + filter3.setValue(Lists.newArrayList(1, 2, 3, 4)); + + QueryFilter filter4 = new QueryFilter(); + filter4.setName("status"); + filter4.setOperator(FilterOperatorEnum.NOT_IN); + filter4.setValue(Lists.newArrayList("inactive", "deleted")); + + QueryFilters queryFilters = new QueryFilters(); + queryFilters.getFilters().add(filter1); + queryFilters.getFilters().add(filter2); + queryFilters.getFilters().add(filter3); + queryFilters.getFilters().add(filter4); + + String parse = QueryFilterParser.parse(queryFilters); + + Assert.assertEquals(parse, "age > 30 AND name LIKE 'John%' AND id IN (1, 2, 3, 4)" + + " AND status NOT_IN ('inactive', 'deleted')"); + } + +} \ No newline at end of file