From 1ef642d0dd7c8f6aa47b017226097c1778e80fd3 Mon Sep 17 00:00:00 2001 From: Zhengyang Jia <67008026+littleforestjia@users.noreply.github.com> Date: Tue, 15 Oct 2024 09:16:32 +0800 Subject: [PATCH] [improvement][common]Add pgvector vector library adaptation. (#1800) --- common/pom.xml | 4 ++ .../config/EmbeddingStoreParameterConfig.java | 57 +++++++++++++++---- .../common/pojo/EmbeddingStoreConfig.java | 2 + .../spring/EmbeddingStoreProperties.java | 23 ++++++++ .../pgvector/spring/PgvectorAutoConfig.java | 20 +++++++ .../spring/PgvectorEmbeddingStoreFactory.java | 46 +++++++++++++++ .../pgvector/spring/Properties.java | 16 ++++++ .../EmbeddingStoreFactoryProvider.java | 6 ++ .../store/embedding/EmbeddingStoreType.java | 2 +- pom.xml | 5 ++ 10 files changed, 169 insertions(+), 12 deletions(-) create mode 100644 common/src/main/java/dev/langchain4j/pgvector/spring/EmbeddingStoreProperties.java create mode 100644 common/src/main/java/dev/langchain4j/pgvector/spring/PgvectorAutoConfig.java create mode 100644 common/src/main/java/dev/langchain4j/pgvector/spring/PgvectorEmbeddingStoreFactory.java create mode 100644 common/src/main/java/dev/langchain4j/pgvector/spring/Properties.java diff --git a/common/pom.xml b/common/pom.xml index b8ecd2ddf..27b847cc9 100644 --- a/common/pom.xml +++ b/common/pom.xml @@ -174,6 +174,10 @@ dev.langchain4j langchain4j-milvus + + dev.langchain4j + langchain4j-pgvector + dev.langchain4j langchain4j-azure-open-ai diff --git a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreParameterConfig.java b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreParameterConfig.java index 4918d770b..33f2bc14a 100644 --- a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreParameterConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreParameterConfig.java @@ -19,7 +19,7 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig { public static final Parameter EMBEDDING_STORE_PROVIDER = new Parameter( "s2.embedding.store.provider", EmbeddingStoreType.IN_MEMORY.name(), "向量库类型", - "目前支持三种类型:IN_MEMORY、MILVUS、CHROMA", "list", MODULE_NAME, getCandidateValues()); + "目前支持四种类型:IN_MEMORY、MILVUS、CHROMA、PGVECTOR", "list", MODULE_NAME, getCandidateValues()); public static final Parameter EMBEDDING_STORE_BASE_URL = new Parameter("s2.embedding.store.base.url", "", "BaseUrl", "", "string", MODULE_NAME, @@ -44,9 +44,18 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig { new Parameter("s2.embedding.store.databaseName", "", "DatabaseName", "", "string", MODULE_NAME, null, getDatabaseNameDependency()); + public static final Parameter EMBEDDING_STORE_POST = + new Parameter("s2.embedding.store.post", "", "端口", "", "number", MODULE_NAME, null, + getPostDependency()); + + public static final Parameter EMBEDDING_STORE_USER = + new Parameter("s2.embedding.store.user", "", "用户名", "", "string", MODULE_NAME, null, + getUserDependency()); + @Override public List getSysParameters() { return Lists.newArrayList(EMBEDDING_STORE_PROVIDER, EMBEDDING_STORE_BASE_URL, + EMBEDDING_STORE_POST, EMBEDDING_STORE_USER, EMBEDDING_STORE_API_KEY, EMBEDDING_STORE_DATABASE_NAME, EMBEDDING_STORE_PERSIST_PATH, EMBEDDING_STORE_TIMEOUT, EMBEDDING_STORE_DIMENSION); } @@ -62,28 +71,38 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig { if (StringUtils.isNumeric(getParameterValue(EMBEDDING_STORE_DIMENSION))) { dimension = Integer.valueOf(getParameterValue(EMBEDDING_STORE_DIMENSION)); } + Integer port = null; + if (StringUtils.isNumeric(getParameterValue(EMBEDDING_STORE_POST))) { + port = Integer.valueOf(getParameterValue(EMBEDDING_STORE_POST)); + } + String user = getParameterValue(EMBEDDING_STORE_USER); return EmbeddingStoreConfig.builder().provider(provider).baseUrl(baseUrl).apiKey(apiKey) .persistPath(persistPath).databaseName(databaseName).timeOut(Long.valueOf(timeOut)) - .dimension(dimension).build(); + .dimension(dimension).post(port).user(user).build(); } private static ArrayList getCandidateValues() { return Lists.newArrayList(EmbeddingStoreType.IN_MEMORY.name(), - EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.CHROMA.name()); + EmbeddingStoreType.MILVUS.name(), + EmbeddingStoreType.CHROMA.name(), + EmbeddingStoreType.PGVECTOR.name()); } private static List getBaseUrlDependency() { return getDependency(EMBEDDING_STORE_PROVIDER.getName(), Lists.newArrayList(EmbeddingStoreType.MILVUS.name(), - EmbeddingStoreType.CHROMA.name()), + EmbeddingStoreType.CHROMA.name(), + EmbeddingStoreType.PGVECTOR.name()), ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "http://localhost:19530", - EmbeddingStoreType.CHROMA.name(), "http://localhost:8000")); + EmbeddingStoreType.CHROMA.name(), "http://localhost:8000", + EmbeddingStoreType.PGVECTOR.name(), "127.0.0.1")); } private static List getApiKeyDependency() { return getDependency(EMBEDDING_STORE_PROVIDER.getName(), - Lists.newArrayList(EmbeddingStoreType.MILVUS.name()), - ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), DEMO)); + Lists.newArrayList(EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.PGVECTOR.name()), + ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), DEMO, + EmbeddingStoreType.PGVECTOR.name(), DEMO)); } private static List getPathDependency() { @@ -94,13 +113,29 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig { private static List getDimensionDependency() { return getDependency(EMBEDDING_STORE_PROVIDER.getName(), - Lists.newArrayList(EmbeddingStoreType.MILVUS.name()), - ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "384")); + Lists.newArrayList(EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.PGVECTOR.name()), + ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "384", + EmbeddingStoreType.PGVECTOR.name(), "768")); } private static List getDatabaseNameDependency() { return getDependency(EMBEDDING_STORE_PROVIDER.getName(), - Lists.newArrayList(EmbeddingStoreType.MILVUS.name()), - ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "")); + Lists.newArrayList(EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.PGVECTOR.name()), + ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "", + EmbeddingStoreType.PGVECTOR.name(), "postgres")); + } + + private static List getPostDependency() { + return getDependency( + EMBEDDING_STORE_PROVIDER.getName(), + Lists.newArrayList(EmbeddingStoreType.PGVECTOR.name()), + ImmutableMap.of(EmbeddingStoreType.PGVECTOR.name(), "54333")); + } + + private static List getUserDependency() { + return getDependency( + EMBEDDING_STORE_PROVIDER.getName(), + Lists.newArrayList(EmbeddingStoreType.PGVECTOR.name()), + ImmutableMap.of(EmbeddingStoreType.PGVECTOR.name(), "pgvector")); } } diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/EmbeddingStoreConfig.java b/common/src/main/java/com/tencent/supersonic/common/pojo/EmbeddingStoreConfig.java index e3703f6f8..eafdd60a0 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/EmbeddingStoreConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/EmbeddingStoreConfig.java @@ -22,4 +22,6 @@ public class EmbeddingStoreConfig implements Serializable { private Long timeOut = 60L; private Integer dimension; private String databaseName; + private Integer post; + private String user; } diff --git a/common/src/main/java/dev/langchain4j/pgvector/spring/EmbeddingStoreProperties.java b/common/src/main/java/dev/langchain4j/pgvector/spring/EmbeddingStoreProperties.java new file mode 100644 index 000000000..631e5b5b1 --- /dev/null +++ b/common/src/main/java/dev/langchain4j/pgvector/spring/EmbeddingStoreProperties.java @@ -0,0 +1,23 @@ +package dev.langchain4j.pgvector.spring; + +import dev.langchain4j.store.embedding.pgvector.MetadataStorageConfig; +import lombok.Getter; +import lombok.Setter; + +@Getter +@Setter +class EmbeddingStoreProperties { + + private String host; + private Integer port; + private String user; + private String password; + private String database; + private String table; + private Integer dimension; + private Boolean useIndex; + private Integer indexListSize; + private Boolean createTable; + private Boolean dropTableFirst; + private MetadataStorageConfig metadataStorageConfig; +} diff --git a/common/src/main/java/dev/langchain4j/pgvector/spring/PgvectorAutoConfig.java b/common/src/main/java/dev/langchain4j/pgvector/spring/PgvectorAutoConfig.java new file mode 100644 index 000000000..61ddfcb89 --- /dev/null +++ b/common/src/main/java/dev/langchain4j/pgvector/spring/PgvectorAutoConfig.java @@ -0,0 +1,20 @@ +package dev.langchain4j.pgvector.spring; + +import dev.langchain4j.store.embedding.EmbeddingStoreFactory; +import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +import static dev.langchain4j.pgvector.spring.Properties.PREFIX; + +@Configuration +@EnableConfigurationProperties(Properties.class) +public class PgvectorAutoConfig { + + @Bean + @ConditionalOnProperty(PREFIX + ".embedding-store.host") + EmbeddingStoreFactory pgvectorChatModel(Properties properties) { + return new PgvectorEmbeddingStoreFactory(properties.getEmbeddingStore()); + } +} diff --git a/common/src/main/java/dev/langchain4j/pgvector/spring/PgvectorEmbeddingStoreFactory.java b/common/src/main/java/dev/langchain4j/pgvector/spring/PgvectorEmbeddingStoreFactory.java new file mode 100644 index 000000000..e0e8b758d --- /dev/null +++ b/common/src/main/java/dev/langchain4j/pgvector/spring/PgvectorEmbeddingStoreFactory.java @@ -0,0 +1,46 @@ +package dev.langchain4j.pgvector.spring; + +import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory; +import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.pgvector.PgVectorEmbeddingStore; +import org.springframework.beans.BeanUtils; + +public class PgvectorEmbeddingStoreFactory extends BaseEmbeddingStoreFactory { + private final EmbeddingStoreProperties storeProperties; + + public PgvectorEmbeddingStoreFactory(EmbeddingStoreConfig storeConfig) { + this(createPropertiesFromConfig(storeConfig)); + } + + public PgvectorEmbeddingStoreFactory(EmbeddingStoreProperties storeProperties) { + this.storeProperties = storeProperties; + } + + private static EmbeddingStoreProperties createPropertiesFromConfig( + EmbeddingStoreConfig storeConfig) { + EmbeddingStoreProperties embeddingStore = new EmbeddingStoreProperties(); + BeanUtils.copyProperties(storeConfig, embeddingStore); + embeddingStore.setHost(storeConfig.getBaseUrl()); + embeddingStore.setPort(storeConfig.getPost()); + embeddingStore.setDatabase(storeConfig.getDatabaseName()); + embeddingStore.setUser(storeConfig.getUser()); + embeddingStore.setPassword(storeConfig.getApiKey()); + return embeddingStore; + } + + @Override + public EmbeddingStore createEmbeddingStore(String collectionName) { + return PgVectorEmbeddingStore.builder() + .host(storeProperties.getHost()) + .port(storeProperties.getPort()) + .database(storeProperties.getDatabase()) + .user(storeProperties.getUser()) + .password(storeProperties.getPassword()) + .table(collectionName) + .dimension(storeProperties.getDimension()) + .build(); + } + +} diff --git a/common/src/main/java/dev/langchain4j/pgvector/spring/Properties.java b/common/src/main/java/dev/langchain4j/pgvector/spring/Properties.java new file mode 100644 index 000000000..6c9e1473e --- /dev/null +++ b/common/src/main/java/dev/langchain4j/pgvector/spring/Properties.java @@ -0,0 +1,16 @@ +package dev.langchain4j.pgvector.spring; + +import lombok.Getter; +import lombok.Setter; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.context.properties.NestedConfigurationProperty; + +@Getter +@Setter +@ConfigurationProperties(prefix = Properties.PREFIX) +public class Properties { + + static final String PREFIX = "langchain4j.pgvector"; + + @NestedConfigurationProperty EmbeddingStoreProperties embeddingStore; +} diff --git a/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreFactoryProvider.java b/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreFactoryProvider.java index cf86025ec..4113bd484 100644 --- a/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreFactoryProvider.java +++ b/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreFactoryProvider.java @@ -6,6 +6,7 @@ import com.tencent.supersonic.common.util.ContextUtils; import dev.langchain4j.chroma.spring.ChromaEmbeddingStoreFactory; import dev.langchain4j.inmemory.spring.InMemoryEmbeddingStoreFactory; import dev.langchain4j.milvus.spring.MilvusEmbeddingStoreFactory; +import dev.langchain4j.pgvector.spring.PgvectorEmbeddingStoreFactory; import org.apache.commons.lang3.StringUtils; import java.util.Map; @@ -34,6 +35,11 @@ public class EmbeddingStoreFactoryProvider { return factoryMap.computeIfAbsent(embeddingStoreConfig, storeConfig -> new MilvusEmbeddingStoreFactory(storeConfig)); } + if (EmbeddingStoreType.PGVECTOR.name().equalsIgnoreCase(embeddingStoreConfig.getProvider())) { + return factoryMap.computeIfAbsent( + embeddingStoreConfig, + storeConfig -> new PgvectorEmbeddingStoreFactory(storeConfig)); + } if (EmbeddingStoreType.IN_MEMORY.name() .equalsIgnoreCase(embeddingStoreConfig.getProvider())) { return factoryMap.computeIfAbsent(embeddingStoreConfig, diff --git a/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreType.java b/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreType.java index 068ac0ada..bb533b0f6 100644 --- a/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreType.java +++ b/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreType.java @@ -1,5 +1,5 @@ package dev.langchain4j.store.embedding; public enum EmbeddingStoreType { - IN_MEMORY, MILVUS, CHROMA + IN_MEMORY, MILVUS, CHROMA, PGVECTOR } diff --git a/pom.xml b/pom.xml index bccb7641c..8705295b7 100644 --- a/pom.xml +++ b/pom.xml @@ -172,6 +172,11 @@ langchain4j-milvus ${langchain4j.version} + + dev.langchain4j + langchain4j-pgvector + ${langchain4j.version} + dev.langchain4j langchain4j-chatglm