mirror of
https://github.com/tencentmusic/supersonic.git
synced 2026-04-30 04:54:25 +08:00
Compare commits
3 Commits
f412ae4539
...
d942d35c93
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d942d35c93 | ||
|
|
198c7c69e6 | ||
|
|
cb139a54e8 |
@@ -174,6 +174,10 @@
|
|||||||
<groupId>dev.langchain4j</groupId>
|
<groupId>dev.langchain4j</groupId>
|
||||||
<artifactId>langchain4j-milvus</artifactId>
|
<artifactId>langchain4j-milvus</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>dev.langchain4j</groupId>
|
||||||
|
<artifactId>langchain4j-opensearch</artifactId>
|
||||||
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>dev.langchain4j</groupId>
|
<groupId>dev.langchain4j</groupId>
|
||||||
<artifactId>langchain4j-pgvector</artifactId>
|
<artifactId>langchain4j-pgvector</artifactId>
|
||||||
@@ -242,6 +246,10 @@
|
|||||||
<groupId>com.google.code.gson</groupId>
|
<groupId>com.google.code.gson</groupId>
|
||||||
<artifactId>gson</artifactId>
|
<artifactId>gson</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.amazonaws</groupId>
|
||||||
|
<artifactId>aws-java-sdk</artifactId>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.codehaus.woodstox</groupId>
|
<groupId>org.codehaus.woodstox</groupId>
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig {
|
|||||||
|
|
||||||
public static final Parameter EMBEDDING_STORE_PROVIDER = new Parameter(
|
public static final Parameter EMBEDDING_STORE_PROVIDER = new Parameter(
|
||||||
"s2.embedding.store.provider", EmbeddingStoreType.IN_MEMORY.name(), "向量库类型",
|
"s2.embedding.store.provider", EmbeddingStoreType.IN_MEMORY.name(), "向量库类型",
|
||||||
"目前支持四种类型:IN_MEMORY、MILVUS、CHROMA、PGVECTOR", "list", MODULE_NAME, getCandidateValues());
|
"目前支持四种类型:IN_MEMORY、MILVUS、CHROMA、PGVECTOR、OPENSEARCH", "list", MODULE_NAME, getCandidateValues());
|
||||||
|
|
||||||
public static final Parameter EMBEDDING_STORE_BASE_URL =
|
public static final Parameter EMBEDDING_STORE_BASE_URL =
|
||||||
new Parameter("s2.embedding.store.base.url", "", "BaseUrl", "", "string", MODULE_NAME,
|
new Parameter("s2.embedding.store.base.url", "", "BaseUrl", "", "string", MODULE_NAME,
|
||||||
@@ -87,16 +87,19 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig {
|
|||||||
private static ArrayList<String> getCandidateValues() {
|
private static ArrayList<String> getCandidateValues() {
|
||||||
return Lists.newArrayList(EmbeddingStoreType.IN_MEMORY.name(),
|
return Lists.newArrayList(EmbeddingStoreType.IN_MEMORY.name(),
|
||||||
EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.CHROMA.name(),
|
EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.CHROMA.name(),
|
||||||
EmbeddingStoreType.PGVECTOR.name());
|
EmbeddingStoreType.PGVECTOR.name(), EmbeddingStoreType.OPENSEARCH.name());
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<Parameter.Dependency> getBaseUrlDependency() {
|
private static List<Parameter.Dependency> getBaseUrlDependency() {
|
||||||
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||||
Lists.newArrayList(EmbeddingStoreType.MILVUS.name(),
|
Lists.newArrayList(EmbeddingStoreType.MILVUS.name(),
|
||||||
EmbeddingStoreType.CHROMA.name(), EmbeddingStoreType.PGVECTOR.name()),
|
EmbeddingStoreType.CHROMA.name(),
|
||||||
|
EmbeddingStoreType.PGVECTOR.name(),
|
||||||
|
EmbeddingStoreType.OPENSEARCH.name()),
|
||||||
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "http://localhost:19530",
|
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "http://localhost:19530",
|
||||||
EmbeddingStoreType.CHROMA.name(), "http://localhost:8000",
|
EmbeddingStoreType.CHROMA.name(), "http://localhost:8000",
|
||||||
EmbeddingStoreType.PGVECTOR.name(), "127.0.0.1"));
|
EmbeddingStoreType.PGVECTOR.name(), "127.0.0.1",
|
||||||
|
EmbeddingStoreType.OPENSEARCH.name(), "http://localhost:9200"));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<Parameter.Dependency> getApiKeyDependency() {
|
private static List<Parameter.Dependency> getApiKeyDependency() {
|
||||||
@@ -114,17 +117,21 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig {
|
|||||||
private static List<Parameter.Dependency> getDimensionDependency() {
|
private static List<Parameter.Dependency> getDimensionDependency() {
|
||||||
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||||
Lists.newArrayList(EmbeddingStoreType.MILVUS.name(),
|
Lists.newArrayList(EmbeddingStoreType.MILVUS.name(),
|
||||||
EmbeddingStoreType.PGVECTOR.name()),
|
EmbeddingStoreType.PGVECTOR.name(),
|
||||||
|
EmbeddingStoreType.OPENSEARCH.name()),
|
||||||
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "384",
|
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "384",
|
||||||
EmbeddingStoreType.PGVECTOR.name(), "512"));
|
EmbeddingStoreType.PGVECTOR.name(), "512",
|
||||||
|
EmbeddingStoreType.OPENSEARCH.name(), "512"));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<Parameter.Dependency> getDatabaseNameDependency() {
|
private static List<Parameter.Dependency> getDatabaseNameDependency() {
|
||||||
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||||
Lists.newArrayList(EmbeddingStoreType.MILVUS.name(),
|
Lists.newArrayList(EmbeddingStoreType.MILVUS.name(),
|
||||||
EmbeddingStoreType.PGVECTOR.name()),
|
EmbeddingStoreType.PGVECTOR.name(),
|
||||||
|
EmbeddingStoreType.OPENSEARCH.name()),
|
||||||
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "",
|
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "",
|
||||||
EmbeddingStoreType.PGVECTOR.name(), "postgres"));
|
EmbeddingStoreType.PGVECTOR.name(), "postgres",
|
||||||
|
EmbeddingStoreType.OPENSEARCH.name(), "ai_sql"));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<Parameter.Dependency> getPortDependency() {
|
private static List<Parameter.Dependency> getPortDependency() {
|
||||||
@@ -136,16 +143,20 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig {
|
|||||||
private static List<Parameter.Dependency> getUserDependency() {
|
private static List<Parameter.Dependency> getUserDependency() {
|
||||||
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||||
Lists.newArrayList(EmbeddingStoreType.MILVUS.name(),
|
Lists.newArrayList(EmbeddingStoreType.MILVUS.name(),
|
||||||
EmbeddingStoreType.PGVECTOR.name()),
|
EmbeddingStoreType.PGVECTOR.name(),
|
||||||
|
EmbeddingStoreType.OPENSEARCH.name()),
|
||||||
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "milvus",
|
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "milvus",
|
||||||
EmbeddingStoreType.PGVECTOR.name(), "postgres"));
|
EmbeddingStoreType.PGVECTOR.name(), "postgres",
|
||||||
|
EmbeddingStoreType.OPENSEARCH.name(), "opensearch"));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<Parameter.Dependency> getPasswordDependency() {
|
private static List<Parameter.Dependency> getPasswordDependency() {
|
||||||
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
|
||||||
Lists.newArrayList(EmbeddingStoreType.MILVUS.name(),
|
Lists.newArrayList(EmbeddingStoreType.MILVUS.name(),
|
||||||
EmbeddingStoreType.PGVECTOR.name()),
|
EmbeddingStoreType.PGVECTOR.name(),
|
||||||
|
EmbeddingStoreType.OPENSEARCH.name()),
|
||||||
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "milvus",
|
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "milvus",
|
||||||
EmbeddingStoreType.PGVECTOR.name(), "postgres"));
|
EmbeddingStoreType.PGVECTOR.name(), "postgres",
|
||||||
|
EmbeddingStoreType.OPENSEARCH.name(), "opensearch"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ public class ChatModelConfig implements Serializable {
|
|||||||
private String baseUrl;
|
private String baseUrl;
|
||||||
private String apiKey;
|
private String apiKey;
|
||||||
private String modelName;
|
private String modelName;
|
||||||
|
private String apiVersion;
|
||||||
private Double temperature = 0.0d;
|
private Double temperature = 0.0d;
|
||||||
private Long timeOut = 60L;
|
private Long timeOut = 60L;
|
||||||
private String endpoint;
|
private String endpoint;
|
||||||
|
|||||||
@@ -34,6 +34,9 @@ public class ChatModelParameters {
|
|||||||
public static final Parameter CHAT_MODEL_API_KEY = new Parameter("apiKey", "", "ApiKey", "",
|
public static final Parameter CHAT_MODEL_API_KEY = new Parameter("apiKey", "", "ApiKey", "",
|
||||||
"password", MODULE_NAME, null, getApiKeyDependency());
|
"password", MODULE_NAME, null, getApiKeyDependency());
|
||||||
|
|
||||||
|
public static final Parameter CHAT_MODEL_API_VERSION = new Parameter("apiVersion", "2024-02-01", "ApiVersion", "",
|
||||||
|
"string", MODULE_NAME, null, getApiVersionDependency());
|
||||||
|
|
||||||
public static final Parameter CHAT_MODEL_ENDPOINT = new Parameter("endpoint", "llama_2_70b",
|
public static final Parameter CHAT_MODEL_ENDPOINT = new Parameter("endpoint", "llama_2_70b",
|
||||||
"Endpoint", "", "string", MODULE_NAME, null, getEndpointDependency());
|
"Endpoint", "", "string", MODULE_NAME, null, getEndpointDependency());
|
||||||
|
|
||||||
@@ -51,7 +54,7 @@ public class ChatModelParameters {
|
|||||||
|
|
||||||
public static List<Parameter> getParameters() {
|
public static List<Parameter> getParameters() {
|
||||||
return Lists.newArrayList(CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_ENDPOINT,
|
return Lists.newArrayList(CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_ENDPOINT,
|
||||||
CHAT_MODEL_API_KEY, CHAT_MODEL_SECRET_KEY, CHAT_MODEL_NAME,
|
CHAT_MODEL_API_KEY, CHAT_MODEL_SECRET_KEY, CHAT_MODEL_NAME, CHAT_MODEL_API_VERSION,
|
||||||
CHAT_MODEL_ENABLE_SEARCH, CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT);
|
CHAT_MODEL_ENABLE_SEARCH, CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -90,6 +93,12 @@ public class ChatModelParameters {
|
|||||||
ModelProvider.DEMO_CHAT_MODEL.getApiKey()));
|
ModelProvider.DEMO_CHAT_MODEL.getApiKey()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static List<Parameter.Dependency> getApiVersionDependency() {
|
||||||
|
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||||
|
Lists.newArrayList(OpenAiModelFactory.PROVIDER),
|
||||||
|
ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_API_VERSION));
|
||||||
|
}
|
||||||
|
|
||||||
private static List<Parameter.Dependency> getModelNameDependency() {
|
private static List<Parameter.Dependency> getModelNameDependency() {
|
||||||
return getDependency(CHAT_MODEL_PROVIDER.getName(), getCandidateValues(),
|
return getDependency(CHAT_MODEL_PROVIDER.getName(), getCandidateValues(),
|
||||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_MODEL_NAME,
|
ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_MODEL_NAME,
|
||||||
|
|||||||
@@ -70,6 +70,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
|||||||
private final OpenAiClient client;
|
private final OpenAiClient client;
|
||||||
private final String baseUrl;
|
private final String baseUrl;
|
||||||
private final String modelName;
|
private final String modelName;
|
||||||
|
private final String apiVersion;
|
||||||
private final Double temperature;
|
private final Double temperature;
|
||||||
private final Double topP;
|
private final Double topP;
|
||||||
private final List<String> stop;
|
private final List<String> stop;
|
||||||
@@ -88,7 +89,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
|||||||
private final List<ChatModelListener> listeners;
|
private final List<ChatModelListener> listeners;
|
||||||
|
|
||||||
@Builder
|
@Builder
|
||||||
public OpenAiChatModel(String baseUrl, String apiKey, String organizationId, String modelName,
|
public OpenAiChatModel(String baseUrl, String apiKey, String organizationId, String modelName, String apiVersion,
|
||||||
Double temperature, Double topP, List<String> stop, Integer maxTokens,
|
Double temperature, Double topP, List<String> stop, Integer maxTokens,
|
||||||
Double presencePenalty, Double frequencyPenalty, Map<String, Integer> logitBias,
|
Double presencePenalty, Double frequencyPenalty, Map<String, Integer> logitBias,
|
||||||
String responseFormat, Boolean strictJsonSchema, Integer seed, String user,
|
String responseFormat, Boolean strictJsonSchema, Integer seed, String user,
|
||||||
@@ -104,12 +105,13 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
|||||||
|
|
||||||
timeout = getOrDefault(timeout, ofSeconds(60));
|
timeout = getOrDefault(timeout, ofSeconds(60));
|
||||||
|
|
||||||
this.client = OpenAiClient.builder().openAiApiKey(apiKey).baseUrl(baseUrl)
|
this.client = OpenAiClient.builder().openAiApiKey(apiKey).baseUrl(baseUrl).apiVersion(apiVersion)
|
||||||
.organizationId(organizationId).callTimeout(timeout).connectTimeout(timeout)
|
.organizationId(organizationId).callTimeout(timeout).connectTimeout(timeout)
|
||||||
.readTimeout(timeout).writeTimeout(timeout).proxy(proxy).logRequests(logRequests)
|
.readTimeout(timeout).writeTimeout(timeout).proxy(proxy).logRequests(logRequests)
|
||||||
.logResponses(logResponses).userAgent(DEFAULT_USER_AGENT)
|
.logResponses(logResponses).userAgent(DEFAULT_USER_AGENT)
|
||||||
.customHeaders(customHeaders).build();
|
.customHeaders(customHeaders).build();
|
||||||
this.modelName = getOrDefault(modelName, GPT_3_5_TURBO);
|
this.modelName = getOrDefault(modelName, GPT_3_5_TURBO);
|
||||||
|
this.apiVersion = apiVersion;
|
||||||
this.temperature = getOrDefault(temperature, 0.7);
|
this.temperature = getOrDefault(temperature, 0.7);
|
||||||
this.topP = topP;
|
this.topP = topP;
|
||||||
this.stop = stop;
|
this.stop = stop;
|
||||||
|
|||||||
@@ -0,0 +1,29 @@
|
|||||||
|
package dev.langchain4j.opensearch.spring;
|
||||||
|
|
||||||
|
import io.milvus.common.clientenum.ConsistencyLevelEnum;
|
||||||
|
import io.milvus.param.IndexType;
|
||||||
|
import io.milvus.param.MetricType;
|
||||||
|
import lombok.Getter;
|
||||||
|
import lombok.Setter;
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
@Setter
|
||||||
|
class EmbeddingStoreProperties {
|
||||||
|
|
||||||
|
private String uri;
|
||||||
|
private String host;
|
||||||
|
private Integer port;
|
||||||
|
private String serviceName;
|
||||||
|
private String region;
|
||||||
|
private String collectionName;
|
||||||
|
private Integer dimension;
|
||||||
|
private IndexType indexType;
|
||||||
|
private MetricType metricType;
|
||||||
|
private String token;
|
||||||
|
private String user;
|
||||||
|
private String password;
|
||||||
|
private ConsistencyLevelEnum consistencyLevel;
|
||||||
|
private Boolean retrieveEmbeddingsOnSearch;
|
||||||
|
private String databaseName;
|
||||||
|
private Boolean autoFlushOnInsert;
|
||||||
|
}
|
||||||
@@ -0,0 +1,19 @@
|
|||||||
|
package dev.langchain4j.opensearch.spring;
|
||||||
|
|
||||||
|
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
|
||||||
|
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
||||||
|
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||||
|
import org.springframework.context.annotation.Bean;
|
||||||
|
import org.springframework.context.annotation.Configuration;
|
||||||
|
|
||||||
|
import static dev.langchain4j.opensearch.spring.Properties.PREFIX;
|
||||||
|
@Configuration
|
||||||
|
@EnableConfigurationProperties(dev.langchain4j.opensearch.spring.Properties.class)
|
||||||
|
public class OpenSearchAutoConfig {
|
||||||
|
|
||||||
|
@Bean
|
||||||
|
@ConditionalOnProperty(PREFIX + ".embedding-store.uri")
|
||||||
|
EmbeddingStoreFactory milvusChatModel(Properties properties) {
|
||||||
|
return new OpenSearchEmbeddingStoreFactory(properties.getEmbeddingStore());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,59 @@
|
|||||||
|
package dev.langchain4j.opensearch.spring;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig;
|
||||||
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
|
import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory;
|
||||||
|
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||||
|
import dev.langchain4j.store.embedding.opensearch.OpenSearchEmbeddingStore;
|
||||||
|
import org.apache.hc.client5.http.auth.AuthScope;
|
||||||
|
import org.apache.hc.client5.http.auth.UsernamePasswordCredentials;
|
||||||
|
import org.apache.hc.client5.http.impl.auth.BasicCredentialsProvider;
|
||||||
|
import org.apache.hc.core5.http.HttpHost;
|
||||||
|
import org.opensearch.client.transport.aws.AwsSdk2TransportOptions;
|
||||||
|
import org.springframework.beans.BeanUtils;
|
||||||
|
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
|
||||||
|
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
|
||||||
|
|
||||||
|
import java.net.URI;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @author zyc
|
||||||
|
*/
|
||||||
|
public class OpenSearchEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
|
||||||
|
private final EmbeddingStoreProperties storeProperties;
|
||||||
|
|
||||||
|
public OpenSearchEmbeddingStoreFactory(EmbeddingStoreConfig storeConfig) {
|
||||||
|
this(createPropertiesFromConfig(storeConfig));
|
||||||
|
}
|
||||||
|
|
||||||
|
public OpenSearchEmbeddingStoreFactory(EmbeddingStoreProperties storeProperties) {
|
||||||
|
this.storeProperties = storeProperties;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static EmbeddingStoreProperties createPropertiesFromConfig(
|
||||||
|
EmbeddingStoreConfig storeConfig) {
|
||||||
|
EmbeddingStoreProperties embeddingStore = new EmbeddingStoreProperties();
|
||||||
|
BeanUtils.copyProperties(storeConfig, embeddingStore);
|
||||||
|
embeddingStore.setUri(storeConfig.getBaseUrl());
|
||||||
|
embeddingStore.setToken(storeConfig.getApiKey());
|
||||||
|
embeddingStore.setDatabaseName(storeConfig.getDatabaseName());
|
||||||
|
return embeddingStore;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public EmbeddingStore<TextSegment> createEmbeddingStore(String collectionName) {
|
||||||
|
final AwsSdk2TransportOptions options = AwsSdk2TransportOptions.builder()
|
||||||
|
.setCredentials(StaticCredentialsProvider.create(AwsBasicCredentials.create(storeProperties.getUser(), storeProperties.getPassword())))
|
||||||
|
.build();
|
||||||
|
final String indexName = storeProperties.getDatabaseName() + "_" + collectionName;
|
||||||
|
return OpenSearchEmbeddingStore.builder().serviceName(storeProperties.getServiceName())
|
||||||
|
.serverUrl(storeProperties.getUri())
|
||||||
|
.region(storeProperties.getRegion())
|
||||||
|
.indexName(indexName)
|
||||||
|
.userName(storeProperties.getUser())
|
||||||
|
.password(storeProperties.getPassword())
|
||||||
|
.apiKey(storeProperties.getToken())
|
||||||
|
.options(options)
|
||||||
|
.build();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
package dev.langchain4j.opensearch.spring;
|
||||||
|
|
||||||
|
import lombok.Getter;
|
||||||
|
import lombok.Setter;
|
||||||
|
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||||
|
import org.springframework.boot.context.properties.NestedConfigurationProperty;
|
||||||
|
|
||||||
|
@Getter
|
||||||
|
@Setter
|
||||||
|
@ConfigurationProperties(prefix = Properties.PREFIX)
|
||||||
|
public class Properties {
|
||||||
|
|
||||||
|
static final String PREFIX = "langchain4j.opensearch";
|
||||||
|
|
||||||
|
@NestedConfigurationProperty
|
||||||
|
dev.langchain4j.opensearch.spring.EmbeddingStoreProperties embeddingStore;
|
||||||
|
}
|
||||||
@@ -18,11 +18,13 @@ public class OpenAiModelFactory implements ModelFactory, InitializingBean {
|
|||||||
public static final String DEFAULT_BASE_URL = "https://api.openai.com/v1";
|
public static final String DEFAULT_BASE_URL = "https://api.openai.com/v1";
|
||||||
public static final String DEFAULT_MODEL_NAME = "gpt-4o-mini";
|
public static final String DEFAULT_MODEL_NAME = "gpt-4o-mini";
|
||||||
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "text-embedding-ada-002";
|
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "text-embedding-ada-002";
|
||||||
|
public static final String DEFAULT_API_VERSION = "2024-02-01";
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||||
return OpenAiChatModel.builder().baseUrl(modelConfig.getBaseUrl())
|
return OpenAiChatModel.builder().baseUrl(modelConfig.getBaseUrl())
|
||||||
.modelName(modelConfig.getModelName()).apiKey(modelConfig.keyDecrypt())
|
.modelName(modelConfig.getModelName()).apiKey(modelConfig.keyDecrypt())
|
||||||
|
.apiVersion(modelConfig.getApiVersion())
|
||||||
.temperature(modelConfig.getTemperature()).topP(modelConfig.getTopP())
|
.temperature(modelConfig.getTemperature()).topP(modelConfig.getTopP())
|
||||||
.maxRetries(modelConfig.getMaxRetries())
|
.maxRetries(modelConfig.getMaxRetries())
|
||||||
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
|
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import com.tencent.supersonic.common.util.ContextUtils;
|
|||||||
import dev.langchain4j.chroma.spring.ChromaEmbeddingStoreFactory;
|
import dev.langchain4j.chroma.spring.ChromaEmbeddingStoreFactory;
|
||||||
import dev.langchain4j.inmemory.spring.InMemoryEmbeddingStoreFactory;
|
import dev.langchain4j.inmemory.spring.InMemoryEmbeddingStoreFactory;
|
||||||
import dev.langchain4j.milvus.spring.MilvusEmbeddingStoreFactory;
|
import dev.langchain4j.milvus.spring.MilvusEmbeddingStoreFactory;
|
||||||
|
import dev.langchain4j.opensearch.spring.OpenSearchEmbeddingStoreFactory;
|
||||||
import dev.langchain4j.pgvector.spring.PgvectorEmbeddingStoreFactory;
|
import dev.langchain4j.pgvector.spring.PgvectorEmbeddingStoreFactory;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
|
||||||
@@ -45,6 +46,11 @@ public class EmbeddingStoreFactoryProvider {
|
|||||||
return factoryMap.computeIfAbsent(embeddingStoreConfig,
|
return factoryMap.computeIfAbsent(embeddingStoreConfig,
|
||||||
storeConfig -> new InMemoryEmbeddingStoreFactory(storeConfig));
|
storeConfig -> new InMemoryEmbeddingStoreFactory(storeConfig));
|
||||||
}
|
}
|
||||||
|
if (EmbeddingStoreType.OPENSEARCH.name()
|
||||||
|
.equalsIgnoreCase(embeddingStoreConfig.getProvider())) {
|
||||||
|
return factoryMap.computeIfAbsent(embeddingStoreConfig,
|
||||||
|
storeConfig -> new OpenSearchEmbeddingStoreFactory(storeConfig));
|
||||||
|
}
|
||||||
throw new RuntimeException("Unsupported EmbeddingStoreFactory provider: "
|
throw new RuntimeException("Unsupported EmbeddingStoreFactory provider: "
|
||||||
+ embeddingStoreConfig.getProvider());
|
+ embeddingStoreConfig.getProvider());
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
package dev.langchain4j.store.embedding;
|
package dev.langchain4j.store.embedding;
|
||||||
|
|
||||||
public enum EmbeddingStoreType {
|
public enum EmbeddingStoreType {
|
||||||
IN_MEMORY, MILVUS, CHROMA, PGVECTOR
|
IN_MEMORY, MILVUS, CHROMA, PGVECTOR, OPENSEARCH
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -300,6 +300,15 @@ public class ModelConverter {
|
|||||||
private static ModelDetail updateModelDetail(ModelReq modelReq) {
|
private static ModelDetail updateModelDetail(ModelReq modelReq) {
|
||||||
ModelDetail modelDetail = new ModelDetail();
|
ModelDetail modelDetail = new ModelDetail();
|
||||||
List<Measure> measures = modelReq.getModelDetail().getMeasures();
|
List<Measure> measures = modelReq.getModelDetail().getMeasures();
|
||||||
|
List<Dimension> dimensions = modelReq.getModelDetail().getDimensions();
|
||||||
|
if (!CollectionUtils.isEmpty(dimensions)) {
|
||||||
|
for (Dimension dimension : dimensions) {
|
||||||
|
if (StringUtils.isNotBlank(dimension.getBizName())
|
||||||
|
&& StringUtils.isBlank(dimension.getExpr())) {
|
||||||
|
dimension.setExpr(dimension.getBizName());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
if (measures == null) {
|
if (measures == null) {
|
||||||
measures = Lists.newArrayList();
|
measures = Lists.newArrayList();
|
||||||
}
|
}
|
||||||
|
|||||||
13
pom.xml
13
pom.xml
@@ -66,7 +66,7 @@
|
|||||||
<mockito-inline.version>4.5.1</mockito-inline.version>
|
<mockito-inline.version>4.5.1</mockito-inline.version>
|
||||||
<easyexcel.version>2.2.6</easyexcel.version>
|
<easyexcel.version>2.2.6</easyexcel.version>
|
||||||
<poi.version>3.17</poi.version>
|
<poi.version>3.17</poi.version>
|
||||||
<langchain4j.version>0.34.0</langchain4j.version>
|
<langchain4j.version>0.35.0</langchain4j.version>
|
||||||
<langchain4j.embedding.version>0.27.1</langchain4j.embedding.version>
|
<langchain4j.embedding.version>0.27.1</langchain4j.embedding.version>
|
||||||
<!-- <postgresql.version>42.7.1</postgresql.version>-->
|
<!-- <postgresql.version>42.7.1</postgresql.version>-->
|
||||||
<st.version>4.0.8</st.version>
|
<st.version>4.0.8</st.version>
|
||||||
@@ -79,6 +79,7 @@
|
|||||||
<spotless.skip>false</spotless.skip>
|
<spotless.skip>false</spotless.skip>
|
||||||
<stax2.version>4.2.1</stax2.version>
|
<stax2.version>4.2.1</stax2.version>
|
||||||
<io.springfox.version>3.0.0</io.springfox.version>
|
<io.springfox.version>3.0.0</io.springfox.version>
|
||||||
|
<aws-java-sdk.version>1.12.780</aws-java-sdk.version>
|
||||||
</properties>
|
</properties>
|
||||||
|
|
||||||
<dependencyManagement>
|
<dependencyManagement>
|
||||||
@@ -173,6 +174,11 @@
|
|||||||
<artifactId>langchain4j-milvus</artifactId>
|
<artifactId>langchain4j-milvus</artifactId>
|
||||||
<version>${langchain4j.version}</version>
|
<version>${langchain4j.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>dev.langchain4j</groupId>
|
||||||
|
<artifactId>langchain4j-opensearch</artifactId>
|
||||||
|
<version>${langchain4j.version}</version>
|
||||||
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>dev.langchain4j</groupId>
|
<groupId>dev.langchain4j</groupId>
|
||||||
<artifactId>langchain4j-pgvector</artifactId>
|
<artifactId>langchain4j-pgvector</artifactId>
|
||||||
@@ -213,6 +219,11 @@
|
|||||||
<artifactId>springdoc-openapi-starter-webmvc-ui</artifactId>
|
<artifactId>springdoc-openapi-starter-webmvc-ui</artifactId>
|
||||||
<version>2.1.0</version>
|
<version>2.1.0</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.amazonaws</groupId>
|
||||||
|
<artifactId>aws-java-sdk</artifactId>
|
||||||
|
<version>${aws-java-sdk.version}</version>
|
||||||
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
</dependencyManagement>
|
</dependencyManagement>
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user