mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 12:07:42 +00:00
[improvement][common] Fix the issue of missing 'ORDER BY' clause generation (#1802)
This commit is contained in:
@@ -21,7 +21,6 @@ import java.util.List;
|
|||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class SqlMergeWithUtils {
|
public class SqlMergeWithUtils {
|
||||||
|
|
||||||
public static String mergeWith(EngineType engineType, String sql, List<String> parentSqlList,
|
public static String mergeWith(EngineType engineType, String sql, List<String> parentSqlList,
|
||||||
List<String> parentWithNameList) throws SqlParseException {
|
List<String> parentWithNameList) throws SqlParseException {
|
||||||
SqlParser.Config parserConfig = Configuration.getParserConfig(engineType);
|
SqlParser.Config parserConfig = Configuration.getParserConfig(engineType);
|
||||||
@@ -51,11 +50,13 @@ public class SqlMergeWithUtils {
|
|||||||
withItemList.add(withItem);
|
withItemList.add(withItem);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the main SQL node contains a LIMIT clause
|
// Check if the main SQL node contains an ORDER BY or LIMIT clause
|
||||||
SqlNode limitNode = null;
|
SqlNode limitNode = null;
|
||||||
|
SqlNodeList orderByList = null;
|
||||||
if (sqlNode1 instanceof SqlOrderBy) {
|
if (sqlNode1 instanceof SqlOrderBy) {
|
||||||
SqlOrderBy sqlOrderBy = (SqlOrderBy) sqlNode1;
|
SqlOrderBy sqlOrderBy = (SqlOrderBy) sqlNode1;
|
||||||
limitNode = sqlOrderBy.fetch;
|
limitNode = sqlOrderBy.fetch;
|
||||||
|
orderByList = sqlOrderBy.orderList;
|
||||||
sqlNode1 = sqlOrderBy.query;
|
sqlNode1 = sqlOrderBy.query;
|
||||||
} else if (sqlNode1 instanceof SqlSelect) {
|
} else if (sqlNode1 instanceof SqlSelect) {
|
||||||
SqlSelect sqlSelect = (SqlSelect) sqlNode1;
|
SqlSelect sqlSelect = (SqlSelect) sqlNode1;
|
||||||
@@ -63,21 +64,23 @@ public class SqlMergeWithUtils {
|
|||||||
sqlSelect.setFetch(null);
|
sqlSelect.setFetch(null);
|
||||||
sqlNode1 = sqlSelect;
|
sqlNode1 = sqlSelect;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract existing WITH items from sqlNode1 if it is a SqlWith
|
// Extract existing WITH items from sqlNode1 if it is a SqlWith
|
||||||
if (sqlNode1 instanceof SqlWith) {
|
if (sqlNode1 instanceof SqlWith) {
|
||||||
SqlWith sqlWith = (SqlWith) sqlNode1;
|
SqlWith sqlWith = (SqlWith) sqlNode1;
|
||||||
withItemList.addAll(sqlWith.withList.getList());
|
withItemList.addAll(sqlWith.withList.getList());
|
||||||
sqlNode1 = sqlWith.body;
|
sqlNode1 = sqlWith.body;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a new SqlWith node
|
// Create a new SqlWith node
|
||||||
SqlWith finalSqlNode = new SqlWith(SqlParserPos.ZERO,
|
SqlWith finalSqlNode = new SqlWith(SqlParserPos.ZERO,
|
||||||
new SqlNodeList(withItemList, SqlParserPos.ZERO), sqlNode1);
|
new SqlNodeList(withItemList, SqlParserPos.ZERO), sqlNode1);
|
||||||
|
|
||||||
// If there was a LIMIT clause, wrap the finalSqlNode in a SqlOrderBy with the LIMIT
|
// If there was an ORDER BY or LIMIT clause, wrap the finalSqlNode in a SqlOrderBy
|
||||||
SqlNode resultNode = finalSqlNode;
|
SqlNode resultNode = finalSqlNode;
|
||||||
if (limitNode != null) {
|
if (orderByList != null || limitNode != null) {
|
||||||
resultNode = new SqlOrderBy(SqlParserPos.ZERO, finalSqlNode, SqlNodeList.EMPTY, null,
|
resultNode = new SqlOrderBy(SqlParserPos.ZERO, finalSqlNode,
|
||||||
limitNode);
|
orderByList != null ? orderByList : SqlNodeList.EMPTY, null, limitNode);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Custom SqlPrettyWriter configuration to avoid quoting identifiers
|
// Custom SqlPrettyWriter configuration to avoid quoting identifiers
|
||||||
|
|||||||
@@ -44,20 +44,18 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig {
|
|||||||
new Parameter("s2.embedding.store.databaseName", "", "DatabaseName", "", "string",
|
new Parameter("s2.embedding.store.databaseName", "", "DatabaseName", "", "string",
|
||||||
MODULE_NAME, null, getDatabaseNameDependency());
|
MODULE_NAME, null, getDatabaseNameDependency());
|
||||||
|
|
||||||
public static final Parameter EMBEDDING_STORE_POST =
|
public static final Parameter EMBEDDING_STORE_POST = new Parameter("s2.embedding.store.post",
|
||||||
new Parameter("s2.embedding.store.post", "", "端口", "", "number", MODULE_NAME, null,
|
"", "端口", "", "number", MODULE_NAME, null, getPostDependency());
|
||||||
getPostDependency());
|
|
||||||
|
|
||||||
public static final Parameter EMBEDDING_STORE_USER =
|
public static final Parameter EMBEDDING_STORE_USER = new Parameter("s2.embedding.store.user",
|
||||||
new Parameter("s2.embedding.store.user", "", "用户名", "", "string", MODULE_NAME, null,
|
"", "用户名", "", "string", MODULE_NAME, null, getUserDependency());
|
||||||
getUserDependency());
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<Parameter> getSysParameters() {
|
public List<Parameter> getSysParameters() {
|
||||||
return Lists.newArrayList(EMBEDDING_STORE_PROVIDER, EMBEDDING_STORE_BASE_URL,
|
return Lists.newArrayList(EMBEDDING_STORE_PROVIDER, EMBEDDING_STORE_BASE_URL,
|
||||||
EMBEDDING_STORE_POST, EMBEDDING_STORE_USER,
|
EMBEDDING_STORE_POST, EMBEDDING_STORE_USER, EMBEDDING_STORE_API_KEY,
|
||||||
EMBEDDING_STORE_API_KEY, EMBEDDING_STORE_DATABASE_NAME,
|
EMBEDDING_STORE_DATABASE_NAME, EMBEDDING_STORE_PERSIST_PATH,
|
||||||
EMBEDDING_STORE_PERSIST_PATH, EMBEDDING_STORE_TIMEOUT, EMBEDDING_STORE_DIMENSION);
|
EMBEDDING_STORE_TIMEOUT, EMBEDDING_STORE_DIMENSION);
|
||||||
}
|
}
|
||||||
|
|
||||||
public EmbeddingStoreConfig convert() {
|
public EmbeddingStoreConfig convert() {
|
||||||
@@ -83,16 +81,14 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig {
|
|||||||
|
|
||||||
private static ArrayList<String> getCandidateValues() {
|
private static ArrayList<String> getCandidateValues() {
|
||||||
return Lists.newArrayList(EmbeddingStoreType.IN_MEMORY.name(),
|
return Lists.newArrayList(EmbeddingStoreType.IN_MEMORY.name(),
|
||||||
EmbeddingStoreType.MILVUS.name(),
|
EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.CHROMA.name(),
|
||||||
EmbeddingStoreType.CHROMA.name(),
|
|
||||||
EmbeddingStoreType.PGVECTOR.name());
|
EmbeddingStoreType.PGVECTOR.name());
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<Parameter.Dependency> getBaseUrlDependency() {
|
private static List<Parameter.Dependency> getBaseUrlDependency() {
|
||||||
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||||
Lists.newArrayList(EmbeddingStoreType.MILVUS.name(),
|
Lists.newArrayList(EmbeddingStoreType.MILVUS.name(),
|
||||||
EmbeddingStoreType.CHROMA.name(),
|
EmbeddingStoreType.CHROMA.name(), EmbeddingStoreType.PGVECTOR.name()),
|
||||||
EmbeddingStoreType.PGVECTOR.name()),
|
|
||||||
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "http://localhost:19530",
|
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "http://localhost:19530",
|
||||||
EmbeddingStoreType.CHROMA.name(), "http://localhost:8000",
|
EmbeddingStoreType.CHROMA.name(), "http://localhost:8000",
|
||||||
EmbeddingStoreType.PGVECTOR.name(), "127.0.0.1"));
|
EmbeddingStoreType.PGVECTOR.name(), "127.0.0.1"));
|
||||||
@@ -100,7 +96,8 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig {
|
|||||||
|
|
||||||
private static List<Parameter.Dependency> getApiKeyDependency() {
|
private static List<Parameter.Dependency> getApiKeyDependency() {
|
||||||
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||||
Lists.newArrayList(EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.PGVECTOR.name()),
|
Lists.newArrayList(EmbeddingStoreType.MILVUS.name(),
|
||||||
|
EmbeddingStoreType.PGVECTOR.name()),
|
||||||
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), DEMO,
|
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), DEMO,
|
||||||
EmbeddingStoreType.PGVECTOR.name(), DEMO));
|
EmbeddingStoreType.PGVECTOR.name(), DEMO));
|
||||||
}
|
}
|
||||||
@@ -113,28 +110,28 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig {
|
|||||||
|
|
||||||
private static List<Parameter.Dependency> getDimensionDependency() {
|
private static List<Parameter.Dependency> getDimensionDependency() {
|
||||||
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||||
Lists.newArrayList(EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.PGVECTOR.name()),
|
Lists.newArrayList(EmbeddingStoreType.MILVUS.name(),
|
||||||
|
EmbeddingStoreType.PGVECTOR.name()),
|
||||||
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "384",
|
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "384",
|
||||||
EmbeddingStoreType.PGVECTOR.name(), "768"));
|
EmbeddingStoreType.PGVECTOR.name(), "768"));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<Parameter.Dependency> getDatabaseNameDependency() {
|
private static List<Parameter.Dependency> getDatabaseNameDependency() {
|
||||||
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||||
Lists.newArrayList(EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.PGVECTOR.name()),
|
Lists.newArrayList(EmbeddingStoreType.MILVUS.name(),
|
||||||
|
EmbeddingStoreType.PGVECTOR.name()),
|
||||||
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "",
|
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "",
|
||||||
EmbeddingStoreType.PGVECTOR.name(), "postgres"));
|
EmbeddingStoreType.PGVECTOR.name(), "postgres"));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<Parameter.Dependency> getPostDependency() {
|
private static List<Parameter.Dependency> getPostDependency() {
|
||||||
return getDependency(
|
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||||
EMBEDDING_STORE_PROVIDER.getName(),
|
|
||||||
Lists.newArrayList(EmbeddingStoreType.PGVECTOR.name()),
|
Lists.newArrayList(EmbeddingStoreType.PGVECTOR.name()),
|
||||||
ImmutableMap.of(EmbeddingStoreType.PGVECTOR.name(), "54333"));
|
ImmutableMap.of(EmbeddingStoreType.PGVECTOR.name(), "54333"));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<Parameter.Dependency> getUserDependency() {
|
private static List<Parameter.Dependency> getUserDependency() {
|
||||||
return getDependency(
|
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||||
EMBEDDING_STORE_PROVIDER.getName(),
|
|
||||||
Lists.newArrayList(EmbeddingStoreType.PGVECTOR.name()),
|
Lists.newArrayList(EmbeddingStoreType.PGVECTOR.name()),
|
||||||
ImmutableMap.of(EmbeddingStoreType.PGVECTOR.name(), "pgvector"));
|
ImmutableMap.of(EmbeddingStoreType.PGVECTOR.name(), "pgvector"));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -32,15 +32,10 @@ public class PgvectorEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public EmbeddingStore<TextSegment> createEmbeddingStore(String collectionName) {
|
public EmbeddingStore<TextSegment> createEmbeddingStore(String collectionName) {
|
||||||
return PgVectorEmbeddingStore.builder()
|
return PgVectorEmbeddingStore.builder().host(storeProperties.getHost())
|
||||||
.host(storeProperties.getHost())
|
.port(storeProperties.getPort()).database(storeProperties.getDatabase())
|
||||||
.port(storeProperties.getPort())
|
.user(storeProperties.getUser()).password(storeProperties.getPassword())
|
||||||
.database(storeProperties.getDatabase())
|
.table(collectionName).dimension(storeProperties.getDimension()).build();
|
||||||
.user(storeProperties.getUser())
|
|
||||||
.password(storeProperties.getPassword())
|
|
||||||
.table(collectionName)
|
|
||||||
.dimension(storeProperties.getDimension())
|
|
||||||
.build();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,5 +12,6 @@ public class Properties {
|
|||||||
|
|
||||||
static final String PREFIX = "langchain4j.pgvector";
|
static final String PREFIX = "langchain4j.pgvector";
|
||||||
|
|
||||||
@NestedConfigurationProperty EmbeddingStoreProperties embeddingStore;
|
@NestedConfigurationProperty
|
||||||
|
EmbeddingStoreProperties embeddingStore;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -35,9 +35,9 @@ public class EmbeddingStoreFactoryProvider {
|
|||||||
return factoryMap.computeIfAbsent(embeddingStoreConfig,
|
return factoryMap.computeIfAbsent(embeddingStoreConfig,
|
||||||
storeConfig -> new MilvusEmbeddingStoreFactory(storeConfig));
|
storeConfig -> new MilvusEmbeddingStoreFactory(storeConfig));
|
||||||
}
|
}
|
||||||
if (EmbeddingStoreType.PGVECTOR.name().equalsIgnoreCase(embeddingStoreConfig.getProvider())) {
|
if (EmbeddingStoreType.PGVECTOR.name()
|
||||||
return factoryMap.computeIfAbsent(
|
.equalsIgnoreCase(embeddingStoreConfig.getProvider())) {
|
||||||
embeddingStoreConfig,
|
return factoryMap.computeIfAbsent(embeddingStoreConfig,
|
||||||
storeConfig -> new PgvectorEmbeddingStoreFactory(storeConfig));
|
storeConfig -> new PgvectorEmbeddingStoreFactory(storeConfig));
|
||||||
}
|
}
|
||||||
if (EmbeddingStoreType.IN_MEMORY.name()
|
if (EmbeddingStoreType.IN_MEMORY.name()
|
||||||
|
|||||||
@@ -136,6 +136,32 @@ class SqlWithMergerTest {
|
|||||||
+ "WHERE 总访问次数 > 100");
|
+ "WHERE 总访问次数 > 100");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void test6() throws SqlParseException {
|
||||||
|
|
||||||
|
String sql1 =
|
||||||
|
"SELECT COUNT(*) FROM Department join Visits WHERE 总访问次数 > 100 ORDER BY 总访问次数 LIMIT 10";
|
||||||
|
|
||||||
|
String sql2 =
|
||||||
|
"SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv`\n"
|
||||||
|
+ "FROM\n" + "(SELECT `user_name`, `department`\n" + "FROM\n"
|
||||||
|
+ "`s2_user_department`) AS `t2`\n"
|
||||||
|
+ "LEFT JOIN (SELECT 1 AS `s2_pv_uv_statis_pv`, `imp_date` AS `sys_imp_date`, `user_name`\n"
|
||||||
|
+ "FROM\n"
|
||||||
|
+ "`s2_pv_uv_statis`) AS `t3` ON `t2`.`user_name` = `t3`.`user_name`";
|
||||||
|
|
||||||
|
String mergeSql = SqlMergeWithUtils.mergeWith(EngineType.MYSQL, sql1,
|
||||||
|
Collections.singletonList(sql2), Collections.singletonList("t_1"));
|
||||||
|
|
||||||
|
|
||||||
|
Assert.assertEquals(format(mergeSql),
|
||||||
|
"WITH t_1 AS (SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv` FROM "
|
||||||
|
+ "(SELECT `user_name`, `department` FROM `s2_user_department`) AS `t2` LEFT JOIN (SELECT 1 AS `s2_pv_uv_statis_pv`,"
|
||||||
|
+ " `imp_date` AS `sys_imp_date`, `user_name` FROM `s2_pv_uv_statis`) AS `t3` ON `t2`.`user_name` = `t3`.`user_name`) "
|
||||||
|
+ "SELECT COUNT(*) FROM Department INNER JOIN Visits WHERE 总访问次数 > 100 ORDER BY 总访问次数 LIMIT 10");
|
||||||
|
}
|
||||||
|
|
||||||
private static String format(String mergeSql) {
|
private static String format(String mergeSql) {
|
||||||
mergeSql = mergeSql.replace("\r\n", "\n");
|
mergeSql = mergeSql.replace("\r\n", "\n");
|
||||||
// Remove extra spaces and newlines
|
// Remove extra spaces and newlines
|
||||||
|
|||||||
Reference in New Issue
Block a user