From 1ef642d0dd7c8f6aa47b017226097c1778e80fd3 Mon Sep 17 00:00:00 2001
From: Zhengyang Jia <67008026+littleforestjia@users.noreply.github.com>
Date: Tue, 15 Oct 2024 09:16:32 +0800
Subject: [PATCH] [improvement][common]Add pgvector vector library adaptation.
(#1800)
---
common/pom.xml | 4 ++
.../config/EmbeddingStoreParameterConfig.java | 57 +++++++++++++++----
.../common/pojo/EmbeddingStoreConfig.java | 2 +
.../spring/EmbeddingStoreProperties.java | 23 ++++++++
.../pgvector/spring/PgvectorAutoConfig.java | 20 +++++++
.../spring/PgvectorEmbeddingStoreFactory.java | 46 +++++++++++++++
.../pgvector/spring/Properties.java | 16 ++++++
.../EmbeddingStoreFactoryProvider.java | 6 ++
.../store/embedding/EmbeddingStoreType.java | 2 +-
pom.xml | 5 ++
10 files changed, 169 insertions(+), 12 deletions(-)
create mode 100644 common/src/main/java/dev/langchain4j/pgvector/spring/EmbeddingStoreProperties.java
create mode 100644 common/src/main/java/dev/langchain4j/pgvector/spring/PgvectorAutoConfig.java
create mode 100644 common/src/main/java/dev/langchain4j/pgvector/spring/PgvectorEmbeddingStoreFactory.java
create mode 100644 common/src/main/java/dev/langchain4j/pgvector/spring/Properties.java
diff --git a/common/pom.xml b/common/pom.xml
index b8ecd2ddf..27b847cc9 100644
--- a/common/pom.xml
+++ b/common/pom.xml
@@ -174,6 +174,10 @@
dev.langchain4j
langchain4j-milvus
+
+ dev.langchain4j
+ langchain4j-pgvector
+
dev.langchain4j
langchain4j-azure-open-ai
diff --git a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreParameterConfig.java b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreParameterConfig.java
index 4918d770b..33f2bc14a 100644
--- a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreParameterConfig.java
+++ b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreParameterConfig.java
@@ -19,7 +19,7 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig {
public static final Parameter EMBEDDING_STORE_PROVIDER = new Parameter(
"s2.embedding.store.provider", EmbeddingStoreType.IN_MEMORY.name(), "向量库类型",
- "目前支持三种类型:IN_MEMORY、MILVUS、CHROMA", "list", MODULE_NAME, getCandidateValues());
+ "目前支持四种类型:IN_MEMORY、MILVUS、CHROMA、PGVECTOR", "list", MODULE_NAME, getCandidateValues());
public static final Parameter EMBEDDING_STORE_BASE_URL =
new Parameter("s2.embedding.store.base.url", "", "BaseUrl", "", "string", MODULE_NAME,
@@ -44,9 +44,18 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig {
new Parameter("s2.embedding.store.databaseName", "", "DatabaseName", "", "string",
MODULE_NAME, null, getDatabaseNameDependency());
+ public static final Parameter EMBEDDING_STORE_POST =
+ new Parameter("s2.embedding.store.post", "", "端口", "", "number", MODULE_NAME, null,
+ getPostDependency());
+
+ public static final Parameter EMBEDDING_STORE_USER =
+ new Parameter("s2.embedding.store.user", "", "用户名", "", "string", MODULE_NAME, null,
+ getUserDependency());
+
@Override
public List getSysParameters() {
return Lists.newArrayList(EMBEDDING_STORE_PROVIDER, EMBEDDING_STORE_BASE_URL,
+ EMBEDDING_STORE_POST, EMBEDDING_STORE_USER,
EMBEDDING_STORE_API_KEY, EMBEDDING_STORE_DATABASE_NAME,
EMBEDDING_STORE_PERSIST_PATH, EMBEDDING_STORE_TIMEOUT, EMBEDDING_STORE_DIMENSION);
}
@@ -62,28 +71,38 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig {
if (StringUtils.isNumeric(getParameterValue(EMBEDDING_STORE_DIMENSION))) {
dimension = Integer.valueOf(getParameterValue(EMBEDDING_STORE_DIMENSION));
}
+ Integer port = null;
+ if (StringUtils.isNumeric(getParameterValue(EMBEDDING_STORE_POST))) {
+ port = Integer.valueOf(getParameterValue(EMBEDDING_STORE_POST));
+ }
+ String user = getParameterValue(EMBEDDING_STORE_USER);
return EmbeddingStoreConfig.builder().provider(provider).baseUrl(baseUrl).apiKey(apiKey)
.persistPath(persistPath).databaseName(databaseName).timeOut(Long.valueOf(timeOut))
- .dimension(dimension).build();
+ .dimension(dimension).post(port).user(user).build();
}
private static ArrayList getCandidateValues() {
return Lists.newArrayList(EmbeddingStoreType.IN_MEMORY.name(),
- EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.CHROMA.name());
+ EmbeddingStoreType.MILVUS.name(),
+ EmbeddingStoreType.CHROMA.name(),
+ EmbeddingStoreType.PGVECTOR.name());
}
private static List getBaseUrlDependency() {
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
Lists.newArrayList(EmbeddingStoreType.MILVUS.name(),
- EmbeddingStoreType.CHROMA.name()),
+ EmbeddingStoreType.CHROMA.name(),
+ EmbeddingStoreType.PGVECTOR.name()),
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "http://localhost:19530",
- EmbeddingStoreType.CHROMA.name(), "http://localhost:8000"));
+ EmbeddingStoreType.CHROMA.name(), "http://localhost:8000",
+ EmbeddingStoreType.PGVECTOR.name(), "127.0.0.1"));
}
private static List getApiKeyDependency() {
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
- Lists.newArrayList(EmbeddingStoreType.MILVUS.name()),
- ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), DEMO));
+ Lists.newArrayList(EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.PGVECTOR.name()),
+ ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), DEMO,
+ EmbeddingStoreType.PGVECTOR.name(), DEMO));
}
private static List getPathDependency() {
@@ -94,13 +113,29 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig {
private static List getDimensionDependency() {
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
- Lists.newArrayList(EmbeddingStoreType.MILVUS.name()),
- ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "384"));
+ Lists.newArrayList(EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.PGVECTOR.name()),
+ ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "384",
+ EmbeddingStoreType.PGVECTOR.name(), "768"));
}
private static List getDatabaseNameDependency() {
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
- Lists.newArrayList(EmbeddingStoreType.MILVUS.name()),
- ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), ""));
+ Lists.newArrayList(EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.PGVECTOR.name()),
+ ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "",
+ EmbeddingStoreType.PGVECTOR.name(), "postgres"));
+ }
+
+ private static List getPostDependency() {
+ return getDependency(
+ EMBEDDING_STORE_PROVIDER.getName(),
+ Lists.newArrayList(EmbeddingStoreType.PGVECTOR.name()),
+ ImmutableMap.of(EmbeddingStoreType.PGVECTOR.name(), "54333"));
+ }
+
+ private static List getUserDependency() {
+ return getDependency(
+ EMBEDDING_STORE_PROVIDER.getName(),
+ Lists.newArrayList(EmbeddingStoreType.PGVECTOR.name()),
+ ImmutableMap.of(EmbeddingStoreType.PGVECTOR.name(), "pgvector"));
}
}
diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/EmbeddingStoreConfig.java b/common/src/main/java/com/tencent/supersonic/common/pojo/EmbeddingStoreConfig.java
index e3703f6f8..eafdd60a0 100644
--- a/common/src/main/java/com/tencent/supersonic/common/pojo/EmbeddingStoreConfig.java
+++ b/common/src/main/java/com/tencent/supersonic/common/pojo/EmbeddingStoreConfig.java
@@ -22,4 +22,6 @@ public class EmbeddingStoreConfig implements Serializable {
private Long timeOut = 60L;
private Integer dimension;
private String databaseName;
+ private Integer post;
+ private String user;
}
diff --git a/common/src/main/java/dev/langchain4j/pgvector/spring/EmbeddingStoreProperties.java b/common/src/main/java/dev/langchain4j/pgvector/spring/EmbeddingStoreProperties.java
new file mode 100644
index 000000000..631e5b5b1
--- /dev/null
+++ b/common/src/main/java/dev/langchain4j/pgvector/spring/EmbeddingStoreProperties.java
@@ -0,0 +1,23 @@
+package dev.langchain4j.pgvector.spring;
+
+import dev.langchain4j.store.embedding.pgvector.MetadataStorageConfig;
+import lombok.Getter;
+import lombok.Setter;
+
+@Getter
+@Setter
+class EmbeddingStoreProperties {
+
+ private String host;
+ private Integer port;
+ private String user;
+ private String password;
+ private String database;
+ private String table;
+ private Integer dimension;
+ private Boolean useIndex;
+ private Integer indexListSize;
+ private Boolean createTable;
+ private Boolean dropTableFirst;
+ private MetadataStorageConfig metadataStorageConfig;
+}
diff --git a/common/src/main/java/dev/langchain4j/pgvector/spring/PgvectorAutoConfig.java b/common/src/main/java/dev/langchain4j/pgvector/spring/PgvectorAutoConfig.java
new file mode 100644
index 000000000..61ddfcb89
--- /dev/null
+++ b/common/src/main/java/dev/langchain4j/pgvector/spring/PgvectorAutoConfig.java
@@ -0,0 +1,20 @@
+package dev.langchain4j.pgvector.spring;
+
+import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
+import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
+import org.springframework.boot.context.properties.EnableConfigurationProperties;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+
+import static dev.langchain4j.pgvector.spring.Properties.PREFIX;
+
+@Configuration
+@EnableConfigurationProperties(Properties.class)
+public class PgvectorAutoConfig {
+
+ @Bean
+ @ConditionalOnProperty(PREFIX + ".embedding-store.host")
+ EmbeddingStoreFactory pgvectorChatModel(Properties properties) {
+ return new PgvectorEmbeddingStoreFactory(properties.getEmbeddingStore());
+ }
+}
diff --git a/common/src/main/java/dev/langchain4j/pgvector/spring/PgvectorEmbeddingStoreFactory.java b/common/src/main/java/dev/langchain4j/pgvector/spring/PgvectorEmbeddingStoreFactory.java
new file mode 100644
index 000000000..e0e8b758d
--- /dev/null
+++ b/common/src/main/java/dev/langchain4j/pgvector/spring/PgvectorEmbeddingStoreFactory.java
@@ -0,0 +1,46 @@
+package dev.langchain4j.pgvector.spring;
+
+import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig;
+import dev.langchain4j.data.segment.TextSegment;
+import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory;
+import dev.langchain4j.store.embedding.EmbeddingStore;
+import dev.langchain4j.store.embedding.pgvector.PgVectorEmbeddingStore;
+import org.springframework.beans.BeanUtils;
+
+public class PgvectorEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
+ private final EmbeddingStoreProperties storeProperties;
+
+ public PgvectorEmbeddingStoreFactory(EmbeddingStoreConfig storeConfig) {
+ this(createPropertiesFromConfig(storeConfig));
+ }
+
+ public PgvectorEmbeddingStoreFactory(EmbeddingStoreProperties storeProperties) {
+ this.storeProperties = storeProperties;
+ }
+
+ private static EmbeddingStoreProperties createPropertiesFromConfig(
+ EmbeddingStoreConfig storeConfig) {
+ EmbeddingStoreProperties embeddingStore = new EmbeddingStoreProperties();
+ BeanUtils.copyProperties(storeConfig, embeddingStore);
+ embeddingStore.setHost(storeConfig.getBaseUrl());
+ embeddingStore.setPort(storeConfig.getPost());
+ embeddingStore.setDatabase(storeConfig.getDatabaseName());
+ embeddingStore.setUser(storeConfig.getUser());
+ embeddingStore.setPassword(storeConfig.getApiKey());
+ return embeddingStore;
+ }
+
+ @Override
+ public EmbeddingStore createEmbeddingStore(String collectionName) {
+ return PgVectorEmbeddingStore.builder()
+ .host(storeProperties.getHost())
+ .port(storeProperties.getPort())
+ .database(storeProperties.getDatabase())
+ .user(storeProperties.getUser())
+ .password(storeProperties.getPassword())
+ .table(collectionName)
+ .dimension(storeProperties.getDimension())
+ .build();
+ }
+
+}
diff --git a/common/src/main/java/dev/langchain4j/pgvector/spring/Properties.java b/common/src/main/java/dev/langchain4j/pgvector/spring/Properties.java
new file mode 100644
index 000000000..6c9e1473e
--- /dev/null
+++ b/common/src/main/java/dev/langchain4j/pgvector/spring/Properties.java
@@ -0,0 +1,16 @@
+package dev.langchain4j.pgvector.spring;
+
+import lombok.Getter;
+import lombok.Setter;
+import org.springframework.boot.context.properties.ConfigurationProperties;
+import org.springframework.boot.context.properties.NestedConfigurationProperty;
+
+@Getter
+@Setter
+@ConfigurationProperties(prefix = Properties.PREFIX)
+public class Properties {
+
+ static final String PREFIX = "langchain4j.pgvector";
+
+ @NestedConfigurationProperty EmbeddingStoreProperties embeddingStore;
+}
diff --git a/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreFactoryProvider.java b/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreFactoryProvider.java
index cf86025ec..4113bd484 100644
--- a/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreFactoryProvider.java
+++ b/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreFactoryProvider.java
@@ -6,6 +6,7 @@ import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.chroma.spring.ChromaEmbeddingStoreFactory;
import dev.langchain4j.inmemory.spring.InMemoryEmbeddingStoreFactory;
import dev.langchain4j.milvus.spring.MilvusEmbeddingStoreFactory;
+import dev.langchain4j.pgvector.spring.PgvectorEmbeddingStoreFactory;
import org.apache.commons.lang3.StringUtils;
import java.util.Map;
@@ -34,6 +35,11 @@ public class EmbeddingStoreFactoryProvider {
return factoryMap.computeIfAbsent(embeddingStoreConfig,
storeConfig -> new MilvusEmbeddingStoreFactory(storeConfig));
}
+ if (EmbeddingStoreType.PGVECTOR.name().equalsIgnoreCase(embeddingStoreConfig.getProvider())) {
+ return factoryMap.computeIfAbsent(
+ embeddingStoreConfig,
+ storeConfig -> new PgvectorEmbeddingStoreFactory(storeConfig));
+ }
if (EmbeddingStoreType.IN_MEMORY.name()
.equalsIgnoreCase(embeddingStoreConfig.getProvider())) {
return factoryMap.computeIfAbsent(embeddingStoreConfig,
diff --git a/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreType.java b/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreType.java
index 068ac0ada..bb533b0f6 100644
--- a/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreType.java
+++ b/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreType.java
@@ -1,5 +1,5 @@
package dev.langchain4j.store.embedding;
public enum EmbeddingStoreType {
- IN_MEMORY, MILVUS, CHROMA
+ IN_MEMORY, MILVUS, CHROMA, PGVECTOR
}
diff --git a/pom.xml b/pom.xml
index bccb7641c..8705295b7 100644
--- a/pom.xml
+++ b/pom.xml
@@ -172,6 +172,11 @@
langchain4j-milvus
${langchain4j.version}
+
+ dev.langchain4j
+ langchain4j-pgvector
+ ${langchain4j.version}
+
dev.langchain4j
langchain4j-chatglm