(improvement)(Headless) distinct select fields in S2CorrectSQL (#912)

This commit is contained in:
mainmain
2024-04-16 16:30:00 +08:00
committed by GitHub
parent 6e0fc87a57
commit 25a618b175
13 changed files with 87 additions and 18 deletions

View File

@@ -8,7 +8,9 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.server.service.ChatQueryService; import com.tencent.supersonic.headless.server.service.ChatQueryService;
import java.util.List; import java.util.List;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class NL2SQLParser implements ChatParser { public class NL2SQLParser implements ChatParser {
@Override @Override

View File

@@ -35,4 +35,5 @@ public interface ChatQueryRepository {
List<ChatParseDO> getParseInfoList(List<Long> questionIds); List<ChatParseDO> getParseInfoList(List<Long> questionIds);
Boolean deleteChatQuery(Long questionId); Boolean deleteChatQuery(Long questionId);
} }

View File

@@ -192,4 +192,5 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
public Boolean deleteChatQuery(Long questionId) { public Boolean deleteChatQuery(Long questionId) {
return chatQueryDOMapper.deleteByPrimaryKey(questionId); return chatQueryDOMapper.deleteByPrimaryKey(questionId);
} }
} }

View File

@@ -5,6 +5,7 @@ import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.ArrayList; import java.util.ArrayList;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.Function; import net.sf.jsqlparser.expression.Function;
@@ -12,6 +13,7 @@ import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.Parenthesis; import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.StringValue; import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression; import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator; import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo; import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.schema.Column; import net.sf.jsqlparser.schema.Column;
@@ -225,7 +227,7 @@ public class SqlAddHelper {
} }
private static void addAggregateToSelectItems(List<SelectItem> selectItems, private static void addAggregateToSelectItems(List<SelectItem> selectItems,
Map<String, String> fieldNameToAggregate) { Map<String, String> fieldNameToAggregate) {
for (SelectItem selectItem : selectItems) { for (SelectItem selectItem : selectItems) {
if (selectItem instanceof SelectExpressionItem) { if (selectItem instanceof SelectExpressionItem) {
SelectExpressionItem selectExpressionItem = (SelectExpressionItem) selectItem; SelectExpressionItem selectExpressionItem = (SelectExpressionItem) selectItem;
@@ -240,7 +242,7 @@ public class SqlAddHelper {
} }
private static void addAggregateToOrderByItems(List<OrderByElement> orderByElements, private static void addAggregateToOrderByItems(List<OrderByElement> orderByElements,
Map<String, String> fieldNameToAggregate) { Map<String, String> fieldNameToAggregate) {
if (orderByElements == null) { if (orderByElements == null) {
return; return;
} }
@@ -255,7 +257,7 @@ public class SqlAddHelper {
} }
private static void addAggregateToGroupByItems(GroupByElement groupByElement, private static void addAggregateToGroupByItems(GroupByElement groupByElement,
Map<String, String> fieldNameToAggregate) { Map<String, String> fieldNameToAggregate) {
if (groupByElement == null) { if (groupByElement == null) {
return; return;
} }
@@ -276,13 +278,22 @@ public class SqlAddHelper {
} }
private static void modifyWhereExpression(Expression whereExpression, private static void modifyWhereExpression(Expression whereExpression,
Map<String, String> fieldNameToAggregate) { Map<String, String> fieldNameToAggregate) {
if (SqlSelectHelper.isLogicExpression(whereExpression)) { if (SqlSelectHelper.isLogicExpression(whereExpression)) {
AndExpression andExpression = (AndExpression) whereExpression; if (whereExpression instanceof AndExpression) {
Expression leftExpression = andExpression.getLeftExpression(); AndExpression andExpression = (AndExpression) whereExpression;
Expression rightExpression = andExpression.getRightExpression(); Expression leftExpression = andExpression.getLeftExpression();
modifyWhereExpression(leftExpression, fieldNameToAggregate); Expression rightExpression = andExpression.getRightExpression();
modifyWhereExpression(rightExpression, fieldNameToAggregate); modifyWhereExpression(leftExpression, fieldNameToAggregate);
modifyWhereExpression(rightExpression, fieldNameToAggregate);
}
if (whereExpression instanceof OrExpression) {
OrExpression orExpression = (OrExpression) whereExpression;
Expression leftExpression = orExpression.getLeftExpression();
Expression rightExpression = orExpression.getRightExpression();
modifyWhereExpression(leftExpression, fieldNameToAggregate);
modifyWhereExpression(rightExpression, fieldNameToAggregate);
}
} else if (whereExpression instanceof Parenthesis) { } else if (whereExpression instanceof Parenthesis) {
modifyWhereExpression(((Parenthesis) whereExpression).getExpression(), fieldNameToAggregate); modifyWhereExpression(((Parenthesis) whereExpression).getExpression(), fieldNameToAggregate);
} else { } else {

View File

@@ -29,6 +29,7 @@ import net.sf.jsqlparser.statement.select.SelectItem;
import net.sf.jsqlparser.statement.select.SelectVisitorAdapter; import net.sf.jsqlparser.statement.select.SelectVisitorAdapter;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
import java.util.Objects; import java.util.Objects;
@@ -60,6 +61,32 @@ public class SqlRemoveHelper {
return selectStatement.toString(); return selectStatement.toString();
} }
/**
 * Removes repeated entries from the SELECT list of {@code sql}, keeping the first
 * occurrence of each.
 *
 * <p>Only {@link SelectExpressionItem} entries are compared, and two entries count as
 * the same when {@code getExpression().toString()} yields equal strings. Any other kind
 * of select item is always kept.
 *
 * @param sql the SQL text to process
 * @return the re-rendered statement with duplicates dropped, or {@code sql} unchanged when
 *         {@code SqlSelectHelper.getSelect} returns {@code null} or the select body is not
 *         a {@link PlainSelect}
 */
public static String removeSameFieldFromSelect(String sql) {
    Select parsed = SqlSelectHelper.getSelect(sql);
    if (parsed == null) {
        return sql;
    }
    SelectBody body = parsed.getSelectBody();
    if (!(body instanceof PlainSelect)) {
        return sql;
    }
    PlainSelect plainSelect = (PlainSelect) body;
    List<SelectItem> items = plainSelect.getSelectItems();
    // Rendered expressions already encountered while walking the SELECT list in order.
    Set<String> seenExpressions = new HashSet<>();
    items.removeIf(item -> {
        if (!(item instanceof SelectExpressionItem)) {
            return false;
        }
        String rendered = ((SelectExpressionItem) item).getExpression().toString();
        // Set.add returns false when the value was already present, i.e. a repeat.
        return !seenExpressions.add(rendered);
    });
    plainSelect.setSelectItems(items);
    return parsed.toString();
}
public static String removeWhereCondition(String sql, Set<String> removeFieldNames) { public static String removeWhereCondition(String sql, Set<String> removeFieldNames) {
Select selectStatement = SqlSelectHelper.getSelect(sql); Select selectStatement = SqlSelectHelper.getSelect(sql);
SelectBody selectBody = selectStatement.getSelectBody(); SelectBody selectBody = selectStatement.getSelectBody();

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.common.util.jsqlparser; package com.tencent.supersonic.common.util.jsqlparser;
import java.util.HashSet; import java.util.HashSet;
import java.util.List;
import java.util.Set; import java.util.Set;
import org.junit.Assert; import org.junit.Assert;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@@ -10,6 +11,21 @@ import org.junit.jupiter.api.Test;
*/ */
class SqlRemoveHelperTest { class SqlRemoveHelperTest {
@Test
void testRemoveSameFieldFromSelect() {
    // 粉丝数, sum(粉丝数) and avg(播放量) are each listed twice in the SELECT clause.
    String sql = "select 歌曲名,歌手名,粉丝数,粉丝数,sum(粉丝数),sum(粉丝数),avg(播放量),avg(播放量)"
            + " from 歌曲库 where sum(粉丝数) > 20000 and 2>1 and "
            + "sum(播放量) > 20000 and 1=1 HAVING sum(播放量) > 20000 and 3>1";
    sql = SqlRemoveHelper.removeSameFieldFromSelect(sql);
    System.out.println(sql);
    // Each repeated field must survive exactly once, in first-seen order.
    Assert.assertTrue(sql.startsWith("SELECT 歌曲名, 歌手名, 粉丝数, sum(粉丝数), avg(播放量) FROM 歌曲库"));
    // Running the helper on already de-duplicated SQL must leave it unchanged.
    Assert.assertEquals(sql, SqlRemoveHelper.removeSameFieldFromSelect(sql));

    // NOTE(review): this second half exercises SqlSelectHelper.getWhereExpressions and only
    // prints its result; the expected field expressions should be asserted once confirmed.
    sql = "SELECT 结算播放量 FROM 艺人 WHERE (歌手名 IN ('林俊杰', '陈奕迅')) AND (数据日期 >= '2024-04-04' AND 数据日期 <= '2024-04-04')";
    List<FieldExpression> fieldExpressionList = SqlSelectHelper.getWhereExpressions(sql);
    fieldExpressionList.forEach(fieldExpression -> System.out.println(fieldExpression.toString()));
}
@Test @Test
void testRemoveWhereHavingCondition() { void testRemoveWhereHavingCondition() {
String sql = "select 歌曲名 from 歌曲库 where sum(粉丝数) > 20000 and 2>1 and " String sql = "select 歌曲名 from 歌曲库 where sum(粉丝数) > 20000 and 2>1 and "

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.headless.core.chat.corrector; package com.tencent.supersonic.headless.core.chat.corrector;
import com.tencent.supersonic.common.util.jsqlparser.SqlRemoveHelper;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.core.pojo.QueryContext; import com.tencent.supersonic.headless.core.pojo.QueryContext;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@@ -28,5 +29,12 @@ public class GrammarCorrector extends BaseSemanticCorrector {
for (BaseSemanticCorrector corrector : correctors) { for (BaseSemanticCorrector corrector : correctors) {
corrector.correct(queryContext, semanticParseInfo); corrector.correct(queryContext, semanticParseInfo);
} }
removeSameFieldFromSelect(semanticParseInfo);
}
/**
 * Passes the corrected S2SQL held by {@code semanticParseInfo} through
 * {@code SqlRemoveHelper.removeSameFieldFromSelect} and writes the result back.
 *
 * NOTE(review): neither {@code getSqlInfo()} nor the S2SQL string is null-checked here;
 * confirm that callers always populate them before this runs.
 *
 * @param semanticParseInfo parse info whose SQL info supplies and receives the corrected S2SQL
 */
public void removeSameFieldFromSelect(SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
correctS2SQL = SqlRemoveHelper.removeSameFieldFromSelect(correctS2SQL);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
} }
} }

View File

@@ -61,11 +61,12 @@ public class LLMRequestService {
} }
public LLMReq getLlmReq(QueryContext queryCtx, Long dataSetId, public LLMReq getLlmReq(QueryContext queryCtx, Long dataSetId,
SemanticSchema semanticSchema, List<LLMReq.ElementValue> linkingValues) { SemanticSchema semanticSchema, List<LLMReq.ElementValue> linkingValues) {
Map<Long, String> dataSetIdToName = semanticSchema.getDataSetIdToName(); Map<Long, String> dataSetIdToName = semanticSchema.getDataSetIdToName();
String queryText = queryCtx.getQueryText(); String queryText = queryCtx.getQueryText();
LLMReq llmReq = new LLMReq(); LLMReq llmReq = new LLMReq();
llmReq.setQueryText(queryText); llmReq.setQueryText(queryText);
LLMReq.FilterCondition filterCondition = new LLMReq.FilterCondition(); LLMReq.FilterCondition filterCondition = new LLMReq.FilterCondition();
llmReq.setFilterCondition(filterCondition); llmReq.setFilterCondition(filterCondition);
@@ -103,7 +104,7 @@ public class LLMRequestService {
} }
protected List<String> getFieldNameList(QueryContext queryCtx, Long dataSetId, protected List<String> getFieldNameList(QueryContext queryCtx, Long dataSetId,
LLMParserConfig llmParserConfig) { LLMParserConfig llmParserConfig) {
Set<String> results = getTopNFieldNames(queryCtx, dataSetId, llmParserConfig); Set<String> results = getTopNFieldNames(queryCtx, dataSetId, llmParserConfig);

View File

@@ -11,6 +11,7 @@ import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlResp; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlResp;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.MapUtils; import org.apache.commons.collections4.MapUtils;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
@@ -28,6 +29,7 @@ public class LLMSqlParser implements SemanticParser {
try { try {
//2.get dataSetId from queryCtx and chatCtx. //2.get dataSetId from queryCtx and chatCtx.
Long dataSetId = requestService.getDataSetId(queryCtx); Long dataSetId = requestService.getDataSetId(queryCtx);
log.info("dataSetId:{}", dataSetId);
if (dataSetId == null) { if (dataSetId == null) {
return; return;
} }

View File

@@ -6,11 +6,11 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.Pair;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import java.util.List;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Map;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.List;
import java.util.Map;
@Component @Component
@Slf4j @Slf4j
@@ -85,7 +85,7 @@ public class SqlPromptGenerator {
} }
public List<List<Map<String, String>>> getExampleCombos(List<Map<String, String>> exampleList, int numFewShots, public List<List<Map<String, String>>> getExampleCombos(List<Map<String, String>> exampleList, int numFewShots,
int numSelfConsistency) { int numSelfConsistency) {
List<List<Map<String, String>>> results = new ArrayList<>(); List<List<Map<String, String>>> results = new ArrayList<>();
for (int i = 0; i < numSelfConsistency; i++) { for (int i = 0; i < numSelfConsistency; i++) {
List<Map<String, String>> shuffledList = new ArrayList<>(exampleList); List<Map<String, String>> shuffledList = new ArrayList<>(exampleList);
@@ -118,7 +118,7 @@ public class SqlPromptGenerator {
} }
public List<String> generateSqlPromptPool(LLMReq llmReq, List<String> schemaLinkStrPool, public List<String> generateSqlPromptPool(LLMReq llmReq, List<String> schemaLinkStrPool,
List<List<Map<String, String>>> fewshotExampleListPool) { List<List<Map<String, String>>> fewshotExampleListPool) {
List<String> sqlPromptPool = new ArrayList<>(); List<String> sqlPromptPool = new ArrayList<>();
for (int i = 0; i < schemaLinkStrPool.size(); i++) { for (int i = 0; i < schemaLinkStrPool.size(); i++) {
String schemaLinkStr = schemaLinkStrPool.get(i); String schemaLinkStr = schemaLinkStrPool.get(i);
@@ -129,4 +129,4 @@ public class SqlPromptGenerator {
return sqlPromptPool; return sqlPromptPool;
} }
} }

View File

@@ -15,6 +15,7 @@ import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Comparator; import java.util.Comparator;
import java.util.List; import java.util.List;

View File

@@ -8,7 +8,6 @@ import com.tencent.supersonic.headless.server.persistence.dataobject.ChatContext
import com.tencent.supersonic.headless.server.persistence.mapper.ChatContextMapper; import com.tencent.supersonic.headless.server.persistence.mapper.ChatContextMapper;
import com.tencent.supersonic.headless.server.persistence.repository.ChatContextRepository; import com.tencent.supersonic.headless.server.persistence.repository.ChatContextRepository;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Primary; import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Repository; import org.springframework.stereotype.Repository;
@@ -17,7 +16,7 @@ import org.springframework.stereotype.Repository;
@Slf4j @Slf4j
public class ChatContextRepositoryImpl implements ChatContextRepository { public class ChatContextRepositoryImpl implements ChatContextRepository {
@Autowired(required = false)
private final ChatContextMapper chatContextMapper; private final ChatContextMapper chatContextMapper;
public ChatContextRepositoryImpl(ChatContextMapper chatContextMapper) { public ChatContextRepositoryImpl(ChatContextMapper chatContextMapper) {