mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(Headless) distinct select fields in S2CorrectSQL (#912)
This commit is contained in:
@@ -8,7 +8,9 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||||
import com.tencent.supersonic.headless.server.service.ChatQueryService;
|
import com.tencent.supersonic.headless.server.service.ChatQueryService;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
public class NL2SQLParser implements ChatParser {
|
public class NL2SQLParser implements ChatParser {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -35,4 +35,5 @@ public interface ChatQueryRepository {
|
|||||||
List<ChatParseDO> getParseInfoList(List<Long> questionIds);
|
List<ChatParseDO> getParseInfoList(List<Long> questionIds);
|
||||||
|
|
||||||
Boolean deleteChatQuery(Long questionId);
|
Boolean deleteChatQuery(Long questionId);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -192,4 +192,5 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
|||||||
public Boolean deleteChatQuery(Long questionId) {
|
public Boolean deleteChatQuery(Long questionId) {
|
||||||
return chatQueryDOMapper.deleteByPrimaryKey(questionId);
|
return chatQueryDOMapper.deleteByPrimaryKey(questionId);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import java.util.Map;
|
|||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import net.sf.jsqlparser.expression.Expression;
|
import net.sf.jsqlparser.expression.Expression;
|
||||||
import net.sf.jsqlparser.expression.Function;
|
import net.sf.jsqlparser.expression.Function;
|
||||||
@@ -12,6 +13,7 @@ import net.sf.jsqlparser.expression.LongValue;
|
|||||||
import net.sf.jsqlparser.expression.Parenthesis;
|
import net.sf.jsqlparser.expression.Parenthesis;
|
||||||
import net.sf.jsqlparser.expression.StringValue;
|
import net.sf.jsqlparser.expression.StringValue;
|
||||||
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
|
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
|
||||||
|
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
|
||||||
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
|
import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator;
|
||||||
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
|
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
|
||||||
import net.sf.jsqlparser.schema.Column;
|
import net.sf.jsqlparser.schema.Column;
|
||||||
@@ -225,7 +227,7 @@ public class SqlAddHelper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private static void addAggregateToSelectItems(List<SelectItem> selectItems,
|
private static void addAggregateToSelectItems(List<SelectItem> selectItems,
|
||||||
Map<String, String> fieldNameToAggregate) {
|
Map<String, String> fieldNameToAggregate) {
|
||||||
for (SelectItem selectItem : selectItems) {
|
for (SelectItem selectItem : selectItems) {
|
||||||
if (selectItem instanceof SelectExpressionItem) {
|
if (selectItem instanceof SelectExpressionItem) {
|
||||||
SelectExpressionItem selectExpressionItem = (SelectExpressionItem) selectItem;
|
SelectExpressionItem selectExpressionItem = (SelectExpressionItem) selectItem;
|
||||||
@@ -240,7 +242,7 @@ public class SqlAddHelper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private static void addAggregateToOrderByItems(List<OrderByElement> orderByElements,
|
private static void addAggregateToOrderByItems(List<OrderByElement> orderByElements,
|
||||||
Map<String, String> fieldNameToAggregate) {
|
Map<String, String> fieldNameToAggregate) {
|
||||||
if (orderByElements == null) {
|
if (orderByElements == null) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -255,7 +257,7 @@ public class SqlAddHelper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private static void addAggregateToGroupByItems(GroupByElement groupByElement,
|
private static void addAggregateToGroupByItems(GroupByElement groupByElement,
|
||||||
Map<String, String> fieldNameToAggregate) {
|
Map<String, String> fieldNameToAggregate) {
|
||||||
if (groupByElement == null) {
|
if (groupByElement == null) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -276,13 +278,22 @@ public class SqlAddHelper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private static void modifyWhereExpression(Expression whereExpression,
|
private static void modifyWhereExpression(Expression whereExpression,
|
||||||
Map<String, String> fieldNameToAggregate) {
|
Map<String, String> fieldNameToAggregate) {
|
||||||
if (SqlSelectHelper.isLogicExpression(whereExpression)) {
|
if (SqlSelectHelper.isLogicExpression(whereExpression)) {
|
||||||
AndExpression andExpression = (AndExpression) whereExpression;
|
if (whereExpression instanceof AndExpression) {
|
||||||
Expression leftExpression = andExpression.getLeftExpression();
|
AndExpression andExpression = (AndExpression) whereExpression;
|
||||||
Expression rightExpression = andExpression.getRightExpression();
|
Expression leftExpression = andExpression.getLeftExpression();
|
||||||
modifyWhereExpression(leftExpression, fieldNameToAggregate);
|
Expression rightExpression = andExpression.getRightExpression();
|
||||||
modifyWhereExpression(rightExpression, fieldNameToAggregate);
|
modifyWhereExpression(leftExpression, fieldNameToAggregate);
|
||||||
|
modifyWhereExpression(rightExpression, fieldNameToAggregate);
|
||||||
|
}
|
||||||
|
if (whereExpression instanceof OrExpression) {
|
||||||
|
OrExpression orExpression = (OrExpression) whereExpression;
|
||||||
|
Expression leftExpression = orExpression.getLeftExpression();
|
||||||
|
Expression rightExpression = orExpression.getRightExpression();
|
||||||
|
modifyWhereExpression(leftExpression, fieldNameToAggregate);
|
||||||
|
modifyWhereExpression(rightExpression, fieldNameToAggregate);
|
||||||
|
}
|
||||||
} else if (whereExpression instanceof Parenthesis) {
|
} else if (whereExpression instanceof Parenthesis) {
|
||||||
modifyWhereExpression(((Parenthesis) whereExpression).getExpression(), fieldNameToAggregate);
|
modifyWhereExpression(((Parenthesis) whereExpression).getExpression(), fieldNameToAggregate);
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ import net.sf.jsqlparser.statement.select.SelectItem;
|
|||||||
import net.sf.jsqlparser.statement.select.SelectVisitorAdapter;
|
import net.sf.jsqlparser.statement.select.SelectVisitorAdapter;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
@@ -60,6 +61,32 @@ public class SqlRemoveHelper {
|
|||||||
return selectStatement.toString();
|
return selectStatement.toString();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static String removeSameFieldFromSelect(String sql) {
|
||||||
|
Select selectStatement = SqlSelectHelper.getSelect(sql);
|
||||||
|
if (selectStatement == null) {
|
||||||
|
return sql;
|
||||||
|
}
|
||||||
|
SelectBody selectBody = selectStatement.getSelectBody();
|
||||||
|
if (!(selectBody instanceof PlainSelect)) {
|
||||||
|
return sql;
|
||||||
|
}
|
||||||
|
List<SelectItem> selectItems = ((PlainSelect) selectBody).getSelectItems();
|
||||||
|
Set<String> fields = new HashSet<>();
|
||||||
|
selectItems.removeIf(selectItem -> {
|
||||||
|
if (selectItem instanceof SelectExpressionItem) {
|
||||||
|
SelectExpressionItem selectExpressionItem = (SelectExpressionItem) selectItem;
|
||||||
|
String field = selectExpressionItem.getExpression().toString();
|
||||||
|
if (fields.contains(field)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
fields.add(field);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
});
|
||||||
|
((PlainSelect) selectBody).setSelectItems(selectItems);
|
||||||
|
return selectStatement.toString();
|
||||||
|
}
|
||||||
|
|
||||||
public static String removeWhereCondition(String sql, Set<String> removeFieldNames) {
|
public static String removeWhereCondition(String sql, Set<String> removeFieldNames) {
|
||||||
Select selectStatement = SqlSelectHelper.getSelect(sql);
|
Select selectStatement = SqlSelectHelper.getSelect(sql);
|
||||||
SelectBody selectBody = selectStatement.getSelectBody();
|
SelectBody selectBody = selectStatement.getSelectBody();
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package com.tencent.supersonic.common.util.jsqlparser;
|
package com.tencent.supersonic.common.util.jsqlparser;
|
||||||
|
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
@@ -10,6 +11,21 @@ import org.junit.jupiter.api.Test;
|
|||||||
*/
|
*/
|
||||||
class SqlRemoveHelperTest {
|
class SqlRemoveHelperTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void testRemoveSameFieldFromSelect() {
|
||||||
|
String sql = "select 歌曲名,歌手名,粉丝数,粉丝数,sum(粉丝数),sum(粉丝数),avg(播放量),avg(播放量)"
|
||||||
|
+ " from 歌曲库 where sum(粉丝数) > 20000 and 2>1 and "
|
||||||
|
+ "sum(播放量) > 20000 and 1=1 HAVING sum(播放量) > 20000 and 3>1";
|
||||||
|
sql = SqlRemoveHelper.removeSameFieldFromSelect(sql);
|
||||||
|
System.out.println(sql);
|
||||||
|
sql = "SELECT 结算播放量 FROM 艺人 WHERE (歌手名 IN ('林俊杰', '陈奕迅')) AND (数据日期 >= '2024-04-04' AND 数据日期 <= '2024-04-04')";
|
||||||
|
List<FieldExpression> fieldExpressionList = SqlSelectHelper.getWhereExpressions(sql);
|
||||||
|
fieldExpressionList.stream().forEach(fieldExpression -> {
|
||||||
|
System.out.println(fieldExpression.toString());
|
||||||
|
});
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void testRemoveWhereHavingCondition() {
|
void testRemoveWhereHavingCondition() {
|
||||||
String sql = "select 歌曲名 from 歌曲库 where sum(粉丝数) > 20000 and 2>1 and "
|
String sql = "select 歌曲名 from 歌曲库 where sum(粉丝数) > 20000 and 2>1 and "
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
package com.tencent.supersonic.headless.core.chat.corrector;
|
package com.tencent.supersonic.headless.core.chat.corrector;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlRemoveHelper;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
@@ -28,5 +29,12 @@ public class GrammarCorrector extends BaseSemanticCorrector {
|
|||||||
for (BaseSemanticCorrector corrector : correctors) {
|
for (BaseSemanticCorrector corrector : correctors) {
|
||||||
corrector.correct(queryContext, semanticParseInfo);
|
corrector.correct(queryContext, semanticParseInfo);
|
||||||
}
|
}
|
||||||
|
removeSameFieldFromSelect(semanticParseInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void removeSameFieldFromSelect(SemanticParseInfo semanticParseInfo) {
|
||||||
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
|
correctS2SQL = SqlRemoveHelper.removeSameFieldFromSelect(correctS2SQL);
|
||||||
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -61,11 +61,12 @@ public class LLMRequestService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public LLMReq getLlmReq(QueryContext queryCtx, Long dataSetId,
|
public LLMReq getLlmReq(QueryContext queryCtx, Long dataSetId,
|
||||||
SemanticSchema semanticSchema, List<LLMReq.ElementValue> linkingValues) {
|
SemanticSchema semanticSchema, List<LLMReq.ElementValue> linkingValues) {
|
||||||
Map<Long, String> dataSetIdToName = semanticSchema.getDataSetIdToName();
|
Map<Long, String> dataSetIdToName = semanticSchema.getDataSetIdToName();
|
||||||
String queryText = queryCtx.getQueryText();
|
String queryText = queryCtx.getQueryText();
|
||||||
|
|
||||||
LLMReq llmReq = new LLMReq();
|
LLMReq llmReq = new LLMReq();
|
||||||
|
|
||||||
llmReq.setQueryText(queryText);
|
llmReq.setQueryText(queryText);
|
||||||
LLMReq.FilterCondition filterCondition = new LLMReq.FilterCondition();
|
LLMReq.FilterCondition filterCondition = new LLMReq.FilterCondition();
|
||||||
llmReq.setFilterCondition(filterCondition);
|
llmReq.setFilterCondition(filterCondition);
|
||||||
@@ -103,7 +104,7 @@ public class LLMRequestService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
protected List<String> getFieldNameList(QueryContext queryCtx, Long dataSetId,
|
protected List<String> getFieldNameList(QueryContext queryCtx, Long dataSetId,
|
||||||
LLMParserConfig llmParserConfig) {
|
LLMParserConfig llmParserConfig) {
|
||||||
|
|
||||||
Set<String> results = getTopNFieldNames(queryCtx, dataSetId, llmParserConfig);
|
Set<String> results = getTopNFieldNames(queryCtx, dataSetId, llmParserConfig);
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
|||||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlResp;
|
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlResp;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections4.MapUtils;
|
import org.apache.commons.collections4.MapUtils;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
@@ -28,6 +29,7 @@ public class LLMSqlParser implements SemanticParser {
|
|||||||
try {
|
try {
|
||||||
//2.get dataSetId from queryCtx and chatCtx.
|
//2.get dataSetId from queryCtx and chatCtx.
|
||||||
Long dataSetId = requestService.getDataSetId(queryCtx);
|
Long dataSetId = requestService.getDataSetId(queryCtx);
|
||||||
|
log.info("dataSetId:{}", dataSetId);
|
||||||
if (dataSetId == null) {
|
if (dataSetId == null) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,11 +6,11 @@ import lombok.extern.slf4j.Slf4j;
|
|||||||
import org.apache.commons.lang3.tuple.Pair;
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
|
import java.util.Map;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
@Component
|
@Component
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@@ -85,7 +85,7 @@ public class SqlPromptGenerator {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public List<List<Map<String, String>>> getExampleCombos(List<Map<String, String>> exampleList, int numFewShots,
|
public List<List<Map<String, String>>> getExampleCombos(List<Map<String, String>> exampleList, int numFewShots,
|
||||||
int numSelfConsistency) {
|
int numSelfConsistency) {
|
||||||
List<List<Map<String, String>>> results = new ArrayList<>();
|
List<List<Map<String, String>>> results = new ArrayList<>();
|
||||||
for (int i = 0; i < numSelfConsistency; i++) {
|
for (int i = 0; i < numSelfConsistency; i++) {
|
||||||
List<Map<String, String>> shuffledList = new ArrayList<>(exampleList);
|
List<Map<String, String>> shuffledList = new ArrayList<>(exampleList);
|
||||||
@@ -118,7 +118,7 @@ public class SqlPromptGenerator {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public List<String> generateSqlPromptPool(LLMReq llmReq, List<String> schemaLinkStrPool,
|
public List<String> generateSqlPromptPool(LLMReq llmReq, List<String> schemaLinkStrPool,
|
||||||
List<List<Map<String, String>>> fewshotExampleListPool) {
|
List<List<Map<String, String>>> fewshotExampleListPool) {
|
||||||
List<String> sqlPromptPool = new ArrayList<>();
|
List<String> sqlPromptPool = new ArrayList<>();
|
||||||
for (int i = 0; i < schemaLinkStrPool.size(); i++) {
|
for (int i = 0; i < schemaLinkStrPool.size(); i++) {
|
||||||
String schemaLinkStr = schemaLinkStrPool.get(i);
|
String schemaLinkStr = schemaLinkStrPool.get(i);
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import lombok.Builder;
|
|||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import com.tencent.supersonic.headless.server.persistence.dataobject.ChatContext
|
|||||||
import com.tencent.supersonic.headless.server.persistence.mapper.ChatContextMapper;
|
import com.tencent.supersonic.headless.server.persistence.mapper.ChatContextMapper;
|
||||||
import com.tencent.supersonic.headless.server.persistence.repository.ChatContextRepository;
|
import com.tencent.supersonic.headless.server.persistence.repository.ChatContextRepository;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
|
||||||
import org.springframework.context.annotation.Primary;
|
import org.springframework.context.annotation.Primary;
|
||||||
import org.springframework.stereotype.Repository;
|
import org.springframework.stereotype.Repository;
|
||||||
|
|
||||||
@@ -17,7 +16,7 @@ import org.springframework.stereotype.Repository;
|
|||||||
@Slf4j
|
@Slf4j
|
||||||
public class ChatContextRepositoryImpl implements ChatContextRepository {
|
public class ChatContextRepositoryImpl implements ChatContextRepository {
|
||||||
|
|
||||||
@Autowired(required = false)
|
|
||||||
private final ChatContextMapper chatContextMapper;
|
private final ChatContextMapper chatContextMapper;
|
||||||
|
|
||||||
public ChatContextRepositoryImpl(ChatContextMapper chatContextMapper) {
|
public ChatContextRepositoryImpl(ChatContextMapper chatContextMapper) {
|
||||||
|
|||||||
Reference in New Issue
Block a user