mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 12:07:42 +00:00
(improvement)(chat) Support additional operators in QueryFilter correction. (#1347)
This commit is contained in:
@@ -1,8 +1,6 @@
|
||||
package com.tencent.supersonic.headless.chat.corrector;
|
||||
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.StringUtil;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
@@ -11,6 +9,7 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
||||
import com.tencent.supersonic.headless.chat.utils.QueryFilterParser;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
@@ -22,7 +21,6 @@ import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Perform SQL corrections on the "Where" section in S2SQL.
|
||||
@@ -38,7 +36,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
updateFieldValueByTechName(queryContext, semanticParseInfo);
|
||||
}
|
||||
|
||||
private void addQueryFilter(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
protected void addQueryFilter(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
String queryFilter = getQueryFilter(queryContext.getQueryFilters());
|
||||
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
@@ -60,14 +58,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) {
|
||||
return null;
|
||||
}
|
||||
return queryFilters.getFilters().stream()
|
||||
.map(filter -> {
|
||||
String bizNameWrap = StringUtil.getSpaceWrap(filter.getName());
|
||||
String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue());
|
||||
String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString());
|
||||
return bizNameWrap + operatorWrap + valueWrap;
|
||||
})
|
||||
.collect(Collectors.joining(Constants.AND_UPPER));
|
||||
return QueryFilterParser.parse(queryFilters);
|
||||
}
|
||||
|
||||
private void updateFieldValueByTechName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
package com.tencent.supersonic.headless.chat.utils;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
public class QueryFilterParser {
|
||||
|
||||
public static String parse(QueryFilters queryFilters) {
|
||||
try {
|
||||
List<String> conditions = queryFilters.getFilters().stream()
|
||||
.map(QueryFilterParser::parseFilter)
|
||||
.collect(Collectors.toList());
|
||||
return String.join(" AND ", conditions);
|
||||
} catch (Exception e) {
|
||||
log.error("", e);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private static String parseFilter(QueryFilter filter) {
|
||||
String column = filter.getName();
|
||||
FilterOperatorEnum operator = filter.getOperator();
|
||||
Object value = filter.getValue();
|
||||
|
||||
switch (operator) {
|
||||
case IN:
|
||||
case NOT_IN:
|
||||
return column + " " + operator.getValue() + " (" + parseList(value) + ")";
|
||||
case BETWEEN:
|
||||
if (value instanceof List && ((List<?>) value).size() == 2) {
|
||||
List<?> values = (List<?>) value;
|
||||
return column + " BETWEEN " + formatValue(values.get(0)) + " AND " + formatValue(values.get(1));
|
||||
}
|
||||
throw new IllegalArgumentException("BETWEEN operator requires a list of two values");
|
||||
case IS_NULL:
|
||||
case IS_NOT_NULL:
|
||||
return column + " " + operator.getValue();
|
||||
case EXISTS:
|
||||
return "EXISTS (" + value + ")";
|
||||
case SQL_PART:
|
||||
return value.toString();
|
||||
default:
|
||||
return column + " " + operator.getValue() + " " + formatValue(value);
|
||||
}
|
||||
}
|
||||
|
||||
private static String parseList(Object value) {
|
||||
if (value instanceof List) {
|
||||
return ((List<?>) value).stream()
|
||||
.map(QueryFilterParser::formatValue)
|
||||
.collect(Collectors.joining(", "));
|
||||
}
|
||||
throw new IllegalArgumentException("IN and NOT IN operators require a list of values");
|
||||
}
|
||||
|
||||
private static String formatValue(Object value) {
|
||||
if (value instanceof String) {
|
||||
return "'" + value + "'";
|
||||
} else if (value instanceof Number) {
|
||||
return value.toString();
|
||||
} else if (value instanceof Boolean) {
|
||||
return (Boolean) value ? "TRUE" : "FALSE";
|
||||
}
|
||||
throw new IllegalArgumentException("Unsupported value type: " + value.getClass());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
package com.tencent.supersonic.headless.chat.corrector;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
|
||||
class WhereCorrectorTest {
|
||||
|
||||
@Test
|
||||
void addQueryFilter() {
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
SqlInfo sqlInfo = new SqlInfo();
|
||||
String sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||
+ "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1";
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
semanticParseInfo.setSqlInfo(sqlInfo);
|
||||
|
||||
QueryContext queryContext = new QueryContext();
|
||||
|
||||
QueryFilter filter1 = new QueryFilter();
|
||||
filter1.setName("age");
|
||||
filter1.setOperator(FilterOperatorEnum.GREATER_THAN);
|
||||
filter1.setValue(30);
|
||||
|
||||
QueryFilter filter2 = new QueryFilter();
|
||||
filter2.setName("name");
|
||||
filter2.setOperator(FilterOperatorEnum.LIKE);
|
||||
filter2.setValue("John%");
|
||||
|
||||
QueryFilter filter3 = new QueryFilter();
|
||||
filter3.setName("id");
|
||||
filter3.setOperator(FilterOperatorEnum.IN);
|
||||
filter3.setValue(Lists.newArrayList(1, 2, 3, 4));
|
||||
|
||||
QueryFilter filter4 = new QueryFilter();
|
||||
filter4.setName("status");
|
||||
filter4.setOperator(FilterOperatorEnum.NOT_IN);
|
||||
filter4.setValue(Lists.newArrayList("inactive", "deleted"));
|
||||
|
||||
QueryFilters queryFilters = new QueryFilters();
|
||||
queryFilters.getFilters().add(filter1);
|
||||
queryFilters.getFilters().add(filter2);
|
||||
queryFilters.getFilters().add(filter3);
|
||||
queryFilters.getFilters().add(filter4);
|
||||
queryContext.setQueryFilters(queryFilters);
|
||||
|
||||
WhereCorrector whereCorrector = new WhereCorrector();
|
||||
whereCorrector.addQueryFilter(queryContext, semanticParseInfo);
|
||||
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
|
||||
Assert.assertEquals(correctS2SQL, "SELECT 维度1, SUM(播放量) FROM 数据库 WHERE "
|
||||
+ "(歌手名 = '张三') AND 数据日期 <= '2023-11-17' AND age > 30 AND "
|
||||
+ "name LIKE 'John%' AND id IN (1, 2, 3, 4) AND status GROUP BY 维度1");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package com.tencent.supersonic.headless.chat.utils;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
|
||||
class QueryFilterParserTest {
|
||||
|
||||
@Test
|
||||
void parse() {
|
||||
// Example usage
|
||||
QueryFilter filter1 = new QueryFilter();
|
||||
filter1.setName("age");
|
||||
filter1.setOperator(FilterOperatorEnum.GREATER_THAN);
|
||||
filter1.setValue(30);
|
||||
|
||||
QueryFilter filter2 = new QueryFilter();
|
||||
filter2.setName("name");
|
||||
filter2.setOperator(FilterOperatorEnum.LIKE);
|
||||
filter2.setValue("John%");
|
||||
|
||||
QueryFilter filter3 = new QueryFilter();
|
||||
filter3.setName("id");
|
||||
filter3.setOperator(FilterOperatorEnum.IN);
|
||||
filter3.setValue(Lists.newArrayList(1, 2, 3, 4));
|
||||
|
||||
QueryFilter filter4 = new QueryFilter();
|
||||
filter4.setName("status");
|
||||
filter4.setOperator(FilterOperatorEnum.NOT_IN);
|
||||
filter4.setValue(Lists.newArrayList("inactive", "deleted"));
|
||||
|
||||
QueryFilters queryFilters = new QueryFilters();
|
||||
queryFilters.getFilters().add(filter1);
|
||||
queryFilters.getFilters().add(filter2);
|
||||
queryFilters.getFilters().add(filter3);
|
||||
queryFilters.getFilters().add(filter4);
|
||||
|
||||
String parse = QueryFilterParser.parse(queryFilters);
|
||||
|
||||
Assert.assertEquals(parse, "age > 30 AND name LIKE 'John%' AND id IN (1, 2, 3, 4)"
|
||||
+ " AND status NOT_IN ('inactive', 'deleted')");
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user