[improvement][common] Fix the issue of missing 'ORDER BY' clause generation (#1802)

This commit is contained in:
lexluo09
2024-10-15 10:22:22 +08:00
committed by GitHub
parent 1ef642d0dd
commit d7f301064a
6 changed files with 61 additions and 39 deletions

View File

@@ -21,7 +21,6 @@ import java.util.List;
@Slf4j
public class SqlMergeWithUtils {
public static String mergeWith(EngineType engineType, String sql, List<String> parentSqlList,
List<String> parentWithNameList) throws SqlParseException {
SqlParser.Config parserConfig = Configuration.getParserConfig(engineType);
@@ -51,11 +50,13 @@ public class SqlMergeWithUtils {
withItemList.add(withItem);
}
// Check if the main SQL node contains a LIMIT clause
// Check if the main SQL node contains an ORDER BY or LIMIT clause
SqlNode limitNode = null;
SqlNodeList orderByList = null;
if (sqlNode1 instanceof SqlOrderBy) {
SqlOrderBy sqlOrderBy = (SqlOrderBy) sqlNode1;
limitNode = sqlOrderBy.fetch;
orderByList = sqlOrderBy.orderList;
sqlNode1 = sqlOrderBy.query;
} else if (sqlNode1 instanceof SqlSelect) {
SqlSelect sqlSelect = (SqlSelect) sqlNode1;
@@ -63,21 +64,23 @@ public class SqlMergeWithUtils {
sqlSelect.setFetch(null);
sqlNode1 = sqlSelect;
}
// Extract existing WITH items from sqlNode1 if it is a SqlWith
if (sqlNode1 instanceof SqlWith) {
SqlWith sqlWith = (SqlWith) sqlNode1;
withItemList.addAll(sqlWith.withList.getList());
sqlNode1 = sqlWith.body;
}
// Create a new SqlWith node
SqlWith finalSqlNode = new SqlWith(SqlParserPos.ZERO,
new SqlNodeList(withItemList, SqlParserPos.ZERO), sqlNode1);
// If there was a LIMIT clause, wrap the finalSqlNode in a SqlOrderBy with the LIMIT
// If there was an ORDER BY or LIMIT clause, wrap the finalSqlNode in a SqlOrderBy
SqlNode resultNode = finalSqlNode;
if (limitNode != null) {
resultNode = new SqlOrderBy(SqlParserPos.ZERO, finalSqlNode, SqlNodeList.EMPTY, null,
limitNode);
if (orderByList != null || limitNode != null) {
resultNode = new SqlOrderBy(SqlParserPos.ZERO, finalSqlNode,
orderByList != null ? orderByList : SqlNodeList.EMPTY, null, limitNode);
}
// Custom SqlPrettyWriter configuration to avoid quoting identifiers

View File

@@ -44,20 +44,18 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig {
new Parameter("s2.embedding.store.databaseName", "", "DatabaseName", "", "string",
MODULE_NAME, null, getDatabaseNameDependency());
public static final Parameter EMBEDDING_STORE_POST =
new Parameter("s2.embedding.store.post", "", "端口", "", "number", MODULE_NAME, null,
getPostDependency());
public static final Parameter EMBEDDING_STORE_POST = new Parameter("s2.embedding.store.post",
"", "端口", "", "number", MODULE_NAME, null, getPostDependency());
public static final Parameter EMBEDDING_STORE_USER =
new Parameter("s2.embedding.store.user", "", "用户名", "", "string", MODULE_NAME, null,
getUserDependency());
public static final Parameter EMBEDDING_STORE_USER = new Parameter("s2.embedding.store.user",
"", "用户名", "", "string", MODULE_NAME, null, getUserDependency());
@Override
public List<Parameter> getSysParameters() {
return Lists.newArrayList(EMBEDDING_STORE_PROVIDER, EMBEDDING_STORE_BASE_URL,
EMBEDDING_STORE_POST, EMBEDDING_STORE_USER,
EMBEDDING_STORE_API_KEY, EMBEDDING_STORE_DATABASE_NAME,
EMBEDDING_STORE_PERSIST_PATH, EMBEDDING_STORE_TIMEOUT, EMBEDDING_STORE_DIMENSION);
EMBEDDING_STORE_POST, EMBEDDING_STORE_USER, EMBEDDING_STORE_API_KEY,
EMBEDDING_STORE_DATABASE_NAME, EMBEDDING_STORE_PERSIST_PATH,
EMBEDDING_STORE_TIMEOUT, EMBEDDING_STORE_DIMENSION);
}
public EmbeddingStoreConfig convert() {
@@ -83,16 +81,14 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig {
private static ArrayList<String> getCandidateValues() {
return Lists.newArrayList(EmbeddingStoreType.IN_MEMORY.name(),
EmbeddingStoreType.MILVUS.name(),
EmbeddingStoreType.CHROMA.name(),
EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.CHROMA.name(),
EmbeddingStoreType.PGVECTOR.name());
}
private static List<Parameter.Dependency> getBaseUrlDependency() {
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
Lists.newArrayList(EmbeddingStoreType.MILVUS.name(),
EmbeddingStoreType.CHROMA.name(),
EmbeddingStoreType.PGVECTOR.name()),
EmbeddingStoreType.CHROMA.name(), EmbeddingStoreType.PGVECTOR.name()),
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "http://localhost:19530",
EmbeddingStoreType.CHROMA.name(), "http://localhost:8000",
EmbeddingStoreType.PGVECTOR.name(), "127.0.0.1"));
@@ -100,7 +96,8 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig {
private static List<Parameter.Dependency> getApiKeyDependency() {
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
Lists.newArrayList(EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.PGVECTOR.name()),
Lists.newArrayList(EmbeddingStoreType.MILVUS.name(),
EmbeddingStoreType.PGVECTOR.name()),
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), DEMO,
EmbeddingStoreType.PGVECTOR.name(), DEMO));
}
@@ -113,28 +110,28 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig {
private static List<Parameter.Dependency> getDimensionDependency() {
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
Lists.newArrayList(EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.PGVECTOR.name()),
Lists.newArrayList(EmbeddingStoreType.MILVUS.name(),
EmbeddingStoreType.PGVECTOR.name()),
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "384",
EmbeddingStoreType.PGVECTOR.name(), "768"));
}
private static List<Parameter.Dependency> getDatabaseNameDependency() {
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
Lists.newArrayList(EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.PGVECTOR.name()),
Lists.newArrayList(EmbeddingStoreType.MILVUS.name(),
EmbeddingStoreType.PGVECTOR.name()),
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "",
EmbeddingStoreType.PGVECTOR.name(), "postgres"));
}
private static List<Parameter.Dependency> getPostDependency() {
return getDependency(
EMBEDDING_STORE_PROVIDER.getName(),
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
Lists.newArrayList(EmbeddingStoreType.PGVECTOR.name()),
ImmutableMap.of(EmbeddingStoreType.PGVECTOR.name(), "54333"));
}
private static List<Parameter.Dependency> getUserDependency() {
return getDependency(
EMBEDDING_STORE_PROVIDER.getName(),
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
Lists.newArrayList(EmbeddingStoreType.PGVECTOR.name()),
ImmutableMap.of(EmbeddingStoreType.PGVECTOR.name(), "pgvector"));
}

View File

@@ -32,15 +32,10 @@ public class PgvectorEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
@Override
public EmbeddingStore<TextSegment> createEmbeddingStore(String collectionName) {
return PgVectorEmbeddingStore.builder()
.host(storeProperties.getHost())
.port(storeProperties.getPort())
.database(storeProperties.getDatabase())
.user(storeProperties.getUser())
.password(storeProperties.getPassword())
.table(collectionName)
.dimension(storeProperties.getDimension())
.build();
return PgVectorEmbeddingStore.builder().host(storeProperties.getHost())
.port(storeProperties.getPort()).database(storeProperties.getDatabase())
.user(storeProperties.getUser()).password(storeProperties.getPassword())
.table(collectionName).dimension(storeProperties.getDimension()).build();
}
}

View File

@@ -12,5 +12,6 @@ public class Properties {
static final String PREFIX = "langchain4j.pgvector";
@NestedConfigurationProperty EmbeddingStoreProperties embeddingStore;
@NestedConfigurationProperty
EmbeddingStoreProperties embeddingStore;
}

View File

@@ -35,9 +35,9 @@ public class EmbeddingStoreFactoryProvider {
return factoryMap.computeIfAbsent(embeddingStoreConfig,
storeConfig -> new MilvusEmbeddingStoreFactory(storeConfig));
}
if (EmbeddingStoreType.PGVECTOR.name().equalsIgnoreCase(embeddingStoreConfig.getProvider())) {
return factoryMap.computeIfAbsent(
embeddingStoreConfig,
if (EmbeddingStoreType.PGVECTOR.name()
.equalsIgnoreCase(embeddingStoreConfig.getProvider())) {
return factoryMap.computeIfAbsent(embeddingStoreConfig,
storeConfig -> new PgvectorEmbeddingStoreFactory(storeConfig));
}
if (EmbeddingStoreType.IN_MEMORY.name()

View File

@@ -136,6 +136,32 @@ class SqlWithMergerTest {
+ "WHERE 总访问次数 > 100");
}
/**
 * Merges a main query that ends with ORDER BY and LIMIT with a single parent query named
 * {@code t_1}, and checks that the formatted result is the parent wrapped in a WITH clause
 * followed by the main query with its ORDER BY and LIMIT still in place.
 */
@Test
void test6() throws SqlParseException {
    // Main query: carries both an ORDER BY and a LIMIT clause.
    String mainSql =
            "SELECT COUNT(*) FROM Department join Visits WHERE 总访问次数 > 100 ORDER BY 总访问次数 LIMIT 10";

    // Parent query that is expected to become the body of the "t_1" WITH item.
    String parentSql =
            "SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv`\n"
                    + "FROM\n"
                    + "(SELECT `user_name`, `department`\n"
                    + "FROM\n"
                    + "`s2_user_department`) AS `t2`\n"
                    + "LEFT JOIN (SELECT 1 AS `s2_pv_uv_statis_pv`, `imp_date` AS `sys_imp_date`, `user_name`\n"
                    + "FROM\n"
                    + "`s2_pv_uv_statis`) AS `t3` ON `t2`.`user_name` = `t3`.`user_name`";

    // Single-line form the merged statement should have once passed through format(...).
    String expectedSql =
            "WITH t_1 AS (SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv` FROM "
                    + "(SELECT `user_name`, `department` FROM `s2_user_department`) AS `t2` LEFT JOIN (SELECT 1 AS `s2_pv_uv_statis_pv`,"
                    + " `imp_date` AS `sys_imp_date`, `user_name` FROM `s2_pv_uv_statis`) AS `t3` ON `t2`.`user_name` = `t3`.`user_name`) "
                    + "SELECT COUNT(*) FROM Department INNER JOIN Visits WHERE 总访问次数 > 100 ORDER BY 总访问次数 LIMIT 10";

    List<String> parentSqlList = Collections.singletonList(parentSql);
    List<String> parentNameList = Collections.singletonList("t_1");
    String merged =
            SqlMergeWithUtils.mergeWith(EngineType.MYSQL, mainSql, parentSqlList, parentNameList);

    Assert.assertEquals(format(merged), expectedSql);
}
private static String format(String mergeSql) {
mergeSql = mergeSql.replace("\r\n", "\n");
// Remove extra spaces and newlines