(improvement)(chat) Support additional operators in QueryFilter correction. (#1347)

This commit is contained in:
lexluo09
2024-07-05 13:20:40 +08:00
committed by GitHub
parent 71954e42a8
commit 097f2f4fe7
4 changed files with 186 additions and 12 deletions

View File

@@ -1,8 +1,6 @@
package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.StringUtil;
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
@@ -11,6 +9,7 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import com.tencent.supersonic.headless.chat.QueryContext;
import com.tencent.supersonic.headless.chat.utils.QueryFilterParser;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
@@ -22,7 +21,6 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
/**
* Perform SQL corrections on the "Where" section in S2SQL.
@@ -38,7 +36,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
updateFieldValueByTechName(queryContext, semanticParseInfo);
}
private void addQueryFilter(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
protected void addQueryFilter(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
String queryFilter = getQueryFilter(queryContext.getQueryFilters());
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
@@ -60,14 +58,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) {
return null;
}
return queryFilters.getFilters().stream()
.map(filter -> {
String bizNameWrap = StringUtil.getSpaceWrap(filter.getName());
String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue());
String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString());
return bizNameWrap + operatorWrap + valueWrap;
})
.collect(Collectors.joining(Constants.AND_UPPER));
return QueryFilterParser.parse(queryFilters);
}
private void updateFieldValueByTechName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {

View File

@@ -0,0 +1,72 @@
package com.tencent.supersonic.headless.chat.utils;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import lombok.extern.slf4j.Slf4j;
import java.util.List;
import java.util.stream.Collectors;
@Slf4j
public class QueryFilterParser {
public static String parse(QueryFilters queryFilters) {
try {
List<String> conditions = queryFilters.getFilters().stream()
.map(QueryFilterParser::parseFilter)
.collect(Collectors.toList());
return String.join(" AND ", conditions);
} catch (Exception e) {
log.error("", e);
}
return null;
}
private static String parseFilter(QueryFilter filter) {
String column = filter.getName();
FilterOperatorEnum operator = filter.getOperator();
Object value = filter.getValue();
switch (operator) {
case IN:
case NOT_IN:
return column + " " + operator.getValue() + " (" + parseList(value) + ")";
case BETWEEN:
if (value instanceof List && ((List<?>) value).size() == 2) {
List<?> values = (List<?>) value;
return column + " BETWEEN " + formatValue(values.get(0)) + " AND " + formatValue(values.get(1));
}
throw new IllegalArgumentException("BETWEEN operator requires a list of two values");
case IS_NULL:
case IS_NOT_NULL:
return column + " " + operator.getValue();
case EXISTS:
return "EXISTS (" + value + ")";
case SQL_PART:
return value.toString();
default:
return column + " " + operator.getValue() + " " + formatValue(value);
}
}
private static String parseList(Object value) {
if (value instanceof List) {
return ((List<?>) value).stream()
.map(QueryFilterParser::formatValue)
.collect(Collectors.joining(", "));
}
throw new IllegalArgumentException("IN and NOT IN operators require a list of values");
}
private static String formatValue(Object value) {
if (value instanceof String) {
return "'" + value + "'";
} else if (value instanceof Number) {
return value.toString();
} else if (value instanceof Boolean) {
return (Boolean) value ? "TRUE" : "FALSE";
}
throw new IllegalArgumentException("Unsupported value type: " + value.getClass());
}
}

View File

@@ -0,0 +1,63 @@
package com.tencent.supersonic.headless.chat.corrector;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import com.tencent.supersonic.headless.chat.QueryContext;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
class WhereCorrectorTest {
@Test
void addQueryFilter() {
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
SqlInfo sqlInfo = new SqlInfo();
String sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+ "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1";
sqlInfo.setCorrectS2SQL(sql);
semanticParseInfo.setSqlInfo(sqlInfo);
QueryContext queryContext = new QueryContext();
QueryFilter filter1 = new QueryFilter();
filter1.setName("age");
filter1.setOperator(FilterOperatorEnum.GREATER_THAN);
filter1.setValue(30);
QueryFilter filter2 = new QueryFilter();
filter2.setName("name");
filter2.setOperator(FilterOperatorEnum.LIKE);
filter2.setValue("John%");
QueryFilter filter3 = new QueryFilter();
filter3.setName("id");
filter3.setOperator(FilterOperatorEnum.IN);
filter3.setValue(Lists.newArrayList(1, 2, 3, 4));
QueryFilter filter4 = new QueryFilter();
filter4.setName("status");
filter4.setOperator(FilterOperatorEnum.NOT_IN);
filter4.setValue(Lists.newArrayList("inactive", "deleted"));
QueryFilters queryFilters = new QueryFilters();
queryFilters.getFilters().add(filter1);
queryFilters.getFilters().add(filter2);
queryFilters.getFilters().add(filter3);
queryFilters.getFilters().add(filter4);
queryContext.setQueryFilters(queryFilters);
WhereCorrector whereCorrector = new WhereCorrector();
whereCorrector.addQueryFilter(queryContext, semanticParseInfo);
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
Assert.assertEquals(correctS2SQL, "SELECT 维度1, SUM(播放量) FROM 数据库 WHERE "
+ "(歌手名 = '张三') AND 数据日期 <= '2023-11-17' AND age > 30 AND "
+ "name LIKE 'John%' AND id IN (1, 2, 3, 4) AND status GROUP BY 维度1");
}
}

View File

@@ -0,0 +1,48 @@
package com.tencent.supersonic.headless.chat.utils;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
class QueryFilterParserTest {
@Test
void parse() {
// Example usage
QueryFilter filter1 = new QueryFilter();
filter1.setName("age");
filter1.setOperator(FilterOperatorEnum.GREATER_THAN);
filter1.setValue(30);
QueryFilter filter2 = new QueryFilter();
filter2.setName("name");
filter2.setOperator(FilterOperatorEnum.LIKE);
filter2.setValue("John%");
QueryFilter filter3 = new QueryFilter();
filter3.setName("id");
filter3.setOperator(FilterOperatorEnum.IN);
filter3.setValue(Lists.newArrayList(1, 2, 3, 4));
QueryFilter filter4 = new QueryFilter();
filter4.setName("status");
filter4.setOperator(FilterOperatorEnum.NOT_IN);
filter4.setValue(Lists.newArrayList("inactive", "deleted"));
QueryFilters queryFilters = new QueryFilters();
queryFilters.getFilters().add(filter1);
queryFilters.getFilters().add(filter2);
queryFilters.getFilters().add(filter3);
queryFilters.getFilters().add(filter4);
String parse = QueryFilterParser.parse(queryFilters);
Assert.assertEquals(parse, "age > 30 AND name LIKE 'John%' AND id IN (1, 2, 3, 4)"
+ " AND status NOT_IN ('inactive', 'deleted')");
}
}