mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 04:27:39 +00:00
(improvement)(chat) Support additional operators in QueryFilter correction. (#1347)
This commit is contained in:
@@ -1,8 +1,6 @@
|
|||||||
package com.tencent.supersonic.headless.chat.corrector;
|
package com.tencent.supersonic.headless.chat.corrector;
|
||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
|
||||||
import com.tencent.supersonic.common.util.StringUtil;
|
|
||||||
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
||||||
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
|
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
@@ -11,6 +9,7 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.QueryContext;
|
||||||
|
import com.tencent.supersonic.headless.chat.utils.QueryFilterParser;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import net.sf.jsqlparser.JSQLParserException;
|
import net.sf.jsqlparser.JSQLParserException;
|
||||||
import net.sf.jsqlparser.expression.Expression;
|
import net.sf.jsqlparser.expression.Expression;
|
||||||
@@ -22,7 +21,6 @@ import java.util.HashMap;
|
|||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.stream.Collectors;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Perform SQL corrections on the "Where" section in S2SQL.
|
* Perform SQL corrections on the "Where" section in S2SQL.
|
||||||
@@ -38,7 +36,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
|||||||
updateFieldValueByTechName(queryContext, semanticParseInfo);
|
updateFieldValueByTechName(queryContext, semanticParseInfo);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void addQueryFilter(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
protected void addQueryFilter(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
String queryFilter = getQueryFilter(queryContext.getQueryFilters());
|
String queryFilter = getQueryFilter(queryContext.getQueryFilters());
|
||||||
|
|
||||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
@@ -60,14 +58,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
|||||||
if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) {
|
if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
return queryFilters.getFilters().stream()
|
return QueryFilterParser.parse(queryFilters);
|
||||||
.map(filter -> {
|
|
||||||
String bizNameWrap = StringUtil.getSpaceWrap(filter.getName());
|
|
||||||
String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue());
|
|
||||||
String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString());
|
|
||||||
return bizNameWrap + operatorWrap + valueWrap;
|
|
||||||
})
|
|
||||||
.collect(Collectors.joining(Constants.AND_UPPER));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void updateFieldValueByTechName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
private void updateFieldValueByTechName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
|
|||||||
@@ -0,0 +1,72 @@
|
|||||||
|
package com.tencent.supersonic.headless.chat.utils;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
public class QueryFilterParser {
|
||||||
|
|
||||||
|
public static String parse(QueryFilters queryFilters) {
|
||||||
|
try {
|
||||||
|
List<String> conditions = queryFilters.getFilters().stream()
|
||||||
|
.map(QueryFilterParser::parseFilter)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
return String.join(" AND ", conditions);
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("", e);
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static String parseFilter(QueryFilter filter) {
|
||||||
|
String column = filter.getName();
|
||||||
|
FilterOperatorEnum operator = filter.getOperator();
|
||||||
|
Object value = filter.getValue();
|
||||||
|
|
||||||
|
switch (operator) {
|
||||||
|
case IN:
|
||||||
|
case NOT_IN:
|
||||||
|
return column + " " + operator.getValue() + " (" + parseList(value) + ")";
|
||||||
|
case BETWEEN:
|
||||||
|
if (value instanceof List && ((List<?>) value).size() == 2) {
|
||||||
|
List<?> values = (List<?>) value;
|
||||||
|
return column + " BETWEEN " + formatValue(values.get(0)) + " AND " + formatValue(values.get(1));
|
||||||
|
}
|
||||||
|
throw new IllegalArgumentException("BETWEEN operator requires a list of two values");
|
||||||
|
case IS_NULL:
|
||||||
|
case IS_NOT_NULL:
|
||||||
|
return column + " " + operator.getValue();
|
||||||
|
case EXISTS:
|
||||||
|
return "EXISTS (" + value + ")";
|
||||||
|
case SQL_PART:
|
||||||
|
return value.toString();
|
||||||
|
default:
|
||||||
|
return column + " " + operator.getValue() + " " + formatValue(value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static String parseList(Object value) {
|
||||||
|
if (value instanceof List) {
|
||||||
|
return ((List<?>) value).stream()
|
||||||
|
.map(QueryFilterParser::formatValue)
|
||||||
|
.collect(Collectors.joining(", "));
|
||||||
|
}
|
||||||
|
throw new IllegalArgumentException("IN and NOT IN operators require a list of values");
|
||||||
|
}
|
||||||
|
|
||||||
|
private static String formatValue(Object value) {
|
||||||
|
if (value instanceof String) {
|
||||||
|
return "'" + value + "'";
|
||||||
|
} else if (value instanceof Number) {
|
||||||
|
return value.toString();
|
||||||
|
} else if (value instanceof Boolean) {
|
||||||
|
return (Boolean) value ? "TRUE" : "FALSE";
|
||||||
|
}
|
||||||
|
throw new IllegalArgumentException("Unsupported value type: " + value.getClass());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,63 @@
|
|||||||
|
package com.tencent.supersonic.headless.chat.corrector;
|
||||||
|
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||||
|
import com.tencent.supersonic.headless.chat.QueryContext;
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
|
||||||
|
class WhereCorrectorTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void addQueryFilter() {
|
||||||
|
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||||
|
SqlInfo sqlInfo = new SqlInfo();
|
||||||
|
String sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||||
|
+ "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1";
|
||||||
|
sqlInfo.setCorrectS2SQL(sql);
|
||||||
|
semanticParseInfo.setSqlInfo(sqlInfo);
|
||||||
|
|
||||||
|
QueryContext queryContext = new QueryContext();
|
||||||
|
|
||||||
|
QueryFilter filter1 = new QueryFilter();
|
||||||
|
filter1.setName("age");
|
||||||
|
filter1.setOperator(FilterOperatorEnum.GREATER_THAN);
|
||||||
|
filter1.setValue(30);
|
||||||
|
|
||||||
|
QueryFilter filter2 = new QueryFilter();
|
||||||
|
filter2.setName("name");
|
||||||
|
filter2.setOperator(FilterOperatorEnum.LIKE);
|
||||||
|
filter2.setValue("John%");
|
||||||
|
|
||||||
|
QueryFilter filter3 = new QueryFilter();
|
||||||
|
filter3.setName("id");
|
||||||
|
filter3.setOperator(FilterOperatorEnum.IN);
|
||||||
|
filter3.setValue(Lists.newArrayList(1, 2, 3, 4));
|
||||||
|
|
||||||
|
QueryFilter filter4 = new QueryFilter();
|
||||||
|
filter4.setName("status");
|
||||||
|
filter4.setOperator(FilterOperatorEnum.NOT_IN);
|
||||||
|
filter4.setValue(Lists.newArrayList("inactive", "deleted"));
|
||||||
|
|
||||||
|
QueryFilters queryFilters = new QueryFilters();
|
||||||
|
queryFilters.getFilters().add(filter1);
|
||||||
|
queryFilters.getFilters().add(filter2);
|
||||||
|
queryFilters.getFilters().add(filter3);
|
||||||
|
queryFilters.getFilters().add(filter4);
|
||||||
|
queryContext.setQueryFilters(queryFilters);
|
||||||
|
|
||||||
|
WhereCorrector whereCorrector = new WhereCorrector();
|
||||||
|
whereCorrector.addQueryFilter(queryContext, semanticParseInfo);
|
||||||
|
|
||||||
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
|
|
||||||
|
Assert.assertEquals(correctS2SQL, "SELECT 维度1, SUM(播放量) FROM 数据库 WHERE "
|
||||||
|
+ "(歌手名 = '张三') AND 数据日期 <= '2023-11-17' AND age > 30 AND "
|
||||||
|
+ "name LIKE 'John%' AND id IN (1, 2, 3, 4) AND status GROUP BY 维度1");
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,48 @@
|
|||||||
|
package com.tencent.supersonic.headless.chat.utils;
|
||||||
|
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
|
||||||
|
class QueryFilterParserTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void parse() {
|
||||||
|
// Example usage
|
||||||
|
QueryFilter filter1 = new QueryFilter();
|
||||||
|
filter1.setName("age");
|
||||||
|
filter1.setOperator(FilterOperatorEnum.GREATER_THAN);
|
||||||
|
filter1.setValue(30);
|
||||||
|
|
||||||
|
QueryFilter filter2 = new QueryFilter();
|
||||||
|
filter2.setName("name");
|
||||||
|
filter2.setOperator(FilterOperatorEnum.LIKE);
|
||||||
|
filter2.setValue("John%");
|
||||||
|
|
||||||
|
QueryFilter filter3 = new QueryFilter();
|
||||||
|
filter3.setName("id");
|
||||||
|
filter3.setOperator(FilterOperatorEnum.IN);
|
||||||
|
filter3.setValue(Lists.newArrayList(1, 2, 3, 4));
|
||||||
|
|
||||||
|
QueryFilter filter4 = new QueryFilter();
|
||||||
|
filter4.setName("status");
|
||||||
|
filter4.setOperator(FilterOperatorEnum.NOT_IN);
|
||||||
|
filter4.setValue(Lists.newArrayList("inactive", "deleted"));
|
||||||
|
|
||||||
|
QueryFilters queryFilters = new QueryFilters();
|
||||||
|
queryFilters.getFilters().add(filter1);
|
||||||
|
queryFilters.getFilters().add(filter2);
|
||||||
|
queryFilters.getFilters().add(filter3);
|
||||||
|
queryFilters.getFilters().add(filter4);
|
||||||
|
|
||||||
|
String parse = QueryFilterParser.parse(queryFilters);
|
||||||
|
|
||||||
|
Assert.assertEquals(parse, "age > 30 AND name LIKE 'John%' AND id IN (1, 2, 3, 4)"
|
||||||
|
+ " AND status NOT_IN ('inactive', 'deleted')");
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user