(improvement)[build] Use Spotless to customize the code formatting (#1750)

lexluo09
2024-10-04 00:05:04 +08:00
committed by GitHub
parent 44d1cde34f
commit 71a9954be5
521 changed files with 7811 additions and 13046 deletions
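
For reference, a minimal sketch of the kind of spotless-maven-plugin setup that produces formatting like the hunks below, assuming a Maven build; the plugin version, formatter-config file name, and execution binding are illustrative and not taken from this commit:

    <plugin>
        <groupId>com.diffplug.spotless</groupId>
        <artifactId>spotless-maven-plugin</artifactId>
        <!-- version is illustrative -->
        <version>2.43.0</version>
        <configuration>
            <java>
                <!-- custom Eclipse formatter rules; the file path is hypothetical -->
                <eclipse>
                    <file>${project.basedir}/style/spotless-formatter.xml</file>
                </eclipse>
                <removeUnusedImports/>
            </java>
        </configuration>
        <executions>
            <execution>
                <!-- report formatting violations during the build -->
                <goals>
                    <goal>check</goal>
                </goals>
            </execution>
        </executions>
    </plugin>

With a setup like this, mvn spotless:apply rewrites the sources in place and mvn spotless:check fails the build on unformatted files.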

View File

@@ -24,11 +24,8 @@ public class ChromaEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
@Override
public EmbeddingStore createEmbeddingStore(String collectionName) {
return ChromaEmbeddingStore.builder()
.baseUrl(storeProperties.getBaseUrl())
.collectionName(collectionName)
.timeout(storeProperties.getTimeout())
.build();
return ChromaEmbeddingStore.builder().baseUrl(storeProperties.getBaseUrl())
.collectionName(collectionName).timeout(storeProperties.getTimeout()).build();
}
private static EmbeddingStoreProperties createPropertiesFromConfig(

View File

@@ -12,5 +12,6 @@ public class Properties {
static final String PREFIX = "langchain4j.chroma";
@NestedConfigurationProperty EmbeddingStoreProperties embeddingStore;
@NestedConfigurationProperty
EmbeddingStoreProperties embeddingStore;
}

View File

@@ -20,18 +20,15 @@ public class DashscopeAutoConfig {
@ConditionalOnProperty(PREFIX + ".chat-model.api-key")
QwenChatModel qwenChatModel(Properties properties) {
ChatModelProperties chatModelProperties = properties.getChatModel();
return QwenChatModel.builder()
.baseUrl(chatModelProperties.getBaseUrl())
return QwenChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
.apiKey(chatModelProperties.getApiKey())
.modelName(chatModelProperties.getModelName())
.topP(chatModelProperties.getTopP())
.modelName(chatModelProperties.getModelName()).topP(chatModelProperties.getTopP())
.topK(chatModelProperties.getTopK())
.enableSearch(chatModelProperties.getEnableSearch())
.seed(chatModelProperties.getSeed())
.repetitionPenalty(chatModelProperties.getRepetitionPenalty())
.temperature(chatModelProperties.getTemperature())
.stops(chatModelProperties.getStops())
.maxTokens(chatModelProperties.getMaxTokens())
.stops(chatModelProperties.getStops()).maxTokens(chatModelProperties.getMaxTokens())
.build();
}
@@ -39,18 +36,15 @@ public class DashscopeAutoConfig {
@ConditionalOnProperty(PREFIX + ".streaming-chat-model.api-key")
QwenStreamingChatModel qwenStreamingChatModel(Properties properties) {
ChatModelProperties chatModelProperties = properties.getStreamingChatModel();
return QwenStreamingChatModel.builder()
.baseUrl(chatModelProperties.getBaseUrl())
return QwenStreamingChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
.apiKey(chatModelProperties.getApiKey())
.modelName(chatModelProperties.getModelName())
.topP(chatModelProperties.getTopP())
.modelName(chatModelProperties.getModelName()).topP(chatModelProperties.getTopP())
.topK(chatModelProperties.getTopK())
.enableSearch(chatModelProperties.getEnableSearch())
.seed(chatModelProperties.getSeed())
.repetitionPenalty(chatModelProperties.getRepetitionPenalty())
.temperature(chatModelProperties.getTemperature())
.stops(chatModelProperties.getStops())
.maxTokens(chatModelProperties.getMaxTokens())
.stops(chatModelProperties.getStops()).maxTokens(chatModelProperties.getMaxTokens())
.build();
}
@@ -58,47 +52,33 @@ public class DashscopeAutoConfig {
@ConditionalOnProperty(PREFIX + ".language-model.api-key")
QwenLanguageModel qwenLanguageModel(Properties properties) {
ChatModelProperties languageModel = properties.getLanguageModel();
return QwenLanguageModel.builder()
.baseUrl(languageModel.getBaseUrl())
.apiKey(languageModel.getApiKey())
.modelName(languageModel.getModelName())
.topP(languageModel.getTopP())
.topK(languageModel.getTopK())
.enableSearch(languageModel.getEnableSearch())
.seed(languageModel.getSeed())
return QwenLanguageModel.builder().baseUrl(languageModel.getBaseUrl())
.apiKey(languageModel.getApiKey()).modelName(languageModel.getModelName())
.topP(languageModel.getTopP()).topK(languageModel.getTopK())
.enableSearch(languageModel.getEnableSearch()).seed(languageModel.getSeed())
.repetitionPenalty(languageModel.getRepetitionPenalty())
.temperature(languageModel.getTemperature())
.stops(languageModel.getStops())
.maxTokens(languageModel.getMaxTokens())
.build();
.temperature(languageModel.getTemperature()).stops(languageModel.getStops())
.maxTokens(languageModel.getMaxTokens()).build();
}
@Bean
@ConditionalOnProperty(PREFIX + ".streaming-language-model.api-key")
QwenStreamingLanguageModel qwenStreamingLanguageModel(Properties properties) {
ChatModelProperties languageModel = properties.getStreamingLanguageModel();
return QwenStreamingLanguageModel.builder()
.baseUrl(languageModel.getBaseUrl())
.apiKey(languageModel.getApiKey())
.modelName(languageModel.getModelName())
.topP(languageModel.getTopP())
.topK(languageModel.getTopK())
.enableSearch(languageModel.getEnableSearch())
.seed(languageModel.getSeed())
return QwenStreamingLanguageModel.builder().baseUrl(languageModel.getBaseUrl())
.apiKey(languageModel.getApiKey()).modelName(languageModel.getModelName())
.topP(languageModel.getTopP()).topK(languageModel.getTopK())
.enableSearch(languageModel.getEnableSearch()).seed(languageModel.getSeed())
.repetitionPenalty(languageModel.getRepetitionPenalty())
.temperature(languageModel.getTemperature())
.stops(languageModel.getStops())
.maxTokens(languageModel.getMaxTokens())
.build();
.temperature(languageModel.getTemperature()).stops(languageModel.getStops())
.maxTokens(languageModel.getMaxTokens()).build();
}
@Bean
@ConditionalOnProperty(PREFIX + ".embedding-model.api-key")
QwenEmbeddingModel qwenEmbeddingModel(Properties properties) {
EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel();
return QwenEmbeddingModel.builder()
.apiKey(embeddingModelProperties.getApiKey())
.modelName(embeddingModelProperties.getModelName())
.build();
return QwenEmbeddingModel.builder().apiKey(embeddingModelProperties.getApiKey())
.modelName(embeddingModelProperties.getModelName()).build();
}
}

View File

@@ -12,13 +12,18 @@ public class Properties {
static final String PREFIX = "langchain4j.dashscope";
@NestedConfigurationProperty ChatModelProperties chatModel;
@NestedConfigurationProperty
ChatModelProperties chatModel;
@NestedConfigurationProperty ChatModelProperties streamingChatModel;
@NestedConfigurationProperty
ChatModelProperties streamingChatModel;
@NestedConfigurationProperty ChatModelProperties languageModel;
@NestedConfigurationProperty
ChatModelProperties languageModel;
@NestedConfigurationProperty ChatModelProperties streamingLanguageModel;
@NestedConfigurationProperty
ChatModelProperties streamingLanguageModel;
@NestedConfigurationProperty EmbeddingModelProperties embeddingModel;
@NestedConfigurationProperty
EmbeddingModelProperties embeddingModel;
}

View File

@@ -74,8 +74,8 @@ public class InMemoryEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
if (MapUtils.isEmpty(super.collectionNameToStore)) {
return;
}
for (Map.Entry<String, EmbeddingStore<TextSegment>> entry :
collectionNameToStore.entrySet()) {
for (Map.Entry<String, EmbeddingStore<TextSegment>> entry : collectionNameToStore
.entrySet()) {
Path filePath = getPersistPath(entry.getKey());
if (Objects.isNull(filePath)) {
continue;

View File

@@ -12,7 +12,9 @@ public class Properties {
static final String PREFIX = "langchain4j.in-memory";
@NestedConfigurationProperty EmbeddingStoreProperties embeddingStore;
@NestedConfigurationProperty
EmbeddingStoreProperties embeddingStore;
@NestedConfigurationProperty EmbeddingModelProperties embeddingModel;
@NestedConfigurationProperty
EmbeddingModelProperties embeddingModel;
}

View File

@@ -20,70 +20,58 @@ public class LocalAiAutoConfig {
@ConditionalOnProperty(PREFIX + ".chat-model.base-url")
LocalAiChatModel localAiChatModel(Properties properties) {
ChatModelProperties chatModelProperties = properties.getChatModel();
return LocalAiChatModel.builder()
.baseUrl(chatModelProperties.getBaseUrl())
return LocalAiChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
.modelName(chatModelProperties.getModelName())
.temperature(chatModelProperties.getTemperature())
.topP(chatModelProperties.getTopP())
.maxRetries(chatModelProperties.getMaxRetries())
.topP(chatModelProperties.getTopP()).maxRetries(chatModelProperties.getMaxRetries())
.logRequests(chatModelProperties.getLogRequests())
.logResponses(chatModelProperties.getLogResponses())
.build();
.logResponses(chatModelProperties.getLogResponses()).build();
}
@Bean
@ConditionalOnProperty(PREFIX + ".streaming-chat-model.base-url")
LocalAiStreamingChatModel localAiStreamingChatModel(Properties properties) {
ChatModelProperties chatModelProperties = properties.getStreamingChatModel();
return LocalAiStreamingChatModel.builder()
.temperature(chatModelProperties.getTemperature())
.topP(chatModelProperties.getTopP())
.baseUrl(chatModelProperties.getBaseUrl())
return LocalAiStreamingChatModel.builder().temperature(chatModelProperties.getTemperature())
.topP(chatModelProperties.getTopP()).baseUrl(chatModelProperties.getBaseUrl())
.modelName(chatModelProperties.getModelName())
.logRequests(chatModelProperties.getLogRequests())
.logResponses(chatModelProperties.getLogResponses())
.build();
.logResponses(chatModelProperties.getLogResponses()).build();
}
@Bean
@ConditionalOnProperty(PREFIX + ".language-model.base-url")
LocalAiLanguageModel localAiLanguageModel(Properties properties) {
LanguageModelProperties languageModelProperties = properties.getLanguageModel();
return LocalAiLanguageModel.builder()
.topP(languageModelProperties.getTopP())
return LocalAiLanguageModel.builder().topP(languageModelProperties.getTopP())
.baseUrl(languageModelProperties.getBaseUrl())
.modelName(languageModelProperties.getModelName())
.temperature(languageModelProperties.getTemperature())
.maxRetries(languageModelProperties.getMaxRetries())
.logRequests(languageModelProperties.getLogRequests())
.logResponses(languageModelProperties.getLogResponses())
.build();
.logResponses(languageModelProperties.getLogResponses()).build();
}
@Bean
@ConditionalOnProperty(PREFIX + ".streaming-language-model.base-url")
LocalAiStreamingLanguageModel localAiStreamingLanguageModel(Properties properties) {
LanguageModelProperties languageModelProperties = properties.getStreamingLanguageModel();
return LocalAiStreamingLanguageModel.builder()
.topP(languageModelProperties.getTopP())
return LocalAiStreamingLanguageModel.builder().topP(languageModelProperties.getTopP())
.baseUrl(languageModelProperties.getBaseUrl())
.modelName(languageModelProperties.getModelName())
.temperature(languageModelProperties.getTemperature())
.logRequests(languageModelProperties.getLogRequests())
.logResponses(languageModelProperties.getLogResponses())
.build();
.logResponses(languageModelProperties.getLogResponses()).build();
}
@Bean
@ConditionalOnProperty(PREFIX + ".embedding-model.base-url")
LocalAiEmbeddingModel localAiEmbeddingModel(Properties properties) {
EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel();
return LocalAiEmbeddingModel.builder()
.baseUrl(embeddingModelProperties.getBaseUrl())
return LocalAiEmbeddingModel.builder().baseUrl(embeddingModelProperties.getBaseUrl())
.modelName(embeddingModelProperties.getModelName())
.maxRetries(embeddingModelProperties.getMaxRetries())
.logRequests(embeddingModelProperties.getLogRequests())
.logResponses(embeddingModelProperties.getLogResponses())
.build();
.logResponses(embeddingModelProperties.getLogResponses()).build();
}
}

View File

@@ -12,13 +12,18 @@ public class Properties {
static final String PREFIX = "langchain4j.local-ai";
@NestedConfigurationProperty ChatModelProperties chatModel;
@NestedConfigurationProperty
ChatModelProperties chatModel;
@NestedConfigurationProperty ChatModelProperties streamingChatModel;
@NestedConfigurationProperty
ChatModelProperties streamingChatModel;
@NestedConfigurationProperty LanguageModelProperties languageModel;
@NestedConfigurationProperty
LanguageModelProperties languageModel;
@NestedConfigurationProperty LanguageModelProperties streamingLanguageModel;
@NestedConfigurationProperty
LanguageModelProperties streamingLanguageModel;
@NestedConfigurationProperty EmbeddingModelProperties embeddingModel;
@NestedConfigurationProperty
EmbeddingModelProperties embeddingModel;
}

View File

@@ -29,21 +29,15 @@ public class MilvusEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
@Override
public EmbeddingStore<TextSegment> createEmbeddingStore(String collectionName) {
return MilvusEmbeddingStore.builder()
.host(storeProperties.getHost())
.port(storeProperties.getPort())
.collectionName(collectionName)
.dimension(storeProperties.getDimension())
.indexType(storeProperties.getIndexType())
.metricType(storeProperties.getMetricType())
.uri(storeProperties.getUri())
.token(storeProperties.getToken())
.username(storeProperties.getUsername())
return MilvusEmbeddingStore.builder().host(storeProperties.getHost())
.port(storeProperties.getPort()).collectionName(collectionName)
.dimension(storeProperties.getDimension()).indexType(storeProperties.getIndexType())
.metricType(storeProperties.getMetricType()).uri(storeProperties.getUri())
.token(storeProperties.getToken()).username(storeProperties.getUsername())
.password(storeProperties.getPassword())
.consistencyLevel(storeProperties.getConsistencyLevel())
.retrieveEmbeddingsOnSearch(storeProperties.getRetrieveEmbeddingsOnSearch())
.autoFlushOnInsert(storeProperties.getAutoFlushOnInsert())
.databaseName(storeProperties.getDatabaseName())
.build();
.databaseName(storeProperties.getDatabaseName()).build();
}
}

View File

@@ -12,5 +12,6 @@ public class Properties {
static final String PREFIX = "langchain4j.milvus";
@NestedConfigurationProperty EmbeddingStoreProperties embeddingStore;
@NestedConfigurationProperty
EmbeddingStoreProperties embeddingStore;
}

View File

@@ -13,10 +13,10 @@ import java.util.Objects;
/**
* An embedding model that runs within your Java application's process. Any BERT-based model (e.g.,
* from HuggingFace) can be used, as long as it is in ONNX format. Information on how to convert
* models into ONNX format can be found <a
* href="https://huggingface.co/docs/optimum/exporters/onnx/usage_guides/export_a_model">here</a>.
* Many models already converted to ONNX format are available <a
* href="https://huggingface.co/Xenova">here</a>. Copy from
* models into ONNX format can be found <a href=
* "https://huggingface.co/docs/optimum/exporters/onnx/usage_guides/export_a_model">here</a>. Many
* models already converted to ONNX format are available
* <a href="https://huggingface.co/Xenova">here</a>. Copy from
* dev.langchain4j.model.embedding.OnnxEmbeddingModel.
*/
public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {
@@ -28,9 +28,8 @@ public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {
if (shouldReloadModel(pathToModel, vocabularyPath)) {
synchronized (S2OnnxEmbeddingModel.class) {
if (shouldReloadModel(pathToModel, vocabularyPath)) {
URL resource =
AbstractInProcessEmbeddingModel.class.getResource(
"/bert-vocabulary-en.txt");
URL resource = AbstractInProcessEmbeddingModel.class
.getResource("/bert-vocabulary-en.txt");
if (StringUtils.isNotBlank(vocabularyPath)) {
try {
resource = Paths.get(vocabularyPath).toUri().toURL();
@@ -56,15 +55,14 @@ public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {
}
private static boolean shouldReloadModel(String pathToModel, String vocabularyPath) {
return cachedModel == null
|| !Objects.equals(cachedModelPath, pathToModel)
return cachedModel == null || !Objects.equals(cachedModelPath, pathToModel)
|| !Objects.equals(cachedVocabularyPath, vocabularyPath);
}
static OnnxBertBiEncoder loadFromFileSystem(Path pathToModel, URL vocabularyFile) {
try {
return new OnnxBertBiEncoder(
Files.newInputStream(pathToModel), vocabularyFile, PoolingMode.MEAN);
return new OnnxBertBiEncoder(Files.newInputStream(pathToModel), vocabularyFile,
PoolingMode.MEAN);
} catch (IOException e) {
throw new RuntimeException(e);
}

View File

@@ -60,8 +60,8 @@ import static java.util.Collections.singletonList;
/**
* Represents an OpenAI language model with a chat completion interface, such as gpt-3.5-turbo and
* gpt-4. You can find description of parameters <a
* href="https://platform.openai.com/docs/api-reference/chat/create">here</a>.
* gpt-4. You can find description of parameters
* <a href="https://platform.openai.com/docs/api-reference/chat/create">here</a>.
*/
@Slf4j
public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
@@ -88,32 +88,13 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
private final List<ChatModelListener> listeners;
@Builder
public OpenAiChatModel(
String baseUrl,
String apiKey,
String organizationId,
String modelName,
Double temperature,
Double topP,
List<String> stop,
Integer maxTokens,
Double presencePenalty,
Double frequencyPenalty,
Map<String, Integer> logitBias,
String responseFormat,
Boolean strictJsonSchema,
Integer seed,
String user,
Boolean strictTools,
Boolean parallelToolCalls,
Duration timeout,
Integer maxRetries,
Proxy proxy,
Boolean logRequests,
Boolean logResponses,
Tokenizer tokenizer,
Map<String, String> customHeaders,
List<ChatModelListener> listeners) {
public OpenAiChatModel(String baseUrl, String apiKey, String organizationId, String modelName,
Double temperature, Double topP, List<String> stop, Integer maxTokens,
Double presencePenalty, Double frequencyPenalty, Map<String, Integer> logitBias,
String responseFormat, Boolean strictJsonSchema, Integer seed, String user,
Boolean strictTools, Boolean parallelToolCalls, Duration timeout, Integer maxRetries,
Proxy proxy, Boolean logRequests, Boolean logResponses, Tokenizer tokenizer,
Map<String, String> customHeaders, List<ChatModelListener> listeners) {
baseUrl = getOrDefault(baseUrl, OPENAI_URL);
if (OPENAI_DEMO_API_KEY.equals(apiKey)) {
@@ -123,21 +104,11 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
timeout = getOrDefault(timeout, ofSeconds(60));
this.client =
OpenAiClient.builder()
.openAiApiKey(apiKey)
.baseUrl(baseUrl)
.organizationId(organizationId)
.callTimeout(timeout)
.connectTimeout(timeout)
.readTimeout(timeout)
.writeTimeout(timeout)
.proxy(proxy)
.logRequests(logRequests)
.logResponses(logResponses)
.userAgent(DEFAULT_USER_AGENT)
.customHeaders(customHeaders)
.build();
this.client = OpenAiClient.builder().openAiApiKey(apiKey).baseUrl(baseUrl)
.organizationId(organizationId).callTimeout(timeout).connectTimeout(timeout)
.readTimeout(timeout).writeTimeout(timeout).proxy(proxy).logRequests(logRequests)
.logResponses(logResponses).userAgent(DEFAULT_USER_AGENT)
.customHeaders(customHeaders).build();
this.modelName = getOrDefault(modelName, GPT_3_5_TURBO);
this.temperature = getOrDefault(temperature, 0.7);
this.topP = topP;
@@ -146,14 +117,10 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
this.presencePenalty = presencePenalty;
this.frequencyPenalty = frequencyPenalty;
this.logitBias = logitBias;
this.responseFormat =
responseFormat == null
? null
: ResponseFormat.builder()
.type(
ResponseFormatType.valueOf(
responseFormat.toUpperCase(Locale.ROOT)))
.build();
this.responseFormat = responseFormat == null ? null
: ResponseFormat.builder()
.type(ResponseFormatType.valueOf(responseFormat.toUpperCase(Locale.ROOT)))
.build();
this.strictJsonSchema = getOrDefault(strictJsonSchema, false);
this.seed = seed;
this.user = user;
@@ -183,61 +150,44 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
}
@Override
public Response<AiMessage> generate(
List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
public Response<AiMessage> generate(List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications) {
return generate(messages, toolSpecifications, null, this.responseFormat);
}
@Override
public Response<AiMessage> generate(
List<ChatMessage> messages, ToolSpecification toolSpecification) {
return generate(
messages, singletonList(toolSpecification), toolSpecification, this.responseFormat);
public Response<AiMessage> generate(List<ChatMessage> messages,
ToolSpecification toolSpecification) {
return generate(messages, singletonList(toolSpecification), toolSpecification,
this.responseFormat);
}
@Override
public ChatResponse chat(ChatRequest request) {
Response<AiMessage> response =
generate(
request.messages(),
request.toolSpecifications(),
null,
generate(request.messages(), request.toolSpecifications(), null,
getOrDefault(
toOpenAiResponseFormat(request.responseFormat(), strictJsonSchema),
this.responseFormat));
return ChatResponse.builder()
.aiMessage(response.content())
.tokenUsage(response.tokenUsage())
.finishReason(response.finishReason())
.build();
return ChatResponse.builder().aiMessage(response.content())
.tokenUsage(response.tokenUsage()).finishReason(response.finishReason()).build();
}
private Response<AiMessage> generate(
List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications,
ToolSpecification toolThatMustBeExecuted,
private Response<AiMessage> generate(List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications, ToolSpecification toolThatMustBeExecuted,
ResponseFormat responseFormat) {
if (responseFormat != null
&& responseFormat.type() == JSON_SCHEMA
if (responseFormat != null && responseFormat.type() == JSON_SCHEMA
&& responseFormat.jsonSchema() == null) {
responseFormat = null;
}
ChatCompletionRequest.Builder requestBuilder =
ChatCompletionRequest.builder()
.model(modelName)
.messages(toOpenAiMessages(messages))
.topP(topP)
.stop(stop)
.maxTokens(maxTokens)
.presencePenalty(presencePenalty)
.frequencyPenalty(frequencyPenalty)
.logitBias(logitBias)
.responseFormat(responseFormat)
.seed(seed)
.user(user)
.parallelToolCalls(parallelToolCalls);
ChatCompletionRequest.Builder requestBuilder = ChatCompletionRequest.builder()
.model(modelName).messages(toOpenAiMessages(messages)).topP(topP).stop(stop)
.maxTokens(maxTokens).presencePenalty(presencePenalty)
.frequencyPenalty(frequencyPenalty).logitBias(logitBias)
.responseFormat(responseFormat).seed(seed).user(user)
.parallelToolCalls(parallelToolCalls);
if (!(baseUrl.contains(ZHIPU))) {
requestBuilder.temperature(temperature);
@@ -257,40 +207,33 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
Map<Object, Object> attributes = new ConcurrentHashMap<>();
ChatModelRequestContext requestContext =
new ChatModelRequestContext(modelListenerRequest, attributes);
listeners.forEach(
listener -> {
try {
listener.onRequest(requestContext);
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}
});
listeners.forEach(listener -> {
try {
listener.onRequest(requestContext);
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}
});
try {
ChatCompletionResponse chatCompletionResponse =
withRetry(() -> client.chatCompletion(request).execute(), maxRetries);
Response<AiMessage> response =
Response.from(
aiMessageFrom(chatCompletionResponse),
tokenUsageFrom(chatCompletionResponse.usage()),
finishReasonFrom(
chatCompletionResponse.choices().get(0).finishReason()));
Response<AiMessage> response = Response.from(aiMessageFrom(chatCompletionResponse),
tokenUsageFrom(chatCompletionResponse.usage()),
finishReasonFrom(chatCompletionResponse.choices().get(0).finishReason()));
ChatModelResponse modelListenerResponse =
createModelListenerResponse(
chatCompletionResponse.id(), chatCompletionResponse.model(), response);
ChatModelResponseContext responseContext =
new ChatModelResponseContext(
modelListenerResponse, modelListenerRequest, attributes);
listeners.forEach(
listener -> {
try {
listener.onResponse(responseContext);
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}
});
ChatModelResponse modelListenerResponse = createModelListenerResponse(
chatCompletionResponse.id(), chatCompletionResponse.model(), response);
ChatModelResponseContext responseContext = new ChatModelResponseContext(
modelListenerResponse, modelListenerRequest, attributes);
listeners.forEach(listener -> {
try {
listener.onResponse(responseContext);
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}
});
return response;
} catch (RuntimeException e) {
@@ -305,14 +248,13 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
ChatModelErrorContext errorContext =
new ChatModelErrorContext(error, modelListenerRequest, null, attributes);
listeners.forEach(
listener -> {
try {
listener.onError(errorContext);
} catch (Exception e2) {
log.warn("Exception while calling model listener", e2);
}
});
listeners.forEach(listener -> {
try {
listener.onError(errorContext);
} catch (Exception e2) {
log.warn("Exception while calling model listener", e2);
}
});
throw e;
}
@@ -328,8 +270,8 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
}
public static OpenAiChatModelBuilder builder() {
for (OpenAiChatModelBuilderFactory factory :
loadFactories(OpenAiChatModelBuilderFactory.class)) {
for (OpenAiChatModelBuilderFactory factory : loadFactories(
OpenAiChatModelBuilderFactory.class)) {
return factory.get();
}
return new OpenAiChatModelBuilder();

View File

@@ -3,9 +3,8 @@ package dev.langchain4j.model.openai;
public enum OpenAiChatModelName {
GPT_3_5_TURBO("gpt-3.5-turbo"), // alias
@Deprecated
GPT_3_5_TURBO_0613("gpt-3.5-turbo-0613"),
GPT_3_5_TURBO_1106("gpt-3.5-turbo-1106"),
GPT_3_5_TURBO_0125("gpt-3.5-turbo-0125"),
GPT_3_5_TURBO_0613("gpt-3.5-turbo-0613"), GPT_3_5_TURBO_1106(
"gpt-3.5-turbo-1106"), GPT_3_5_TURBO_0125("gpt-3.5-turbo-0125"),
GPT_3_5_TURBO_16K("gpt-3.5-turbo-16k"), // alias
@Deprecated
@@ -13,22 +12,18 @@ public enum OpenAiChatModelName {
GPT_4("gpt-4"), // alias
@Deprecated
GPT_4_0314("gpt-4-0314"),
GPT_4_0613("gpt-4-0613"),
GPT_4_0314("gpt-4-0314"), GPT_4_0613("gpt-4-0613"),
GPT_4_TURBO_PREVIEW("gpt-4-turbo-preview"), // alias
GPT_4_1106_PREVIEW("gpt-4-1106-preview"),
GPT_4_0125_PREVIEW("gpt-4-0125-preview"),
GPT_4_1106_PREVIEW("gpt-4-1106-preview"), GPT_4_0125_PREVIEW("gpt-4-0125-preview"),
GPT_4_32K("gpt-4-32k"), // alias
GPT_4_32K_0314("gpt-4-32k-0314"),
GPT_4_32K_0613("gpt-4-32k-0613"),
GPT_4_32K_0314("gpt-4-32k-0314"), GPT_4_32K_0613("gpt-4-32k-0613"),
@Deprecated
GPT_4_VISION_PREVIEW("gpt-4-vision-preview"),
GPT_4_O("gpt-4o"),
GPT_4_O_MINI("gpt-4o-mini");
GPT_4_O("gpt-4o"), GPT_4_O_MINI("gpt-4o-mini");
private final String stringValue;

View File

@@ -1,9 +1,7 @@
package dev.langchain4j.model.zhipu;
public enum ChatCompletionModel {
GLM_4("glm-4"),
GLM_3_TURBO("glm-3-turbo"),
CHATGLM_TURBO("chatglm_turbo");
GLM_4("glm-4"), GLM_3_TURBO("glm-3-turbo"), CHATGLM_TURBO("chatglm_turbo");
private final String value;

View File

@@ -27,8 +27,8 @@ import static java.util.Collections.singletonList;
/**
* Represents a ZhipuAi language model with a chat completion interface, such as glm-3-turbo and
* glm-4. You can find description of parameters <a
* href="https://open.bigmodel.cn/dev/api">here</a>.
* glm-4. You can find description of parameters
* <a href="https://open.bigmodel.cn/dev/api">here</a>.
*/
public class ZhipuAiChatModel implements ChatLanguageModel {
@@ -41,15 +41,8 @@ public class ZhipuAiChatModel implements ChatLanguageModel {
private final ZhipuAiClient client;
@Builder
public ZhipuAiChatModel(
String baseUrl,
String apiKey,
Double temperature,
Double topP,
String model,
Integer maxRetries,
Integer maxToken,
Boolean logRequests,
public ZhipuAiChatModel(String baseUrl, String apiKey, Double temperature, Double topP,
String model, Integer maxRetries, Integer maxToken, Boolean logRequests,
Boolean logResponses) {
this.baseUrl = getOrDefault(baseUrl, "https://open.bigmodel.cn/");
this.temperature = getOrDefault(temperature, 0.7);
@@ -57,18 +50,14 @@ public class ZhipuAiChatModel implements ChatLanguageModel {
this.model = getOrDefault(model, ChatCompletionModel.GLM_4.toString());
this.maxRetries = getOrDefault(maxRetries, 3);
this.maxToken = getOrDefault(maxToken, 512);
this.client =
ZhipuAiClient.builder()
.baseUrl(this.baseUrl)
.apiKey(apiKey)
.logRequests(getOrDefault(logRequests, false))
.logResponses(getOrDefault(logResponses, false))
.build();
this.client = ZhipuAiClient.builder().baseUrl(this.baseUrl).apiKey(apiKey)
.logRequests(getOrDefault(logRequests, false))
.logResponses(getOrDefault(logResponses, false)).build();
}
public static ZhipuAiChatModelBuilder builder() {
for (ZhipuAiChatModelBuilderFactory factories :
loadFactories(ZhipuAiChatModelBuilderFactory.class)) {
for (ZhipuAiChatModelBuilderFactory factories : loadFactories(
ZhipuAiChatModelBuilderFactory.class)) {
return factories.get();
}
return new ZhipuAiChatModelBuilder();
@@ -80,15 +69,13 @@ public class ZhipuAiChatModel implements ChatLanguageModel {
}
@Override
public Response<AiMessage> generate(
List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
public Response<AiMessage> generate(List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications) {
ensureNotEmpty(messages, "messages");
ChatCompletionRequest.Builder requestBuilder =
ChatCompletionRequest.builder().model(this.model).maxTokens(maxToken).stream(false)
.topP(topP)
.toolChoice(AUTO)
.messages(toZhipuAiMessages(messages));
.topP(topP).toolChoice(AUTO).messages(toZhipuAiMessages(messages));
if (!isNullOrEmpty(toolSpecifications)) {
requestBuilder.tools(toTools(toolSpecifications));
@@ -96,17 +83,15 @@ public class ZhipuAiChatModel implements ChatLanguageModel {
ChatCompletionResponse response =
withRetry(() -> client.chatCompletion(requestBuilder.build()), maxRetries);
return Response.from(
aiMessageFrom(response),
tokenUsageFrom(response.getUsage()),
return Response.from(aiMessageFrom(response), tokenUsageFrom(response.getUsage()),
finishReasonFrom(response.getChoices().get(0).getFinishReason()));
}
@Override
public Response<AiMessage> generate(
List<ChatMessage> messages, ToolSpecification toolSpecification) {
return generate(
messages, toolSpecification != null ? singletonList(toolSpecification) : null);
public Response<AiMessage> generate(List<ChatMessage> messages,
ToolSpecification toolSpecification) {
return generate(messages,
toolSpecification != null ? singletonList(toolSpecification) : null);
}
public static class ZhipuAiChatModelBuilder {

View File

@@ -20,36 +20,27 @@ public class AzureModelFactory implements ModelFactory, InitializingBean {
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
AzureOpenAiChatModel.Builder builder =
AzureOpenAiChatModel.builder()
.endpoint(modelConfig.getBaseUrl())
.apiKey(modelConfig.getApiKey())
.deploymentName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature())
.maxRetries(modelConfig.getMaxRetries())
.topP(modelConfig.getTopP())
.timeout(
Duration.ofSeconds(
modelConfig.getTimeOut() == null
? 0L
: modelConfig.getTimeOut()))
.logRequestsAndResponses(
modelConfig.getLogRequests() != null
&& modelConfig.getLogResponses());
AzureOpenAiChatModel.Builder builder = AzureOpenAiChatModel.builder()
.endpoint(modelConfig.getBaseUrl()).apiKey(modelConfig.getApiKey())
.deploymentName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature()).maxRetries(modelConfig.getMaxRetries())
.topP(modelConfig.getTopP())
.timeout(Duration.ofSeconds(
modelConfig.getTimeOut() == null ? 0L : modelConfig.getTimeOut()))
.logRequestsAndResponses(
modelConfig.getLogRequests() != null && modelConfig.getLogResponses());
return builder.build();
}
@Override
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
AzureOpenAiEmbeddingModel.Builder builder =
AzureOpenAiEmbeddingModel.builder()
.endpoint(embeddingModelConfig.getBaseUrl())
AzureOpenAiEmbeddingModel.builder().endpoint(embeddingModelConfig.getBaseUrl())
.apiKey(embeddingModelConfig.getApiKey())
.deploymentName(embeddingModelConfig.getModelName())
.maxRetries(embeddingModelConfig.getMaxRetries())
.logRequestsAndResponses(
embeddingModelConfig.getLogRequests() != null
&& embeddingModelConfig.getLogResponses());
.logRequestsAndResponses(embeddingModelConfig.getLogRequests() != null
&& embeddingModelConfig.getLogResponses());
return builder.build();
}

View File

@@ -19,25 +19,17 @@ public class DashscopeModelFactory implements ModelFactory, InitializingBean {
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return QwenChatModel.builder()
.baseUrl(modelConfig.getBaseUrl())
.apiKey(modelConfig.getApiKey())
.modelName(modelConfig.getModelName())
.temperature(
modelConfig.getTemperature() == null
? 0L
: modelConfig.getTemperature().floatValue())
.topP(modelConfig.getTopP())
.enableSearch(modelConfig.getEnableSearch())
.build();
return QwenChatModel.builder().baseUrl(modelConfig.getBaseUrl())
.apiKey(modelConfig.getApiKey()).modelName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature() == null ? 0L
: modelConfig.getTemperature().floatValue())
.topP(modelConfig.getTopP()).enableSearch(modelConfig.getEnableSearch()).build();
}
@Override
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
return QwenEmbeddingModel.builder()
.apiKey(embeddingModelConfig.getApiKey())
.modelName(embeddingModelConfig.getModelName())
.build();
return QwenEmbeddingModel.builder().apiKey(embeddingModelConfig.getApiKey())
.modelName(embeddingModelConfig.getModelName()).build();
}
@Override

View File

@@ -19,27 +19,20 @@ public class LocalAiModelFactory implements ModelFactory, InitializingBean {
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return LocalAiChatModel.builder()
.baseUrl(modelConfig.getBaseUrl())
.modelName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature())
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
.topP(modelConfig.getTopP())
return LocalAiChatModel.builder().baseUrl(modelConfig.getBaseUrl())
.modelName(modelConfig.getModelName()).temperature(modelConfig.getTemperature())
.timeout(Duration.ofSeconds(modelConfig.getTimeOut())).topP(modelConfig.getTopP())
.logRequests(modelConfig.getLogRequests())
.logResponses(modelConfig.getLogResponses())
.maxRetries(modelConfig.getMaxRetries())
.logResponses(modelConfig.getLogResponses()).maxRetries(modelConfig.getMaxRetries())
.build();
}
@Override
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModel) {
return LocalAiEmbeddingModel.builder()
.baseUrl(embeddingModel.getBaseUrl())
.modelName(embeddingModel.getModelName())
.maxRetries(embeddingModel.getMaxRetries())
return LocalAiEmbeddingModel.builder().baseUrl(embeddingModel.getBaseUrl())
.modelName(embeddingModel.getModelName()).maxRetries(embeddingModel.getMaxRetries())
.logRequests(embeddingModel.getLogRequests())
.logResponses(embeddingModel.getLogResponses())
.build();
.logResponses(embeddingModel.getLogResponses()).build();
}
@Override

View File

@@ -25,8 +25,7 @@ public class ModelProvider {
}
public static ChatLanguageModel getChatModel(ChatModelConfig modelConfig) {
if (modelConfig == null
|| StringUtils.isBlank(modelConfig.getProvider())
if (modelConfig == null || StringUtils.isBlank(modelConfig.getProvider())
|| StringUtils.isBlank(modelConfig.getBaseUrl())) {
ChatModelParameterConfig parameterConfig =
ContextUtils.getBean(ChatModelParameterConfig.class);

View File

@@ -21,27 +21,20 @@ public class OllamaModelFactory implements ModelFactory, InitializingBean {
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return OllamaChatModel.builder()
.baseUrl(modelConfig.getBaseUrl())
.modelName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature())
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
.topP(modelConfig.getTopP())
.maxRetries(modelConfig.getMaxRetries())
.logRequests(modelConfig.getLogRequests())
.logResponses(modelConfig.getLogResponses())
.build();
return OllamaChatModel.builder().baseUrl(modelConfig.getBaseUrl())
.modelName(modelConfig.getModelName()).temperature(modelConfig.getTemperature())
.timeout(Duration.ofSeconds(modelConfig.getTimeOut())).topP(modelConfig.getTopP())
.maxRetries(modelConfig.getMaxRetries()).logRequests(modelConfig.getLogRequests())
.logResponses(modelConfig.getLogResponses()).build();
}
@Override
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
return OllamaEmbeddingModel.builder()
.baseUrl(embeddingModelConfig.getBaseUrl())
return OllamaEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
.modelName(embeddingModelConfig.getModelName())
.maxRetries(embeddingModelConfig.getMaxRetries())
.logRequests(embeddingModelConfig.getLogRequests())
.logResponses(embeddingModelConfig.getLogResponses())
.build();
.logResponses(embeddingModelConfig.getLogResponses()).build();
}
@Override

View File

@@ -21,29 +21,22 @@ public class OpenAiModelFactory implements ModelFactory, InitializingBean {
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return OpenAiChatModel.builder()
.baseUrl(modelConfig.getBaseUrl())
.modelName(modelConfig.getModelName())
.apiKey(modelConfig.keyDecrypt())
.temperature(modelConfig.getTemperature())
.topP(modelConfig.getTopP())
return OpenAiChatModel.builder().baseUrl(modelConfig.getBaseUrl())
.modelName(modelConfig.getModelName()).apiKey(modelConfig.keyDecrypt())
.temperature(modelConfig.getTemperature()).topP(modelConfig.getTopP())
.maxRetries(modelConfig.getMaxRetries())
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
.logRequests(modelConfig.getLogRequests())
.logResponses(modelConfig.getLogResponses())
.build();
.logResponses(modelConfig.getLogResponses()).build();
}
@Override
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModel) {
return OpenAiEmbeddingModel.builder()
.baseUrl(embeddingModel.getBaseUrl())
.apiKey(embeddingModel.getApiKey())
.modelName(embeddingModel.getModelName())
return OpenAiEmbeddingModel.builder().baseUrl(embeddingModel.getBaseUrl())
.apiKey(embeddingModel.getApiKey()).modelName(embeddingModel.getModelName())
.maxRetries(embeddingModel.getMaxRetries())
.logRequests(embeddingModel.getLogRequests())
.logResponses(embeddingModel.getLogResponses())
.build();
.logResponses(embeddingModel.getLogResponses()).build();
}
@Override

View File

@@ -21,31 +21,23 @@ public class QianfanModelFactory implements ModelFactory, InitializingBean {
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return QianfanChatModel.builder()
.baseUrl(modelConfig.getBaseUrl())
.apiKey(modelConfig.getApiKey())
.secretKey(modelConfig.getSecretKey())
.endpoint(modelConfig.getEndpoint())
.modelName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature())
.topP(modelConfig.getTopP())
.maxRetries(modelConfig.getMaxRetries())
.logRequests(modelConfig.getLogRequests())
.logResponses(modelConfig.getLogResponses())
.build();
return QianfanChatModel.builder().baseUrl(modelConfig.getBaseUrl())
.apiKey(modelConfig.getApiKey()).secretKey(modelConfig.getSecretKey())
.endpoint(modelConfig.getEndpoint()).modelName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature()).topP(modelConfig.getTopP())
.maxRetries(modelConfig.getMaxRetries()).logRequests(modelConfig.getLogRequests())
.logResponses(modelConfig.getLogResponses()).build();
}
@Override
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
return QianfanEmbeddingModel.builder()
.baseUrl(embeddingModelConfig.getBaseUrl())
return QianfanEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
.apiKey(embeddingModelConfig.getApiKey())
.secretKey(embeddingModelConfig.getSecretKey())
.modelName(embeddingModelConfig.getModelName())
.maxRetries(embeddingModelConfig.getMaxRetries())
.logRequests(embeddingModelConfig.getLogRequests())
.logResponses(embeddingModelConfig.getLogResponses())
.build();
.logResponses(embeddingModelConfig.getLogResponses()).build();
}
@Override

View File

@@ -19,28 +19,20 @@ public class ZhipuModelFactory implements ModelFactory, InitializingBean {
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return ZhipuAiChatModel.builder()
.baseUrl(modelConfig.getBaseUrl())
.apiKey(modelConfig.getApiKey())
.model(modelConfig.getModelName())
.temperature(modelConfig.getTemperature())
.topP(modelConfig.getTopP())
.maxRetries(modelConfig.getMaxRetries())
.logRequests(modelConfig.getLogRequests())
.logResponses(modelConfig.getLogResponses())
.build();
return ZhipuAiChatModel.builder().baseUrl(modelConfig.getBaseUrl())
.apiKey(modelConfig.getApiKey()).model(modelConfig.getModelName())
.temperature(modelConfig.getTemperature()).topP(modelConfig.getTopP())
.maxRetries(modelConfig.getMaxRetries()).logRequests(modelConfig.getLogRequests())
.logResponses(modelConfig.getLogResponses()).build();
}
@Override
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
return ZhipuAiEmbeddingModel.builder()
.baseUrl(embeddingModelConfig.getBaseUrl())
.apiKey(embeddingModelConfig.getApiKey())
.model(embeddingModelConfig.getModelName())
return ZhipuAiEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
.apiKey(embeddingModelConfig.getApiKey()).model(embeddingModelConfig.getModelName())
.maxRetries(embeddingModelConfig.getMaxRetries())
.logRequests(embeddingModelConfig.getLogRequests())
.logResponses(embeddingModelConfig.getLogResponses())
.build();
.logResponses(embeddingModelConfig.getLogResponses()).build();
}
@Override

View File

@@ -12,13 +12,18 @@ public class Properties {
static final String PREFIX = "langchain4j.qianfan";
@NestedConfigurationProperty ChatModelProperties chatModel;
@NestedConfigurationProperty
ChatModelProperties chatModel;
@NestedConfigurationProperty ChatModelProperties streamingChatModel;
@NestedConfigurationProperty
ChatModelProperties streamingChatModel;
@NestedConfigurationProperty LanguageModelProperties languageModel;
@NestedConfigurationProperty
LanguageModelProperties languageModel;
@NestedConfigurationProperty LanguageModelProperties streamingLanguageModel;
@NestedConfigurationProperty
LanguageModelProperties streamingLanguageModel;
@NestedConfigurationProperty EmbeddingModelProperties embeddingModel;
@NestedConfigurationProperty
EmbeddingModelProperties embeddingModel;
}

View File

@@ -20,8 +20,7 @@ public class QianfanAutoConfig {
@ConditionalOnProperty(PREFIX + ".chat-model.api-key")
QianfanChatModel qianfanChatModel(Properties properties) {
ChatModelProperties chatModelProperties = properties.getChatModel();
return QianfanChatModel.builder()
.baseUrl(chatModelProperties.getBaseUrl())
return QianfanChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
.apiKey(chatModelProperties.getApiKey())
.secretKey(chatModelProperties.getSecretKey())
.endpoint(chatModelProperties.getEndpoint())
@@ -32,38 +31,32 @@ public class QianfanAutoConfig {
.responseFormat(chatModelProperties.getResponseFormat())
.maxRetries(chatModelProperties.getMaxRetries())
.logRequests(chatModelProperties.getLogRequests())
.logResponses(chatModelProperties.getLogResponses())
.build();
.logResponses(chatModelProperties.getLogResponses()).build();
}
@Bean
@ConditionalOnProperty(PREFIX + ".streaming-chat-model.api-key")
QianfanStreamingChatModel qianfanStreamingChatModel(Properties properties) {
ChatModelProperties chatModelProperties = properties.getStreamingChatModel();
return QianfanStreamingChatModel.builder()
.endpoint(chatModelProperties.getEndpoint())
return QianfanStreamingChatModel.builder().endpoint(chatModelProperties.getEndpoint())
.penaltyScore(chatModelProperties.getPenaltyScore())
.temperature(chatModelProperties.getTemperature())
.topP(chatModelProperties.getTopP())
.baseUrl(chatModelProperties.getBaseUrl())
.topP(chatModelProperties.getTopP()).baseUrl(chatModelProperties.getBaseUrl())
.apiKey(chatModelProperties.getApiKey())
.secretKey(chatModelProperties.getSecretKey())
.modelName(chatModelProperties.getModelName())
.responseFormat(chatModelProperties.getResponseFormat())
.logRequests(chatModelProperties.getLogRequests())
.logResponses(chatModelProperties.getLogResponses())
.build();
.logResponses(chatModelProperties.getLogResponses()).build();
}
@Bean
@ConditionalOnProperty(PREFIX + ".language-model.api-key")
QianfanLanguageModel qianfanLanguageModel(Properties properties) {
LanguageModelProperties languageModelProperties = properties.getLanguageModel();
return QianfanLanguageModel.builder()
.endpoint(languageModelProperties.getEndpoint())
return QianfanLanguageModel.builder().endpoint(languageModelProperties.getEndpoint())
.penaltyScore(languageModelProperties.getPenaltyScore())
.topK(languageModelProperties.getTopK())
.topP(languageModelProperties.getTopP())
.topK(languageModelProperties.getTopK()).topP(languageModelProperties.getTopP())
.baseUrl(languageModelProperties.getBaseUrl())
.apiKey(languageModelProperties.getApiKey())
.secretKey(languageModelProperties.getSecretKey())
@@ -71,8 +64,7 @@ public class QianfanAutoConfig {
.temperature(languageModelProperties.getTemperature())
.maxRetries(languageModelProperties.getMaxRetries())
.logRequests(languageModelProperties.getLogRequests())
.logResponses(languageModelProperties.getLogResponses())
.build();
.logResponses(languageModelProperties.getLogResponses()).build();
}
@Bean
@@ -82,8 +74,7 @@ public class QianfanAutoConfig {
return QianfanStreamingLanguageModel.builder()
.endpoint(languageModelProperties.getEndpoint())
.penaltyScore(languageModelProperties.getPenaltyScore())
.topK(languageModelProperties.getTopK())
.topP(languageModelProperties.getTopP())
.topK(languageModelProperties.getTopK()).topP(languageModelProperties.getTopP())
.baseUrl(languageModelProperties.getBaseUrl())
.apiKey(languageModelProperties.getApiKey())
.secretKey(languageModelProperties.getSecretKey())
@@ -91,16 +82,14 @@ public class QianfanAutoConfig {
.temperature(languageModelProperties.getTemperature())
.maxRetries(languageModelProperties.getMaxRetries())
.logRequests(languageModelProperties.getLogRequests())
.logResponses(languageModelProperties.getLogResponses())
.build();
.logResponses(languageModelProperties.getLogResponses()).build();
}
@Bean
@ConditionalOnProperty(PREFIX + ".embedding-model.api-key")
QianfanEmbeddingModel qianfanEmbeddingModel(Properties properties) {
EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel();
return QianfanEmbeddingModel.builder()
.baseUrl(embeddingModelProperties.getBaseUrl())
return QianfanEmbeddingModel.builder().baseUrl(embeddingModelProperties.getBaseUrl())
.endpoint(embeddingModelProperties.getEndpoint())
.apiKey(embeddingModelProperties.getApiKey())
.secretKey(embeddingModelProperties.getSecretKey())
@@ -108,7 +97,6 @@ public class QianfanAutoConfig {
.user(embeddingModelProperties.getUser())
.maxRetries(embeddingModelProperties.getMaxRetries())
.logRequests(embeddingModelProperties.getLogRequests())
.logResponses(embeddingModelProperties.getLogResponses())
.build();
.logResponses(embeddingModelProperties.getLogResponses()).build();
}
}

View File

@@ -27,24 +27,19 @@ public class EmbeddingStoreFactoryProvider {
return ContextUtils.getBean(EmbeddingStoreFactory.class);
}
if (EmbeddingStoreType.CHROMA.name().equalsIgnoreCase(embeddingStoreConfig.getProvider())) {
return factoryMap.computeIfAbsent(
embeddingStoreConfig,
return factoryMap.computeIfAbsent(embeddingStoreConfig,
storeConfig -> new ChromaEmbeddingStoreFactory(storeConfig));
}
if (EmbeddingStoreType.MILVUS.name().equalsIgnoreCase(embeddingStoreConfig.getProvider())) {
return factoryMap.computeIfAbsent(
embeddingStoreConfig,
return factoryMap.computeIfAbsent(embeddingStoreConfig,
storeConfig -> new MilvusEmbeddingStoreFactory(storeConfig));
}
if (EmbeddingStoreType.IN_MEMORY
.name()
if (EmbeddingStoreType.IN_MEMORY.name()
.equalsIgnoreCase(embeddingStoreConfig.getProvider())) {
return factoryMap.computeIfAbsent(
embeddingStoreConfig,
return factoryMap.computeIfAbsent(embeddingStoreConfig,
storeConfig -> new InMemoryEmbeddingStoreFactory(storeConfig));
}
throw new RuntimeException(
"Unsupported EmbeddingStoreFactory provider: "
+ embeddingStoreConfig.getProvider());
throw new RuntimeException("Unsupported EmbeddingStoreFactory provider: "
+ embeddingStoreConfig.getProvider());
}
}

View File

@@ -1,7 +1,5 @@
package dev.langchain4j.store.embedding;
public enum EmbeddingStoreType {
IN_MEMORY,
MILVUS,
CHROMA
IN_MEMORY, MILVUS, CHROMA
}

View File

@@ -36,8 +36,7 @@ public class Retrieval {
}
Retrieval retrieval = (Retrieval) o;
return Double.compare(retrieval.similarity, similarity) == 0
&& Objects.equal(id, retrieval.id)
&& Objects.equal(query, retrieval.query)
&& Objects.equal(id, retrieval.id) && Objects.equal(query, retrieval.query)
&& Objects.equal(metadata, retrieval.metadata);
}

View File

@@ -17,20 +17,12 @@ public class TextSegmentConvert {
public static final String QUERY_ID = "queryId";
public static List<TextSegment> convertToEmbedding(List<DataItem> dataItems) {
return dataItems.stream()
.map(
dataItem -> {
Map meta =
JSONObject.parseObject(
JSONObject.toJSONString(dataItem), Map.class);
TextSegment textSegment =
TextSegment.from(dataItem.getName(), new Metadata(meta));
addQueryId(
textSegment,
dataItem.getId() + dataItem.getType().name().toLowerCase());
return textSegment;
})
.collect(Collectors.toList());
return dataItems.stream().map(dataItem -> {
Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class);
TextSegment textSegment = TextSegment.from(dataItem.getName(), new Metadata(meta));
addQueryId(textSegment, dataItem.getId() + dataItem.getType().name().toLowerCase());
return textSegment;
}).collect(Collectors.toList());
}
public static void addQueryId(TextSegment textSegment, String queryId) {

View File

@@ -40,16 +40,19 @@ import static java.util.stream.Collectors.toList;
/**
* An {@link EmbeddingStore} that stores embeddings in memory.
*
* <p>Uses a brute force approach by iterating over all embeddings to find the best matches.
* <p>
* Uses a brute force approach by iterating over all embeddings to find the best matches.
*
* <p>This store can be persisted using the {@link #serializeToJson()} and {@link
* #serializeToFile(Path)} methods.
* <p>
* This store can be persisted using the {@link #serializeToJson()} and
* {@link #serializeToFile(Path)} methods.
*
* <p>It can also be recreated from JSON or a file using the {@link #fromJson(String)} and {@link
* #fromFile(Path)} methods.
* <p>
* It can also be recreated from JSON or a file using the {@link #fromJson(String)} and
* {@link #fromFile(Path)} methods.
*
* @param <Embedded> The class of the object that has been embedded. Typically, it is {@link
* dev.langchain4j.data.segment.TextSegment}.
* @param <Embedded> The class of the object that has been embedded. Typically, it is
* {@link dev.langchain4j.data.segment.TextSegment}.
*/
public class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded> {
@@ -88,10 +91,8 @@ public class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded
@Override
public List<String> addAll(List<Embedding> embeddings) {
List<Entry<Embedded>> newEntries =
embeddings.stream()
.map(embedding -> new Entry<Embedded>(randomUUID(), embedding))
.collect(toList());
List<Entry<Embedded>> newEntries = embeddings.stream()
.map(embedding -> new Entry<Embedded>(randomUUID(), embedding)).collect(toList());
return add(newEntries);
}
@@ -103,11 +104,9 @@ public class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded
"The list of embeddings and embedded must have the same size");
}
List<Entry<Embedded>> newEntries =
IntStream.range(0, embeddings.size())
.mapToObj(
i -> new Entry<>(randomUUID(), embeddings.get(i), embedded.get(i)))
.collect(toList());
List<Entry<Embedded>> newEntries = IntStream.range(0, embeddings.size())
.mapToObj(i -> new Entry<>(randomUUID(), embeddings.get(i), embedded.get(i)))
.collect(toList());
return add(newEntries);
}
@@ -123,16 +122,15 @@ public class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded
public void removeAll(Filter filter) {
ensureNotNull(filter, "filter");
entries.removeIf(
entry -> {
if (entry.embedded instanceof TextSegment) {
return filter.test(((TextSegment) entry.embedded).metadata());
} else if (entry.embedded == null) {
return false;
} else {
throw new UnsupportedOperationException("Not supported yet.");
}
});
entries.removeIf(entry -> {
if (entry.embedded instanceof TextSegment) {
return filter.test(((TextSegment) entry.embedded).metadata());
} else if (entry.embedded == null) {
return false;
} else {
throw new UnsupportedOperationException("Not supported yet.");
}
});
}
@Override
@@ -157,9 +155,8 @@ public class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded
}
}
double cosineSimilarity =
CosineSimilarity.between(
entry.embedding, embeddingSearchRequest.queryEmbedding());
double cosineSimilarity = CosineSimilarity.between(entry.embedding,
embeddingSearchRequest.queryEmbedding());
double score = RelevanceScore.fromCosineSimilarity(cosineSimilarity);
if (score >= embeddingSearchRequest.minScore()) {
matches.add(new EmbeddingMatch<>(score, entry.id, entry.embedding, entry.embedded));
@@ -247,8 +244,8 @@ public class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded
}
private static InMemoryEmbeddingStoreJsonCodec loadCodec() {
for (InMemoryEmbeddingStoreJsonCodecFactory factory :
loadFactories(InMemoryEmbeddingStoreJsonCodecFactory.class)) {
for (InMemoryEmbeddingStoreJsonCodecFactory factory : loadFactories(
InMemoryEmbeddingStoreJsonCodecFactory.class)) {
return factory.create();
}
return new GsonInMemoryEmbeddingStoreJsonCodec();

View File

@@ -58,27 +58,13 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
private final boolean retrieveEmbeddingsOnSearch;
private final boolean autoFlushOnInsert;
public MilvusEmbeddingStore(
String host,
Integer port,
String collectionName,
Integer dimension,
IndexType indexType,
MetricType metricType,
String uri,
String token,
String username,
String password,
ConsistencyLevelEnum consistencyLevel,
Boolean retrieveEmbeddingsOnSearch,
Boolean autoFlushOnInsert,
String databaseName) {
public MilvusEmbeddingStore(String host, Integer port, String collectionName, Integer dimension,
IndexType indexType, MetricType metricType, String uri, String token, String username,
String password, ConsistencyLevelEnum consistencyLevel,
Boolean retrieveEmbeddingsOnSearch, Boolean autoFlushOnInsert, String databaseName) {
ConnectParam.Builder connectBuilder =
ConnectParam.newBuilder()
.withHost(getOrDefault(host, "localhost"))
.withPort(getOrDefault(port, 19530))
.withUri(uri)
.withToken(token)
ConnectParam.newBuilder().withHost(getOrDefault(host, "localhost"))
.withPort(getOrDefault(port, 19530)).withUri(uri).withToken(token)
.withAuthorization(getOrDefault(username, ""), getOrDefault(password, ""));
if (databaseName != null) {
@@ -93,12 +79,9 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
this.autoFlushOnInsert = getOrDefault(autoFlushOnInsert, false);
if (!hasCollection(this.milvusClient, this.collectionName)) {
createCollection(
this.milvusClient, this.collectionName, ensureNotNull(dimension, "dimension"));
createIndex(
this.milvusClient,
this.collectionName,
getOrDefault(indexType, FLAT),
createCollection(this.milvusClient, this.collectionName,
ensureNotNull(dimension, "dimension"));
createIndex(this.milvusClient, this.collectionName, getOrDefault(indexType, FLAT),
this.metricType);
}
@@ -145,49 +128,36 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
public EmbeddingSearchResult<TextSegment> search(
EmbeddingSearchRequest embeddingSearchRequest) {
SearchParam searchParam =
buildSearchRequest(
collectionName,
embeddingSearchRequest.queryEmbedding().vectorAsList(),
embeddingSearchRequest.filter(),
embeddingSearchRequest.maxResults(),
metricType,
consistencyLevel);
SearchParam searchParam = buildSearchRequest(collectionName,
embeddingSearchRequest.queryEmbedding().vectorAsList(),
embeddingSearchRequest.filter(), embeddingSearchRequest.maxResults(), metricType,
consistencyLevel);
SearchResultsWrapper resultsWrapper =
CollectionOperationsExecutor.search(milvusClient, searchParam);
List<EmbeddingMatch<TextSegment>> matches =
toEmbeddingMatches(
milvusClient,
resultsWrapper,
collectionName,
consistencyLevel,
retrieveEmbeddingsOnSearch);
List<EmbeddingMatch<TextSegment>> matches = toEmbeddingMatches(milvusClient, resultsWrapper,
collectionName, consistencyLevel, retrieveEmbeddingsOnSearch);
List<EmbeddingMatch<TextSegment>> result =
matches.stream()
.filter(match -> match.score() >= embeddingSearchRequest.minScore())
matches.stream().filter(match -> match.score() >= embeddingSearchRequest.minScore())
.collect(toList());
return new EmbeddingSearchResult<>(result);
}
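search(...) builds a Milvus SearchParam, executes it, converts the hits into EmbeddingMatch objects, and finally drops matches below minScore. A caller-side sketch, where embeddingModel is any langchain4j EmbeddingModel and store is the instance built above (both hypothetical):

```java
// Caller-side sketch; "embeddingModel" and "store" are assumed to exist in the surrounding setup.
Embedding query = embeddingModel.embed("What is Milvus?").content();

EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
        .queryEmbedding(query)
        .maxResults(5)
        .minScore(0.7) // matches below this score are filtered out after the Milvus call
        .build();

EmbeddingSearchResult<TextSegment> result = store.search(request);
result.matches().forEach(match ->
        System.out.println(match.score() + " -> " + match.embedded().text()));
```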
private void addInternal(String id, Embedding embedding, TextSegment textSegment) {
addAllInternal(
singletonList(id),
singletonList(embedding),
addAllInternal(singletonList(id), singletonList(embedding),
textSegment == null ? null : singletonList(textSegment));
}
private void addAllInternal(
List<String> ids, List<Embedding> embeddings, List<TextSegment> textSegments) {
private void addAllInternal(List<String> ids, List<Embedding> embeddings,
List<TextSegment> textSegments) {
List<InsertParam.Field> fields = new ArrayList<>();
fields.add(new InsertParam.Field(ID_FIELD_NAME, ids));
fields.add(new InsertParam.Field(TEXT_FIELD_NAME, toScalars(textSegments, ids.size())));
fields.add(
new InsertParam.Field(
METADATA_FIELD_NAME, toMetadataJsons(textSegments, ids.size())));
fields.add(new InsertParam.Field(METADATA_FIELD_NAME,
toMetadataJsons(textSegments, ids.size())));
fields.add(new InsertParam.Field(VECTOR_FIELD_NAME, toVectors(embeddings)));
insert(this.milvusClient, this.collectionName, fields);
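addAllInternal(...) packs IDs, segment texts, metadata JSON, and vectors into Milvus insert fields; callers reach it through the public add/addAll methods of EmbeddingStore. A short sketch of the single-entry path, reusing the hypothetical store and embeddingModel from above:

```java
// Sketch of the public add path that ends in addInternal -> addAllInternal.
TextSegment segment = TextSegment.from("Milvus is a vector database.");
Embedding embedding = embeddingModel.embed(segment).content();
String id = store.add(embedding, segment); // returns the generated ID
```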
@@ -199,22 +169,22 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
/**
* Removes all embeddings that match the specified IDs from the store.
*
* <p>CAUTION
* <p>
* CAUTION
*
* <ul>
* <li>Deleted entities can still be retrieved immediately after the deletion if the
* consistency level is set lower than {@code Strong}
* <li>Entities deleted beyond the pre-specified span of time for Time Travel cannot be
* retrieved again.
* <li>Frequent deletion operations will impact the system performance.
* <li>Before deleting entities by complex boolean expressions, make sure the collection has
* been loaded.
* <li>Deleting entities by complex boolean expressions is not an atomic operation. Therefore,
* if it fails halfway through, some data may still be deleted.
* <li>Deleting entities by complex boolean expressions is supported only when the consistency
* is set to Bounded. For details, <a
* href="https://milvus.io/docs/v2.3.x/consistency.md#Consistency-levels">see
* Consistency</a>
* <li>Deleted entities can still be retrieved immediately after the deletion if the consistency
* level is set lower than {@code Strong}
* <li>Entities deleted beyond the pre-specified span of time for Time Travel cannot be
* retrieved again.
* <li>Frequent deletion operations will impact the system performance.
* <li>Before deleting entities by complex boolean expressions, make sure the collection has
* been loaded.
* <li>Deleting entities by complex boolean expressions is not an atomic operation. Therefore,
* if it fails halfway through, some data may still be deleted.
* <li>Deleting entities by complex boolean expressions is supported only when the consistency
* is set to Bounded. For details,
* <a href="https://milvus.io/docs/v2.3.x/consistency.md#Consistency-levels">see Consistency</a>
* </ul>
*
* @param ids A collection of unique IDs of the embeddings to be removed.
@@ -223,36 +193,34 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
@Override
public void removeAll(Collection<String> ids) {
ensureNotEmpty(ids, "ids");
removeForVector(
this.milvusClient,
this.collectionName,
removeForVector(this.milvusClient, this.collectionName,
format("%s in %s", ID_FIELD_NAME, formatValues(ids)));
}
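removeAll(Collection&lt;String&gt;) turns the IDs into a Milvus boolean expression (id in [...]) and delegates to removeForVector. A caller-side sketch with placeholder IDs:

```java
// Placeholder IDs; in practice use the values returned by add(...)/addAll(...).
// Note the Javadoc caution: below Strong consistency, deleted entities may still
// appear in searches for a short time.
store.removeAll(java.util.List.of(idA, idB));
```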
/**
* Removes all embeddings that match the specified {@link Filter} from the store.
*
* <p>CAUTION
* <p>
* CAUTION
*
* <ul>
* <li>Deleted entities can still be retrieved immediately after the deletion if the
* consistency level is set lower than {@code Strong}
* <li>Entities deleted beyond the pre-specified span of time for Time Travel cannot be
* retrieved again.
* <li>Frequent deletion operations will impact the system performance.
* <li>Before deleting entities by complex boolean expressions, make sure the collection has
* been loaded.
* <li>Deleting entities by complex boolean expressions is not an atomic operation. Therefore,
* if it fails halfway through, some data may still be deleted.
* <li>Deleting entities by complex boolean expressions is supported only when the consistency
* is set to Bounded. For details, <a
* href="https://milvus.io/docs/v2.3.x/consistency.md#Consistency-levels">see
* Consistency</a>
* <li>Deleted entities can still be retrieved immediately after the deletion if the consistency
* level is set lower than {@code Strong}
* <li>Entities deleted beyond the pre-specified span of time for Time Travel cannot be
* retrieved again.
* <li>Frequent deletion operations will impact the system performance.
* <li>Before deleting entities by complex boolean expressions, make sure the collection has
* been loaded.
* <li>Deleting entities by complex boolean expressions is not an atomic operation. Therefore,
* if it fails halfway through, some data may still be deleted.
* <li>Deleting entities by complex boolean expressions is supported only when the consistency
* is set to Bounded. For details,
* <a href="https://milvus.io/docs/v2.3.x/consistency.md#Consistency-levels">see Consistency</a>
* </ul>
*
* @param filter The filter to be applied to the {@link Metadata} of the {@link TextSegment}
* during removal. Only embeddings whose {@code TextSegment}'s {@code Metadata} match the
* {@code Filter} will be removed.
* during removal. Only embeddings whose {@code TextSegment}'s {@code Metadata} match the
* {@code Filter} will be removed.
* @since Milvus version 2.3.x
*/
@Override
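The filter-based overload (its body falls outside this hunk) removes every embedding whose segment Metadata matches the given Filter. A sketch using langchain4j's metadata filter builder, assuming MetadataFilterBuilder is on the classpath; the key and value are placeholders:

```java
import static dev.langchain4j.store.embedding.filter.MetadataFilterBuilder.metadataKey;

// Removes all embeddings whose TextSegment metadata has source == "wiki" (placeholder values).
store.removeAll(metadataKey("source").isEqualTo("wiki"));
```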
@@ -264,30 +232,30 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
/**
* Removes all embeddings from the store.
*
* <p>CAUTION
* <p>
* CAUTION
*
* <ul>
* <li>Deleted entities can still be retrieved immediately after the deletion if the
* consistency level is set lower than {@code Strong}
* <li>Entities deleted beyond the pre-specified span of time for Time Travel cannot be
* retrieved again.
* <li>Frequent deletion operations will impact the system performance.
* <li>Before deleting entities by complex boolean expressions, make sure the collection has
* been loaded.
* <li>Deleting entities by complex boolean expressions is not an atomic operation. Therefore,
* if it fails halfway through, some data may still be deleted.
* <li>Deleting entities by complex boolean expressions is supported only when the consistency
* is set to Bounded. For details, <a
* href="https://milvus.io/docs/v2.3.x/consistency.md#Consistency-levels">see
* Consistency</a>
* <li>Deleted entities can still be retrieved immediately after the deletion if the consistency
* level is set lower than {@code Strong}
* <li>Entities deleted beyond the pre-specified span of time for Time Travel cannot be
* retrieved again.
* <li>Frequent deletion operations will impact the system performance.
* <li>Before deleting entities by complex boolean expressions, make sure the collection has
* been loaded.
* <li>Deleting entities by complex boolean expressions is not an atomic operation. Therefore,
* if it fails halfway through, some data may still be deleted.
* <li>Deleting entities by complex boolean expressions is supported only when the consistency
* is set to Bounded. For details,
* <a href="https://milvus.io/docs/v2.3.x/consistency.md#Consistency-levels">see Consistency</a>
* </ul>
*
* @since Milvus version 2.3.x
*/
@Override
public void removeAll() {
removeForVector(
this.milvusClient, this.collectionName, format("%s != \"\"", ID_FIELD_NAME));
removeForVector(this.milvusClient, this.collectionName,
format("%s != \"\"", ID_FIELD_NAME));
}
public static class Builder {
@@ -327,7 +295,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
/**
* @param collectionName The name of the Milvus collection. If there is no such collection
* yet, it will be created automatically. Default value: "default".
* yet, it will be created automatically. Default value: "default".
* @return builder
*/
public Builder collectionName(String collectionName) {
@@ -337,7 +305,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
/**
* @param dimension The dimension of the embedding vector. (e.g. 384) Mandatory if a new
* collection should be created.
* collection should be created.
* @return builder
*/
public Builder dimension(Integer dimension) {
@@ -356,7 +324,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
/**
* @param metricType The type of the metric used for similarity search. Default value:
* COSINE.
* COSINE.
* @return builder
*/
public Builder metricType(MetricType metricType) {
@@ -366,7 +334,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
/**
* @param uri The URI of the managed Milvus instance. (e.g.
* "https://xxx.api.gcp-us-west1.zillizcloud.com")
* "https://xxx.api.gcp-us-west1.zillizcloud.com")
* @return builder
*/
public Builder uri(String uri) {
@@ -384,8 +352,8 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
}
/**
* @param username The username. See details <a
* href="https://milvus.io/docs/authenticate.md">here</a>.
* @param username The username. See details
* <a href="https://milvus.io/docs/authenticate.md">here</a>.
* @return builder
*/
public Builder username(String username) {
@@ -394,8 +362,8 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
}
/**
* @param password The password. See details <a
* href="https://milvus.io/docs/authenticate.md">here</a>.
* @param password The password. See details
* <a href="https://milvus.io/docs/authenticate.md">here</a>.
* @return builder
*/
public Builder password(String password) {
@@ -414,10 +382,10 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
/**
* @param retrieveEmbeddingsOnSearch During a similarity search in Milvus (when calling
* findRelevant()), the embedding itself is not retrieved. To retrieve the embedding, an
* additional query is required. Setting this parameter to "true" will ensure that
* embedding is retrieved. Be aware that this will impact the performance of the search.
* Default value: false.
* findRelevant()), the embedding itself is not retrieved. To retrieve the embedding,
* an additional query is required. Setting this parameter to "true" will ensure that
* embedding is retrieved. Be aware that this will impact the performance of the
* search. Default value: false.
* @return builder
*/
public Builder retrieveEmbeddingsOnSearch(Boolean retrieveEmbeddingsOnSearch) {
@@ -428,8 +396,8 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
/**
* @param autoFlushOnInsert Whether to automatically flush after each insert ({@code
* add(...)} or {@code addAll(...)} methods). Default value: false. More info can be
* found <a
* href="https://milvus.io/api-reference/pymilvus/v2.4.x/ORM/Collection/flush.md">here</a>.
* found <a href=
* "https://milvus.io/api-reference/pymilvus/v2.4.x/ORM/Collection/flush.md">here</a>.
* @return builder
*/
public Builder autoFlushOnInsert(Boolean autoFlushOnInsert) {
@@ -439,7 +407,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
/**
* @param databaseName The name of the Milvus database. Default value: null, in which case the
*     default Milvus database will be used.
*        default Milvus database will be used.
* @return builder
*/
public Builder databaseName(String databaseName) {
@@ -448,21 +416,9 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
}
public MilvusEmbeddingStore build() {
return new MilvusEmbeddingStore(
host,
port,
collectionName,
dimension,
indexType,
metricType,
uri,
token,
username,
password,
consistencyLevel,
retrieveEmbeddingsOnSearch,
autoFlushOnInsert,
databaseName);
return new MilvusEmbeddingStore(host, port, collectionName, dimension, indexType,
metricType, uri, token, username, password, consistencyLevel,
retrieveEmbeddingsOnSearch, autoFlushOnInsert, databaseName);
}
}
}
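Beyond host/port, the builder also supports managed deployments via uri/token. A sketch for a hosted instance; all values are placeholders, and retrieveEmbeddingsOnSearch adds an extra query per search, as the Javadoc above warns:

```java
// Placeholder URI and token; see the builder Javadoc above for each option's default.
MilvusEmbeddingStore cloudStore = MilvusEmbeddingStore.builder()
        .uri("https://example.api.gcp-us-west1.zillizcloud.com")
        .token("<api-token>")
        .collectionName("docs")
        .dimension(1536)
        .metricType(MetricType.COSINE) // the documented default
        .retrieveEmbeddingsOnSearch(true)
        .build();
```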

View File

@@ -12,9 +12,12 @@ public class Properties {
static final String PREFIX = "langchain4j.zhipu";
@NestedConfigurationProperty ChatModelProperties chatModel;
@NestedConfigurationProperty
ChatModelProperties chatModel;
@NestedConfigurationProperty ChatModelProperties streamingChatModel;
@NestedConfigurationProperty
ChatModelProperties streamingChatModel;
@NestedConfigurationProperty EmbeddingModelProperties embeddingModel;
@NestedConfigurationProperty
EmbeddingModelProperties embeddingModel;
}

View File

@@ -18,46 +18,36 @@ public class ZhipuAutoConfig {
@ConditionalOnProperty(PREFIX + ".chat-model.api-key")
ZhipuAiChatModel zhipuAiChatModel(Properties properties) {
ChatModelProperties chatModelProperties = properties.getChatModel();
return ZhipuAiChatModel.builder()
.baseUrl(chatModelProperties.getBaseUrl())
.apiKey(chatModelProperties.getApiKey())
.model(chatModelProperties.getModelName())
return ZhipuAiChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
.apiKey(chatModelProperties.getApiKey()).model(chatModelProperties.getModelName())
.temperature(chatModelProperties.getTemperature())
.topP(chatModelProperties.getTopP())
.maxRetries(chatModelProperties.getMaxRetries())
.topP(chatModelProperties.getTopP()).maxRetries(chatModelProperties.getMaxRetries())
.maxToken(chatModelProperties.getMaxToken())
.logRequests(chatModelProperties.getLogRequests())
.logResponses(chatModelProperties.getLogResponses())
.build();
.logResponses(chatModelProperties.getLogResponses()).build();
}
@Bean
@ConditionalOnProperty(PREFIX + ".streaming-chat-model.api-key")
ZhipuAiStreamingChatModel zhipuStreamingChatModel(Properties properties) {
ChatModelProperties chatModelProperties = properties.getStreamingChatModel();
return ZhipuAiStreamingChatModel.builder()
.baseUrl(chatModelProperties.getBaseUrl())
.apiKey(chatModelProperties.getApiKey())
.model(chatModelProperties.getModelName())
return ZhipuAiStreamingChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
.apiKey(chatModelProperties.getApiKey()).model(chatModelProperties.getModelName())
.temperature(chatModelProperties.getTemperature())
.topP(chatModelProperties.getTopP())
.maxToken(chatModelProperties.getMaxToken())
.topP(chatModelProperties.getTopP()).maxToken(chatModelProperties.getMaxToken())
.logRequests(chatModelProperties.getLogRequests())
.logResponses(chatModelProperties.getLogResponses())
.build();
.logResponses(chatModelProperties.getLogResponses()).build();
}
@Bean
@ConditionalOnProperty(PREFIX + ".embedding-model.api-key")
ZhipuAiEmbeddingModel zhipuEmbeddingModel(Properties properties) {
EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel();
return ZhipuAiEmbeddingModel.builder()
.baseUrl(embeddingModelProperties.getBaseUrl())
return ZhipuAiEmbeddingModel.builder().baseUrl(embeddingModelProperties.getBaseUrl())
.apiKey(embeddingModelProperties.getApiKey())
.model(embeddingModelProperties.getModel())
.maxRetries(embeddingModelProperties.getMaxRetries())
.logRequests(embeddingModelProperties.getLogRequests())
.logResponses(embeddingModelProperties.getLogResponses())
.build();
.logResponses(embeddingModelProperties.getLogResponses()).build();
}
}
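Each bean above is created only when the matching api-key property is present. A hedged sketch of the Spring Boot configuration that would activate them; only the *.api-key keys are confirmed by the @ConditionalOnProperty annotations, the remaining names are inferred from the property getters via relaxed binding:

```properties
# Only the *.api-key keys are confirmed by @ConditionalOnProperty above; the other
# key names follow Spring Boot relaxed binding of the builder getters (assumption).
# Values are illustrative placeholders.
langchain4j.zhipu.chat-model.api-key=<your-api-key>
langchain4j.zhipu.chat-model.model-name=glm-4
langchain4j.zhipu.chat-model.temperature=0.7
langchain4j.zhipu.streaming-chat-model.api-key=<your-api-key>
langchain4j.zhipu.embedding-model.api-key=<your-api-key>
```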