diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlRemoveHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlRemoveHelper.java index a7f578a35..fd13570fd 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlRemoveHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlRemoveHelper.java @@ -25,11 +25,13 @@ import net.sf.jsqlparser.statement.select.GroupByElement; import net.sf.jsqlparser.statement.select.PlainSelect; import net.sf.jsqlparser.statement.select.Select; import net.sf.jsqlparser.statement.select.SelectItem; +import net.sf.jsqlparser.statement.select.SelectItemVisitorAdapter; import net.sf.jsqlparser.statement.select.SelectVisitorAdapter; import org.springframework.util.CollectionUtils; import java.util.ArrayList; import java.util.HashSet; +import java.util.Iterator; import java.util.List; import java.util.Objects; import java.util.Set; @@ -42,10 +44,7 @@ public class SqlRemoveHelper { public static String removeAsteriskAndAddFields(String sql, Set needAddDefaultFields) { Select selectStatement = SqlSelectHelper.getSelect(sql); - if (Objects.isNull(selectStatement)) { - return sql; - } - if (!(selectStatement instanceof PlainSelect)) { + if (isInvalidSelect(selectStatement)) { return sql; } List> selectItems = ((PlainSelect) selectStatement).getSelectItems(); @@ -63,10 +62,7 @@ public class SqlRemoveHelper { public static String removeSameFieldFromSelect(String sql) { Select selectStatement = SqlSelectHelper.getSelect(sql); - if (selectStatement == null) { - return sql; - } - if (!(selectStatement instanceof PlainSelect)) { + if (isInvalidSelect(selectStatement)) { return sql; } List> selectItems = ((PlainSelect) selectStatement).getSelectItems(); @@ -106,10 +102,7 @@ public class SqlRemoveHelper { public static String removeNumberFilter(String sql) { Select selectStatement = SqlSelectHelper.getSelect(sql); - if (selectStatement == null) { - return sql; - } - if (!(selectStatement instanceof PlainSelect)) { + if (isInvalidSelect(selectStatement)) { return sql; } Expression where = ((PlainSelect) selectStatement).getWhere(); @@ -226,10 +219,7 @@ public class SqlRemoveHelper { public static String removeGroupBy(String sql, Set fields) { Select selectStatement = SqlSelectHelper.getSelect(sql); - if (selectStatement == null) { - return sql; - } - if (!(selectStatement instanceof PlainSelect)) { + if (isInvalidSelect(selectStatement)) { return sql; } GroupByElement groupByElement = ((PlainSelect) selectStatement).getGroupBy(); @@ -250,6 +240,30 @@ public class SqlRemoveHelper { return selectStatement.toString(); } + public static String removeSelect(String sql, Set fields) { + Select selectStatement = SqlSelectHelper.getSelect(sql); + if (isInvalidSelect(selectStatement)) { + return sql; + } + List> selectItems = ((PlainSelect) selectStatement).getSelectItems(); + Iterator> iterator = selectItems.iterator(); + while (iterator.hasNext()) { + SelectItem selectItem = iterator.next(); + selectItem.accept(new SelectItemVisitorAdapter() { + @Override + public void visit(SelectItem item) { + if (fields.contains(item.getExpression().toString())) { + iterator.remove(); + } + } + }); + } + if (selectItems.isEmpty()) { + selectItems.add(new SelectItem(new AllColumns())); + } + return selectStatement.toString(); + } + public static Expression filteredExpression(Expression where, SqlEditEnum sqlEditEnum) throws Exception { if (Objects.isNull(where)) { return null; @@ -339,5 +353,9 @@ public class SqlRemoveHelper { } } + private static boolean isInvalidSelect(Select selectStatement) { + return Objects.isNull(selectStatement) || !(selectStatement instanceof PlainSelect); + } + } diff --git a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlRemoveHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlRemoveHelperTest.java index d0cab596d..09e859214 100644 --- a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlRemoveHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlRemoveHelperTest.java @@ -1,10 +1,11 @@ package com.tencent.supersonic.common.jsqlparser; +import org.junit.Assert; +import org.junit.jupiter.api.Test; + import java.util.HashSet; import java.util.List; import java.util.Set; -import org.junit.Assert; -import org.junit.jupiter.api.Test; /** * SqlParser Remove Helper Test @@ -128,4 +129,38 @@ class SqlRemoveHelperTest { replaceSql); } + @Test + void testRemoveSelect() { + String sql = "select 数据日期,歌曲名 from 歌曲库 where 歌曲名 = '邓紫棋' and 数据日期 = '2023-08-09' and 歌曲发布时间 = '2023-08-01'"; + + Set removeFieldNames = new HashSet<>(); + removeFieldNames.add("数据日期"); + String replaceSql = SqlRemoveHelper.removeSelect(sql, removeFieldNames); + + Assert.assertEquals( + "SELECT 歌曲名 FROM 歌曲库 WHERE 歌曲名 = '邓紫棋' AND 数据日期 = '2023-08-09' AND 歌曲发布时间 = '2023-08-01'", + replaceSql); + + sql = "select 数据日期 from 歌曲库 where 歌曲名 = '邓紫棋' and 数据日期 = '2023-08-09' and 歌曲发布时间 = '2023-08-01'"; + + replaceSql = SqlRemoveHelper.removeSelect(sql, removeFieldNames); + + Assert.assertEquals( + "SELECT * FROM 歌曲库 WHERE 歌曲名 = '邓紫棋' AND 数据日期 = '2023-08-09' AND 歌曲发布时间 = '2023-08-01'", + replaceSql); + } + + @Test + void testRemoveGroupBy() { + String sql = "select 数据日期 from 歌曲库 where 歌曲名 = '邓紫棋' and 数据日期 = '2023-08-09' and " + + "歌曲发布时间 = '2023-08-01' group by 数据日期"; + + Set removeFieldNames = new HashSet<>(); + removeFieldNames.add("数据日期"); + String replaceSql = SqlRemoveHelper.removeGroupBy(sql, removeFieldNames); + + Assert.assertEquals( + "SELECT 数据日期 FROM 歌曲库 WHERE 歌曲名 = '邓紫棋' AND 数据日期 = '2023-08-09' AND 歌曲发布时间 = '2023-08-01'", + replaceSql); + } } diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 0d6b22c1f..7d6f65178 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -51,7 +51,7 @@ services: retries: 5 db_init: - image: supersonicbi/supersonic:${SUPERSONIC_VERSION:-latest} + image: supersonicbi/supersonic:${SUPERSONIC_VERSION:-0.9.6} privileged: true platform: linux/amd64 container_name: supersonic_db_init @@ -64,8 +64,8 @@ services: sh -c " sleep 15 && if ! mysql -h supersonic_mysql -usupersonic_user -psupersonic_password -e 'use supersonic_db; show tables;' | grep -q 's2_database'; then - mysql -h supersonic_mysql -usupersonic_user -psupersonic_password supersonic_db < /usr/src/app/supersonic-standalone-${SUPERSONIC_VERSION:-latest}/conf/db/schema-mysql.sql && - mysql -h supersonic_mysql -usupersonic_user -psupersonic_password supersonic_db < /usr/src/app/supersonic-standalone-${SUPERSONIC_VERSION:-latest}/conf/db/data-mysql.sql + mysql -h supersonic_mysql -usupersonic_user -psupersonic_password supersonic_db < /usr/src/app/supersonic-standalone-${SUPERSONIC_VERSION:-0.9.6}/conf/db/schema-mysql.sql && + mysql -h supersonic_mysql -usupersonic_user -psupersonic_password supersonic_db < /usr/src/app/supersonic-standalone-${SUPERSONIC_VERSION:-0.9.6}/conf/db/data-mysql.sql else echo 'Database already initialized.' fi @@ -76,7 +76,7 @@ services: - 8.8.4.4 supersonic_standalone: - image: supersonicbi/supersonic:${SUPERSONIC_VERSION:-latest} + image: supersonicbi/supersonic:${SUPERSONIC_VERSION:-0.9.6} privileged: true platform: linux/amd64 container_name: supersonic_standalone @@ -103,7 +103,7 @@ services: - 8.8.4.4 volumes: #1.Named Volumes are best for persistent data managed by Docker. - - supersonic_data:/usr/src/app/supersonic-standalone-${SUPERSONIC_VERSION:-latest} + - supersonic_data:/usr/src/app/supersonic-standalone-${SUPERSONIC_VERSION:-0.9.6} #2.Bind Mounts are suitable for frequent modifications and debugging. # - ./conf/langchain4j-prd.yaml:/usr/src/app/supersonic-standalone-${SUPERSONIC_VERSION:-latest}/conf/langchain4j-prd.yaml #3.Detailed Bind Mounts offer more control over the mount behavior. diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/BaseSemanticCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/BaseSemanticCorrector.java index a7caa7b70..504881292 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/BaseSemanticCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/BaseSemanticCorrector.java @@ -1,8 +1,10 @@ package com.tencent.supersonic.headless.chat.corrector; import com.tencent.supersonic.common.jsqlparser.SqlAddHelper; +import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper; import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum; import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; +import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; @@ -61,15 +63,13 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector { return elements.stream(); }) .collect(Collectors.toMap(a -> a, a -> a, (k1, k2) -> k1)); - if (chatQueryContext.containsPartitionDimensions(dataSetId)) { - result.put(TimeDimensionEnum.DAY.getChName(), TimeDimensionEnum.DAY.getChName()); - result.put(TimeDimensionEnum.MONTH.getChName(), TimeDimensionEnum.MONTH.getChName()); - result.put(TimeDimensionEnum.WEEK.getChName(), TimeDimensionEnum.WEEK.getChName()); + result.put(TimeDimensionEnum.DAY.getChName(), TimeDimensionEnum.DAY.getChName()); + result.put(TimeDimensionEnum.MONTH.getChName(), TimeDimensionEnum.MONTH.getChName()); + result.put(TimeDimensionEnum.WEEK.getChName(), TimeDimensionEnum.WEEK.getChName()); - result.put(TimeDimensionEnum.DAY.getName(), TimeDimensionEnum.DAY.getChName()); - result.put(TimeDimensionEnum.MONTH.getName(), TimeDimensionEnum.MONTH.getChName()); - result.put(TimeDimensionEnum.WEEK.getName(), TimeDimensionEnum.WEEK.getChName()); - } + result.put(TimeDimensionEnum.DAY.getName(), TimeDimensionEnum.DAY.getChName()); + result.put(TimeDimensionEnum.MONTH.getName(), TimeDimensionEnum.MONTH.getChName()); + result.put(TimeDimensionEnum.WEEK.getName(), TimeDimensionEnum.WEEK.getChName()); return result; } @@ -122,4 +122,24 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector { dimensions.add(TimeDimensionEnum.DAY.getChName()); return dimensions; } + + protected boolean containsPartitionDimensions(ChatQueryContext chatQueryContext, + SemanticParseInfo semanticParseInfo) { + Long dataSetId = semanticParseInfo.getDataSetId(); + SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema(); + DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId); + return dataSetSchema.containsPartitionDimensions(); + } + + protected void removeDateIfExist(SemanticParseInfo semanticParseInfo) { + String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); + Set removeFieldNames = new HashSet<>(); + removeFieldNames.add(TimeDimensionEnum.DAY.getChName()); + removeFieldNames.add(TimeDimensionEnum.WEEK.getChName()); + removeFieldNames.add(TimeDimensionEnum.MONTH.getChName()); + correctS2SQL = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames); + correctS2SQL = SqlRemoveHelper.removeSelect(correctS2SQL, removeFieldNames); + correctS2SQL = SqlRemoveHelper.removeGroupBy(correctS2SQL, removeFieldNames); + semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL); + } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrector.java index e06558f20..313ac73a3 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrector.java @@ -36,6 +36,8 @@ public class SchemaCorrector extends BaseSemanticCorrector { @Override public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { + removeDateFields(chatQueryContext, semanticParseInfo); + correctAggFunction(semanticParseInfo); replaceAlias(semanticParseInfo); @@ -47,6 +49,13 @@ public class SchemaCorrector extends BaseSemanticCorrector { correctFieldName(chatQueryContext, semanticParseInfo); } + private void removeDateFields(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { + if (containsPartitionDimensions(chatQueryContext, semanticParseInfo)) { + return; + } + removeDateIfExist(semanticParseInfo); + } + private void correctAggFunction(SemanticParseInfo semanticParseInfo) { Map aggregateEnum = AggregateEnum.getAggregateEnum(); SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java index c609e2cfa..e8ed5f124 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java @@ -4,12 +4,10 @@ package com.tencent.supersonic.headless.chat.corrector; import com.tencent.supersonic.common.jsqlparser.DateVisitor.DateBoundInfo; import com.tencent.supersonic.common.jsqlparser.SqlAddHelper; import com.tencent.supersonic.common.jsqlparser.SqlDateSelectHelper; -import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper; import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.chat.ChatQueryContext; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.JSQLParserException; @@ -19,10 +17,8 @@ import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.Pair; import org.springframework.util.CollectionUtils; -import java.util.HashSet; import java.util.List; import java.util.Objects; -import java.util.Set; /** * Perform SQL corrections on the time in S2SQL. @@ -40,17 +36,6 @@ public class TimeCorrector extends BaseSemanticCorrector { addLowerBoundDate(semanticParseInfo); } - private void removeDateIfExist(SemanticParseInfo semanticParseInfo) { - String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); - Set removeFieldNames = new HashSet<>(); - removeFieldNames.add(TimeDimensionEnum.DAY.getChName()); - removeFieldNames.add(TimeDimensionEnum.WEEK.getChName()); - removeFieldNames.add(TimeDimensionEnum.MONTH.getChName()); - correctS2SQL = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames); - correctS2SQL = SqlRemoveHelper.removeGroupBy(correctS2SQL, removeFieldNames); - semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL); - } - private void addDateIfNotExist(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); List whereFields = SqlSelectHelper.getWhereFields(correctS2SQL); @@ -80,14 +65,6 @@ public class TimeCorrector extends BaseSemanticCorrector { semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL); } - private boolean containsPartitionDimensions(ChatQueryContext chatQueryContext, - SemanticParseInfo semanticParseInfo) { - Long dataSetId = semanticParseInfo.getDataSetId(); - SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema(); - DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId); - return dataSetSchema.containsPartitionDimensions(); - } - private void addLowerBoundDate(SemanticParseInfo semanticParseInfo) { String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); DateBoundInfo dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(correctS2SQL);