From d7f301064abe819de3fc0c1df77555b5fddd28c4 Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Tue, 15 Oct 2024 10:22:22 +0800 Subject: [PATCH] [improvement][common] Fix the issue of missing 'ORDER BY' clause generation (#1802) --- .../common/calcite/SqlMergeWithUtils.java | 15 +++++--- .../config/EmbeddingStoreParameterConfig.java | 37 +++++++++---------- .../spring/PgvectorEmbeddingStoreFactory.java | 13 ++----- .../pgvector/spring/Properties.java | 3 +- .../EmbeddingStoreFactoryProvider.java | 6 +-- .../common/calcite/SqlWithMergerTest.java | 26 +++++++++++++ 6 files changed, 61 insertions(+), 39 deletions(-) diff --git a/common/src/main/java/com/tencent/supersonic/common/calcite/SqlMergeWithUtils.java b/common/src/main/java/com/tencent/supersonic/common/calcite/SqlMergeWithUtils.java index ccf5fb701..79ca83dc0 100644 --- a/common/src/main/java/com/tencent/supersonic/common/calcite/SqlMergeWithUtils.java +++ b/common/src/main/java/com/tencent/supersonic/common/calcite/SqlMergeWithUtils.java @@ -21,7 +21,6 @@ import java.util.List; @Slf4j public class SqlMergeWithUtils { - public static String mergeWith(EngineType engineType, String sql, List parentSqlList, List parentWithNameList) throws SqlParseException { SqlParser.Config parserConfig = Configuration.getParserConfig(engineType); @@ -51,11 +50,13 @@ public class SqlMergeWithUtils { withItemList.add(withItem); } - // Check if the main SQL node contains a LIMIT clause + // Check if the main SQL node contains an ORDER BY or LIMIT clause SqlNode limitNode = null; + SqlNodeList orderByList = null; if (sqlNode1 instanceof SqlOrderBy) { SqlOrderBy sqlOrderBy = (SqlOrderBy) sqlNode1; limitNode = sqlOrderBy.fetch; + orderByList = sqlOrderBy.orderList; sqlNode1 = sqlOrderBy.query; } else if (sqlNode1 instanceof SqlSelect) { SqlSelect sqlSelect = (SqlSelect) sqlNode1; @@ -63,21 +64,23 @@ public class SqlMergeWithUtils { sqlSelect.setFetch(null); sqlNode1 = sqlSelect; } + // Extract existing WITH items from sqlNode1 if it is a SqlWith if (sqlNode1 instanceof SqlWith) { SqlWith sqlWith = (SqlWith) sqlNode1; withItemList.addAll(sqlWith.withList.getList()); sqlNode1 = sqlWith.body; } + // Create a new SqlWith node SqlWith finalSqlNode = new SqlWith(SqlParserPos.ZERO, new SqlNodeList(withItemList, SqlParserPos.ZERO), sqlNode1); - // If there was a LIMIT clause, wrap the finalSqlNode in a SqlOrderBy with the LIMIT + // If there was an ORDER BY or LIMIT clause, wrap the finalSqlNode in a SqlOrderBy SqlNode resultNode = finalSqlNode; - if (limitNode != null) { - resultNode = new SqlOrderBy(SqlParserPos.ZERO, finalSqlNode, SqlNodeList.EMPTY, null, - limitNode); + if (orderByList != null || limitNode != null) { + resultNode = new SqlOrderBy(SqlParserPos.ZERO, finalSqlNode, + orderByList != null ? orderByList : SqlNodeList.EMPTY, null, limitNode); } // Custom SqlPrettyWriter configuration to avoid quoting identifiers diff --git a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreParameterConfig.java b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreParameterConfig.java index 33f2bc14a..8a5da2131 100644 --- a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreParameterConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreParameterConfig.java @@ -44,20 +44,18 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig { new Parameter("s2.embedding.store.databaseName", "", "DatabaseName", "", "string", MODULE_NAME, null, getDatabaseNameDependency()); - public static final Parameter EMBEDDING_STORE_POST = - new Parameter("s2.embedding.store.post", "", "端口", "", "number", MODULE_NAME, null, - getPostDependency()); + public static final Parameter EMBEDDING_STORE_POST = new Parameter("s2.embedding.store.post", + "", "端口", "", "number", MODULE_NAME, null, getPostDependency()); - public static final Parameter EMBEDDING_STORE_USER = - new Parameter("s2.embedding.store.user", "", "用户名", "", "string", MODULE_NAME, null, - getUserDependency()); + public static final Parameter EMBEDDING_STORE_USER = new Parameter("s2.embedding.store.user", + "", "用户名", "", "string", MODULE_NAME, null, getUserDependency()); @Override public List getSysParameters() { return Lists.newArrayList(EMBEDDING_STORE_PROVIDER, EMBEDDING_STORE_BASE_URL, - EMBEDDING_STORE_POST, EMBEDDING_STORE_USER, - EMBEDDING_STORE_API_KEY, EMBEDDING_STORE_DATABASE_NAME, - EMBEDDING_STORE_PERSIST_PATH, EMBEDDING_STORE_TIMEOUT, EMBEDDING_STORE_DIMENSION); + EMBEDDING_STORE_POST, EMBEDDING_STORE_USER, EMBEDDING_STORE_API_KEY, + EMBEDDING_STORE_DATABASE_NAME, EMBEDDING_STORE_PERSIST_PATH, + EMBEDDING_STORE_TIMEOUT, EMBEDDING_STORE_DIMENSION); } public EmbeddingStoreConfig convert() { @@ -83,16 +81,14 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig { private static ArrayList getCandidateValues() { return Lists.newArrayList(EmbeddingStoreType.IN_MEMORY.name(), - EmbeddingStoreType.MILVUS.name(), - EmbeddingStoreType.CHROMA.name(), + EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.CHROMA.name(), EmbeddingStoreType.PGVECTOR.name()); } private static List getBaseUrlDependency() { return getDependency(EMBEDDING_STORE_PROVIDER.getName(), Lists.newArrayList(EmbeddingStoreType.MILVUS.name(), - EmbeddingStoreType.CHROMA.name(), - EmbeddingStoreType.PGVECTOR.name()), + EmbeddingStoreType.CHROMA.name(), EmbeddingStoreType.PGVECTOR.name()), ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "http://localhost:19530", EmbeddingStoreType.CHROMA.name(), "http://localhost:8000", EmbeddingStoreType.PGVECTOR.name(), "127.0.0.1")); @@ -100,7 +96,8 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig { private static List getApiKeyDependency() { return getDependency(EMBEDDING_STORE_PROVIDER.getName(), - Lists.newArrayList(EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.PGVECTOR.name()), + Lists.newArrayList(EmbeddingStoreType.MILVUS.name(), + EmbeddingStoreType.PGVECTOR.name()), ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), DEMO, EmbeddingStoreType.PGVECTOR.name(), DEMO)); } @@ -113,28 +110,28 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig { private static List getDimensionDependency() { return getDependency(EMBEDDING_STORE_PROVIDER.getName(), - Lists.newArrayList(EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.PGVECTOR.name()), + Lists.newArrayList(EmbeddingStoreType.MILVUS.name(), + EmbeddingStoreType.PGVECTOR.name()), ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "384", EmbeddingStoreType.PGVECTOR.name(), "768")); } private static List getDatabaseNameDependency() { return getDependency(EMBEDDING_STORE_PROVIDER.getName(), - Lists.newArrayList(EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.PGVECTOR.name()), + Lists.newArrayList(EmbeddingStoreType.MILVUS.name(), + EmbeddingStoreType.PGVECTOR.name()), ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "", EmbeddingStoreType.PGVECTOR.name(), "postgres")); } private static List getPostDependency() { - return getDependency( - EMBEDDING_STORE_PROVIDER.getName(), + return getDependency(EMBEDDING_STORE_PROVIDER.getName(), Lists.newArrayList(EmbeddingStoreType.PGVECTOR.name()), ImmutableMap.of(EmbeddingStoreType.PGVECTOR.name(), "54333")); } private static List getUserDependency() { - return getDependency( - EMBEDDING_STORE_PROVIDER.getName(), + return getDependency(EMBEDDING_STORE_PROVIDER.getName(), Lists.newArrayList(EmbeddingStoreType.PGVECTOR.name()), ImmutableMap.of(EmbeddingStoreType.PGVECTOR.name(), "pgvector")); } diff --git a/common/src/main/java/dev/langchain4j/pgvector/spring/PgvectorEmbeddingStoreFactory.java b/common/src/main/java/dev/langchain4j/pgvector/spring/PgvectorEmbeddingStoreFactory.java index e0e8b758d..cde961c8c 100644 --- a/common/src/main/java/dev/langchain4j/pgvector/spring/PgvectorEmbeddingStoreFactory.java +++ b/common/src/main/java/dev/langchain4j/pgvector/spring/PgvectorEmbeddingStoreFactory.java @@ -32,15 +32,10 @@ public class PgvectorEmbeddingStoreFactory extends BaseEmbeddingStoreFactory { @Override public EmbeddingStore createEmbeddingStore(String collectionName) { - return PgVectorEmbeddingStore.builder() - .host(storeProperties.getHost()) - .port(storeProperties.getPort()) - .database(storeProperties.getDatabase()) - .user(storeProperties.getUser()) - .password(storeProperties.getPassword()) - .table(collectionName) - .dimension(storeProperties.getDimension()) - .build(); + return PgVectorEmbeddingStore.builder().host(storeProperties.getHost()) + .port(storeProperties.getPort()).database(storeProperties.getDatabase()) + .user(storeProperties.getUser()).password(storeProperties.getPassword()) + .table(collectionName).dimension(storeProperties.getDimension()).build(); } } diff --git a/common/src/main/java/dev/langchain4j/pgvector/spring/Properties.java b/common/src/main/java/dev/langchain4j/pgvector/spring/Properties.java index 6c9e1473e..25ac4ea33 100644 --- a/common/src/main/java/dev/langchain4j/pgvector/spring/Properties.java +++ b/common/src/main/java/dev/langchain4j/pgvector/spring/Properties.java @@ -12,5 +12,6 @@ public class Properties { static final String PREFIX = "langchain4j.pgvector"; - @NestedConfigurationProperty EmbeddingStoreProperties embeddingStore; + @NestedConfigurationProperty + EmbeddingStoreProperties embeddingStore; } diff --git a/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreFactoryProvider.java b/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreFactoryProvider.java index 4113bd484..828a0e8fc 100644 --- a/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreFactoryProvider.java +++ b/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreFactoryProvider.java @@ -35,9 +35,9 @@ public class EmbeddingStoreFactoryProvider { return factoryMap.computeIfAbsent(embeddingStoreConfig, storeConfig -> new MilvusEmbeddingStoreFactory(storeConfig)); } - if (EmbeddingStoreType.PGVECTOR.name().equalsIgnoreCase(embeddingStoreConfig.getProvider())) { - return factoryMap.computeIfAbsent( - embeddingStoreConfig, + if (EmbeddingStoreType.PGVECTOR.name() + .equalsIgnoreCase(embeddingStoreConfig.getProvider())) { + return factoryMap.computeIfAbsent(embeddingStoreConfig, storeConfig -> new PgvectorEmbeddingStoreFactory(storeConfig)); } if (EmbeddingStoreType.IN_MEMORY.name() diff --git a/common/src/test/java/com/tencent/supersonic/common/calcite/SqlWithMergerTest.java b/common/src/test/java/com/tencent/supersonic/common/calcite/SqlWithMergerTest.java index 53e8eef13..5da8fc84b 100644 --- a/common/src/test/java/com/tencent/supersonic/common/calcite/SqlWithMergerTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/calcite/SqlWithMergerTest.java @@ -136,6 +136,32 @@ class SqlWithMergerTest { + "WHERE 总访问次数 > 100"); } + + @Test + void test6() throws SqlParseException { + + String sql1 = + "SELECT COUNT(*) FROM Department join Visits WHERE 总访问次数 > 100 ORDER BY 总访问次数 LIMIT 10"; + + String sql2 = + "SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv`\n" + + "FROM\n" + "(SELECT `user_name`, `department`\n" + "FROM\n" + + "`s2_user_department`) AS `t2`\n" + + "LEFT JOIN (SELECT 1 AS `s2_pv_uv_statis_pv`, `imp_date` AS `sys_imp_date`, `user_name`\n" + + "FROM\n" + + "`s2_pv_uv_statis`) AS `t3` ON `t2`.`user_name` = `t3`.`user_name`"; + + String mergeSql = SqlMergeWithUtils.mergeWith(EngineType.MYSQL, sql1, + Collections.singletonList(sql2), Collections.singletonList("t_1")); + + + Assert.assertEquals(format(mergeSql), + "WITH t_1 AS (SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv` FROM " + + "(SELECT `user_name`, `department` FROM `s2_user_department`) AS `t2` LEFT JOIN (SELECT 1 AS `s2_pv_uv_statis_pv`," + + " `imp_date` AS `sys_imp_date`, `user_name` FROM `s2_pv_uv_statis`) AS `t3` ON `t2`.`user_name` = `t3`.`user_name`) " + + "SELECT COUNT(*) FROM Department INNER JOIN Visits WHERE 总访问次数 > 100 ORDER BY 总访问次数 LIMIT 10"); + } + private static String format(String mergeSql) { mergeSql = mergeSql.replace("\r\n", "\n"); // Remove extra spaces and newlines