mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
[improvement][common] Fix the issue of missing 'ORDER BY' clause generation (#1802)
This commit is contained in:
@@ -21,7 +21,6 @@ import java.util.List;
|
||||
|
||||
@Slf4j
|
||||
public class SqlMergeWithUtils {
|
||||
|
||||
public static String mergeWith(EngineType engineType, String sql, List<String> parentSqlList,
|
||||
List<String> parentWithNameList) throws SqlParseException {
|
||||
SqlParser.Config parserConfig = Configuration.getParserConfig(engineType);
|
||||
@@ -51,11 +50,13 @@ public class SqlMergeWithUtils {
|
||||
withItemList.add(withItem);
|
||||
}
|
||||
|
||||
// Check if the main SQL node contains a LIMIT clause
|
||||
// Check if the main SQL node contains an ORDER BY or LIMIT clause
|
||||
SqlNode limitNode = null;
|
||||
SqlNodeList orderByList = null;
|
||||
if (sqlNode1 instanceof SqlOrderBy) {
|
||||
SqlOrderBy sqlOrderBy = (SqlOrderBy) sqlNode1;
|
||||
limitNode = sqlOrderBy.fetch;
|
||||
orderByList = sqlOrderBy.orderList;
|
||||
sqlNode1 = sqlOrderBy.query;
|
||||
} else if (sqlNode1 instanceof SqlSelect) {
|
||||
SqlSelect sqlSelect = (SqlSelect) sqlNode1;
|
||||
@@ -63,21 +64,23 @@ public class SqlMergeWithUtils {
|
||||
sqlSelect.setFetch(null);
|
||||
sqlNode1 = sqlSelect;
|
||||
}
|
||||
|
||||
// Extract existing WITH items from sqlNode1 if it is a SqlWith
|
||||
if (sqlNode1 instanceof SqlWith) {
|
||||
SqlWith sqlWith = (SqlWith) sqlNode1;
|
||||
withItemList.addAll(sqlWith.withList.getList());
|
||||
sqlNode1 = sqlWith.body;
|
||||
}
|
||||
|
||||
// Create a new SqlWith node
|
||||
SqlWith finalSqlNode = new SqlWith(SqlParserPos.ZERO,
|
||||
new SqlNodeList(withItemList, SqlParserPos.ZERO), sqlNode1);
|
||||
|
||||
// If there was a LIMIT clause, wrap the finalSqlNode in a SqlOrderBy with the LIMIT
|
||||
// If there was an ORDER BY or LIMIT clause, wrap the finalSqlNode in a SqlOrderBy
|
||||
SqlNode resultNode = finalSqlNode;
|
||||
if (limitNode != null) {
|
||||
resultNode = new SqlOrderBy(SqlParserPos.ZERO, finalSqlNode, SqlNodeList.EMPTY, null,
|
||||
limitNode);
|
||||
if (orderByList != null || limitNode != null) {
|
||||
resultNode = new SqlOrderBy(SqlParserPos.ZERO, finalSqlNode,
|
||||
orderByList != null ? orderByList : SqlNodeList.EMPTY, null, limitNode);
|
||||
}
|
||||
|
||||
// Custom SqlPrettyWriter configuration to avoid quoting identifiers
|
||||
|
||||
@@ -44,20 +44,18 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig {
|
||||
new Parameter("s2.embedding.store.databaseName", "", "DatabaseName", "", "string",
|
||||
MODULE_NAME, null, getDatabaseNameDependency());
|
||||
|
||||
public static final Parameter EMBEDDING_STORE_POST =
|
||||
new Parameter("s2.embedding.store.post", "", "端口", "", "number", MODULE_NAME, null,
|
||||
getPostDependency());
|
||||
public static final Parameter EMBEDDING_STORE_POST = new Parameter("s2.embedding.store.post",
|
||||
"", "端口", "", "number", MODULE_NAME, null, getPostDependency());
|
||||
|
||||
public static final Parameter EMBEDDING_STORE_USER =
|
||||
new Parameter("s2.embedding.store.user", "", "用户名", "", "string", MODULE_NAME, null,
|
||||
getUserDependency());
|
||||
public static final Parameter EMBEDDING_STORE_USER = new Parameter("s2.embedding.store.user",
|
||||
"", "用户名", "", "string", MODULE_NAME, null, getUserDependency());
|
||||
|
||||
@Override
|
||||
public List<Parameter> getSysParameters() {
|
||||
return Lists.newArrayList(EMBEDDING_STORE_PROVIDER, EMBEDDING_STORE_BASE_URL,
|
||||
EMBEDDING_STORE_POST, EMBEDDING_STORE_USER,
|
||||
EMBEDDING_STORE_API_KEY, EMBEDDING_STORE_DATABASE_NAME,
|
||||
EMBEDDING_STORE_PERSIST_PATH, EMBEDDING_STORE_TIMEOUT, EMBEDDING_STORE_DIMENSION);
|
||||
EMBEDDING_STORE_POST, EMBEDDING_STORE_USER, EMBEDDING_STORE_API_KEY,
|
||||
EMBEDDING_STORE_DATABASE_NAME, EMBEDDING_STORE_PERSIST_PATH,
|
||||
EMBEDDING_STORE_TIMEOUT, EMBEDDING_STORE_DIMENSION);
|
||||
}
|
||||
|
||||
public EmbeddingStoreConfig convert() {
|
||||
@@ -83,16 +81,14 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig {
|
||||
|
||||
private static ArrayList<String> getCandidateValues() {
|
||||
return Lists.newArrayList(EmbeddingStoreType.IN_MEMORY.name(),
|
||||
EmbeddingStoreType.MILVUS.name(),
|
||||
EmbeddingStoreType.CHROMA.name(),
|
||||
EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.CHROMA.name(),
|
||||
EmbeddingStoreType.PGVECTOR.name());
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getBaseUrlDependency() {
|
||||
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||
Lists.newArrayList(EmbeddingStoreType.MILVUS.name(),
|
||||
EmbeddingStoreType.CHROMA.name(),
|
||||
EmbeddingStoreType.PGVECTOR.name()),
|
||||
EmbeddingStoreType.CHROMA.name(), EmbeddingStoreType.PGVECTOR.name()),
|
||||
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "http://localhost:19530",
|
||||
EmbeddingStoreType.CHROMA.name(), "http://localhost:8000",
|
||||
EmbeddingStoreType.PGVECTOR.name(), "127.0.0.1"));
|
||||
@@ -100,7 +96,8 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig {
|
||||
|
||||
private static List<Parameter.Dependency> getApiKeyDependency() {
|
||||
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||
Lists.newArrayList(EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.PGVECTOR.name()),
|
||||
Lists.newArrayList(EmbeddingStoreType.MILVUS.name(),
|
||||
EmbeddingStoreType.PGVECTOR.name()),
|
||||
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), DEMO,
|
||||
EmbeddingStoreType.PGVECTOR.name(), DEMO));
|
||||
}
|
||||
@@ -113,28 +110,28 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig {
|
||||
|
||||
private static List<Parameter.Dependency> getDimensionDependency() {
|
||||
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||
Lists.newArrayList(EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.PGVECTOR.name()),
|
||||
Lists.newArrayList(EmbeddingStoreType.MILVUS.name(),
|
||||
EmbeddingStoreType.PGVECTOR.name()),
|
||||
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "384",
|
||||
EmbeddingStoreType.PGVECTOR.name(), "768"));
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getDatabaseNameDependency() {
|
||||
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||
Lists.newArrayList(EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.PGVECTOR.name()),
|
||||
Lists.newArrayList(EmbeddingStoreType.MILVUS.name(),
|
||||
EmbeddingStoreType.PGVECTOR.name()),
|
||||
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "",
|
||||
EmbeddingStoreType.PGVECTOR.name(), "postgres"));
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getPostDependency() {
|
||||
return getDependency(
|
||||
EMBEDDING_STORE_PROVIDER.getName(),
|
||||
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||
Lists.newArrayList(EmbeddingStoreType.PGVECTOR.name()),
|
||||
ImmutableMap.of(EmbeddingStoreType.PGVECTOR.name(), "54333"));
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getUserDependency() {
|
||||
return getDependency(
|
||||
EMBEDDING_STORE_PROVIDER.getName(),
|
||||
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||
Lists.newArrayList(EmbeddingStoreType.PGVECTOR.name()),
|
||||
ImmutableMap.of(EmbeddingStoreType.PGVECTOR.name(), "pgvector"));
|
||||
}
|
||||
|
||||
@@ -32,15 +32,10 @@ public class PgvectorEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
|
||||
|
||||
@Override
|
||||
public EmbeddingStore<TextSegment> createEmbeddingStore(String collectionName) {
|
||||
return PgVectorEmbeddingStore.builder()
|
||||
.host(storeProperties.getHost())
|
||||
.port(storeProperties.getPort())
|
||||
.database(storeProperties.getDatabase())
|
||||
.user(storeProperties.getUser())
|
||||
.password(storeProperties.getPassword())
|
||||
.table(collectionName)
|
||||
.dimension(storeProperties.getDimension())
|
||||
.build();
|
||||
return PgVectorEmbeddingStore.builder().host(storeProperties.getHost())
|
||||
.port(storeProperties.getPort()).database(storeProperties.getDatabase())
|
||||
.user(storeProperties.getUser()).password(storeProperties.getPassword())
|
||||
.table(collectionName).dimension(storeProperties.getDimension()).build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -12,5 +12,6 @@ public class Properties {
|
||||
|
||||
static final String PREFIX = "langchain4j.pgvector";
|
||||
|
||||
@NestedConfigurationProperty EmbeddingStoreProperties embeddingStore;
|
||||
@NestedConfigurationProperty
|
||||
EmbeddingStoreProperties embeddingStore;
|
||||
}
|
||||
|
||||
@@ -35,9 +35,9 @@ public class EmbeddingStoreFactoryProvider {
|
||||
return factoryMap.computeIfAbsent(embeddingStoreConfig,
|
||||
storeConfig -> new MilvusEmbeddingStoreFactory(storeConfig));
|
||||
}
|
||||
if (EmbeddingStoreType.PGVECTOR.name().equalsIgnoreCase(embeddingStoreConfig.getProvider())) {
|
||||
return factoryMap.computeIfAbsent(
|
||||
embeddingStoreConfig,
|
||||
if (EmbeddingStoreType.PGVECTOR.name()
|
||||
.equalsIgnoreCase(embeddingStoreConfig.getProvider())) {
|
||||
return factoryMap.computeIfAbsent(embeddingStoreConfig,
|
||||
storeConfig -> new PgvectorEmbeddingStoreFactory(storeConfig));
|
||||
}
|
||||
if (EmbeddingStoreType.IN_MEMORY.name()
|
||||
|
||||
@@ -136,6 +136,32 @@ class SqlWithMergerTest {
|
||||
+ "WHERE 总访问次数 > 100");
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
void test6() throws SqlParseException {
|
||||
|
||||
String sql1 =
|
||||
"SELECT COUNT(*) FROM Department join Visits WHERE 总访问次数 > 100 ORDER BY 总访问次数 LIMIT 10";
|
||||
|
||||
String sql2 =
|
||||
"SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv`\n"
|
||||
+ "FROM\n" + "(SELECT `user_name`, `department`\n" + "FROM\n"
|
||||
+ "`s2_user_department`) AS `t2`\n"
|
||||
+ "LEFT JOIN (SELECT 1 AS `s2_pv_uv_statis_pv`, `imp_date` AS `sys_imp_date`, `user_name`\n"
|
||||
+ "FROM\n"
|
||||
+ "`s2_pv_uv_statis`) AS `t3` ON `t2`.`user_name` = `t3`.`user_name`";
|
||||
|
||||
String mergeSql = SqlMergeWithUtils.mergeWith(EngineType.MYSQL, sql1,
|
||||
Collections.singletonList(sql2), Collections.singletonList("t_1"));
|
||||
|
||||
|
||||
Assert.assertEquals(format(mergeSql),
|
||||
"WITH t_1 AS (SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv` FROM "
|
||||
+ "(SELECT `user_name`, `department` FROM `s2_user_department`) AS `t2` LEFT JOIN (SELECT 1 AS `s2_pv_uv_statis_pv`,"
|
||||
+ " `imp_date` AS `sys_imp_date`, `user_name` FROM `s2_pv_uv_statis`) AS `t3` ON `t2`.`user_name` = `t3`.`user_name`) "
|
||||
+ "SELECT COUNT(*) FROM Department INNER JOIN Visits WHERE 总访问次数 > 100 ORDER BY 总访问次数 LIMIT 10");
|
||||
}
|
||||
|
||||
private static String format(String mergeSql) {
|
||||
mergeSql = mergeSql.replace("\r\n", "\n");
|
||||
// Remove extra spaces and newlines
|
||||
|
||||
Reference in New Issue
Block a user