From 25a618b1756ab9f7cb31ba4002edd7b2fffb57a7 Mon Sep 17 00:00:00 2001 From: mainmain <57514971+mainmainer@users.noreply.github.com> Date: Tue, 16 Apr 2024 16:30:00 +0800 Subject: [PATCH] (improvement)(Headless) distinct select fields in S2CorrectSQL (#912) --- .../chat/server/parser/NL2SQLParser.java | 2 ++ .../repository/ChatQueryRepository.java | 1 + .../impl/ChatQueryRepositoryImpl.java | 1 + .../processor/parse/MetricCheckProcessor.java | 0 .../common/util/jsqlparser/SqlAddHelper.java | 29 +++++++++++++------ .../util/jsqlparser/SqlRemoveHelper.java | 27 +++++++++++++++++ .../util/jsqlparser/SqlRemoveHelperTest.java | 16 ++++++++++ .../core/chat/corrector/GrammarCorrector.java | 8 +++++ .../chat/parser/llm/LLMRequestService.java | 5 ++-- .../core/chat/parser/llm/LLMSqlParser.java | 2 ++ .../chat/parser/llm/SqlPromptGenerator.java | 10 +++---- .../headless/core/pojo/QueryContext.java | 1 + .../impl/ChatContextRepositoryImpl.java | 3 +- 13 files changed, 87 insertions(+), 18 deletions(-) create mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/MetricCheckProcessor.java diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index be0ac2efa..5da72ec1f 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -8,7 +8,9 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryReq; import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.server.service.ChatQueryService; import java.util.List; +import lombok.extern.slf4j.Slf4j; +@Slf4j public class NL2SQLParser implements ChatParser { @Override diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/ChatQueryRepository.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/ChatQueryRepository.java index 9f0d46c68..7255ba78e 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/ChatQueryRepository.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/ChatQueryRepository.java @@ -35,4 +35,5 @@ public interface ChatQueryRepository { List getParseInfoList(List questionIds); Boolean deleteChatQuery(Long questionId); + } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/impl/ChatQueryRepositoryImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/impl/ChatQueryRepositoryImpl.java index f07f626c9..0c9e07a4b 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/impl/ChatQueryRepositoryImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/impl/ChatQueryRepositoryImpl.java @@ -192,4 +192,5 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository { public Boolean deleteChatQuery(Long questionId) { return chatQueryDOMapper.deleteByPrimaryKey(questionId); } + } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/MetricCheckProcessor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/MetricCheckProcessor.java new file mode 100644 index 000000000..e69de29bb diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlAddHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlAddHelper.java index ea2d26d55..ffd03bd82 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlAddHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlAddHelper.java @@ -5,6 +5,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.ArrayList; + import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.Function; @@ -12,6 +13,7 @@ import net.sf.jsqlparser.expression.LongValue; import net.sf.jsqlparser.expression.Parenthesis; import net.sf.jsqlparser.expression.StringValue; import net.sf.jsqlparser.expression.operators.conditional.AndExpression; +import net.sf.jsqlparser.expression.operators.conditional.OrExpression; import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator; import net.sf.jsqlparser.expression.operators.relational.EqualsTo; import net.sf.jsqlparser.schema.Column; @@ -225,7 +227,7 @@ public class SqlAddHelper { } private static void addAggregateToSelectItems(List selectItems, - Map fieldNameToAggregate) { + Map fieldNameToAggregate) { for (SelectItem selectItem : selectItems) { if (selectItem instanceof SelectExpressionItem) { SelectExpressionItem selectExpressionItem = (SelectExpressionItem) selectItem; @@ -240,7 +242,7 @@ public class SqlAddHelper { } private static void addAggregateToOrderByItems(List orderByElements, - Map fieldNameToAggregate) { + Map fieldNameToAggregate) { if (orderByElements == null) { return; } @@ -255,7 +257,7 @@ public class SqlAddHelper { } private static void addAggregateToGroupByItems(GroupByElement groupByElement, - Map fieldNameToAggregate) { + Map fieldNameToAggregate) { if (groupByElement == null) { return; } @@ -276,13 +278,22 @@ public class SqlAddHelper { } private static void modifyWhereExpression(Expression whereExpression, - Map fieldNameToAggregate) { + Map fieldNameToAggregate) { if (SqlSelectHelper.isLogicExpression(whereExpression)) { - AndExpression andExpression = (AndExpression) whereExpression; - Expression leftExpression = andExpression.getLeftExpression(); - Expression rightExpression = andExpression.getRightExpression(); - modifyWhereExpression(leftExpression, fieldNameToAggregate); - modifyWhereExpression(rightExpression, fieldNameToAggregate); + if (whereExpression instanceof AndExpression) { + AndExpression andExpression = (AndExpression) whereExpression; + Expression leftExpression = andExpression.getLeftExpression(); + Expression rightExpression = andExpression.getRightExpression(); + modifyWhereExpression(leftExpression, fieldNameToAggregate); + modifyWhereExpression(rightExpression, fieldNameToAggregate); + } + if (whereExpression instanceof OrExpression) { + OrExpression orExpression = (OrExpression) whereExpression; + Expression leftExpression = orExpression.getLeftExpression(); + Expression rightExpression = orExpression.getRightExpression(); + modifyWhereExpression(leftExpression, fieldNameToAggregate); + modifyWhereExpression(rightExpression, fieldNameToAggregate); + } } else if (whereExpression instanceof Parenthesis) { modifyWhereExpression(((Parenthesis) whereExpression).getExpression(), fieldNameToAggregate); } else { diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlRemoveHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlRemoveHelper.java index 12c736263..df0671036 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlRemoveHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlRemoveHelper.java @@ -29,6 +29,7 @@ import net.sf.jsqlparser.statement.select.SelectItem; import net.sf.jsqlparser.statement.select.SelectVisitorAdapter; import org.springframework.util.CollectionUtils; +import java.util.HashSet; import java.util.List; import java.util.Set; import java.util.Objects; @@ -60,6 +61,32 @@ public class SqlRemoveHelper { return selectStatement.toString(); } + public static String removeSameFieldFromSelect(String sql) { + Select selectStatement = SqlSelectHelper.getSelect(sql); + if (selectStatement == null) { + return sql; + } + SelectBody selectBody = selectStatement.getSelectBody(); + if (!(selectBody instanceof PlainSelect)) { + return sql; + } + List selectItems = ((PlainSelect) selectBody).getSelectItems(); + Set fields = new HashSet<>(); + selectItems.removeIf(selectItem -> { + if (selectItem instanceof SelectExpressionItem) { + SelectExpressionItem selectExpressionItem = (SelectExpressionItem) selectItem; + String field = selectExpressionItem.getExpression().toString(); + if (fields.contains(field)) { + return true; + } + fields.add(field); + } + return false; + }); + ((PlainSelect) selectBody).setSelectItems(selectItems); + return selectStatement.toString(); + } + public static String removeWhereCondition(String sql, Set removeFieldNames) { Select selectStatement = SqlSelectHelper.getSelect(sql); SelectBody selectBody = selectStatement.getSelectBody(); diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlRemoveHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlRemoveHelperTest.java index ab06c81d3..8db7d50ba 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlRemoveHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlRemoveHelperTest.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.common.util.jsqlparser; import java.util.HashSet; +import java.util.List; import java.util.Set; import org.junit.Assert; import org.junit.jupiter.api.Test; @@ -10,6 +11,21 @@ import org.junit.jupiter.api.Test; */ class SqlRemoveHelperTest { + @Test + void testRemoveSameFieldFromSelect() { + String sql = "select 歌曲名,歌手名,粉丝数,粉丝数,sum(粉丝数),sum(粉丝数),avg(播放量),avg(播放量)" + + " from 歌曲库 where sum(粉丝数) > 20000 and 2>1 and " + + "sum(播放量) > 20000 and 1=1 HAVING sum(播放量) > 20000 and 3>1"; + sql = SqlRemoveHelper.removeSameFieldFromSelect(sql); + System.out.println(sql); + sql = "SELECT 结算播放量 FROM 艺人 WHERE (歌手名 IN ('林俊杰', '陈奕迅')) AND (数据日期 >= '2024-04-04' AND 数据日期 <= '2024-04-04')"; + List fieldExpressionList = SqlSelectHelper.getWhereExpressions(sql); + fieldExpressionList.stream().forEach(fieldExpression -> { + System.out.println(fieldExpression.toString()); + }); + + } + @Test void testRemoveWhereHavingCondition() { String sql = "select 歌曲名 from 歌曲库 where sum(粉丝数) > 20000 and 2>1 and " diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/corrector/GrammarCorrector.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/corrector/GrammarCorrector.java index fc22f12ea..944a6633c 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/corrector/GrammarCorrector.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/corrector/GrammarCorrector.java @@ -1,5 +1,6 @@ package com.tencent.supersonic.headless.core.chat.corrector; +import com.tencent.supersonic.common.util.jsqlparser.SqlRemoveHelper; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.core.pojo.QueryContext; import lombok.extern.slf4j.Slf4j; @@ -28,5 +29,12 @@ public class GrammarCorrector extends BaseSemanticCorrector { for (BaseSemanticCorrector corrector : correctors) { corrector.correct(queryContext, semanticParseInfo); } + removeSameFieldFromSelect(semanticParseInfo); + } + + public void removeSameFieldFromSelect(SemanticParseInfo semanticParseInfo) { + String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); + correctS2SQL = SqlRemoveHelper.removeSameFieldFromSelect(correctS2SQL); + semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL); } } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMRequestService.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMRequestService.java index 14141ef87..65a2828c9 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMRequestService.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMRequestService.java @@ -61,11 +61,12 @@ public class LLMRequestService { } public LLMReq getLlmReq(QueryContext queryCtx, Long dataSetId, - SemanticSchema semanticSchema, List linkingValues) { + SemanticSchema semanticSchema, List linkingValues) { Map dataSetIdToName = semanticSchema.getDataSetIdToName(); String queryText = queryCtx.getQueryText(); LLMReq llmReq = new LLMReq(); + llmReq.setQueryText(queryText); LLMReq.FilterCondition filterCondition = new LLMReq.FilterCondition(); llmReq.setFilterCondition(filterCondition); @@ -103,7 +104,7 @@ public class LLMRequestService { } protected List getFieldNameList(QueryContext queryCtx, Long dataSetId, - LLMParserConfig llmParserConfig) { + LLMParserConfig llmParserConfig) { Set results = getTopNFieldNames(queryCtx, dataSetId, llmParserConfig); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMSqlParser.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMSqlParser.java index b95b73092..3250cfb92 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMSqlParser.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/LLMSqlParser.java @@ -11,6 +11,7 @@ import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlResp; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections4.MapUtils; + import java.util.List; import java.util.Map; import java.util.Objects; @@ -28,6 +29,7 @@ public class LLMSqlParser implements SemanticParser { try { //2.get dataSetId from queryCtx and chatCtx. Long dataSetId = requestService.getDataSetId(queryCtx); + log.info("dataSetId:{}", dataSetId); if (dataSetId == null) { return; } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlPromptGenerator.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlPromptGenerator.java index 0e95c15dc..889dabe46 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlPromptGenerator.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/SqlPromptGenerator.java @@ -6,11 +6,11 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.tuple.Pair; import org.springframework.stereotype.Component; +import java.util.List; import java.util.ArrayList; +import java.util.Map; import java.util.Arrays; import java.util.Collections; -import java.util.List; -import java.util.Map; @Component @Slf4j @@ -85,7 +85,7 @@ public class SqlPromptGenerator { } public List>> getExampleCombos(List> exampleList, int numFewShots, - int numSelfConsistency) { + int numSelfConsistency) { List>> results = new ArrayList<>(); for (int i = 0; i < numSelfConsistency; i++) { List> shuffledList = new ArrayList<>(exampleList); @@ -118,7 +118,7 @@ public class SqlPromptGenerator { } public List generateSqlPromptPool(LLMReq llmReq, List schemaLinkStrPool, - List>> fewshotExampleListPool) { + List>> fewshotExampleListPool) { List sqlPromptPool = new ArrayList<>(); for (int i = 0; i < schemaLinkStrPool.size(); i++) { String schemaLinkStr = schemaLinkStrPool.get(i); @@ -129,4 +129,4 @@ public class SqlPromptGenerator { return sqlPromptPool; } -} \ No newline at end of file +} diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/QueryContext.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/QueryContext.java index 58d442071..c7492ce1d 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/QueryContext.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/QueryContext.java @@ -15,6 +15,7 @@ import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; + import java.util.ArrayList; import java.util.Comparator; import java.util.List; diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/ChatContextRepositoryImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/ChatContextRepositoryImpl.java index 98c66d3cd..50a1f8d89 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/ChatContextRepositoryImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/ChatContextRepositoryImpl.java @@ -8,7 +8,6 @@ import com.tencent.supersonic.headless.server.persistence.dataobject.ChatContext import com.tencent.supersonic.headless.server.persistence.mapper.ChatContextMapper; import com.tencent.supersonic.headless.server.persistence.repository.ChatContextRepository; import lombok.extern.slf4j.Slf4j; -import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Primary; import org.springframework.stereotype.Repository; @@ -17,7 +16,7 @@ import org.springframework.stereotype.Repository; @Slf4j public class ChatContextRepositoryImpl implements ChatContextRepository { - @Autowired(required = false) + private final ChatContextMapper chatContextMapper; public ChatContextRepositoryImpl(ChatContextMapper chatContextMapper) {