mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 12:07:42 +00:00
(improvement)[build] Use Spotless to customize the code formatting (#1750)
This commit is contained in:
@@ -24,11 +24,8 @@ public class ChromaEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
|
||||
|
||||
@Override
|
||||
public EmbeddingStore createEmbeddingStore(String collectionName) {
|
||||
return ChromaEmbeddingStore.builder()
|
||||
.baseUrl(storeProperties.getBaseUrl())
|
||||
.collectionName(collectionName)
|
||||
.timeout(storeProperties.getTimeout())
|
||||
.build();
|
||||
return ChromaEmbeddingStore.builder().baseUrl(storeProperties.getBaseUrl())
|
||||
.collectionName(collectionName).timeout(storeProperties.getTimeout()).build();
|
||||
}
|
||||
|
||||
private static EmbeddingStoreProperties createPropertiesFromConfig(
|
||||
|
||||
@@ -12,5 +12,6 @@ public class Properties {
|
||||
|
||||
static final String PREFIX = "langchain4j.chroma";
|
||||
|
||||
@NestedConfigurationProperty EmbeddingStoreProperties embeddingStore;
|
||||
@NestedConfigurationProperty
|
||||
EmbeddingStoreProperties embeddingStore;
|
||||
}
|
||||
|
||||
@@ -20,18 +20,15 @@ public class DashscopeAutoConfig {
|
||||
@ConditionalOnProperty(PREFIX + ".chat-model.api-key")
|
||||
QwenChatModel qwenChatModel(Properties properties) {
|
||||
ChatModelProperties chatModelProperties = properties.getChatModel();
|
||||
return QwenChatModel.builder()
|
||||
.baseUrl(chatModelProperties.getBaseUrl())
|
||||
return QwenChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
|
||||
.apiKey(chatModelProperties.getApiKey())
|
||||
.modelName(chatModelProperties.getModelName())
|
||||
.topP(chatModelProperties.getTopP())
|
||||
.modelName(chatModelProperties.getModelName()).topP(chatModelProperties.getTopP())
|
||||
.topK(chatModelProperties.getTopK())
|
||||
.enableSearch(chatModelProperties.getEnableSearch())
|
||||
.seed(chatModelProperties.getSeed())
|
||||
.repetitionPenalty(chatModelProperties.getRepetitionPenalty())
|
||||
.temperature(chatModelProperties.getTemperature())
|
||||
.stops(chatModelProperties.getStops())
|
||||
.maxTokens(chatModelProperties.getMaxTokens())
|
||||
.stops(chatModelProperties.getStops()).maxTokens(chatModelProperties.getMaxTokens())
|
||||
.build();
|
||||
}
|
||||
|
||||
@@ -39,18 +36,15 @@ public class DashscopeAutoConfig {
|
||||
@ConditionalOnProperty(PREFIX + ".streaming-chat-model.api-key")
|
||||
QwenStreamingChatModel qwenStreamingChatModel(Properties properties) {
|
||||
ChatModelProperties chatModelProperties = properties.getStreamingChatModel();
|
||||
return QwenStreamingChatModel.builder()
|
||||
.baseUrl(chatModelProperties.getBaseUrl())
|
||||
return QwenStreamingChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
|
||||
.apiKey(chatModelProperties.getApiKey())
|
||||
.modelName(chatModelProperties.getModelName())
|
||||
.topP(chatModelProperties.getTopP())
|
||||
.modelName(chatModelProperties.getModelName()).topP(chatModelProperties.getTopP())
|
||||
.topK(chatModelProperties.getTopK())
|
||||
.enableSearch(chatModelProperties.getEnableSearch())
|
||||
.seed(chatModelProperties.getSeed())
|
||||
.repetitionPenalty(chatModelProperties.getRepetitionPenalty())
|
||||
.temperature(chatModelProperties.getTemperature())
|
||||
.stops(chatModelProperties.getStops())
|
||||
.maxTokens(chatModelProperties.getMaxTokens())
|
||||
.stops(chatModelProperties.getStops()).maxTokens(chatModelProperties.getMaxTokens())
|
||||
.build();
|
||||
}
|
||||
|
||||
@@ -58,47 +52,33 @@ public class DashscopeAutoConfig {
|
||||
@ConditionalOnProperty(PREFIX + ".language-model.api-key")
|
||||
QwenLanguageModel qwenLanguageModel(Properties properties) {
|
||||
ChatModelProperties languageModel = properties.getLanguageModel();
|
||||
return QwenLanguageModel.builder()
|
||||
.baseUrl(languageModel.getBaseUrl())
|
||||
.apiKey(languageModel.getApiKey())
|
||||
.modelName(languageModel.getModelName())
|
||||
.topP(languageModel.getTopP())
|
||||
.topK(languageModel.getTopK())
|
||||
.enableSearch(languageModel.getEnableSearch())
|
||||
.seed(languageModel.getSeed())
|
||||
return QwenLanguageModel.builder().baseUrl(languageModel.getBaseUrl())
|
||||
.apiKey(languageModel.getApiKey()).modelName(languageModel.getModelName())
|
||||
.topP(languageModel.getTopP()).topK(languageModel.getTopK())
|
||||
.enableSearch(languageModel.getEnableSearch()).seed(languageModel.getSeed())
|
||||
.repetitionPenalty(languageModel.getRepetitionPenalty())
|
||||
.temperature(languageModel.getTemperature())
|
||||
.stops(languageModel.getStops())
|
||||
.maxTokens(languageModel.getMaxTokens())
|
||||
.build();
|
||||
.temperature(languageModel.getTemperature()).stops(languageModel.getStops())
|
||||
.maxTokens(languageModel.getMaxTokens()).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".streaming-language-model.api-key")
|
||||
QwenStreamingLanguageModel qwenStreamingLanguageModel(Properties properties) {
|
||||
ChatModelProperties languageModel = properties.getStreamingLanguageModel();
|
||||
return QwenStreamingLanguageModel.builder()
|
||||
.baseUrl(languageModel.getBaseUrl())
|
||||
.apiKey(languageModel.getApiKey())
|
||||
.modelName(languageModel.getModelName())
|
||||
.topP(languageModel.getTopP())
|
||||
.topK(languageModel.getTopK())
|
||||
.enableSearch(languageModel.getEnableSearch())
|
||||
.seed(languageModel.getSeed())
|
||||
return QwenStreamingLanguageModel.builder().baseUrl(languageModel.getBaseUrl())
|
||||
.apiKey(languageModel.getApiKey()).modelName(languageModel.getModelName())
|
||||
.topP(languageModel.getTopP()).topK(languageModel.getTopK())
|
||||
.enableSearch(languageModel.getEnableSearch()).seed(languageModel.getSeed())
|
||||
.repetitionPenalty(languageModel.getRepetitionPenalty())
|
||||
.temperature(languageModel.getTemperature())
|
||||
.stops(languageModel.getStops())
|
||||
.maxTokens(languageModel.getMaxTokens())
|
||||
.build();
|
||||
.temperature(languageModel.getTemperature()).stops(languageModel.getStops())
|
||||
.maxTokens(languageModel.getMaxTokens()).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".embedding-model.api-key")
|
||||
QwenEmbeddingModel qwenEmbeddingModel(Properties properties) {
|
||||
EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel();
|
||||
return QwenEmbeddingModel.builder()
|
||||
.apiKey(embeddingModelProperties.getApiKey())
|
||||
.modelName(embeddingModelProperties.getModelName())
|
||||
.build();
|
||||
return QwenEmbeddingModel.builder().apiKey(embeddingModelProperties.getApiKey())
|
||||
.modelName(embeddingModelProperties.getModelName()).build();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,13 +12,18 @@ public class Properties {
|
||||
|
||||
static final String PREFIX = "langchain4j.dashscope";
|
||||
|
||||
@NestedConfigurationProperty ChatModelProperties chatModel;
|
||||
@NestedConfigurationProperty
|
||||
ChatModelProperties chatModel;
|
||||
|
||||
@NestedConfigurationProperty ChatModelProperties streamingChatModel;
|
||||
@NestedConfigurationProperty
|
||||
ChatModelProperties streamingChatModel;
|
||||
|
||||
@NestedConfigurationProperty ChatModelProperties languageModel;
|
||||
@NestedConfigurationProperty
|
||||
ChatModelProperties languageModel;
|
||||
|
||||
@NestedConfigurationProperty ChatModelProperties streamingLanguageModel;
|
||||
@NestedConfigurationProperty
|
||||
ChatModelProperties streamingLanguageModel;
|
||||
|
||||
@NestedConfigurationProperty EmbeddingModelProperties embeddingModel;
|
||||
@NestedConfigurationProperty
|
||||
EmbeddingModelProperties embeddingModel;
|
||||
}
|
||||
|
||||
@@ -74,8 +74,8 @@ public class InMemoryEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
|
||||
if (MapUtils.isEmpty(super.collectionNameToStore)) {
|
||||
return;
|
||||
}
|
||||
for (Map.Entry<String, EmbeddingStore<TextSegment>> entry :
|
||||
collectionNameToStore.entrySet()) {
|
||||
for (Map.Entry<String, EmbeddingStore<TextSegment>> entry : collectionNameToStore
|
||||
.entrySet()) {
|
||||
Path filePath = getPersistPath(entry.getKey());
|
||||
if (Objects.isNull(filePath)) {
|
||||
continue;
|
||||
|
||||
@@ -12,7 +12,9 @@ public class Properties {
|
||||
|
||||
static final String PREFIX = "langchain4j.in-memory";
|
||||
|
||||
@NestedConfigurationProperty EmbeddingStoreProperties embeddingStore;
|
||||
@NestedConfigurationProperty
|
||||
EmbeddingStoreProperties embeddingStore;
|
||||
|
||||
@NestedConfigurationProperty EmbeddingModelProperties embeddingModel;
|
||||
@NestedConfigurationProperty
|
||||
EmbeddingModelProperties embeddingModel;
|
||||
}
|
||||
|
||||
@@ -20,70 +20,58 @@ public class LocalAiAutoConfig {
|
||||
@ConditionalOnProperty(PREFIX + ".chat-model.base-url")
|
||||
LocalAiChatModel localAiChatModel(Properties properties) {
|
||||
ChatModelProperties chatModelProperties = properties.getChatModel();
|
||||
return LocalAiChatModel.builder()
|
||||
.baseUrl(chatModelProperties.getBaseUrl())
|
||||
return LocalAiChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
|
||||
.modelName(chatModelProperties.getModelName())
|
||||
.temperature(chatModelProperties.getTemperature())
|
||||
.topP(chatModelProperties.getTopP())
|
||||
.maxRetries(chatModelProperties.getMaxRetries())
|
||||
.topP(chatModelProperties.getTopP()).maxRetries(chatModelProperties.getMaxRetries())
|
||||
.logRequests(chatModelProperties.getLogRequests())
|
||||
.logResponses(chatModelProperties.getLogResponses())
|
||||
.build();
|
||||
.logResponses(chatModelProperties.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".streaming-chat-model.base-url")
|
||||
LocalAiStreamingChatModel localAiStreamingChatModel(Properties properties) {
|
||||
ChatModelProperties chatModelProperties = properties.getStreamingChatModel();
|
||||
return LocalAiStreamingChatModel.builder()
|
||||
.temperature(chatModelProperties.getTemperature())
|
||||
.topP(chatModelProperties.getTopP())
|
||||
.baseUrl(chatModelProperties.getBaseUrl())
|
||||
return LocalAiStreamingChatModel.builder().temperature(chatModelProperties.getTemperature())
|
||||
.topP(chatModelProperties.getTopP()).baseUrl(chatModelProperties.getBaseUrl())
|
||||
.modelName(chatModelProperties.getModelName())
|
||||
.logRequests(chatModelProperties.getLogRequests())
|
||||
.logResponses(chatModelProperties.getLogResponses())
|
||||
.build();
|
||||
.logResponses(chatModelProperties.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".language-model.base-url")
|
||||
LocalAiLanguageModel localAiLanguageModel(Properties properties) {
|
||||
LanguageModelProperties languageModelProperties = properties.getLanguageModel();
|
||||
return LocalAiLanguageModel.builder()
|
||||
.topP(languageModelProperties.getTopP())
|
||||
return LocalAiLanguageModel.builder().topP(languageModelProperties.getTopP())
|
||||
.baseUrl(languageModelProperties.getBaseUrl())
|
||||
.modelName(languageModelProperties.getModelName())
|
||||
.temperature(languageModelProperties.getTemperature())
|
||||
.maxRetries(languageModelProperties.getMaxRetries())
|
||||
.logRequests(languageModelProperties.getLogRequests())
|
||||
.logResponses(languageModelProperties.getLogResponses())
|
||||
.build();
|
||||
.logResponses(languageModelProperties.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".streaming-language-model.base-url")
|
||||
LocalAiStreamingLanguageModel localAiStreamingLanguageModel(Properties properties) {
|
||||
LanguageModelProperties languageModelProperties = properties.getStreamingLanguageModel();
|
||||
return LocalAiStreamingLanguageModel.builder()
|
||||
.topP(languageModelProperties.getTopP())
|
||||
return LocalAiStreamingLanguageModel.builder().topP(languageModelProperties.getTopP())
|
||||
.baseUrl(languageModelProperties.getBaseUrl())
|
||||
.modelName(languageModelProperties.getModelName())
|
||||
.temperature(languageModelProperties.getTemperature())
|
||||
.logRequests(languageModelProperties.getLogRequests())
|
||||
.logResponses(languageModelProperties.getLogResponses())
|
||||
.build();
|
||||
.logResponses(languageModelProperties.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".embedding-model.base-url")
|
||||
LocalAiEmbeddingModel localAiEmbeddingModel(Properties properties) {
|
||||
EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel();
|
||||
return LocalAiEmbeddingModel.builder()
|
||||
.baseUrl(embeddingModelProperties.getBaseUrl())
|
||||
return LocalAiEmbeddingModel.builder().baseUrl(embeddingModelProperties.getBaseUrl())
|
||||
.modelName(embeddingModelProperties.getModelName())
|
||||
.maxRetries(embeddingModelProperties.getMaxRetries())
|
||||
.logRequests(embeddingModelProperties.getLogRequests())
|
||||
.logResponses(embeddingModelProperties.getLogResponses())
|
||||
.build();
|
||||
.logResponses(embeddingModelProperties.getLogResponses()).build();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,13 +12,18 @@ public class Properties {
|
||||
|
||||
static final String PREFIX = "langchain4j.local-ai";
|
||||
|
||||
@NestedConfigurationProperty ChatModelProperties chatModel;
|
||||
@NestedConfigurationProperty
|
||||
ChatModelProperties chatModel;
|
||||
|
||||
@NestedConfigurationProperty ChatModelProperties streamingChatModel;
|
||||
@NestedConfigurationProperty
|
||||
ChatModelProperties streamingChatModel;
|
||||
|
||||
@NestedConfigurationProperty LanguageModelProperties languageModel;
|
||||
@NestedConfigurationProperty
|
||||
LanguageModelProperties languageModel;
|
||||
|
||||
@NestedConfigurationProperty LanguageModelProperties streamingLanguageModel;
|
||||
@NestedConfigurationProperty
|
||||
LanguageModelProperties streamingLanguageModel;
|
||||
|
||||
@NestedConfigurationProperty EmbeddingModelProperties embeddingModel;
|
||||
@NestedConfigurationProperty
|
||||
EmbeddingModelProperties embeddingModel;
|
||||
}
|
||||
|
||||
@@ -29,21 +29,15 @@ public class MilvusEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
|
||||
|
||||
@Override
|
||||
public EmbeddingStore<TextSegment> createEmbeddingStore(String collectionName) {
|
||||
return MilvusEmbeddingStore.builder()
|
||||
.host(storeProperties.getHost())
|
||||
.port(storeProperties.getPort())
|
||||
.collectionName(collectionName)
|
||||
.dimension(storeProperties.getDimension())
|
||||
.indexType(storeProperties.getIndexType())
|
||||
.metricType(storeProperties.getMetricType())
|
||||
.uri(storeProperties.getUri())
|
||||
.token(storeProperties.getToken())
|
||||
.username(storeProperties.getUsername())
|
||||
return MilvusEmbeddingStore.builder().host(storeProperties.getHost())
|
||||
.port(storeProperties.getPort()).collectionName(collectionName)
|
||||
.dimension(storeProperties.getDimension()).indexType(storeProperties.getIndexType())
|
||||
.metricType(storeProperties.getMetricType()).uri(storeProperties.getUri())
|
||||
.token(storeProperties.getToken()).username(storeProperties.getUsername())
|
||||
.password(storeProperties.getPassword())
|
||||
.consistencyLevel(storeProperties.getConsistencyLevel())
|
||||
.retrieveEmbeddingsOnSearch(storeProperties.getRetrieveEmbeddingsOnSearch())
|
||||
.autoFlushOnInsert(storeProperties.getAutoFlushOnInsert())
|
||||
.databaseName(storeProperties.getDatabaseName())
|
||||
.build();
|
||||
.databaseName(storeProperties.getDatabaseName()).build();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,5 +12,6 @@ public class Properties {
|
||||
|
||||
static final String PREFIX = "langchain4j.milvus";
|
||||
|
||||
@NestedConfigurationProperty EmbeddingStoreProperties embeddingStore;
|
||||
@NestedConfigurationProperty
|
||||
EmbeddingStoreProperties embeddingStore;
|
||||
}
|
||||
|
||||
@@ -13,10 +13,10 @@ import java.util.Objects;
|
||||
/**
|
||||
* An embedding model that runs within your Java application's process. Any BERT-based model (e.g.,
|
||||
* from HuggingFace) can be used, as long as it is in ONNX format. Information on how to convert
|
||||
* models into ONNX format can be found <a
|
||||
* href="https://huggingface.co/docs/optimum/exporters/onnx/usage_guides/export_a_model">here</a>.
|
||||
* Many models already converted to ONNX format are available <a
|
||||
* href="https://huggingface.co/Xenova">here</a>. Copy from
|
||||
* models into ONNX format can be found <a href=
|
||||
* "https://huggingface.co/docs/optimum/exporters/onnx/usage_guides/export_a_model">here</a>. Many
|
||||
* models already converted to ONNX format are available
|
||||
* <a href="https://huggingface.co/Xenova">here</a>. Copy from
|
||||
* dev.langchain4j.model.embedding.OnnxEmbeddingModel.
|
||||
*/
|
||||
public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {
|
||||
@@ -28,9 +28,8 @@ public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {
|
||||
if (shouldReloadModel(pathToModel, vocabularyPath)) {
|
||||
synchronized (S2OnnxEmbeddingModel.class) {
|
||||
if (shouldReloadModel(pathToModel, vocabularyPath)) {
|
||||
URL resource =
|
||||
AbstractInProcessEmbeddingModel.class.getResource(
|
||||
"/bert-vocabulary-en.txt");
|
||||
URL resource = AbstractInProcessEmbeddingModel.class
|
||||
.getResource("/bert-vocabulary-en.txt");
|
||||
if (StringUtils.isNotBlank(vocabularyPath)) {
|
||||
try {
|
||||
resource = Paths.get(vocabularyPath).toUri().toURL();
|
||||
@@ -56,15 +55,14 @@ public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {
|
||||
}
|
||||
|
||||
private static boolean shouldReloadModel(String pathToModel, String vocabularyPath) {
|
||||
return cachedModel == null
|
||||
|| !Objects.equals(cachedModelPath, pathToModel)
|
||||
return cachedModel == null || !Objects.equals(cachedModelPath, pathToModel)
|
||||
|| !Objects.equals(cachedVocabularyPath, vocabularyPath);
|
||||
}
|
||||
|
||||
static OnnxBertBiEncoder loadFromFileSystem(Path pathToModel, URL vocabularyFile) {
|
||||
try {
|
||||
return new OnnxBertBiEncoder(
|
||||
Files.newInputStream(pathToModel), vocabularyFile, PoolingMode.MEAN);
|
||||
return new OnnxBertBiEncoder(Files.newInputStream(pathToModel), vocabularyFile,
|
||||
PoolingMode.MEAN);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
@@ -60,8 +60,8 @@ import static java.util.Collections.singletonList;
|
||||
|
||||
/**
|
||||
* Represents an OpenAI language model with a chat completion interface, such as gpt-3.5-turbo and
|
||||
* gpt-4. You can find description of parameters <a
|
||||
* href="https://platform.openai.com/docs/api-reference/chat/create">here</a>.
|
||||
* gpt-4. You can find description of parameters
|
||||
* <a href="https://platform.openai.com/docs/api-reference/chat/create">here</a>.
|
||||
*/
|
||||
@Slf4j
|
||||
public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
||||
@@ -88,32 +88,13 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
||||
private final List<ChatModelListener> listeners;
|
||||
|
||||
@Builder
|
||||
public OpenAiChatModel(
|
||||
String baseUrl,
|
||||
String apiKey,
|
||||
String organizationId,
|
||||
String modelName,
|
||||
Double temperature,
|
||||
Double topP,
|
||||
List<String> stop,
|
||||
Integer maxTokens,
|
||||
Double presencePenalty,
|
||||
Double frequencyPenalty,
|
||||
Map<String, Integer> logitBias,
|
||||
String responseFormat,
|
||||
Boolean strictJsonSchema,
|
||||
Integer seed,
|
||||
String user,
|
||||
Boolean strictTools,
|
||||
Boolean parallelToolCalls,
|
||||
Duration timeout,
|
||||
Integer maxRetries,
|
||||
Proxy proxy,
|
||||
Boolean logRequests,
|
||||
Boolean logResponses,
|
||||
Tokenizer tokenizer,
|
||||
Map<String, String> customHeaders,
|
||||
List<ChatModelListener> listeners) {
|
||||
public OpenAiChatModel(String baseUrl, String apiKey, String organizationId, String modelName,
|
||||
Double temperature, Double topP, List<String> stop, Integer maxTokens,
|
||||
Double presencePenalty, Double frequencyPenalty, Map<String, Integer> logitBias,
|
||||
String responseFormat, Boolean strictJsonSchema, Integer seed, String user,
|
||||
Boolean strictTools, Boolean parallelToolCalls, Duration timeout, Integer maxRetries,
|
||||
Proxy proxy, Boolean logRequests, Boolean logResponses, Tokenizer tokenizer,
|
||||
Map<String, String> customHeaders, List<ChatModelListener> listeners) {
|
||||
|
||||
baseUrl = getOrDefault(baseUrl, OPENAI_URL);
|
||||
if (OPENAI_DEMO_API_KEY.equals(apiKey)) {
|
||||
@@ -123,21 +104,11 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
||||
|
||||
timeout = getOrDefault(timeout, ofSeconds(60));
|
||||
|
||||
this.client =
|
||||
OpenAiClient.builder()
|
||||
.openAiApiKey(apiKey)
|
||||
.baseUrl(baseUrl)
|
||||
.organizationId(organizationId)
|
||||
.callTimeout(timeout)
|
||||
.connectTimeout(timeout)
|
||||
.readTimeout(timeout)
|
||||
.writeTimeout(timeout)
|
||||
.proxy(proxy)
|
||||
.logRequests(logRequests)
|
||||
.logResponses(logResponses)
|
||||
.userAgent(DEFAULT_USER_AGENT)
|
||||
.customHeaders(customHeaders)
|
||||
.build();
|
||||
this.client = OpenAiClient.builder().openAiApiKey(apiKey).baseUrl(baseUrl)
|
||||
.organizationId(organizationId).callTimeout(timeout).connectTimeout(timeout)
|
||||
.readTimeout(timeout).writeTimeout(timeout).proxy(proxy).logRequests(logRequests)
|
||||
.logResponses(logResponses).userAgent(DEFAULT_USER_AGENT)
|
||||
.customHeaders(customHeaders).build();
|
||||
this.modelName = getOrDefault(modelName, GPT_3_5_TURBO);
|
||||
this.temperature = getOrDefault(temperature, 0.7);
|
||||
this.topP = topP;
|
||||
@@ -146,14 +117,10 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
||||
this.presencePenalty = presencePenalty;
|
||||
this.frequencyPenalty = frequencyPenalty;
|
||||
this.logitBias = logitBias;
|
||||
this.responseFormat =
|
||||
responseFormat == null
|
||||
? null
|
||||
: ResponseFormat.builder()
|
||||
.type(
|
||||
ResponseFormatType.valueOf(
|
||||
responseFormat.toUpperCase(Locale.ROOT)))
|
||||
.build();
|
||||
this.responseFormat = responseFormat == null ? null
|
||||
: ResponseFormat.builder()
|
||||
.type(ResponseFormatType.valueOf(responseFormat.toUpperCase(Locale.ROOT)))
|
||||
.build();
|
||||
this.strictJsonSchema = getOrDefault(strictJsonSchema, false);
|
||||
this.seed = seed;
|
||||
this.user = user;
|
||||
@@ -183,61 +150,44 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
||||
}
|
||||
|
||||
@Override
|
||||
public Response<AiMessage> generate(
|
||||
List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
|
||||
public Response<AiMessage> generate(List<ChatMessage> messages,
|
||||
List<ToolSpecification> toolSpecifications) {
|
||||
return generate(messages, toolSpecifications, null, this.responseFormat);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Response<AiMessage> generate(
|
||||
List<ChatMessage> messages, ToolSpecification toolSpecification) {
|
||||
return generate(
|
||||
messages, singletonList(toolSpecification), toolSpecification, this.responseFormat);
|
||||
public Response<AiMessage> generate(List<ChatMessage> messages,
|
||||
ToolSpecification toolSpecification) {
|
||||
return generate(messages, singletonList(toolSpecification), toolSpecification,
|
||||
this.responseFormat);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatResponse chat(ChatRequest request) {
|
||||
Response<AiMessage> response =
|
||||
generate(
|
||||
request.messages(),
|
||||
request.toolSpecifications(),
|
||||
null,
|
||||
generate(request.messages(), request.toolSpecifications(), null,
|
||||
getOrDefault(
|
||||
toOpenAiResponseFormat(request.responseFormat(), strictJsonSchema),
|
||||
this.responseFormat));
|
||||
return ChatResponse.builder()
|
||||
.aiMessage(response.content())
|
||||
.tokenUsage(response.tokenUsage())
|
||||
.finishReason(response.finishReason())
|
||||
.build();
|
||||
return ChatResponse.builder().aiMessage(response.content())
|
||||
.tokenUsage(response.tokenUsage()).finishReason(response.finishReason()).build();
|
||||
}
|
||||
|
||||
private Response<AiMessage> generate(
|
||||
List<ChatMessage> messages,
|
||||
List<ToolSpecification> toolSpecifications,
|
||||
ToolSpecification toolThatMustBeExecuted,
|
||||
private Response<AiMessage> generate(List<ChatMessage> messages,
|
||||
List<ToolSpecification> toolSpecifications, ToolSpecification toolThatMustBeExecuted,
|
||||
ResponseFormat responseFormat) {
|
||||
|
||||
if (responseFormat != null
|
||||
&& responseFormat.type() == JSON_SCHEMA
|
||||
if (responseFormat != null && responseFormat.type() == JSON_SCHEMA
|
||||
&& responseFormat.jsonSchema() == null) {
|
||||
responseFormat = null;
|
||||
}
|
||||
|
||||
ChatCompletionRequest.Builder requestBuilder =
|
||||
ChatCompletionRequest.builder()
|
||||
.model(modelName)
|
||||
.messages(toOpenAiMessages(messages))
|
||||
.topP(topP)
|
||||
.stop(stop)
|
||||
.maxTokens(maxTokens)
|
||||
.presencePenalty(presencePenalty)
|
||||
.frequencyPenalty(frequencyPenalty)
|
||||
.logitBias(logitBias)
|
||||
.responseFormat(responseFormat)
|
||||
.seed(seed)
|
||||
.user(user)
|
||||
.parallelToolCalls(parallelToolCalls);
|
||||
ChatCompletionRequest.Builder requestBuilder = ChatCompletionRequest.builder()
|
||||
.model(modelName).messages(toOpenAiMessages(messages)).topP(topP).stop(stop)
|
||||
.maxTokens(maxTokens).presencePenalty(presencePenalty)
|
||||
.frequencyPenalty(frequencyPenalty).logitBias(logitBias)
|
||||
.responseFormat(responseFormat).seed(seed).user(user)
|
||||
.parallelToolCalls(parallelToolCalls);
|
||||
|
||||
if (!(baseUrl.contains(ZHIPU))) {
|
||||
requestBuilder.temperature(temperature);
|
||||
@@ -257,40 +207,33 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
||||
Map<Object, Object> attributes = new ConcurrentHashMap<>();
|
||||
ChatModelRequestContext requestContext =
|
||||
new ChatModelRequestContext(modelListenerRequest, attributes);
|
||||
listeners.forEach(
|
||||
listener -> {
|
||||
try {
|
||||
listener.onRequest(requestContext);
|
||||
} catch (Exception e) {
|
||||
log.warn("Exception while calling model listener", e);
|
||||
}
|
||||
});
|
||||
listeners.forEach(listener -> {
|
||||
try {
|
||||
listener.onRequest(requestContext);
|
||||
} catch (Exception e) {
|
||||
log.warn("Exception while calling model listener", e);
|
||||
}
|
||||
});
|
||||
|
||||
try {
|
||||
ChatCompletionResponse chatCompletionResponse =
|
||||
withRetry(() -> client.chatCompletion(request).execute(), maxRetries);
|
||||
|
||||
Response<AiMessage> response =
|
||||
Response.from(
|
||||
aiMessageFrom(chatCompletionResponse),
|
||||
tokenUsageFrom(chatCompletionResponse.usage()),
|
||||
finishReasonFrom(
|
||||
chatCompletionResponse.choices().get(0).finishReason()));
|
||||
Response<AiMessage> response = Response.from(aiMessageFrom(chatCompletionResponse),
|
||||
tokenUsageFrom(chatCompletionResponse.usage()),
|
||||
finishReasonFrom(chatCompletionResponse.choices().get(0).finishReason()));
|
||||
|
||||
ChatModelResponse modelListenerResponse =
|
||||
createModelListenerResponse(
|
||||
chatCompletionResponse.id(), chatCompletionResponse.model(), response);
|
||||
ChatModelResponseContext responseContext =
|
||||
new ChatModelResponseContext(
|
||||
modelListenerResponse, modelListenerRequest, attributes);
|
||||
listeners.forEach(
|
||||
listener -> {
|
||||
try {
|
||||
listener.onResponse(responseContext);
|
||||
} catch (Exception e) {
|
||||
log.warn("Exception while calling model listener", e);
|
||||
}
|
||||
});
|
||||
ChatModelResponse modelListenerResponse = createModelListenerResponse(
|
||||
chatCompletionResponse.id(), chatCompletionResponse.model(), response);
|
||||
ChatModelResponseContext responseContext = new ChatModelResponseContext(
|
||||
modelListenerResponse, modelListenerRequest, attributes);
|
||||
listeners.forEach(listener -> {
|
||||
try {
|
||||
listener.onResponse(responseContext);
|
||||
} catch (Exception e) {
|
||||
log.warn("Exception while calling model listener", e);
|
||||
}
|
||||
});
|
||||
|
||||
return response;
|
||||
} catch (RuntimeException e) {
|
||||
@@ -305,14 +248,13 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
||||
ChatModelErrorContext errorContext =
|
||||
new ChatModelErrorContext(error, modelListenerRequest, null, attributes);
|
||||
|
||||
listeners.forEach(
|
||||
listener -> {
|
||||
try {
|
||||
listener.onError(errorContext);
|
||||
} catch (Exception e2) {
|
||||
log.warn("Exception while calling model listener", e2);
|
||||
}
|
||||
});
|
||||
listeners.forEach(listener -> {
|
||||
try {
|
||||
listener.onError(errorContext);
|
||||
} catch (Exception e2) {
|
||||
log.warn("Exception while calling model listener", e2);
|
||||
}
|
||||
});
|
||||
|
||||
throw e;
|
||||
}
|
||||
@@ -328,8 +270,8 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
||||
}
|
||||
|
||||
public static OpenAiChatModelBuilder builder() {
|
||||
for (OpenAiChatModelBuilderFactory factory :
|
||||
loadFactories(OpenAiChatModelBuilderFactory.class)) {
|
||||
for (OpenAiChatModelBuilderFactory factory : loadFactories(
|
||||
OpenAiChatModelBuilderFactory.class)) {
|
||||
return factory.get();
|
||||
}
|
||||
return new OpenAiChatModelBuilder();
|
||||
|
||||
@@ -3,9 +3,8 @@ package dev.langchain4j.model.openai;
|
||||
public enum OpenAiChatModelName {
|
||||
GPT_3_5_TURBO("gpt-3.5-turbo"), // alias
|
||||
@Deprecated
|
||||
GPT_3_5_TURBO_0613("gpt-3.5-turbo-0613"),
|
||||
GPT_3_5_TURBO_1106("gpt-3.5-turbo-1106"),
|
||||
GPT_3_5_TURBO_0125("gpt-3.5-turbo-0125"),
|
||||
GPT_3_5_TURBO_0613("gpt-3.5-turbo-0613"), GPT_3_5_TURBO_1106(
|
||||
"gpt-3.5-turbo-1106"), GPT_3_5_TURBO_0125("gpt-3.5-turbo-0125"),
|
||||
|
||||
GPT_3_5_TURBO_16K("gpt-3.5-turbo-16k"), // alias
|
||||
@Deprecated
|
||||
@@ -13,22 +12,18 @@ public enum OpenAiChatModelName {
|
||||
|
||||
GPT_4("gpt-4"), // alias
|
||||
@Deprecated
|
||||
GPT_4_0314("gpt-4-0314"),
|
||||
GPT_4_0613("gpt-4-0613"),
|
||||
GPT_4_0314("gpt-4-0314"), GPT_4_0613("gpt-4-0613"),
|
||||
|
||||
GPT_4_TURBO_PREVIEW("gpt-4-turbo-preview"), // alias
|
||||
GPT_4_1106_PREVIEW("gpt-4-1106-preview"),
|
||||
GPT_4_0125_PREVIEW("gpt-4-0125-preview"),
|
||||
GPT_4_1106_PREVIEW("gpt-4-1106-preview"), GPT_4_0125_PREVIEW("gpt-4-0125-preview"),
|
||||
|
||||
GPT_4_32K("gpt-4-32k"), // alias
|
||||
GPT_4_32K_0314("gpt-4-32k-0314"),
|
||||
GPT_4_32K_0613("gpt-4-32k-0613"),
|
||||
GPT_4_32K_0314("gpt-4-32k-0314"), GPT_4_32K_0613("gpt-4-32k-0613"),
|
||||
|
||||
@Deprecated
|
||||
GPT_4_VISION_PREVIEW("gpt-4-vision-preview"),
|
||||
|
||||
GPT_4_O("gpt-4o"),
|
||||
GPT_4_O_MINI("gpt-4o-mini");
|
||||
GPT_4_O("gpt-4o"), GPT_4_O_MINI("gpt-4o-mini");
|
||||
|
||||
private final String stringValue;
|
||||
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
package dev.langchain4j.model.zhipu;
|
||||
|
||||
public enum ChatCompletionModel {
|
||||
GLM_4("glm-4"),
|
||||
GLM_3_TURBO("glm-3-turbo"),
|
||||
CHATGLM_TURBO("chatglm_turbo");
|
||||
GLM_4("glm-4"), GLM_3_TURBO("glm-3-turbo"), CHATGLM_TURBO("chatglm_turbo");
|
||||
|
||||
private final String value;
|
||||
|
||||
|
||||
@@ -27,8 +27,8 @@ import static java.util.Collections.singletonList;
|
||||
|
||||
/**
|
||||
* Represents an ZhipuAi language model with a chat completion interface, such as glm-3-turbo and
|
||||
* glm-4. You can find description of parameters <a
|
||||
* href="https://open.bigmodel.cn/dev/api">here</a>.
|
||||
* glm-4. You can find description of parameters
|
||||
* <a href="https://open.bigmodel.cn/dev/api">here</a>.
|
||||
*/
|
||||
public class ZhipuAiChatModel implements ChatLanguageModel {
|
||||
|
||||
@@ -41,15 +41,8 @@ public class ZhipuAiChatModel implements ChatLanguageModel {
|
||||
private final ZhipuAiClient client;
|
||||
|
||||
@Builder
|
||||
public ZhipuAiChatModel(
|
||||
String baseUrl,
|
||||
String apiKey,
|
||||
Double temperature,
|
||||
Double topP,
|
||||
String model,
|
||||
Integer maxRetries,
|
||||
Integer maxToken,
|
||||
Boolean logRequests,
|
||||
public ZhipuAiChatModel(String baseUrl, String apiKey, Double temperature, Double topP,
|
||||
String model, Integer maxRetries, Integer maxToken, Boolean logRequests,
|
||||
Boolean logResponses) {
|
||||
this.baseUrl = getOrDefault(baseUrl, "https://open.bigmodel.cn/");
|
||||
this.temperature = getOrDefault(temperature, 0.7);
|
||||
@@ -57,18 +50,14 @@ public class ZhipuAiChatModel implements ChatLanguageModel {
|
||||
this.model = getOrDefault(model, ChatCompletionModel.GLM_4.toString());
|
||||
this.maxRetries = getOrDefault(maxRetries, 3);
|
||||
this.maxToken = getOrDefault(maxToken, 512);
|
||||
this.client =
|
||||
ZhipuAiClient.builder()
|
||||
.baseUrl(this.baseUrl)
|
||||
.apiKey(apiKey)
|
||||
.logRequests(getOrDefault(logRequests, false))
|
||||
.logResponses(getOrDefault(logResponses, false))
|
||||
.build();
|
||||
this.client = ZhipuAiClient.builder().baseUrl(this.baseUrl).apiKey(apiKey)
|
||||
.logRequests(getOrDefault(logRequests, false))
|
||||
.logResponses(getOrDefault(logResponses, false)).build();
|
||||
}
|
||||
|
||||
public static ZhipuAiChatModelBuilder builder() {
|
||||
for (ZhipuAiChatModelBuilderFactory factories :
|
||||
loadFactories(ZhipuAiChatModelBuilderFactory.class)) {
|
||||
for (ZhipuAiChatModelBuilderFactory factories : loadFactories(
|
||||
ZhipuAiChatModelBuilderFactory.class)) {
|
||||
return factories.get();
|
||||
}
|
||||
return new ZhipuAiChatModelBuilder();
|
||||
@@ -80,15 +69,13 @@ public class ZhipuAiChatModel implements ChatLanguageModel {
|
||||
}
|
||||
|
||||
@Override
|
||||
public Response<AiMessage> generate(
|
||||
List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
|
||||
public Response<AiMessage> generate(List<ChatMessage> messages,
|
||||
List<ToolSpecification> toolSpecifications) {
|
||||
ensureNotEmpty(messages, "messages");
|
||||
|
||||
ChatCompletionRequest.Builder requestBuilder =
|
||||
ChatCompletionRequest.builder().model(this.model).maxTokens(maxToken).stream(false)
|
||||
.topP(topP)
|
||||
.toolChoice(AUTO)
|
||||
.messages(toZhipuAiMessages(messages));
|
||||
.topP(topP).toolChoice(AUTO).messages(toZhipuAiMessages(messages));
|
||||
|
||||
if (!isNullOrEmpty(toolSpecifications)) {
|
||||
requestBuilder.tools(toTools(toolSpecifications));
|
||||
@@ -96,17 +83,15 @@ public class ZhipuAiChatModel implements ChatLanguageModel {
|
||||
|
||||
ChatCompletionResponse response =
|
||||
withRetry(() -> client.chatCompletion(requestBuilder.build()), maxRetries);
|
||||
return Response.from(
|
||||
aiMessageFrom(response),
|
||||
tokenUsageFrom(response.getUsage()),
|
||||
return Response.from(aiMessageFrom(response), tokenUsageFrom(response.getUsage()),
|
||||
finishReasonFrom(response.getChoices().get(0).getFinishReason()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Response<AiMessage> generate(
|
||||
List<ChatMessage> messages, ToolSpecification toolSpecification) {
|
||||
return generate(
|
||||
messages, toolSpecification != null ? singletonList(toolSpecification) : null);
|
||||
public Response<AiMessage> generate(List<ChatMessage> messages,
|
||||
ToolSpecification toolSpecification) {
|
||||
return generate(messages,
|
||||
toolSpecification != null ? singletonList(toolSpecification) : null);
|
||||
}
|
||||
|
||||
public static class ZhipuAiChatModelBuilder {
|
||||
|
||||
@@ -20,36 +20,27 @@ public class AzureModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
AzureOpenAiChatModel.Builder builder =
|
||||
AzureOpenAiChatModel.builder()
|
||||
.endpoint(modelConfig.getBaseUrl())
|
||||
.apiKey(modelConfig.getApiKey())
|
||||
.deploymentName(modelConfig.getModelName())
|
||||
.temperature(modelConfig.getTemperature())
|
||||
.maxRetries(modelConfig.getMaxRetries())
|
||||
.topP(modelConfig.getTopP())
|
||||
.timeout(
|
||||
Duration.ofSeconds(
|
||||
modelConfig.getTimeOut() == null
|
||||
? 0L
|
||||
: modelConfig.getTimeOut()))
|
||||
.logRequestsAndResponses(
|
||||
modelConfig.getLogRequests() != null
|
||||
&& modelConfig.getLogResponses());
|
||||
AzureOpenAiChatModel.Builder builder = AzureOpenAiChatModel.builder()
|
||||
.endpoint(modelConfig.getBaseUrl()).apiKey(modelConfig.getApiKey())
|
||||
.deploymentName(modelConfig.getModelName())
|
||||
.temperature(modelConfig.getTemperature()).maxRetries(modelConfig.getMaxRetries())
|
||||
.topP(modelConfig.getTopP())
|
||||
.timeout(Duration.ofSeconds(
|
||||
modelConfig.getTimeOut() == null ? 0L : modelConfig.getTimeOut()))
|
||||
.logRequestsAndResponses(
|
||||
modelConfig.getLogRequests() != null && modelConfig.getLogResponses());
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
||||
AzureOpenAiEmbeddingModel.Builder builder =
|
||||
AzureOpenAiEmbeddingModel.builder()
|
||||
.endpoint(embeddingModelConfig.getBaseUrl())
|
||||
AzureOpenAiEmbeddingModel.builder().endpoint(embeddingModelConfig.getBaseUrl())
|
||||
.apiKey(embeddingModelConfig.getApiKey())
|
||||
.deploymentName(embeddingModelConfig.getModelName())
|
||||
.maxRetries(embeddingModelConfig.getMaxRetries())
|
||||
.logRequestsAndResponses(
|
||||
embeddingModelConfig.getLogRequests() != null
|
||||
&& embeddingModelConfig.getLogResponses());
|
||||
.logRequestsAndResponses(embeddingModelConfig.getLogRequests() != null
|
||||
&& embeddingModelConfig.getLogResponses());
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
|
||||
@@ -19,25 +19,17 @@ public class DashscopeModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
return QwenChatModel.builder()
|
||||
.baseUrl(modelConfig.getBaseUrl())
|
||||
.apiKey(modelConfig.getApiKey())
|
||||
.modelName(modelConfig.getModelName())
|
||||
.temperature(
|
||||
modelConfig.getTemperature() == null
|
||||
? 0L
|
||||
: modelConfig.getTemperature().floatValue())
|
||||
.topP(modelConfig.getTopP())
|
||||
.enableSearch(modelConfig.getEnableSearch())
|
||||
.build();
|
||||
return QwenChatModel.builder().baseUrl(modelConfig.getBaseUrl())
|
||||
.apiKey(modelConfig.getApiKey()).modelName(modelConfig.getModelName())
|
||||
.temperature(modelConfig.getTemperature() == null ? 0L
|
||||
: modelConfig.getTemperature().floatValue())
|
||||
.topP(modelConfig.getTopP()).enableSearch(modelConfig.getEnableSearch()).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
||||
return QwenEmbeddingModel.builder()
|
||||
.apiKey(embeddingModelConfig.getApiKey())
|
||||
.modelName(embeddingModelConfig.getModelName())
|
||||
.build();
|
||||
return QwenEmbeddingModel.builder().apiKey(embeddingModelConfig.getApiKey())
|
||||
.modelName(embeddingModelConfig.getModelName()).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -19,27 +19,20 @@ public class LocalAiModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
return LocalAiChatModel.builder()
|
||||
.baseUrl(modelConfig.getBaseUrl())
|
||||
.modelName(modelConfig.getModelName())
|
||||
.temperature(modelConfig.getTemperature())
|
||||
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
|
||||
.topP(modelConfig.getTopP())
|
||||
return LocalAiChatModel.builder().baseUrl(modelConfig.getBaseUrl())
|
||||
.modelName(modelConfig.getModelName()).temperature(modelConfig.getTemperature())
|
||||
.timeout(Duration.ofSeconds(modelConfig.getTimeOut())).topP(modelConfig.getTopP())
|
||||
.logRequests(modelConfig.getLogRequests())
|
||||
.logResponses(modelConfig.getLogResponses())
|
||||
.maxRetries(modelConfig.getMaxRetries())
|
||||
.logResponses(modelConfig.getLogResponses()).maxRetries(modelConfig.getMaxRetries())
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModel) {
|
||||
return LocalAiEmbeddingModel.builder()
|
||||
.baseUrl(embeddingModel.getBaseUrl())
|
||||
.modelName(embeddingModel.getModelName())
|
||||
.maxRetries(embeddingModel.getMaxRetries())
|
||||
return LocalAiEmbeddingModel.builder().baseUrl(embeddingModel.getBaseUrl())
|
||||
.modelName(embeddingModel.getModelName()).maxRetries(embeddingModel.getMaxRetries())
|
||||
.logRequests(embeddingModel.getLogRequests())
|
||||
.logResponses(embeddingModel.getLogResponses())
|
||||
.build();
|
||||
.logResponses(embeddingModel.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -25,8 +25,7 @@ public class ModelProvider {
|
||||
}
|
||||
|
||||
public static ChatLanguageModel getChatModel(ChatModelConfig modelConfig) {
|
||||
if (modelConfig == null
|
||||
|| StringUtils.isBlank(modelConfig.getProvider())
|
||||
if (modelConfig == null || StringUtils.isBlank(modelConfig.getProvider())
|
||||
|| StringUtils.isBlank(modelConfig.getBaseUrl())) {
|
||||
ChatModelParameterConfig parameterConfig =
|
||||
ContextUtils.getBean(ChatModelParameterConfig.class);
|
||||
|
||||
@@ -21,27 +21,20 @@ public class OllamaModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
return OllamaChatModel.builder()
|
||||
.baseUrl(modelConfig.getBaseUrl())
|
||||
.modelName(modelConfig.getModelName())
|
||||
.temperature(modelConfig.getTemperature())
|
||||
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
|
||||
.topP(modelConfig.getTopP())
|
||||
.maxRetries(modelConfig.getMaxRetries())
|
||||
.logRequests(modelConfig.getLogRequests())
|
||||
.logResponses(modelConfig.getLogResponses())
|
||||
.build();
|
||||
return OllamaChatModel.builder().baseUrl(modelConfig.getBaseUrl())
|
||||
.modelName(modelConfig.getModelName()).temperature(modelConfig.getTemperature())
|
||||
.timeout(Duration.ofSeconds(modelConfig.getTimeOut())).topP(modelConfig.getTopP())
|
||||
.maxRetries(modelConfig.getMaxRetries()).logRequests(modelConfig.getLogRequests())
|
||||
.logResponses(modelConfig.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
||||
return OllamaEmbeddingModel.builder()
|
||||
.baseUrl(embeddingModelConfig.getBaseUrl())
|
||||
return OllamaEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
|
||||
.modelName(embeddingModelConfig.getModelName())
|
||||
.maxRetries(embeddingModelConfig.getMaxRetries())
|
||||
.logRequests(embeddingModelConfig.getLogRequests())
|
||||
.logResponses(embeddingModelConfig.getLogResponses())
|
||||
.build();
|
||||
.logResponses(embeddingModelConfig.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -21,29 +21,22 @@ public class OpenAiModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
return OpenAiChatModel.builder()
|
||||
.baseUrl(modelConfig.getBaseUrl())
|
||||
.modelName(modelConfig.getModelName())
|
||||
.apiKey(modelConfig.keyDecrypt())
|
||||
.temperature(modelConfig.getTemperature())
|
||||
.topP(modelConfig.getTopP())
|
||||
return OpenAiChatModel.builder().baseUrl(modelConfig.getBaseUrl())
|
||||
.modelName(modelConfig.getModelName()).apiKey(modelConfig.keyDecrypt())
|
||||
.temperature(modelConfig.getTemperature()).topP(modelConfig.getTopP())
|
||||
.maxRetries(modelConfig.getMaxRetries())
|
||||
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
|
||||
.logRequests(modelConfig.getLogRequests())
|
||||
.logResponses(modelConfig.getLogResponses())
|
||||
.build();
|
||||
.logResponses(modelConfig.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModel) {
|
||||
return OpenAiEmbeddingModel.builder()
|
||||
.baseUrl(embeddingModel.getBaseUrl())
|
||||
.apiKey(embeddingModel.getApiKey())
|
||||
.modelName(embeddingModel.getModelName())
|
||||
return OpenAiEmbeddingModel.builder().baseUrl(embeddingModel.getBaseUrl())
|
||||
.apiKey(embeddingModel.getApiKey()).modelName(embeddingModel.getModelName())
|
||||
.maxRetries(embeddingModel.getMaxRetries())
|
||||
.logRequests(embeddingModel.getLogRequests())
|
||||
.logResponses(embeddingModel.getLogResponses())
|
||||
.build();
|
||||
.logResponses(embeddingModel.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -21,31 +21,23 @@ public class QianfanModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
return QianfanChatModel.builder()
|
||||
.baseUrl(modelConfig.getBaseUrl())
|
||||
.apiKey(modelConfig.getApiKey())
|
||||
.secretKey(modelConfig.getSecretKey())
|
||||
.endpoint(modelConfig.getEndpoint())
|
||||
.modelName(modelConfig.getModelName())
|
||||
.temperature(modelConfig.getTemperature())
|
||||
.topP(modelConfig.getTopP())
|
||||
.maxRetries(modelConfig.getMaxRetries())
|
||||
.logRequests(modelConfig.getLogRequests())
|
||||
.logResponses(modelConfig.getLogResponses())
|
||||
.build();
|
||||
return QianfanChatModel.builder().baseUrl(modelConfig.getBaseUrl())
|
||||
.apiKey(modelConfig.getApiKey()).secretKey(modelConfig.getSecretKey())
|
||||
.endpoint(modelConfig.getEndpoint()).modelName(modelConfig.getModelName())
|
||||
.temperature(modelConfig.getTemperature()).topP(modelConfig.getTopP())
|
||||
.maxRetries(modelConfig.getMaxRetries()).logRequests(modelConfig.getLogRequests())
|
||||
.logResponses(modelConfig.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
||||
return QianfanEmbeddingModel.builder()
|
||||
.baseUrl(embeddingModelConfig.getBaseUrl())
|
||||
return QianfanEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
|
||||
.apiKey(embeddingModelConfig.getApiKey())
|
||||
.secretKey(embeddingModelConfig.getSecretKey())
|
||||
.modelName(embeddingModelConfig.getModelName())
|
||||
.maxRetries(embeddingModelConfig.getMaxRetries())
|
||||
.logRequests(embeddingModelConfig.getLogRequests())
|
||||
.logResponses(embeddingModelConfig.getLogResponses())
|
||||
.build();
|
||||
.logResponses(embeddingModelConfig.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -19,28 +19,20 @@ public class ZhipuModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
return ZhipuAiChatModel.builder()
|
||||
.baseUrl(modelConfig.getBaseUrl())
|
||||
.apiKey(modelConfig.getApiKey())
|
||||
.model(modelConfig.getModelName())
|
||||
.temperature(modelConfig.getTemperature())
|
||||
.topP(modelConfig.getTopP())
|
||||
.maxRetries(modelConfig.getMaxRetries())
|
||||
.logRequests(modelConfig.getLogRequests())
|
||||
.logResponses(modelConfig.getLogResponses())
|
||||
.build();
|
||||
return ZhipuAiChatModel.builder().baseUrl(modelConfig.getBaseUrl())
|
||||
.apiKey(modelConfig.getApiKey()).model(modelConfig.getModelName())
|
||||
.temperature(modelConfig.getTemperature()).topP(modelConfig.getTopP())
|
||||
.maxRetries(modelConfig.getMaxRetries()).logRequests(modelConfig.getLogRequests())
|
||||
.logResponses(modelConfig.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
||||
return ZhipuAiEmbeddingModel.builder()
|
||||
.baseUrl(embeddingModelConfig.getBaseUrl())
|
||||
.apiKey(embeddingModelConfig.getApiKey())
|
||||
.model(embeddingModelConfig.getModelName())
|
||||
return ZhipuAiEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
|
||||
.apiKey(embeddingModelConfig.getApiKey()).model(embeddingModelConfig.getModelName())
|
||||
.maxRetries(embeddingModelConfig.getMaxRetries())
|
||||
.logRequests(embeddingModelConfig.getLogRequests())
|
||||
.logResponses(embeddingModelConfig.getLogResponses())
|
||||
.build();
|
||||
.logResponses(embeddingModelConfig.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -12,13 +12,18 @@ public class Properties {
|
||||
|
||||
static final String PREFIX = "langchain4j.qianfan";
|
||||
|
||||
@NestedConfigurationProperty ChatModelProperties chatModel;
|
||||
@NestedConfigurationProperty
|
||||
ChatModelProperties chatModel;
|
||||
|
||||
@NestedConfigurationProperty ChatModelProperties streamingChatModel;
|
||||
@NestedConfigurationProperty
|
||||
ChatModelProperties streamingChatModel;
|
||||
|
||||
@NestedConfigurationProperty LanguageModelProperties languageModel;
|
||||
@NestedConfigurationProperty
|
||||
LanguageModelProperties languageModel;
|
||||
|
||||
@NestedConfigurationProperty LanguageModelProperties streamingLanguageModel;
|
||||
@NestedConfigurationProperty
|
||||
LanguageModelProperties streamingLanguageModel;
|
||||
|
||||
@NestedConfigurationProperty EmbeddingModelProperties embeddingModel;
|
||||
@NestedConfigurationProperty
|
||||
EmbeddingModelProperties embeddingModel;
|
||||
}
|
||||
|
||||
@@ -20,8 +20,7 @@ public class QianfanAutoConfig {
|
||||
@ConditionalOnProperty(PREFIX + ".chat-model.api-key")
|
||||
QianfanChatModel qianfanChatModel(Properties properties) {
|
||||
ChatModelProperties chatModelProperties = properties.getChatModel();
|
||||
return QianfanChatModel.builder()
|
||||
.baseUrl(chatModelProperties.getBaseUrl())
|
||||
return QianfanChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
|
||||
.apiKey(chatModelProperties.getApiKey())
|
||||
.secretKey(chatModelProperties.getSecretKey())
|
||||
.endpoint(chatModelProperties.getEndpoint())
|
||||
@@ -32,38 +31,32 @@ public class QianfanAutoConfig {
|
||||
.responseFormat(chatModelProperties.getResponseFormat())
|
||||
.maxRetries(chatModelProperties.getMaxRetries())
|
||||
.logRequests(chatModelProperties.getLogRequests())
|
||||
.logResponses(chatModelProperties.getLogResponses())
|
||||
.build();
|
||||
.logResponses(chatModelProperties.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".streaming-chat-model.api-key")
|
||||
QianfanStreamingChatModel qianfanStreamingChatModel(Properties properties) {
|
||||
ChatModelProperties chatModelProperties = properties.getStreamingChatModel();
|
||||
return QianfanStreamingChatModel.builder()
|
||||
.endpoint(chatModelProperties.getEndpoint())
|
||||
return QianfanStreamingChatModel.builder().endpoint(chatModelProperties.getEndpoint())
|
||||
.penaltyScore(chatModelProperties.getPenaltyScore())
|
||||
.temperature(chatModelProperties.getTemperature())
|
||||
.topP(chatModelProperties.getTopP())
|
||||
.baseUrl(chatModelProperties.getBaseUrl())
|
||||
.topP(chatModelProperties.getTopP()).baseUrl(chatModelProperties.getBaseUrl())
|
||||
.apiKey(chatModelProperties.getApiKey())
|
||||
.secretKey(chatModelProperties.getSecretKey())
|
||||
.modelName(chatModelProperties.getModelName())
|
||||
.responseFormat(chatModelProperties.getResponseFormat())
|
||||
.logRequests(chatModelProperties.getLogRequests())
|
||||
.logResponses(chatModelProperties.getLogResponses())
|
||||
.build();
|
||||
.logResponses(chatModelProperties.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".language-model.api-key")
|
||||
QianfanLanguageModel qianfanLanguageModel(Properties properties) {
|
||||
LanguageModelProperties languageModelProperties = properties.getLanguageModel();
|
||||
return QianfanLanguageModel.builder()
|
||||
.endpoint(languageModelProperties.getEndpoint())
|
||||
return QianfanLanguageModel.builder().endpoint(languageModelProperties.getEndpoint())
|
||||
.penaltyScore(languageModelProperties.getPenaltyScore())
|
||||
.topK(languageModelProperties.getTopK())
|
||||
.topP(languageModelProperties.getTopP())
|
||||
.topK(languageModelProperties.getTopK()).topP(languageModelProperties.getTopP())
|
||||
.baseUrl(languageModelProperties.getBaseUrl())
|
||||
.apiKey(languageModelProperties.getApiKey())
|
||||
.secretKey(languageModelProperties.getSecretKey())
|
||||
@@ -71,8 +64,7 @@ public class QianfanAutoConfig {
|
||||
.temperature(languageModelProperties.getTemperature())
|
||||
.maxRetries(languageModelProperties.getMaxRetries())
|
||||
.logRequests(languageModelProperties.getLogRequests())
|
||||
.logResponses(languageModelProperties.getLogResponses())
|
||||
.build();
|
||||
.logResponses(languageModelProperties.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@@ -82,8 +74,7 @@ public class QianfanAutoConfig {
|
||||
return QianfanStreamingLanguageModel.builder()
|
||||
.endpoint(languageModelProperties.getEndpoint())
|
||||
.penaltyScore(languageModelProperties.getPenaltyScore())
|
||||
.topK(languageModelProperties.getTopK())
|
||||
.topP(languageModelProperties.getTopP())
|
||||
.topK(languageModelProperties.getTopK()).topP(languageModelProperties.getTopP())
|
||||
.baseUrl(languageModelProperties.getBaseUrl())
|
||||
.apiKey(languageModelProperties.getApiKey())
|
||||
.secretKey(languageModelProperties.getSecretKey())
|
||||
@@ -91,16 +82,14 @@ public class QianfanAutoConfig {
|
||||
.temperature(languageModelProperties.getTemperature())
|
||||
.maxRetries(languageModelProperties.getMaxRetries())
|
||||
.logRequests(languageModelProperties.getLogRequests())
|
||||
.logResponses(languageModelProperties.getLogResponses())
|
||||
.build();
|
||||
.logResponses(languageModelProperties.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".embedding-model.api-key")
|
||||
QianfanEmbeddingModel qianfanEmbeddingModel(Properties properties) {
|
||||
EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel();
|
||||
return QianfanEmbeddingModel.builder()
|
||||
.baseUrl(embeddingModelProperties.getBaseUrl())
|
||||
return QianfanEmbeddingModel.builder().baseUrl(embeddingModelProperties.getBaseUrl())
|
||||
.endpoint(embeddingModelProperties.getEndpoint())
|
||||
.apiKey(embeddingModelProperties.getApiKey())
|
||||
.secretKey(embeddingModelProperties.getSecretKey())
|
||||
@@ -108,7 +97,6 @@ public class QianfanAutoConfig {
|
||||
.user(embeddingModelProperties.getUser())
|
||||
.maxRetries(embeddingModelProperties.getMaxRetries())
|
||||
.logRequests(embeddingModelProperties.getLogRequests())
|
||||
.logResponses(embeddingModelProperties.getLogResponses())
|
||||
.build();
|
||||
.logResponses(embeddingModelProperties.getLogResponses()).build();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,24 +27,19 @@ public class EmbeddingStoreFactoryProvider {
|
||||
return ContextUtils.getBean(EmbeddingStoreFactory.class);
|
||||
}
|
||||
if (EmbeddingStoreType.CHROMA.name().equalsIgnoreCase(embeddingStoreConfig.getProvider())) {
|
||||
return factoryMap.computeIfAbsent(
|
||||
embeddingStoreConfig,
|
||||
return factoryMap.computeIfAbsent(embeddingStoreConfig,
|
||||
storeConfig -> new ChromaEmbeddingStoreFactory(storeConfig));
|
||||
}
|
||||
if (EmbeddingStoreType.MILVUS.name().equalsIgnoreCase(embeddingStoreConfig.getProvider())) {
|
||||
return factoryMap.computeIfAbsent(
|
||||
embeddingStoreConfig,
|
||||
return factoryMap.computeIfAbsent(embeddingStoreConfig,
|
||||
storeConfig -> new MilvusEmbeddingStoreFactory(storeConfig));
|
||||
}
|
||||
if (EmbeddingStoreType.IN_MEMORY
|
||||
.name()
|
||||
if (EmbeddingStoreType.IN_MEMORY.name()
|
||||
.equalsIgnoreCase(embeddingStoreConfig.getProvider())) {
|
||||
return factoryMap.computeIfAbsent(
|
||||
embeddingStoreConfig,
|
||||
return factoryMap.computeIfAbsent(embeddingStoreConfig,
|
||||
storeConfig -> new InMemoryEmbeddingStoreFactory(storeConfig));
|
||||
}
|
||||
throw new RuntimeException(
|
||||
"Unsupported EmbeddingStoreFactory provider: "
|
||||
+ embeddingStoreConfig.getProvider());
|
||||
throw new RuntimeException("Unsupported EmbeddingStoreFactory provider: "
|
||||
+ embeddingStoreConfig.getProvider());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
package dev.langchain4j.store.embedding;
|
||||
|
||||
public enum EmbeddingStoreType {
|
||||
IN_MEMORY,
|
||||
MILVUS,
|
||||
CHROMA
|
||||
IN_MEMORY, MILVUS, CHROMA
|
||||
}
|
||||
|
||||
@@ -36,8 +36,7 @@ public class Retrieval {
|
||||
}
|
||||
Retrieval retrieval = (Retrieval) o;
|
||||
return Double.compare(retrieval.similarity, similarity) == 0
|
||||
&& Objects.equal(id, retrieval.id)
|
||||
&& Objects.equal(query, retrieval.query)
|
||||
&& Objects.equal(id, retrieval.id) && Objects.equal(query, retrieval.query)
|
||||
&& Objects.equal(metadata, retrieval.metadata);
|
||||
}
|
||||
|
||||
|
||||
@@ -17,20 +17,12 @@ public class TextSegmentConvert {
|
||||
public static final String QUERY_ID = "queryId";
|
||||
|
||||
public static List<TextSegment> convertToEmbedding(List<DataItem> dataItems) {
|
||||
return dataItems.stream()
|
||||
.map(
|
||||
dataItem -> {
|
||||
Map meta =
|
||||
JSONObject.parseObject(
|
||||
JSONObject.toJSONString(dataItem), Map.class);
|
||||
TextSegment textSegment =
|
||||
TextSegment.from(dataItem.getName(), new Metadata(meta));
|
||||
addQueryId(
|
||||
textSegment,
|
||||
dataItem.getId() + dataItem.getType().name().toLowerCase());
|
||||
return textSegment;
|
||||
})
|
||||
.collect(Collectors.toList());
|
||||
return dataItems.stream().map(dataItem -> {
|
||||
Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class);
|
||||
TextSegment textSegment = TextSegment.from(dataItem.getName(), new Metadata(meta));
|
||||
addQueryId(textSegment, dataItem.getId() + dataItem.getType().name().toLowerCase());
|
||||
return textSegment;
|
||||
}).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public static void addQueryId(TextSegment textSegment, String queryId) {
|
||||
|
||||
@@ -40,16 +40,19 @@ import static java.util.stream.Collectors.toList;
|
||||
/**
|
||||
* An {@link EmbeddingStore} that stores embeddings in memory.
|
||||
*
|
||||
* <p>Uses a brute force approach by iterating over all embeddings to find the best matches.
|
||||
* <p>
|
||||
* Uses a brute force approach by iterating over all embeddings to find the best matches.
|
||||
*
|
||||
* <p>This store can be persisted using the {@link #serializeToJson()} and {@link
|
||||
* #serializeToFile(Path)} methods.
|
||||
* <p>
|
||||
* This store can be persisted using the {@link #serializeToJson()} and
|
||||
* {@link #serializeToFile(Path)} methods.
|
||||
*
|
||||
* <p>It can also be recreated from JSON or a file using the {@link #fromJson(String)} and {@link
|
||||
* #fromFile(Path)} methods.
|
||||
* <p>
|
||||
* It can also be recreated from JSON or a file using the {@link #fromJson(String)} and
|
||||
* {@link #fromFile(Path)} methods.
|
||||
*
|
||||
* @param <Embedded> The class of the object that has been embedded. Typically, it is {@link
|
||||
* dev.langchain4j.data.segment.TextSegment}.
|
||||
* @param <Embedded> The class of the object that has been embedded. Typically, it is
|
||||
* {@link dev.langchain4j.data.segment.TextSegment}.
|
||||
*/
|
||||
public class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded> {
|
||||
|
||||
@@ -88,10 +91,8 @@ public class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded
|
||||
@Override
|
||||
public List<String> addAll(List<Embedding> embeddings) {
|
||||
|
||||
List<Entry<Embedded>> newEntries =
|
||||
embeddings.stream()
|
||||
.map(embedding -> new Entry<Embedded>(randomUUID(), embedding))
|
||||
.collect(toList());
|
||||
List<Entry<Embedded>> newEntries = embeddings.stream()
|
||||
.map(embedding -> new Entry<Embedded>(randomUUID(), embedding)).collect(toList());
|
||||
|
||||
return add(newEntries);
|
||||
}
|
||||
@@ -103,11 +104,9 @@ public class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded
|
||||
"The list of embeddings and embedded must have the same size");
|
||||
}
|
||||
|
||||
List<Entry<Embedded>> newEntries =
|
||||
IntStream.range(0, embeddings.size())
|
||||
.mapToObj(
|
||||
i -> new Entry<>(randomUUID(), embeddings.get(i), embedded.get(i)))
|
||||
.collect(toList());
|
||||
List<Entry<Embedded>> newEntries = IntStream.range(0, embeddings.size())
|
||||
.mapToObj(i -> new Entry<>(randomUUID(), embeddings.get(i), embedded.get(i)))
|
||||
.collect(toList());
|
||||
|
||||
return add(newEntries);
|
||||
}
|
||||
@@ -123,16 +122,15 @@ public class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded
|
||||
public void removeAll(Filter filter) {
|
||||
ensureNotNull(filter, "filter");
|
||||
|
||||
entries.removeIf(
|
||||
entry -> {
|
||||
if (entry.embedded instanceof TextSegment) {
|
||||
return filter.test(((TextSegment) entry.embedded).metadata());
|
||||
} else if (entry.embedded == null) {
|
||||
return false;
|
||||
} else {
|
||||
throw new UnsupportedOperationException("Not supported yet.");
|
||||
}
|
||||
});
|
||||
entries.removeIf(entry -> {
|
||||
if (entry.embedded instanceof TextSegment) {
|
||||
return filter.test(((TextSegment) entry.embedded).metadata());
|
||||
} else if (entry.embedded == null) {
|
||||
return false;
|
||||
} else {
|
||||
throw new UnsupportedOperationException("Not supported yet.");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -157,9 +155,8 @@ public class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded
|
||||
}
|
||||
}
|
||||
|
||||
double cosineSimilarity =
|
||||
CosineSimilarity.between(
|
||||
entry.embedding, embeddingSearchRequest.queryEmbedding());
|
||||
double cosineSimilarity = CosineSimilarity.between(entry.embedding,
|
||||
embeddingSearchRequest.queryEmbedding());
|
||||
double score = RelevanceScore.fromCosineSimilarity(cosineSimilarity);
|
||||
if (score >= embeddingSearchRequest.minScore()) {
|
||||
matches.add(new EmbeddingMatch<>(score, entry.id, entry.embedding, entry.embedded));
|
||||
@@ -247,8 +244,8 @@ public class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded
|
||||
}
|
||||
|
||||
private static InMemoryEmbeddingStoreJsonCodec loadCodec() {
|
||||
for (InMemoryEmbeddingStoreJsonCodecFactory factory :
|
||||
loadFactories(InMemoryEmbeddingStoreJsonCodecFactory.class)) {
|
||||
for (InMemoryEmbeddingStoreJsonCodecFactory factory : loadFactories(
|
||||
InMemoryEmbeddingStoreJsonCodecFactory.class)) {
|
||||
return factory.create();
|
||||
}
|
||||
return new GsonInMemoryEmbeddingStoreJsonCodec();
|
||||
|
||||
@@ -58,27 +58,13 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
|
||||
private final boolean retrieveEmbeddingsOnSearch;
|
||||
private final boolean autoFlushOnInsert;
|
||||
|
||||
public MilvusEmbeddingStore(
|
||||
String host,
|
||||
Integer port,
|
||||
String collectionName,
|
||||
Integer dimension,
|
||||
IndexType indexType,
|
||||
MetricType metricType,
|
||||
String uri,
|
||||
String token,
|
||||
String username,
|
||||
String password,
|
||||
ConsistencyLevelEnum consistencyLevel,
|
||||
Boolean retrieveEmbeddingsOnSearch,
|
||||
Boolean autoFlushOnInsert,
|
||||
String databaseName) {
|
||||
public MilvusEmbeddingStore(String host, Integer port, String collectionName, Integer dimension,
|
||||
IndexType indexType, MetricType metricType, String uri, String token, String username,
|
||||
String password, ConsistencyLevelEnum consistencyLevel,
|
||||
Boolean retrieveEmbeddingsOnSearch, Boolean autoFlushOnInsert, String databaseName) {
|
||||
ConnectParam.Builder connectBuilder =
|
||||
ConnectParam.newBuilder()
|
||||
.withHost(getOrDefault(host, "localhost"))
|
||||
.withPort(getOrDefault(port, 19530))
|
||||
.withUri(uri)
|
||||
.withToken(token)
|
||||
ConnectParam.newBuilder().withHost(getOrDefault(host, "localhost"))
|
||||
.withPort(getOrDefault(port, 19530)).withUri(uri).withToken(token)
|
||||
.withAuthorization(getOrDefault(username, ""), getOrDefault(password, ""));
|
||||
|
||||
if (databaseName != null) {
|
||||
@@ -93,12 +79,9 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
|
||||
this.autoFlushOnInsert = getOrDefault(autoFlushOnInsert, false);
|
||||
|
||||
if (!hasCollection(this.milvusClient, this.collectionName)) {
|
||||
createCollection(
|
||||
this.milvusClient, this.collectionName, ensureNotNull(dimension, "dimension"));
|
||||
createIndex(
|
||||
this.milvusClient,
|
||||
this.collectionName,
|
||||
getOrDefault(indexType, FLAT),
|
||||
createCollection(this.milvusClient, this.collectionName,
|
||||
ensureNotNull(dimension, "dimension"));
|
||||
createIndex(this.milvusClient, this.collectionName, getOrDefault(indexType, FLAT),
|
||||
this.metricType);
|
||||
}
|
||||
|
||||
@@ -145,49 +128,36 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
|
||||
public EmbeddingSearchResult<TextSegment> search(
|
||||
EmbeddingSearchRequest embeddingSearchRequest) {
|
||||
|
||||
SearchParam searchParam =
|
||||
buildSearchRequest(
|
||||
collectionName,
|
||||
embeddingSearchRequest.queryEmbedding().vectorAsList(),
|
||||
embeddingSearchRequest.filter(),
|
||||
embeddingSearchRequest.maxResults(),
|
||||
metricType,
|
||||
consistencyLevel);
|
||||
SearchParam searchParam = buildSearchRequest(collectionName,
|
||||
embeddingSearchRequest.queryEmbedding().vectorAsList(),
|
||||
embeddingSearchRequest.filter(), embeddingSearchRequest.maxResults(), metricType,
|
||||
consistencyLevel);
|
||||
|
||||
SearchResultsWrapper resultsWrapper =
|
||||
CollectionOperationsExecutor.search(milvusClient, searchParam);
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> matches =
|
||||
toEmbeddingMatches(
|
||||
milvusClient,
|
||||
resultsWrapper,
|
||||
collectionName,
|
||||
consistencyLevel,
|
||||
retrieveEmbeddingsOnSearch);
|
||||
List<EmbeddingMatch<TextSegment>> matches = toEmbeddingMatches(milvusClient, resultsWrapper,
|
||||
collectionName, consistencyLevel, retrieveEmbeddingsOnSearch);
|
||||
|
||||
List<EmbeddingMatch<TextSegment>> result =
|
||||
matches.stream()
|
||||
.filter(match -> match.score() >= embeddingSearchRequest.minScore())
|
||||
matches.stream().filter(match -> match.score() >= embeddingSearchRequest.minScore())
|
||||
.collect(toList());
|
||||
|
||||
return new EmbeddingSearchResult<>(result);
|
||||
}
|
||||
|
||||
private void addInternal(String id, Embedding embedding, TextSegment textSegment) {
|
||||
addAllInternal(
|
||||
singletonList(id),
|
||||
singletonList(embedding),
|
||||
addAllInternal(singletonList(id), singletonList(embedding),
|
||||
textSegment == null ? null : singletonList(textSegment));
|
||||
}
|
||||
|
||||
private void addAllInternal(
|
||||
List<String> ids, List<Embedding> embeddings, List<TextSegment> textSegments) {
|
||||
private void addAllInternal(List<String> ids, List<Embedding> embeddings,
|
||||
List<TextSegment> textSegments) {
|
||||
List<InsertParam.Field> fields = new ArrayList<>();
|
||||
fields.add(new InsertParam.Field(ID_FIELD_NAME, ids));
|
||||
fields.add(new InsertParam.Field(TEXT_FIELD_NAME, toScalars(textSegments, ids.size())));
|
||||
fields.add(
|
||||
new InsertParam.Field(
|
||||
METADATA_FIELD_NAME, toMetadataJsons(textSegments, ids.size())));
|
||||
fields.add(new InsertParam.Field(METADATA_FIELD_NAME,
|
||||
toMetadataJsons(textSegments, ids.size())));
|
||||
fields.add(new InsertParam.Field(VECTOR_FIELD_NAME, toVectors(embeddings)));
|
||||
|
||||
insert(this.milvusClient, this.collectionName, fields);
|
||||
@@ -199,22 +169,22 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
|
||||
/**
|
||||
* Removes a single embedding from the store by ID.
|
||||
*
|
||||
* <p>CAUTION
|
||||
* <p>
|
||||
* CAUTION
|
||||
*
|
||||
* <ul>
|
||||
* <li>Deleted entities can still be retrieved immediately after the deletion if the
|
||||
* consistency level is set lower than {@code Strong}
|
||||
* <li>Entities deleted beyond the pre-specified span of time for Time Travel cannot be
|
||||
* retrieved again.
|
||||
* <li>Frequent deletion operations will impact the system performance.
|
||||
* <li>Before deleting entities by complex boolean expressions, make sure the collection has
|
||||
* been loaded.
|
||||
* <li>Deleting entities by complex boolean expressions is not an atomic operation. Therefore,
|
||||
* if it fails halfway through, some data may still be deleted.
|
||||
* <li>Deleting entities by complex boolean expressions is supported only when the consistency
|
||||
* is set to Bounded. For details, <a
|
||||
* href="https://milvus.io/docs/v2.3.x/consistency.md#Consistency-levels">see
|
||||
* Consistency</a>
|
||||
* <li>Deleted entities can still be retrieved immediately after the deletion if the consistency
|
||||
* level is set lower than {@code Strong}
|
||||
* <li>Entities deleted beyond the pre-specified span of time for Time Travel cannot be
|
||||
* retrieved again.
|
||||
* <li>Frequent deletion operations will impact the system performance.
|
||||
* <li>Before deleting entities by complex boolean expressions, make sure the collection has
|
||||
* been loaded.
|
||||
* <li>Deleting entities by complex boolean expressions is not an atomic operation. Therefore,
|
||||
* if it fails halfway through, some data may still be deleted.
|
||||
* <li>Deleting entities by complex boolean expressions is supported only when the consistency
|
||||
* is set to Bounded. For details,
|
||||
* <a href="https://milvus.io/docs/v2.3.x/consistency.md#Consistency-levels">see Consistency</a>
|
||||
* </ul>
|
||||
*
|
||||
* @param ids A collection of unique IDs of the embeddings to be removed.
|
||||
@@ -223,36 +193,34 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
|
||||
@Override
|
||||
public void removeAll(Collection<String> ids) {
|
||||
ensureNotEmpty(ids, "ids");
|
||||
removeForVector(
|
||||
this.milvusClient,
|
||||
this.collectionName,
|
||||
removeForVector(this.milvusClient, this.collectionName,
|
||||
format("%s in %s", ID_FIELD_NAME, formatValues(ids)));
|
||||
}
|
||||
|
||||
/**
|
||||
* Removes all embeddings that match the specified {@link Filter} from the store.
|
||||
*
|
||||
* <p>CAUTION
|
||||
* <p>
|
||||
* CAUTION
|
||||
*
|
||||
* <ul>
|
||||
* <li>Deleted entities can still be retrieved immediately after the deletion if the
|
||||
* consistency level is set lower than {@code Strong}
|
||||
* <li>Entities deleted beyond the pre-specified span of time for Time Travel cannot be
|
||||
* retrieved again.
|
||||
* <li>Frequent deletion operations will impact the system performance.
|
||||
* <li>Before deleting entities by complex boolean expressions, make sure the collection has
|
||||
* been loaded.
|
||||
* <li>Deleting entities by complex boolean expressions is not an atomic operation. Therefore,
|
||||
* if it fails halfway through, some data may still be deleted.
|
||||
* <li>Deleting entities by complex boolean expressions is supported only when the consistency
|
||||
* is set to Bounded. For details, <a
|
||||
* href="https://milvus.io/docs/v2.3.x/consistency.md#Consistency-levels">see
|
||||
* Consistency</a>
|
||||
* <li>Deleted entities can still be retrieved immediately after the deletion if the consistency
|
||||
* level is set lower than {@code Strong}
|
||||
* <li>Entities deleted beyond the pre-specified span of time for Time Travel cannot be
|
||||
* retrieved again.
|
||||
* <li>Frequent deletion operations will impact the system performance.
|
||||
* <li>Before deleting entities by complex boolean expressions, make sure the collection has
|
||||
* been loaded.
|
||||
* <li>Deleting entities by complex boolean expressions is not an atomic operation. Therefore,
|
||||
* if it fails halfway through, some data may still be deleted.
|
||||
* <li>Deleting entities by complex boolean expressions is supported only when the consistency
|
||||
* is set to Bounded. For details,
|
||||
* <a href="https://milvus.io/docs/v2.3.x/consistency.md#Consistency-levels">see Consistency</a>
|
||||
* </ul>
|
||||
*
|
||||
* @param filter The filter to be applied to the {@link Metadata} of the {@link TextSegment}
|
||||
* during removal. Only embeddings whose {@code TextSegment}'s {@code Metadata} match the
|
||||
* {@code Filter} will be removed.
|
||||
* during removal. Only embeddings whose {@code TextSegment}'s {@code Metadata} match the
|
||||
* {@code Filter} will be removed.
|
||||
* @since Milvus version 2.3.x
|
||||
*/
|
||||
@Override
|
||||
@@ -264,30 +232,30 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
|
||||
/**
|
||||
* Removes all embeddings from the store.
|
||||
*
|
||||
* <p>CAUTION
|
||||
* <p>
|
||||
* CAUTION
|
||||
*
|
||||
* <ul>
|
||||
* <li>Deleted entities can still be retrieved immediately after the deletion if the
|
||||
* consistency level is set lower than {@code Strong}
|
||||
* <li>Entities deleted beyond the pre-specified span of time for Time Travel cannot be
|
||||
* retrieved again.
|
||||
* <li>Frequent deletion operations will impact the system performance.
|
||||
* <li>Before deleting entities by complex boolean expressions, make sure the collection has
|
||||
* been loaded.
|
||||
* <li>Deleting entities by complex boolean expressions is not an atomic operation. Therefore,
|
||||
* if it fails halfway through, some data may still be deleted.
|
||||
* <li>Deleting entities by complex boolean expressions is supported only when the consistency
|
||||
* is set to Bounded. For details, <a
|
||||
* href="https://milvus.io/docs/v2.3.x/consistency.md#Consistency-levels">see
|
||||
* Consistency</a>
|
||||
* <li>Deleted entities can still be retrieved immediately after the deletion if the consistency
|
||||
* level is set lower than {@code Strong}
|
||||
* <li>Entities deleted beyond the pre-specified span of time for Time Travel cannot be
|
||||
* retrieved again.
|
||||
* <li>Frequent deletion operations will impact the system performance.
|
||||
* <li>Before deleting entities by complex boolean expressions, make sure the collection has
|
||||
* been loaded.
|
||||
* <li>Deleting entities by complex boolean expressions is not an atomic operation. Therefore,
|
||||
* if it fails halfway through, some data may still be deleted.
|
||||
* <li>Deleting entities by complex boolean expressions is supported only when the consistency
|
||||
* is set to Bounded. For details,
|
||||
* <a href="https://milvus.io/docs/v2.3.x/consistency.md#Consistency-levels">see Consistency</a>
|
||||
* </ul>
|
||||
*
|
||||
* @since Milvus version 2.3.x
|
||||
*/
|
||||
@Override
|
||||
public void removeAll() {
|
||||
removeForVector(
|
||||
this.milvusClient, this.collectionName, format("%s != \"\"", ID_FIELD_NAME));
|
||||
removeForVector(this.milvusClient, this.collectionName,
|
||||
format("%s != \"\"", ID_FIELD_NAME));
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
@@ -327,7 +295,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
|
||||
|
||||
/**
|
||||
* @param collectionName The name of the Milvus collection. If there is no such collection
|
||||
* yet, it will be created automatically. Default value: "default".
|
||||
* yet, it will be created automatically. Default value: "default".
|
||||
* @return builder
|
||||
*/
|
||||
public Builder collectionName(String collectionName) {
|
||||
@@ -337,7 +305,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
|
||||
|
||||
/**
|
||||
* @param dimension The dimension of the embedding vector. (e.g. 384) Mandatory if a new
|
||||
* collection should be created.
|
||||
* collection should be created.
|
||||
* @return builder
|
||||
*/
|
||||
public Builder dimension(Integer dimension) {
|
||||
@@ -356,7 +324,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
|
||||
|
||||
/**
|
||||
* @param metricType The type of the metric used for similarity search. Default value:
|
||||
* COSINE.
|
||||
* COSINE.
|
||||
* @return builder
|
||||
*/
|
||||
public Builder metricType(MetricType metricType) {
|
||||
@@ -366,7 +334,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
|
||||
|
||||
/**
|
||||
* @param uri The URI of the managed Milvus instance. (e.g.
|
||||
* "https://xxx.api.gcp-us-west1.zillizcloud.com")
|
||||
* "https://xxx.api.gcp-us-west1.zillizcloud.com")
|
||||
* @return builder
|
||||
*/
|
||||
public Builder uri(String uri) {
|
||||
@@ -384,8 +352,8 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
|
||||
}
|
||||
|
||||
/**
|
||||
* @param username The username. See details <a
|
||||
* href="https://milvus.io/docs/authenticate.md">here</a>.
|
||||
* @param username The username. See details
|
||||
* <a href="https://milvus.io/docs/authenticate.md">here</a>.
|
||||
* @return builder
|
||||
*/
|
||||
public Builder username(String username) {
|
||||
@@ -394,8 +362,8 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
|
||||
}
|
||||
|
||||
/**
|
||||
* @param password The password. See details <a
|
||||
* href="https://milvus.io/docs/authenticate.md">here</a>.
|
||||
* @param password The password. See details
|
||||
* <a href="https://milvus.io/docs/authenticate.md">here</a>.
|
||||
* @return builder
|
||||
*/
|
||||
public Builder password(String password) {
|
||||
@@ -414,10 +382,10 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
|
||||
|
||||
/**
|
||||
* @param retrieveEmbeddingsOnSearch During a similarity search in Milvus (when calling
|
||||
* findRelevant()), the embedding itself is not retrieved. To retrieve the embedding, an
|
||||
* additional query is required. Setting this parameter to "true" will ensure that
|
||||
* embedding is retrieved. Be aware that this will impact the performance of the search.
|
||||
* Default value: false.
|
||||
* findRelevant()), the embedding itself is not retrieved. To retrieve the embedding,
|
||||
* an additional query is required. Setting this parameter to "true" will ensure that
|
||||
* embedding is retrieved. Be aware that this will impact the performance of the
|
||||
* search. Default value: false.
|
||||
* @return builder
|
||||
*/
|
||||
public Builder retrieveEmbeddingsOnSearch(Boolean retrieveEmbeddingsOnSearch) {
|
||||
@@ -428,8 +396,8 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
|
||||
/**
|
||||
* @param autoFlushOnInsert Whether to automatically flush after each insert ({@code
|
||||
* add(...)} or {@code addAll(...)} methods). Default value: false. More info can be
|
||||
* found <a
|
||||
* href="https://milvus.io/api-reference/pymilvus/v2.4.x/ORM/Collection/flush.md">here</a>.
|
||||
* found <a href=
|
||||
* "https://milvus.io/api-reference/pymilvus/v2.4.x/ORM/Collection/flush.md">here</a>.
|
||||
* @return builder
|
||||
*/
|
||||
public Builder autoFlushOnInsert(Boolean autoFlushOnInsert) {
|
||||
@@ -439,7 +407,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
|
||||
|
||||
/**
|
||||
* @param databaseName Name of the Milvus database. Default value: null. In this case default
|
||||
* Milvus database name will be used.
|
||||
* Milvus database name will be used.
|
||||
* @return builder
|
||||
*/
|
||||
public Builder databaseName(String databaseName) {
|
||||
@@ -448,21 +416,9 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
|
||||
}
|
||||
|
||||
public MilvusEmbeddingStore build() {
|
||||
return new MilvusEmbeddingStore(
|
||||
host,
|
||||
port,
|
||||
collectionName,
|
||||
dimension,
|
||||
indexType,
|
||||
metricType,
|
||||
uri,
|
||||
token,
|
||||
username,
|
||||
password,
|
||||
consistencyLevel,
|
||||
retrieveEmbeddingsOnSearch,
|
||||
autoFlushOnInsert,
|
||||
databaseName);
|
||||
return new MilvusEmbeddingStore(host, port, collectionName, dimension, indexType,
|
||||
metricType, uri, token, username, password, consistencyLevel,
|
||||
retrieveEmbeddingsOnSearch, autoFlushOnInsert, databaseName);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,9 +12,12 @@ public class Properties {
|
||||
|
||||
static final String PREFIX = "langchain4j.zhipu";
|
||||
|
||||
@NestedConfigurationProperty ChatModelProperties chatModel;
|
||||
@NestedConfigurationProperty
|
||||
ChatModelProperties chatModel;
|
||||
|
||||
@NestedConfigurationProperty ChatModelProperties streamingChatModel;
|
||||
@NestedConfigurationProperty
|
||||
ChatModelProperties streamingChatModel;
|
||||
|
||||
@NestedConfigurationProperty EmbeddingModelProperties embeddingModel;
|
||||
@NestedConfigurationProperty
|
||||
EmbeddingModelProperties embeddingModel;
|
||||
}
|
||||
|
||||
@@ -18,46 +18,36 @@ public class ZhipuAutoConfig {
|
||||
@ConditionalOnProperty(PREFIX + ".chat-model.api-key")
|
||||
ZhipuAiChatModel zhipuAiChatModel(Properties properties) {
|
||||
ChatModelProperties chatModelProperties = properties.getChatModel();
|
||||
return ZhipuAiChatModel.builder()
|
||||
.baseUrl(chatModelProperties.getBaseUrl())
|
||||
.apiKey(chatModelProperties.getApiKey())
|
||||
.model(chatModelProperties.getModelName())
|
||||
return ZhipuAiChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
|
||||
.apiKey(chatModelProperties.getApiKey()).model(chatModelProperties.getModelName())
|
||||
.temperature(chatModelProperties.getTemperature())
|
||||
.topP(chatModelProperties.getTopP())
|
||||
.maxRetries(chatModelProperties.getMaxRetries())
|
||||
.topP(chatModelProperties.getTopP()).maxRetries(chatModelProperties.getMaxRetries())
|
||||
.maxToken(chatModelProperties.getMaxToken())
|
||||
.logRequests(chatModelProperties.getLogRequests())
|
||||
.logResponses(chatModelProperties.getLogResponses())
|
||||
.build();
|
||||
.logResponses(chatModelProperties.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".streaming-chat-model.api-key")
|
||||
ZhipuAiStreamingChatModel zhipuStreamingChatModel(Properties properties) {
|
||||
ChatModelProperties chatModelProperties = properties.getStreamingChatModel();
|
||||
return ZhipuAiStreamingChatModel.builder()
|
||||
.baseUrl(chatModelProperties.getBaseUrl())
|
||||
.apiKey(chatModelProperties.getApiKey())
|
||||
.model(chatModelProperties.getModelName())
|
||||
return ZhipuAiStreamingChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
|
||||
.apiKey(chatModelProperties.getApiKey()).model(chatModelProperties.getModelName())
|
||||
.temperature(chatModelProperties.getTemperature())
|
||||
.topP(chatModelProperties.getTopP())
|
||||
.maxToken(chatModelProperties.getMaxToken())
|
||||
.topP(chatModelProperties.getTopP()).maxToken(chatModelProperties.getMaxToken())
|
||||
.logRequests(chatModelProperties.getLogRequests())
|
||||
.logResponses(chatModelProperties.getLogResponses())
|
||||
.build();
|
||||
.logResponses(chatModelProperties.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".embedding-model.api-key")
|
||||
ZhipuAiEmbeddingModel zhipuEmbeddingModel(Properties properties) {
|
||||
EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel();
|
||||
return ZhipuAiEmbeddingModel.builder()
|
||||
.baseUrl(embeddingModelProperties.getBaseUrl())
|
||||
return ZhipuAiEmbeddingModel.builder().baseUrl(embeddingModelProperties.getBaseUrl())
|
||||
.apiKey(embeddingModelProperties.getApiKey())
|
||||
.model(embeddingModelProperties.getModel())
|
||||
.maxRetries(embeddingModelProperties.getMaxRetries())
|
||||
.logRequests(embeddingModelProperties.getLogRequests())
|
||||
.logResponses(embeddingModelProperties.getLogResponses())
|
||||
.build();
|
||||
.logResponses(embeddingModelProperties.getLogResponses()).build();
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user