mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-15 06:27:21 +00:00
[improvement](chat) remove duplicates from multiple SQL identified by LLM. (#391)
This commit is contained in:
@@ -7,19 +7,23 @@ import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserEqualHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import java.util.HashMap;
import java.util.LinkedHashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.MapUtils;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
public class LLMResponseService {
|
||||
|
||||
public SemanticParseInfo addParseInfo(QueryContext queryCtx, ParseResult parseResult, String s2SQL, Double weight) {
|
||||
if (Objects.isNull(weight)) {
|
||||
weight = 0D;
|
||||
@@ -51,4 +55,19 @@ public class LLMResponseService {
|
||||
queryCtx.getCandidateQueries().add(semanticQuery);
|
||||
return parseInfo;
|
||||
}
|
||||
|
||||
public Map<String, Double> getDeduplicationSqlWeight(LLMResp llmResp) {
|
||||
if (MapUtils.isEmpty(llmResp.getSqlWeight())) {
|
||||
return llmResp.getSqlWeight();
|
||||
}
|
||||
Map<String, Double> result = new HashMap<>();
|
||||
for (Map.Entry<String, Double> entry : llmResp.getSqlWeight().entrySet()) {
|
||||
String key = entry.getKey();
|
||||
if (result.keySet().stream().anyMatch(existKey -> SqlParserEqualHelper.equals(existKey, key))) {
|
||||
continue;
|
||||
}
|
||||
result.put(key, entry.getValue());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.MapUtils;
|
||||
|
||||
@Slf4j
|
||||
public class LLMS2SQLParser implements SemanticParser {
|
||||
@@ -45,8 +46,9 @@ public class LLMS2SQLParser implements SemanticParser {
|
||||
if (Objects.isNull(llmResp)) {
|
||||
return;
|
||||
}
|
||||
//5. get and update parserInfo
|
||||
Map<String, Double> sqlWeight = llmResp.getSqlWeight();
|
||||
//5. deduplicate the SQL result list and build parserInfo
|
||||
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
|
||||
Map<String, Double> deduplicationSqlWeight = responseService.getDeduplicationSqlWeight(llmResp);
|
||||
ParseResult parseResult = ParseResult.builder()
|
||||
.request(request)
|
||||
.modelId(modelId)
|
||||
@@ -56,12 +58,10 @@ public class LLMS2SQLParser implements SemanticParser {
|
||||
.linkingValues(linkingValues)
|
||||
.build();
|
||||
|
||||
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
|
||||
|
||||
if (Objects.isNull(sqlWeight) || sqlWeight.isEmpty()) {
|
||||
if (MapUtils.isEmpty(deduplicationSqlWeight)) {
|
||||
responseService.addParseInfo(queryCtx, parseResult, llmResp.getSqlOutput(), 1D);
|
||||
} else {
|
||||
sqlWeight.forEach((sql, weight) -> {
|
||||
deduplicationSqlWeight.forEach((sql, weight) -> {
|
||||
responseService.addParseInfo(queryCtx, parseResult, sql, weight);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ import com.tencent.supersonic.common.pojo.DateConf.DateMode;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.FilterExpression;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.FieldExpression;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
@@ -47,7 +47,7 @@ public class ParserInfoServiceImpl implements ParseInfoService {
|
||||
return;
|
||||
}
|
||||
|
||||
List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctS2SQL);
|
||||
List<FieldExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctS2SQL);
|
||||
//set dataInfo
|
||||
try {
|
||||
if (!CollectionUtils.isEmpty(expressions)) {
|
||||
@@ -112,9 +112,9 @@ public class ParserInfoServiceImpl implements ParseInfoService {
|
||||
|
||||
|
||||
private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> fieldNameToElement,
|
||||
List<FilterExpression> filterExpressions) {
|
||||
List<FieldExpression> fieldExpressions) {
|
||||
List<QueryFilter> result = Lists.newArrayList();
|
||||
for (FilterExpression expression : filterExpressions) {
|
||||
for (FieldExpression expression : fieldExpressions) {
|
||||
QueryFilter dimensionFilter = new QueryFilter();
|
||||
dimensionFilter.setValue(expression.getFieldValue());
|
||||
SchemaElement schemaElement = fieldNameToElement.get(expression.getFieldName());
|
||||
@@ -133,8 +133,8 @@ public class ParserInfoServiceImpl implements ParseInfoService {
|
||||
return result;
|
||||
}
|
||||
|
||||
private DateConf getDateInfo(List<FilterExpression> filterExpressions) {
|
||||
List<FilterExpression> dateExpressions = filterExpressions.stream()
|
||||
private DateConf getDateInfo(List<FieldExpression> fieldExpressions) {
|
||||
List<FieldExpression> dateExpressions = fieldExpressions.stream()
|
||||
.filter(expression -> TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(expression.getFieldName()))
|
||||
.collect(Collectors.toList());
|
||||
if (CollectionUtils.isEmpty(dateExpressions)) {
|
||||
@@ -142,7 +142,7 @@ public class ParserInfoServiceImpl implements ParseInfoService {
|
||||
}
|
||||
DateConf dateInfo = new DateConf();
|
||||
dateInfo.setDateMode(DateMode.BETWEEN);
|
||||
FilterExpression firstExpression = dateExpressions.get(0);
|
||||
FieldExpression firstExpression = dateExpressions.get(0);
|
||||
|
||||
FilterOperatorEnum firstOperator = FilterOperatorEnum.getSqlOperator(firstExpression.getOperator());
|
||||
if (FilterOperatorEnum.EQUALS.equals(firstOperator) && Objects.nonNull(firstExpression.getFieldValue())) {
|
||||
@@ -168,12 +168,12 @@ public class ParserInfoServiceImpl implements ParseInfoService {
|
||||
return dateInfo;
|
||||
}
|
||||
|
||||
private boolean containOperators(FilterExpression expression, FilterOperatorEnum firstOperator,
|
||||
private boolean containOperators(FieldExpression expression, FilterOperatorEnum firstOperator,
|
||||
FilterOperatorEnum... operatorEnums) {
|
||||
return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(expression.getFieldValue()));
|
||||
}
|
||||
|
||||
private boolean hasSecondDate(List<FilterExpression> dateExpressions) {
|
||||
private boolean hasSecondDate(List<FieldExpression> dateExpressions) {
|
||||
return dateExpressions.size() > 1 && Objects.nonNull(dateExpressions.get(1).getFieldValue());
|
||||
}
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.FilterExpression;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.FieldExpression;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserRemoveHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
||||
@@ -296,8 +296,8 @@ public class QueryServiceImpl implements QueryService {
|
||||
String correctorSql = parseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
log.info("correctorSql before replacing:{}", correctorSql);
|
||||
// get where filter and having filter
|
||||
List<FilterExpression> whereExpressionList = SqlParserSelectHelper.getWhereExpressions(correctorSql);
|
||||
List<FilterExpression> havingExpressionList = SqlParserSelectHelper.getHavingExpressions(correctorSql);
|
||||
List<FieldExpression> whereExpressionList = SqlParserSelectHelper.getWhereExpressions(correctorSql);
|
||||
List<FieldExpression> havingExpressionList = SqlParserSelectHelper.getHavingExpressions(correctorSql);
|
||||
List<Expression> addWhereConditions = new ArrayList<>();
|
||||
List<Expression> addHavingConditions = new ArrayList<>();
|
||||
Set<String> removeWhereFieldNames = new HashSet<>();
|
||||
@@ -350,7 +350,7 @@ public class QueryServiceImpl implements QueryService {
|
||||
|
||||
private void updateDateInfo(QueryDataReq queryData, SemanticParseInfo parseInfo,
|
||||
Map<String, Map<String, String>> filedNameToValueMap,
|
||||
List<FilterExpression> filterExpressionList,
|
||||
List<FieldExpression> fieldExpressionList,
|
||||
List<Expression> addConditions,
|
||||
Set<String> removeFieldNames) {
|
||||
if (Objects.isNull(queryData.getDateInfo())) {
|
||||
@@ -364,12 +364,12 @@ public class QueryServiceImpl implements QueryService {
|
||||
}
|
||||
// startDate equals to endDate
|
||||
if (queryData.getDateInfo().getStartDate().equals(queryData.getDateInfo().getEndDate())) {
|
||||
for (FilterExpression filterExpression : filterExpressionList) {
|
||||
if (TimeDimensionEnum.DAY.getChName().equals(filterExpression.getFieldName())) {
|
||||
for (FieldExpression fieldExpression : fieldExpressionList) {
|
||||
if (TimeDimensionEnum.DAY.getChName().equals(fieldExpression.getFieldName())) {
|
||||
//sql where condition exists 'equals' operator about date,just replace
|
||||
if (filterExpression.getOperator().equals(FilterOperatorEnum.EQUALS)) {
|
||||
dateField = filterExpression.getFieldName();
|
||||
map.put(filterExpression.getFieldValue().toString(),
|
||||
if (fieldExpression.getOperator().equals(FilterOperatorEnum.EQUALS)) {
|
||||
dateField = fieldExpression.getFieldName();
|
||||
map.put(fieldExpression.getFieldValue().toString(),
|
||||
queryData.getDateInfo().getStartDate());
|
||||
filedNameToValueMap.put(dateField, map);
|
||||
} else {
|
||||
@@ -386,23 +386,23 @@ public class QueryServiceImpl implements QueryService {
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (FilterExpression filterExpression : filterExpressionList) {
|
||||
if (TimeDimensionEnum.DAY.getChName().equals(filterExpression.getFieldName())) {
|
||||
dateField = filterExpression.getFieldName();
|
||||
for (FieldExpression fieldExpression : fieldExpressionList) {
|
||||
if (TimeDimensionEnum.DAY.getChName().equals(fieldExpression.getFieldName())) {
|
||||
dateField = fieldExpression.getFieldName();
|
||||
//just replace
|
||||
if (FilterOperatorEnum.GREATER_THAN_EQUALS.getValue().equals(filterExpression.getOperator())
|
||||
|| FilterOperatorEnum.GREATER_THAN.getValue().equals(filterExpression.getOperator())) {
|
||||
map.put(filterExpression.getFieldValue().toString(),
|
||||
if (FilterOperatorEnum.GREATER_THAN_EQUALS.getValue().equals(fieldExpression.getOperator())
|
||||
|| FilterOperatorEnum.GREATER_THAN.getValue().equals(fieldExpression.getOperator())) {
|
||||
map.put(fieldExpression.getFieldValue().toString(),
|
||||
queryData.getDateInfo().getStartDate());
|
||||
}
|
||||
if (FilterOperatorEnum.MINOR_THAN_EQUALS.getValue().equals(filterExpression.getOperator())
|
||||
|| FilterOperatorEnum.MINOR_THAN.getValue().equals(filterExpression.getOperator())) {
|
||||
map.put(filterExpression.getFieldValue().toString(),
|
||||
if (FilterOperatorEnum.MINOR_THAN_EQUALS.getValue().equals(fieldExpression.getOperator())
|
||||
|| FilterOperatorEnum.MINOR_THAN.getValue().equals(fieldExpression.getOperator())) {
|
||||
map.put(fieldExpression.getFieldValue().toString(),
|
||||
queryData.getDateInfo().getEndDate());
|
||||
}
|
||||
filedNameToValueMap.put(dateField, map);
|
||||
// first remove,then add
|
||||
if (FilterOperatorEnum.EQUALS.getValue().equals(filterExpression.getOperator())) {
|
||||
if (FilterOperatorEnum.EQUALS.getValue().equals(fieldExpression.getOperator())) {
|
||||
removeFieldNames.add(TimeDimensionEnum.DAY.getChName());
|
||||
GreaterThanEquals greaterThanEquals = new GreaterThanEquals();
|
||||
addTimeFilters(queryData.getDateInfo().getStartDate(), greaterThanEquals, addConditions);
|
||||
@@ -425,7 +425,7 @@ public class QueryServiceImpl implements QueryService {
|
||||
addConditions.add(comparisonExpression);
|
||||
}
|
||||
|
||||
private void updateFilters(List<FilterExpression> filterExpressionList,
|
||||
private void updateFilters(List<FieldExpression> fieldExpressionList,
|
||||
Set<QueryFilter> metricFilters,
|
||||
Set<QueryFilter> contextMetricFilters,
|
||||
List<Expression> addConditions,
|
||||
@@ -434,9 +434,9 @@ public class QueryServiceImpl implements QueryService {
|
||||
return;
|
||||
}
|
||||
for (QueryFilter dslQueryFilter : metricFilters) {
|
||||
for (FilterExpression filterExpression : filterExpressionList) {
|
||||
if (filterExpression.getFieldName() != null
|
||||
&& filterExpression.getFieldName().contains(dslQueryFilter.getName())) {
|
||||
for (FieldExpression fieldExpression : fieldExpressionList) {
|
||||
if (fieldExpression.getFieldName() != null
|
||||
&& fieldExpression.getFieldName().contains(dslQueryFilter.getName())) {
|
||||
removeFieldNames.add(dslQueryFilter.getName());
|
||||
if (dslQueryFilter.getOperator().equals(FilterOperatorEnum.EQUALS)) {
|
||||
EqualsTo equalsTo = new EqualsTo();
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
package com.tencent.supersonic.chat.parser.llm.s2sql;
|
||||
|
||||
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class LLMResponseServiceTest {
|
||||
|
||||
@Test
|
||||
void deduplicationSqlWeight() {
|
||||
String sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
|
||||
String sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
|
||||
|
||||
LLMResp llmResp = new LLMResp();
|
||||
Map<String, Double> sqlWeight = new HashMap<>();
|
||||
sqlWeight.put(sql1, 0.2D);
|
||||
sqlWeight.put(sql2, 0.8D);
|
||||
llmResp.setSqlWeight(sqlWeight);
|
||||
LLMResponseService llmResponseService = new LLMResponseService();
|
||||
Map<String, Double> deduplicationSqlWeight = llmResponseService.getDeduplicationSqlWeight(llmResp);
|
||||
|
||||
Assert.assertEquals(deduplicationSqlWeight.size(), 1);
|
||||
|
||||
sql1 = "SELECT a,b,c,d FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
|
||||
sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
|
||||
|
||||
LLMResp llmResp2 = new LLMResp();
|
||||
Map<String, Double> sqlWeight2 = new HashMap<>();
|
||||
sqlWeight2.put(sql1, 0.2D);
|
||||
sqlWeight2.put(sql2, 0.8D);
|
||||
llmResp2.setSqlWeight(sqlWeight2);
|
||||
deduplicationSqlWeight = llmResponseService.getDeduplicationSqlWeight(llmResp2);
|
||||
|
||||
Assert.assertEquals(deduplicationSqlWeight.size(), 1);
|
||||
|
||||
sql1 = "SELECT a,b,c,d,e FROM table1 WHERE column1 = 1 AND column2 = 2 order by a";
|
||||
sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a";
|
||||
|
||||
LLMResp llmResp3 = new LLMResp();
|
||||
Map<String, Double> sqlWeight3 = new HashMap<>();
|
||||
sqlWeight3.put(sql1, 0.2D);
|
||||
sqlWeight3.put(sql2, 0.8D);
|
||||
llmResp3.setSqlWeight(sqlWeight3);
|
||||
deduplicationSqlWeight = llmResponseService.getDeduplicationSqlWeight(llmResp3);
|
||||
|
||||
Assert.assertEquals(deduplicationSqlWeight.size(), 2);
|
||||
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user