(improvement)(build) Add spotless during the build process. (#1639)

This commit is contained in:
lexluo09
2024-09-07 00:36:17 +08:00
committed by GitHub
parent ee15a88b06
commit 5f59e89eea
986 changed files with 15609 additions and 12706 deletions

View File

@@ -1,6 +1,5 @@
package dev.langchain4j.chroma.spring;
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
@@ -18,4 +17,4 @@ public class ChromaAutoConfig {
EmbeddingStoreFactory chromaChatModel(Properties properties) {
return new ChromaEmbeddingStoreFactory(properties.getEmbeddingStore());
}
}
}

View File

@@ -31,10 +31,11 @@ public class ChromaEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
.build();
}
private static EmbeddingStoreProperties createPropertiesFromConfig(EmbeddingStoreConfig storeConfig) {
private static EmbeddingStoreProperties createPropertiesFromConfig(
EmbeddingStoreConfig storeConfig) {
EmbeddingStoreProperties embeddingStore = new EmbeddingStoreProperties();
BeanUtils.copyProperties(storeConfig, embeddingStore);
embeddingStore.setTimeout(Duration.ofSeconds(storeConfig.getTimeOut()));
return embeddingStore;
}
}
}

View File

@@ -13,4 +13,4 @@ public class EmbeddingStoreProperties {
private String baseUrl;
private Duration timeout;
}
}

View File

@@ -12,6 +12,5 @@ public class Properties {
static final String PREFIX = "langchain4j.chroma";
@NestedConfigurationProperty
EmbeddingStoreProperties embeddingStore;
}
@NestedConfigurationProperty EmbeddingStoreProperties embeddingStore;
}

View File

@@ -1,9 +1,10 @@
package dev.langchain4j.dashscope.spring;
import java.util.List;
import lombok.Getter;
import lombok.Setter;
import java.util.List;
@Getter
@Setter
class ChatModelProperties {
@@ -19,4 +20,4 @@ class ChatModelProperties {
Float temperature;
List<String> stops;
Integer maxTokens;
}
}

View File

@@ -1,8 +1,5 @@
package dev.langchain4j.dashscope.spring;
import static dev.langchain4j.dashscope.spring.Properties.PREFIX;
import dev.langchain4j.model.dashscope.QwenChatModel;
import dev.langchain4j.model.dashscope.QwenEmbeddingModel;
import dev.langchain4j.model.dashscope.QwenLanguageModel;
@@ -13,6 +10,8 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import static dev.langchain4j.dashscope.spring.Properties.PREFIX;
@Configuration
@EnableConfigurationProperties(Properties.class)
public class DashscopeAutoConfig {
@@ -102,4 +101,4 @@ public class DashscopeAutoConfig {
.modelName(embeddingModelProperties.getModelName())
.build();
}
}
}

View File

@@ -9,4 +9,4 @@ class EmbeddingModelProperties {
private String apiKey;
private String modelName;
}
}

View File

@@ -12,18 +12,13 @@ public class Properties {
static final String PREFIX = "langchain4j.dashscope";
@NestedConfigurationProperty
ChatModelProperties chatModel;
@NestedConfigurationProperty ChatModelProperties chatModel;
@NestedConfigurationProperty
ChatModelProperties streamingChatModel;
@NestedConfigurationProperty ChatModelProperties streamingChatModel;
@NestedConfigurationProperty
ChatModelProperties languageModel;
@NestedConfigurationProperty ChatModelProperties languageModel;
@NestedConfigurationProperty
ChatModelProperties streamingLanguageModel;
@NestedConfigurationProperty ChatModelProperties streamingLanguageModel;
@NestedConfigurationProperty
EmbeddingModelProperties embeddingModel;
}
@NestedConfigurationProperty EmbeddingModelProperties embeddingModel;
}

View File

@@ -10,4 +10,4 @@ class EmbeddingModelProperties {
private String modelName;
private String modelPath;
private String vocabularyPath;
}
}

View File

@@ -8,4 +8,4 @@ import lombok.Setter;
public class EmbeddingStoreProperties {
private String persistPath;
}
}

View File

@@ -1,6 +1,5 @@
package dev.langchain4j.inmemory.spring;
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
@@ -42,4 +41,4 @@ public class InMemoryAutoConfig {
}
return new BgeSmallZhEmbeddingModel();
}
}
}

View File

@@ -33,7 +33,8 @@ public class InMemoryEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
this.embeddingStore = embeddingStore;
}
private static EmbeddingStoreProperties createPropertiesFromConfig(EmbeddingStoreConfig storeConfig) {
private static EmbeddingStoreProperties createPropertiesFromConfig(
EmbeddingStoreConfig storeConfig) {
EmbeddingStoreProperties embeddingStore = new EmbeddingStoreProperties();
BeanUtils.copyProperties(storeConfig, embeddingStore);
return embeddingStore;
@@ -56,7 +57,8 @@ public class InMemoryEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
InMemoryEmbeddingStore<TextSegment> embeddingStore = null;
try {
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
if (Files.exists(filePath) && !collectionName.equals(embeddingConfig.getMetaCollectionName())
if (Files.exists(filePath)
&& !collectionName.equals(embeddingConfig.getMetaCollectionName())
&& !collectionName.equals(embeddingConfig.getText2sqlCollectionName())) {
embeddingStore = InMemoryEmbeddingStore.fromFile(filePath);
embeddingStore.entries = new CopyOnWriteArraySet<>(embeddingStore.entries);
@@ -72,7 +74,8 @@ public class InMemoryEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
if (MapUtils.isEmpty(super.collectionNameToStore)) {
return;
}
for (Map.Entry<String, EmbeddingStore<TextSegment>> entry : collectionNameToStore.entrySet()) {
for (Map.Entry<String, EmbeddingStore<TextSegment>> entry :
collectionNameToStore.entrySet()) {
Path filePath = getPersistPath(entry.getKey());
if (Objects.isNull(filePath)) {
continue;
@@ -101,5 +104,4 @@ public class InMemoryEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
}
return Paths.get(persistPath, persistFile);
}
}
}

View File

@@ -12,9 +12,7 @@ public class Properties {
static final String PREFIX = "langchain4j.in-memory";
@NestedConfigurationProperty
EmbeddingStoreProperties embeddingStore;
@NestedConfigurationProperty EmbeddingStoreProperties embeddingStore;
@NestedConfigurationProperty
EmbeddingModelProperties embeddingModel;
}
@NestedConfigurationProperty EmbeddingModelProperties embeddingModel;
}

View File

@@ -18,4 +18,4 @@ class ChatModelProperties {
private Double penaltyScore;
private Boolean logRequests;
private Boolean logResponses;
}
}

View File

@@ -15,5 +15,4 @@ class EmbeddingModelProperties {
private String user;
private Boolean logRequests;
private Boolean logResponses;
}
}

View File

@@ -18,5 +18,4 @@ class LanguageModelProperties {
private Double penaltyScore;
private Boolean logRequests;
private Boolean logResponses;
}
}

View File

@@ -86,4 +86,4 @@ public class LocalAiAutoConfig {
.logResponses(embeddingModelProperties.getLogResponses())
.build();
}
}
}

View File

@@ -12,18 +12,13 @@ public class Properties {
static final String PREFIX = "langchain4j.local-ai";
@NestedConfigurationProperty
ChatModelProperties chatModel;
@NestedConfigurationProperty ChatModelProperties chatModel;
@NestedConfigurationProperty
ChatModelProperties streamingChatModel;
@NestedConfigurationProperty ChatModelProperties streamingChatModel;
@NestedConfigurationProperty
LanguageModelProperties languageModel;
@NestedConfigurationProperty LanguageModelProperties languageModel;
@NestedConfigurationProperty
LanguageModelProperties streamingLanguageModel;
@NestedConfigurationProperty LanguageModelProperties streamingLanguageModel;
@NestedConfigurationProperty
EmbeddingModelProperties embeddingModel;
@NestedConfigurationProperty EmbeddingModelProperties embeddingModel;
}

View File

@@ -24,4 +24,4 @@ class EmbeddingStoreProperties {
private Boolean retrieveEmbeddingsOnSearch;
private String databaseName;
private Boolean autoFlushOnInsert;
}
}

View File

@@ -17,4 +17,4 @@ public class MilvusAutoConfig {
EmbeddingStoreFactory milvusChatModel(Properties properties) {
return new MilvusEmbeddingStoreFactory(properties.getEmbeddingStore());
}
}
}

View File

@@ -18,7 +18,8 @@ public class MilvusEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
this.storeProperties = storeProperties;
}
private static EmbeddingStoreProperties createPropertiesFromConfig(EmbeddingStoreConfig storeConfig) {
private static EmbeddingStoreProperties createPropertiesFromConfig(
EmbeddingStoreConfig storeConfig) {
EmbeddingStoreProperties embeddingStore = new EmbeddingStoreProperties();
BeanUtils.copyProperties(storeConfig, embeddingStore);
embeddingStore.setUri(storeConfig.getBaseUrl());
@@ -45,4 +46,4 @@ public class MilvusEmbeddingStoreFactory extends BaseEmbeddingStoreFactory {
.databaseName(storeProperties.getDatabaseName())
.build();
}
}
}

View File

@@ -12,6 +12,5 @@ public class Properties {
static final String PREFIX = "langchain4j.milvus";
@NestedConfigurationProperty
EmbeddingStoreProperties embeddingStore;
}
@NestedConfigurationProperty EmbeddingStoreProperties embeddingStore;
}

View File

@@ -11,12 +11,13 @@ import java.nio.file.Paths;
import java.util.Objects;
/**
* An embedding model that runs within your Java application's process.
* Any BERT-based model (e.g., from HuggingFace) can be used, as long as it is in ONNX format.
* Information on how to convert models into ONNX format can be found <a
* An embedding model that runs within your Java application's process. Any BERT-based model (e.g.,
* from HuggingFace) can be used, as long as it is in ONNX format. Information on how to convert
* models into ONNX format can be found <a
* href="https://huggingface.co/docs/optimum/exporters/onnx/usage_guides/export_a_model">here</a>.
* Many models already converted to ONNX format are available <a href="https://huggingface.co/Xenova">here</a>.
* Copy from dev.langchain4j.model.embedding.OnnxEmbeddingModel.
* Many models already converted to ONNX format are available <a
* href="https://huggingface.co/Xenova">here</a>. Copy from
* dev.langchain4j.model.embedding.OnnxEmbeddingModel.
*/
public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {
private static volatile OnnxBertBiEncoder cachedModel;
@@ -27,7 +28,9 @@ public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {
if (shouldReloadModel(pathToModel, vocabularyPath)) {
synchronized (S2OnnxEmbeddingModel.class) {
if (shouldReloadModel(pathToModel, vocabularyPath)) {
URL resource = AbstractInProcessEmbeddingModel.class.getResource("/bert-vocabulary-en.txt");
URL resource =
AbstractInProcessEmbeddingModel.class.getResource(
"/bert-vocabulary-en.txt");
if (StringUtils.isNotBlank(vocabularyPath)) {
try {
resource = Paths.get(vocabularyPath).toUri().toURL();
@@ -53,19 +56,17 @@ public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel {
}
private static boolean shouldReloadModel(String pathToModel, String vocabularyPath) {
return cachedModel == null || !Objects.equals(cachedModelPath, pathToModel)
return cachedModel == null
|| !Objects.equals(cachedModelPath, pathToModel)
|| !Objects.equals(cachedVocabularyPath, vocabularyPath);
}
static OnnxBertBiEncoder loadFromFileSystem(Path pathToModel, URL vocabularyFile) {
try {
return new OnnxBertBiEncoder(
Files.newInputStream(pathToModel),
vocabularyFile,
PoolingMode.MEAN
);
Files.newInputStream(pathToModel), vocabularyFile, PoolingMode.MEAN);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
}
}

View File

@@ -44,8 +44,9 @@ import static java.util.Collections.emptyList;
import static java.util.Collections.singletonList;
/**
* Represents an OpenAI language model with a chat completion interface, such as gpt-3.5-turbo and gpt-4.
* You can find description of parameters <a href="https://platform.openai.com/docs/api-reference/chat/create">here</a>.
* Represents an OpenAI language model with a chat completion interface, such as gpt-3.5-turbo and
* gpt-4. You can find description of parameters <a
* href="https://platform.openai.com/docs/api-reference/chat/create">here</a>.
*/
@Slf4j
public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
@@ -67,31 +68,33 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
private final Integer maxRetries;
private final Tokenizer tokenizer;
private final List<ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>> listeners;
private final List<ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>>
listeners;
@Builder
public OpenAiChatModel(String baseUrl,
String apiKey,
String organizationId,
String modelName,
Double temperature,
Double topP,
List<String> stop,
Integer maxTokens,
Double presencePenalty,
Double frequencyPenalty,
Map<String, Integer> logitBias,
String responseFormat,
Integer seed,
String user,
Duration timeout,
Integer maxRetries,
Proxy proxy,
Boolean logRequests,
Boolean logResponses,
Tokenizer tokenizer,
Map<String, String> customHeaders,
List<ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>> listeners) {
public OpenAiChatModel(
String baseUrl,
String apiKey,
String organizationId,
String modelName,
Double temperature,
Double topP,
List<String> stop,
Integer maxTokens,
Double presencePenalty,
Double frequencyPenalty,
Map<String, Integer> logitBias,
String responseFormat,
Integer seed,
String user,
Duration timeout,
Integer maxRetries,
Proxy proxy,
Boolean logRequests,
Boolean logResponses,
Tokenizer tokenizer,
Map<String, String> customHeaders,
List<ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>> listeners) {
baseUrl = getOrDefault(baseUrl, OPENAI_URL);
if (OPENAI_DEMO_API_KEY.equals(apiKey)) {
@@ -101,20 +104,21 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
timeout = getOrDefault(timeout, ofSeconds(60));
this.client = OpenAiClient.builder()
.openAiApiKey(apiKey)
.baseUrl(baseUrl)
.organizationId(organizationId)
.callTimeout(timeout)
.connectTimeout(timeout)
.readTimeout(timeout)
.writeTimeout(timeout)
.proxy(proxy)
.logRequests(logRequests)
.logResponses(logResponses)
.userAgent(DEFAULT_USER_AGENT)
.customHeaders(customHeaders)
.build();
this.client =
OpenAiClient.builder()
.openAiApiKey(apiKey)
.baseUrl(baseUrl)
.organizationId(organizationId)
.callTimeout(timeout)
.connectTimeout(timeout)
.readTimeout(timeout)
.writeTimeout(timeout)
.proxy(proxy)
.logRequests(logRequests)
.logResponses(logResponses)
.userAgent(DEFAULT_USER_AGENT)
.customHeaders(customHeaders)
.build();
this.modelName = getOrDefault(modelName, GPT_3_5_TURBO);
this.temperature = getOrDefault(temperature, 0.7);
this.topP = topP;
@@ -141,31 +145,34 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
}
@Override
public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
public Response<AiMessage> generate(
List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
return generate(messages, toolSpecifications, null);
}
@Override
public Response<AiMessage> generate(List<ChatMessage> messages, ToolSpecification toolSpecification) {
public Response<AiMessage> generate(
List<ChatMessage> messages, ToolSpecification toolSpecification) {
return generate(messages, singletonList(toolSpecification), toolSpecification);
}
private Response<AiMessage> generate(List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications,
ToolSpecification toolThatMustBeExecuted
) {
ChatCompletionRequest.Builder requestBuilder = ChatCompletionRequest.builder()
.model(modelName)
.messages(toOpenAiMessages(messages))
.topP(topP)
.stop(stop)
.maxTokens(maxTokens)
.presencePenalty(presencePenalty)
.frequencyPenalty(frequencyPenalty)
.logitBias(logitBias)
.responseFormat(responseFormat)
.seed(seed)
.user(user);
private Response<AiMessage> generate(
List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications,
ToolSpecification toolThatMustBeExecuted) {
ChatCompletionRequest.Builder requestBuilder =
ChatCompletionRequest.builder()
.model(modelName)
.messages(toOpenAiMessages(messages))
.topP(topP)
.stop(stop)
.maxTokens(maxTokens)
.presencePenalty(presencePenalty)
.frequencyPenalty(frequencyPenalty)
.logitBias(logitBias)
.responseFormat(responseFormat)
.seed(seed)
.user(user);
if (!(baseUrl.contains(ZHIPU))) {
requestBuilder.temperature(temperature);
}
@@ -181,36 +188,37 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
ChatLanguageModelRequest modelListenerRequest =
createModelListenerRequest(request, messages, toolSpecifications);
listeners.forEach(listener -> {
try {
listener.onRequest(modelListenerRequest);
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}
});
listeners.forEach(
listener -> {
try {
listener.onRequest(modelListenerRequest);
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}
});
try {
ChatCompletionResponse chatCompletionResponse =
withRetry(() -> client.chatCompletion(request).execute(), maxRetries);
Response<AiMessage> response = Response.from(
aiMessageFrom(chatCompletionResponse),
tokenUsageFrom(chatCompletionResponse.usage()),
finishReasonFrom(chatCompletionResponse.choices().get(0).finishReason())
);
Response<AiMessage> response =
Response.from(
aiMessageFrom(chatCompletionResponse),
tokenUsageFrom(chatCompletionResponse.usage()),
finishReasonFrom(
chatCompletionResponse.choices().get(0).finishReason()));
ChatLanguageModelResponse modelListenerResponse = createModelListenerResponse(
chatCompletionResponse.id(),
chatCompletionResponse.model(),
response
);
listeners.forEach(listener -> {
try {
listener.onResponse(modelListenerResponse, modelListenerRequest);
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}
});
ChatLanguageModelResponse modelListenerResponse =
createModelListenerResponse(
chatCompletionResponse.id(), chatCompletionResponse.model(), response);
listeners.forEach(
listener -> {
try {
listener.onResponse(modelListenerResponse, modelListenerRequest);
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}
});
return response;
} catch (RuntimeException e) {
@@ -222,13 +230,14 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
error = e;
}
listeners.forEach(listener -> {
try {
listener.onError(error, null, modelListenerRequest);
} catch (Exception e2) {
log.warn("Exception while calling model listener", e2);
}
});
listeners.forEach(
listener -> {
try {
listener.onError(error, null, modelListenerRequest);
} catch (Exception e2) {
log.warn("Exception while calling model listener", e2);
}
});
throw e;
}
}
@@ -243,7 +252,8 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
}
public static OpenAiChatModelBuilder builder() {
for (OpenAiChatModelBuilderFactory factory : loadFactories(OpenAiChatModelBuilderFactory.class)) {
for (OpenAiChatModelBuilderFactory factory :
loadFactories(OpenAiChatModelBuilderFactory.class)) {
return factory.get();
}
return new OpenAiChatModelBuilder();
@@ -261,4 +271,4 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
return this;
}
}
}
}

View File

@@ -15,4 +15,4 @@ public enum ChatCompletionModel {
public String toString() {
return this.value;
}
}
}

View File

@@ -26,8 +26,9 @@ import static dev.langchain4j.spi.ServiceHelper.loadFactories;
import static java.util.Collections.singletonList;
/**
* Represents an ZhipuAi language model with a chat completion interface, such as glm-3-turbo and glm-4.
* You can find description of parameters <a href="https://open.bigmodel.cn/dev/api">here</a>.
* Represents an ZhipuAi language model with a chat completion interface, such as glm-3-turbo and
* glm-4. You can find description of parameters <a
* href="https://open.bigmodel.cn/dev/api">here</a>.
*/
public class ZhipuAiChatModel implements ChatLanguageModel {
@@ -49,24 +50,25 @@ public class ZhipuAiChatModel implements ChatLanguageModel {
Integer maxRetries,
Integer maxToken,
Boolean logRequests,
Boolean logResponses
) {
Boolean logResponses) {
this.baseUrl = getOrDefault(baseUrl, "https://open.bigmodel.cn/");
this.temperature = getOrDefault(temperature, 0.7);
this.topP = topP;
this.model = getOrDefault(model, ChatCompletionModel.GLM_4.toString());
this.maxRetries = getOrDefault(maxRetries, 3);
this.maxToken = getOrDefault(maxToken, 512);
this.client = ZhipuAiClient.builder()
.baseUrl(this.baseUrl)
.apiKey(apiKey)
.logRequests(getOrDefault(logRequests, false))
.logResponses(getOrDefault(logResponses, false))
.build();
this.client =
ZhipuAiClient.builder()
.baseUrl(this.baseUrl)
.apiKey(apiKey)
.logRequests(getOrDefault(logRequests, false))
.logResponses(getOrDefault(logResponses, false))
.build();
}
public static ZhipuAiChatModelBuilder builder() {
for (ZhipuAiChatModelBuilderFactory factories : loadFactories(ZhipuAiChatModelBuilderFactory.class)) {
for (ZhipuAiChatModelBuilderFactory factories :
loadFactories(ZhipuAiChatModelBuilderFactory.class)) {
return factories.get();
}
return new ZhipuAiChatModelBuilder();
@@ -78,36 +80,36 @@ public class ZhipuAiChatModel implements ChatLanguageModel {
}
@Override
public Response<AiMessage> generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
public Response<AiMessage> generate(
List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
ensureNotEmpty(messages, "messages");
ChatCompletionRequest.Builder requestBuilder = ChatCompletionRequest.builder()
.model(this.model)
.maxTokens(maxToken)
.stream(false)
.topP(topP)
.toolChoice(AUTO)
.messages(toZhipuAiMessages(messages));
ChatCompletionRequest.Builder requestBuilder =
ChatCompletionRequest.builder().model(this.model).maxTokens(maxToken).stream(false)
.topP(topP)
.toolChoice(AUTO)
.messages(toZhipuAiMessages(messages));
if (!isNullOrEmpty(toolSpecifications)) {
requestBuilder.tools(toTools(toolSpecifications));
}
ChatCompletionResponse response = withRetry(() -> client.chatCompletion(requestBuilder.build()), maxRetries);
ChatCompletionResponse response =
withRetry(() -> client.chatCompletion(requestBuilder.build()), maxRetries);
return Response.from(
aiMessageFrom(response),
tokenUsageFrom(response.getUsage()),
finishReasonFrom(response.getChoices().get(0).getFinishReason())
);
finishReasonFrom(response.getChoices().get(0).getFinishReason()));
}
@Override
public Response<AiMessage> generate(List<ChatMessage> messages, ToolSpecification toolSpecification) {
return generate(messages, toolSpecification != null ? singletonList(toolSpecification) : null);
public Response<AiMessage> generate(
List<ChatMessage> messages, ToolSpecification toolSpecification) {
return generate(
messages, toolSpecification != null ? singletonList(toolSpecification) : null);
}
public static class ZhipuAiChatModelBuilder {
public ZhipuAiChatModelBuilder() {
}
public ZhipuAiChatModelBuilder() {}
}
}

View File

@@ -20,27 +20,36 @@ public class AzureModelFactory implements ModelFactory, InitializingBean {
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
AzureOpenAiChatModel.Builder builder = AzureOpenAiChatModel.builder()
.endpoint(modelConfig.getBaseUrl())
.apiKey(modelConfig.getApiKey())
.deploymentName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature())
.maxRetries(modelConfig.getMaxRetries())
.topP(modelConfig.getTopP())
.timeout(Duration.ofSeconds(modelConfig.getTimeOut() == null ? 0L : modelConfig.getTimeOut()))
.logRequestsAndResponses(modelConfig.getLogRequests() != null && modelConfig.getLogResponses());
AzureOpenAiChatModel.Builder builder =
AzureOpenAiChatModel.builder()
.endpoint(modelConfig.getBaseUrl())
.apiKey(modelConfig.getApiKey())
.deploymentName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature())
.maxRetries(modelConfig.getMaxRetries())
.topP(modelConfig.getTopP())
.timeout(
Duration.ofSeconds(
modelConfig.getTimeOut() == null
? 0L
: modelConfig.getTimeOut()))
.logRequestsAndResponses(
modelConfig.getLogRequests() != null
&& modelConfig.getLogResponses());
return builder.build();
}
@Override
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
AzureOpenAiEmbeddingModel.Builder builder = AzureOpenAiEmbeddingModel.builder()
.endpoint(embeddingModelConfig.getBaseUrl())
.apiKey(embeddingModelConfig.getApiKey())
.deploymentName(embeddingModelConfig.getModelName())
.maxRetries(embeddingModelConfig.getMaxRetries())
.logRequestsAndResponses(embeddingModelConfig.getLogRequests() != null
&& embeddingModelConfig.getLogResponses());
AzureOpenAiEmbeddingModel.Builder builder =
AzureOpenAiEmbeddingModel.builder()
.endpoint(embeddingModelConfig.getBaseUrl())
.apiKey(embeddingModelConfig.getApiKey())
.deploymentName(embeddingModelConfig.getModelName())
.maxRetries(embeddingModelConfig.getMaxRetries())
.logRequestsAndResponses(
embeddingModelConfig.getLogRequests() != null
&& embeddingModelConfig.getLogResponses());
return builder.build();
}

View File

@@ -23,8 +23,10 @@ public class DashscopeModelFactory implements ModelFactory, InitializingBean {
.baseUrl(modelConfig.getBaseUrl())
.apiKey(modelConfig.getApiKey())
.modelName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature() == null ? 0L :
modelConfig.getTemperature().floatValue())
.temperature(
modelConfig.getTemperature() == null
? 0L
: modelConfig.getTemperature().floatValue())
.topP(modelConfig.getTopP())
.enableSearch(modelConfig.getEnableSearch())
.build();

View File

@@ -11,6 +11,6 @@ public class EmbeddingModelConstant {
public static final String BGE_SMALL_ZH = "bge-small-zh";
public static final String ALL_MINILM_L6_V2 = "all-minilm-l6-v2-q";
public static final EmbeddingModel BGE_SMALL_ZH_MODEL = new BgeSmallZhEmbeddingModel();
public static final EmbeddingModel ALL_MINI_LM_L6_V2_MODEL = new AllMiniLmL6V2QuantizedEmbeddingModel();
public static final EmbeddingModel ALL_MINI_LM_L6_V2_MODEL =
new AllMiniLmL6V2QuantizedEmbeddingModel();
}

View File

@@ -16,10 +16,10 @@ public class LocalAiModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "LOCAL_AI";
public static final String DEFAULT_BASE_URL = "http://localhost:8080";
public static final String DEFAULT_MODEL_NAME = "ggml-gpt4all-j";
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return LocalAiChatModel
.builder()
return LocalAiChatModel.builder()
.baseUrl(modelConfig.getBaseUrl())
.modelName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature())
@@ -46,4 +46,4 @@ public class LocalAiModelFactory implements ModelFactory, InitializingBean {
public void afterPropertiesSet() {
ModelProvider.add(PROVIDER, this);
}
}
}

View File

@@ -25,10 +25,11 @@ public class ModelProvider {
}
public static ChatLanguageModel getChatModel(ChatModelConfig modelConfig) {
if (modelConfig == null || StringUtils.isBlank(modelConfig.getProvider())
if (modelConfig == null
|| StringUtils.isBlank(modelConfig.getProvider())
|| StringUtils.isBlank(modelConfig.getBaseUrl())) {
ChatModelParameterConfig parameterConfig = ContextUtils.getBean(
ChatModelParameterConfig.class);
ChatModelParameterConfig parameterConfig =
ContextUtils.getBean(ChatModelParameterConfig.class);
modelConfig = parameterConfig.convert();
}
ModelFactory modelFactory = factories.get(modelConfig.getProvider().toUpperCase());
@@ -36,7 +37,8 @@ public class ModelProvider {
return modelFactory.createChatModel(modelConfig);
}
throw new RuntimeException("Unsupported ChatLanguageModel provider: " + modelConfig.getProvider());
throw new RuntimeException(
"Unsupported ChatLanguageModel provider: " + modelConfig.getProvider());
}
public static EmbeddingModel getEmbeddingModel() {
@@ -45,8 +47,8 @@ public class ModelProvider {
public static EmbeddingModel getEmbeddingModel(EmbeddingModelConfig embeddingModel) {
if (embeddingModel == null || StringUtils.isBlank(embeddingModel.getProvider())) {
EmbeddingModelParameterConfig parameterConfig = ContextUtils.getBean(
EmbeddingModelParameterConfig.class);
EmbeddingModelParameterConfig parameterConfig =
ContextUtils.getBean(EmbeddingModelParameterConfig.class);
embeddingModel = parameterConfig.convert();
}
ModelFactory modelFactory = factories.get(embeddingModel.getProvider().toUpperCase());
@@ -54,6 +56,7 @@ public class ModelProvider {
return modelFactory.createEmbeddingModel(embeddingModel);
}
throw new RuntimeException("Unsupported EmbeddingModel provider: " + embeddingModel.getProvider());
throw new RuntimeException(
"Unsupported EmbeddingModel provider: " + embeddingModel.getProvider());
}
}
}

View File

@@ -21,8 +21,7 @@ public class OllamaModelFactory implements ModelFactory, InitializingBean {
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return OllamaChatModel
.builder()
return OllamaChatModel.builder()
.baseUrl(modelConfig.getBaseUrl())
.modelName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature())

View File

@@ -21,8 +21,7 @@ public class OpenAiModelFactory implements ModelFactory, InitializingBean {
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return OpenAiChatModel
.builder()
return OpenAiChatModel.builder()
.baseUrl(modelConfig.getBaseUrl())
.modelName(modelConfig.getModelName())
.apiKey(modelConfig.keyDecrypt())

View File

@@ -16,6 +16,7 @@ public class ZhipuModelFactory implements ModelFactory, InitializingBean {
public static final String DEFAULT_BASE_URL = "https://open.bigmodel.cn/";
public static final String DEFAULT_MODEL_NAME = ChatCompletionModel.GLM_4.toString();
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "embedding-2";
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return ZhipuAiChatModel.builder()

View File

@@ -18,4 +18,4 @@ class ChatModelProperties {
private Double penaltyScore;
private Boolean logRequests;
private Boolean logResponses;
}
}

View File

@@ -15,5 +15,4 @@ class EmbeddingModelProperties {
private String user;
private Boolean logRequests;
private Boolean logResponses;
}
}

View File

@@ -18,5 +18,4 @@ class LanguageModelProperties {
private Double penaltyScore;
private Boolean logRequests;
private Boolean logResponses;
}
}

View File

@@ -12,18 +12,13 @@ public class Properties {
static final String PREFIX = "langchain4j.qianfan";
@NestedConfigurationProperty
ChatModelProperties chatModel;
@NestedConfigurationProperty ChatModelProperties chatModel;
@NestedConfigurationProperty
ChatModelProperties streamingChatModel;
@NestedConfigurationProperty ChatModelProperties streamingChatModel;
@NestedConfigurationProperty
LanguageModelProperties languageModel;
@NestedConfigurationProperty LanguageModelProperties languageModel;
@NestedConfigurationProperty
LanguageModelProperties streamingLanguageModel;
@NestedConfigurationProperty LanguageModelProperties streamingLanguageModel;
@NestedConfigurationProperty
EmbeddingModelProperties embeddingModel;
}
@NestedConfigurationProperty EmbeddingModelProperties embeddingModel;
}

View File

@@ -1,7 +1,5 @@
package dev.langchain4j.qianfan.spring;
import static dev.langchain4j.qianfan.spring.Properties.PREFIX;
import dev.langchain4j.model.qianfan.QianfanChatModel;
import dev.langchain4j.model.qianfan.QianfanEmbeddingModel;
import dev.langchain4j.model.qianfan.QianfanLanguageModel;
@@ -12,6 +10,8 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import static dev.langchain4j.qianfan.spring.Properties.PREFIX;
@Configuration
@EnableConfigurationProperties(Properties.class)
public class QianfanAutoConfig {
@@ -111,4 +111,4 @@ public class QianfanAutoConfig {
.logResponses(embeddingModelProperties.getLogResponses())
.build();
}
}
}

View File

@@ -6,11 +6,12 @@ import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
public abstract class BaseEmbeddingStoreFactory implements EmbeddingStoreFactory {
protected final Map<String, EmbeddingStore<TextSegment>> collectionNameToStore = new ConcurrentHashMap<>();
protected final Map<String, EmbeddingStore<TextSegment>> collectionNameToStore =
new ConcurrentHashMap<>();
public EmbeddingStore<TextSegment> create(String collectionName) {
return collectionNameToStore.computeIfAbsent(collectionName, this::createEmbeddingStore);
}
public abstract EmbeddingStore<TextSegment> createEmbeddingStore(String collectionName);
}
}

View File

@@ -1,6 +1,5 @@
package dev.langchain4j.store.embedding;
import lombok.Data;
import java.util.Map;
@@ -13,5 +12,4 @@ public class EmbeddingCollection {
private String name;
private Map<String, String> metaData;
}

View File

@@ -5,5 +5,4 @@ import dev.langchain4j.data.segment.TextSegment;
public interface EmbeddingStoreFactory {
EmbeddingStore<TextSegment> create(String collectionName);
}
}

View File

@@ -12,29 +12,39 @@ import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
public class EmbeddingStoreFactoryProvider {
protected static final Map<EmbeddingStoreConfig, EmbeddingStoreFactory> factoryMap = new ConcurrentHashMap<>();
protected static final Map<EmbeddingStoreConfig, EmbeddingStoreFactory> factoryMap =
new ConcurrentHashMap<>();
public static EmbeddingStoreFactory getFactory() {
EmbeddingStoreParameterConfig parameterConfig = ContextUtils.getBean(EmbeddingStoreParameterConfig.class);
EmbeddingStoreParameterConfig parameterConfig =
ContextUtils.getBean(EmbeddingStoreParameterConfig.class);
return getFactory(parameterConfig.convert());
}
public static EmbeddingStoreFactory getFactory(EmbeddingStoreConfig embeddingStoreConfig) {
if (embeddingStoreConfig == null || StringUtils.isBlank(embeddingStoreConfig.getProvider())) {
if (embeddingStoreConfig == null
|| StringUtils.isBlank(embeddingStoreConfig.getProvider())) {
return ContextUtils.getBean(EmbeddingStoreFactory.class);
}
if (EmbeddingStoreType.CHROMA.name().equalsIgnoreCase(embeddingStoreConfig.getProvider())) {
return factoryMap.computeIfAbsent(embeddingStoreConfig,
return factoryMap.computeIfAbsent(
embeddingStoreConfig,
storeConfig -> new ChromaEmbeddingStoreFactory(storeConfig));
}
if (EmbeddingStoreType.MILVUS.name().equalsIgnoreCase(embeddingStoreConfig.getProvider())) {
return factoryMap.computeIfAbsent(embeddingStoreConfig,
return factoryMap.computeIfAbsent(
embeddingStoreConfig,
storeConfig -> new MilvusEmbeddingStoreFactory(storeConfig));
}
if (EmbeddingStoreType.IN_MEMORY.name().equalsIgnoreCase(embeddingStoreConfig.getProvider())) {
return factoryMap.computeIfAbsent(embeddingStoreConfig,
if (EmbeddingStoreType.IN_MEMORY
.name()
.equalsIgnoreCase(embeddingStoreConfig.getProvider())) {
return factoryMap.computeIfAbsent(
embeddingStoreConfig,
storeConfig -> new InMemoryEmbeddingStoreFactory(storeConfig));
}
throw new RuntimeException("Unsupported EmbeddingStoreFactory provider: " + embeddingStoreConfig.getProvider());
throw new RuntimeException(
"Unsupported EmbeddingStoreFactory provider: "
+ embeddingStoreConfig.getProvider());
}
}
}

View File

@@ -35,8 +35,9 @@ public class Retrieval {
return false;
}
Retrieval retrieval = (Retrieval) o;
return Double.compare(retrieval.distance, distance) == 0 && Objects.equal(id,
retrieval.id) && Objects.equal(query, retrieval.query)
return Double.compare(retrieval.distance, distance) == 0
&& Objects.equal(id, retrieval.id)
&& Objects.equal(query, retrieval.query)
&& Objects.equal(metadata, retrieval.metadata);
}

View File

@@ -15,6 +15,4 @@ public class RetrieveQuery {
private Map<String, Object> filterCondition;
private List<List<Double>> queryEmbeddings;
}

View File

@@ -1,6 +1,5 @@
package dev.langchain4j.store.embedding;
import lombok.Data;
import java.util.List;
@@ -11,5 +10,4 @@ public class RetrieveQueryResult {
private String query;
private List<Retrieval> retrieval;
}

View File

@@ -1,6 +1,5 @@
package dev.langchain4j.store.embedding;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.common.pojo.DataItem;
import dev.langchain4j.data.document.Metadata;
@@ -18,12 +17,20 @@ public class TextSegmentConvert {
public static final String QUERY_ID = "queryId";
public static List<TextSegment> convertToEmbedding(List<DataItem> dataItems) {
return dataItems.stream().map(dataItem -> {
Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class);
TextSegment textSegment = TextSegment.from(dataItem.getName(), new Metadata(meta));
addQueryId(textSegment, dataItem.getId() + dataItem.getType().name().toLowerCase());
return textSegment;
}).collect(Collectors.toList());
return dataItems.stream()
.map(
dataItem -> {
Map meta =
JSONObject.parseObject(
JSONObject.toJSONString(dataItem), Map.class);
TextSegment textSegment =
TextSegment.from(dataItem.getName(), new Metadata(meta));
addQueryId(
textSegment,
dataItem.getId() + dataItem.getType().name().toLowerCase());
return textSegment;
})
.collect(Collectors.toList());
}
public static void addQueryId(TextSegment textSegment, String queryId) {

View File

@@ -39,15 +39,17 @@ import static java.util.stream.Collectors.toList;
/**
* An {@link EmbeddingStore} that stores embeddings in memory.
* <p>
* Uses a brute force approach by iterating over all embeddings to find the best matches.
* <p>
* This store can be persisted using the {@link #serializeToJson()} and {@link #serializeToFile(Path)} methods.
* <p>
* It can also be recreated from JSON or a file using the {@link #fromJson(String)} and {@link #fromFile(Path)} methods.
*
* @param <Embedded> The class of the object that has been embedded.
* Typically, it is {@link dev.langchain4j.data.segment.TextSegment}.
* <p>Uses a brute force approach by iterating over all embeddings to find the best matches.
*
* <p>This store can be persisted using the {@link #serializeToJson()} and {@link
* #serializeToFile(Path)} methods.
*
* <p>It can also be recreated from JSON or a file using the {@link #fromJson(String)} and {@link
* #fromFile(Path)} methods.
*
* @param <Embedded> The class of the object that has been embedded. Typically, it is {@link
* dev.langchain4j.data.segment.TextSegment}.
*/
public class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded> {
@@ -80,17 +82,16 @@ public class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded
entries.addAll(newEntries);
return newEntries.stream()
.map(entry -> entry.id)
.collect(toList());
return newEntries.stream().map(entry -> entry.id).collect(toList());
}
@Override
public List<String> addAll(List<Embedding> embeddings) {
List<Entry<Embedded>> newEntries = embeddings.stream()
.map(embedding -> new Entry<Embedded>(randomUUID(), embedding))
.collect(toList());
List<Entry<Embedded>> newEntries =
embeddings.stream()
.map(embedding -> new Entry<Embedded>(randomUUID(), embedding))
.collect(toList());
return add(newEntries);
}
@@ -98,12 +99,15 @@ public class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded
@Override
public List<String> addAll(List<Embedding> embeddings, List<Embedded> embedded) {
if (embeddings.size() != embedded.size()) {
throw new IllegalArgumentException("The list of embeddings and embedded must have the same size");
throw new IllegalArgumentException(
"The list of embeddings and embedded must have the same size");
}
List<Entry<Embedded>> newEntries = IntStream.range(0, embeddings.size())
.mapToObj(i -> new Entry<>(randomUUID(), embeddings.get(i), embedded.get(i)))
.collect(toList());
List<Entry<Embedded>> newEntries =
IntStream.range(0, embeddings.size())
.mapToObj(
i -> new Entry<>(randomUUID(), embeddings.get(i), embedded.get(i)))
.collect(toList());
return add(newEntries);
}
@@ -119,15 +123,16 @@ public class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded
public void removeAll(Filter filter) {
ensureNotNull(filter, "filter");
entries.removeIf(entry -> {
if (entry.embedded instanceof TextSegment) {
return filter.test(((TextSegment) entry.embedded).metadata());
} else if (entry.embedded == null) {
return false;
} else {
throw new UnsupportedOperationException("Not supported yet.");
}
});
entries.removeIf(
entry -> {
if (entry.embedded instanceof TextSegment) {
return filter.test(((TextSegment) entry.embedded).metadata());
} else if (entry.embedded == null) {
return false;
} else {
throw new UnsupportedOperationException("Not supported yet.");
}
});
}
@Override
@@ -152,8 +157,9 @@ public class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded
}
}
double cosineSimilarity = CosineSimilarity.between(entry.embedding,
embeddingSearchRequest.queryEmbedding());
double cosineSimilarity =
CosineSimilarity.between(
entry.embedding, embeddingSearchRequest.queryEmbedding());
double score = RelevanceScore.fromCosineSimilarity(cosineSimilarity);
if (score >= embeddingSearchRequest.minScore()) {
matches.add(new EmbeddingMatch<>(score, entry.id, entry.embedding, entry.embedded));

View File

@@ -42,12 +42,10 @@ import static java.util.Collections.singletonList;
import static java.util.stream.Collectors.toList;
/**
* Represents an <a href="https://milvus.io/">Milvus</a> index as an embedding store.
* <br>
* Supports both local and <a href="https://zilliz.com/">managed</a> Milvus instances.
* <br>
* Supports storing {@link Metadata} and filtering by it using a {@link Filter}
* (provided inside an {@link EmbeddingSearchRequest}).
* Represents an <a href="https://milvus.io/">Milvus</a> index as an embedding store. <br>
* Supports both local and <a href="https://zilliz.com/">managed</a> Milvus instances. <br>
* Supports storing {@link Metadata} and filtering by it using a {@link Filter} (provided inside an
* {@link EmbeddingSearchRequest}).
*/
public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
@@ -78,15 +76,14 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
ConsistencyLevelEnum consistencyLevel,
Boolean retrieveEmbeddingsOnSearch,
Boolean autoFlushOnInsert,
String databaseName
) {
ConnectParam.Builder connectBuilder = ConnectParam
.newBuilder()
.withHost(getOrDefault(host, "localhost"))
.withPort(getOrDefault(port, 19530))
.withUri(uri)
.withToken(token)
.withAuthorization(username, password);
String databaseName) {
ConnectParam.Builder connectBuilder =
ConnectParam.newBuilder()
.withHost(getOrDefault(host, "localhost"))
.withPort(getOrDefault(port, 19530))
.withUri(uri)
.withToken(token)
.withAuthorization(username, password);
if (databaseName != null) {
connectBuilder.withDatabaseName(databaseName);
@@ -99,8 +96,13 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
this.retrieveEmbeddingsOnSearch = getOrDefault(retrieveEmbeddingsOnSearch, false);
this.autoFlushOnInsert = getOrDefault(autoFlushOnInsert, false);
if (!hasCollection(milvusClient, this.collectionName)) {
createCollection(milvusClient, this.collectionName, ensureNotNull(dimension, "dimension"));
createIndex(milvusClient, this.collectionName, getOrDefault(indexType, FLAT), this.metricType);
createCollection(
milvusClient, this.collectionName, ensureNotNull(dimension, "dimension"));
createIndex(
milvusClient,
this.collectionName,
getOrDefault(indexType, FLAT),
this.metricType);
}
loadCollectionInMemory(milvusClient, collectionName);
@@ -139,30 +141,33 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
}
@Override
public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest embeddingSearchRequest) {
public EmbeddingSearchResult<TextSegment> search(
EmbeddingSearchRequest embeddingSearchRequest) {
SearchParam searchParam = buildSearchRequest(
collectionName,
embeddingSearchRequest.queryEmbedding().vectorAsList(),
embeddingSearchRequest.filter(),
embeddingSearchRequest.maxResults(),
metricType,
consistencyLevel
);
SearchParam searchParam =
buildSearchRequest(
collectionName,
embeddingSearchRequest.queryEmbedding().vectorAsList(),
embeddingSearchRequest.filter(),
embeddingSearchRequest.maxResults(),
metricType,
consistencyLevel);
SearchResultsWrapper resultsWrapper = CollectionOperationsExecutor.search(milvusClient, searchParam);
SearchResultsWrapper resultsWrapper =
CollectionOperationsExecutor.search(milvusClient, searchParam);
List<EmbeddingMatch<TextSegment>> matches = toEmbeddingMatches(
milvusClient,
resultsWrapper,
collectionName,
consistencyLevel,
retrieveEmbeddingsOnSearch
);
List<EmbeddingMatch<TextSegment>> matches =
toEmbeddingMatches(
milvusClient,
resultsWrapper,
collectionName,
consistencyLevel,
retrieveEmbeddingsOnSearch);
List<EmbeddingMatch<TextSegment>> result = matches.stream()
.filter(match -> match.score() >= embeddingSearchRequest.minScore())
.collect(toList());
List<EmbeddingMatch<TextSegment>> result =
matches.stream()
.filter(match -> match.score() >= embeddingSearchRequest.minScore())
.collect(toList());
return new EmbeddingSearchResult<>(result);
}
@@ -171,15 +176,17 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
addAllInternal(
singletonList(id),
singletonList(embedding),
textSegment == null ? null : singletonList(textSegment)
);
textSegment == null ? null : singletonList(textSegment));
}
private void addAllInternal(List<String> ids, List<Embedding> embeddings, List<TextSegment> textSegments) {
private void addAllInternal(
List<String> ids, List<Embedding> embeddings, List<TextSegment> textSegments) {
List<InsertParam.Field> fields = new ArrayList<>();
fields.add(new InsertParam.Field(ID_FIELD_NAME, ids));
fields.add(new InsertParam.Field(TEXT_FIELD_NAME, toScalars(textSegments, ids.size())));
fields.add(new InsertParam.Field(METADATA_FIELD_NAME, toMetadataJsons(textSegments, ids.size())));
fields.add(
new InsertParam.Field(
METADATA_FIELD_NAME, toMetadataJsons(textSegments, ids.size())));
fields.add(new InsertParam.Field(VECTOR_FIELD_NAME, toVectors(embeddings)));
insert(milvusClient, collectionName, fields);
@@ -210,8 +217,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
private String databaseName;
/**
* @param host The host of the self-managed Milvus instance.
* Default value: "localhost".
* @param host The host of the self-managed Milvus instance. Default value: "localhost".
* @return builder
*/
public Builder host(String host) {
@@ -220,8 +226,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
}
/**
* @param port The port of the self-managed Milvus instance.
* Default value: 19530.
* @param port The port of the self-managed Milvus instance. Default value: 19530.
* @return builder
*/
public Builder port(Integer port) {
@@ -230,9 +235,8 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
}
/**
* @param collectionName The name of the Milvus collection.
* If there is no such collection yet, it will be created automatically.
* Default value: "default".
* @param collectionName The name of the Milvus collection. If there is no such collection
* yet, it will be created automatically. Default value: "default".
* @return builder
*/
public Builder collectionName(String collectionName) {
@@ -241,8 +245,8 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
}
/**
* @param dimension The dimension of the embedding vector. (e.g. 384)
* Mandatory if a new collection should be created.
* @param dimension The dimension of the embedding vector. (e.g. 384) Mandatory if a new
* collection should be created.
* @return builder
*/
public Builder dimension(Integer dimension) {
@@ -251,8 +255,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
}
/**
* @param indexType The type of the index.
* Default value: FLAT.
* @param indexType The type of the index. Default value: FLAT.
* @return builder
*/
public Builder indexType(IndexType indexType) {
@@ -261,8 +264,8 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
}
/**
* @param metricType The type of the metric used for similarity search.
* Default value: COSINE.
* @param metricType The type of the metric used for similarity search. Default value:
* COSINE.
* @return builder
*/
public Builder metricType(MetricType metricType) {
@@ -271,7 +274,8 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
}
/**
* @param uri The URI of the managed Milvus instance. (e.g. "https://xxx.api.gcp-us-west1.zillizcloud.com")
* @param uri The URI of the managed Milvus instance. (e.g.
* "https://xxx.api.gcp-us-west1.zillizcloud.com")
* @return builder
*/
public Builder uri(String uri) {
@@ -289,7 +293,8 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
}
/**
* @param username The username. See details <a href="https://milvus.io/docs/authenticate.md">here</a>.
* @param username The username. See details <a
* href="https://milvus.io/docs/authenticate.md">here</a>.
* @return builder
*/
public Builder username(String username) {
@@ -298,7 +303,8 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
}
/**
* @param password The password. See details <a href="https://milvus.io/docs/authenticate.md">here</a>.
* @param password The password. See details <a
* href="https://milvus.io/docs/authenticate.md">here</a>.
* @return builder
*/
public Builder password(String password) {
@@ -307,8 +313,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
}
/**
* @param consistencyLevel The consistency level used by Milvus.
* Default value: EVENTUALLY.
* @param consistencyLevel The consistency level used by Milvus. Default value: EVENTUALLY.
* @return builder
*/
public Builder consistencyLevel(ConsistencyLevelEnum consistencyLevel) {
@@ -317,12 +322,11 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
}
/**
* @param retrieveEmbeddingsOnSearch During a similarity search in Milvus (when calling findRelevant()),
* the embedding itself is not retrieved.
* To retrieve the embedding, an additional query is required.
* Setting this parameter to "true" will ensure that embedding is retrieved.
* Be aware that this will impact the performance of the search.
* Default value: false.
* @param retrieveEmbeddingsOnSearch During a similarity search in Milvus (when calling
* findRelevant()), the embedding itself is not retrieved. To retrieve the embedding, an
* additional query is required. Setting this parameter to "true" will ensure that
* embedding is retrieved. Be aware that this will impact the performance of the search.
* Default value: false.
* @return builder
*/
public Builder retrieveEmbeddingsOnSearch(Boolean retrieveEmbeddingsOnSearch) {
@@ -331,11 +335,10 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
}
/**
* @param autoFlushOnInsert Whether to automatically flush after each insert
* ({@code add(...)} or {@code addAll(...)} methods).
* Default value: false.
* More info can be found
* <a href="https://milvus.io/api-reference/pymilvus/v2.4.x/ORM/Collection/flush.md">here</a>.
* @param autoFlushOnInsert Whether to automatically flush after each insert ({@code
* add(...)} or {@code addAll(...)} methods). Default value: false. More info can be
* found <a
* href="https://milvus.io/api-reference/pymilvus/v2.4.x/ORM/Collection/flush.md">here</a>.
* @return builder
*/
public Builder autoFlushOnInsert(Boolean autoFlushOnInsert) {
@@ -344,8 +347,8 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
}
/**
* @param databaseName Milvus name of database.
* Default value: null. In this case default Milvus database name will be used.
* @param databaseName Milvus name of database. Default value: null. In this case default
* Milvus database name will be used.
* @return builder
*/
public Builder databaseName(String databaseName) {
@@ -368,8 +371,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore<TextSegment> {
consistencyLevel,
retrieveEmbeddingsOnSearch,
autoFlushOnInsert,
databaseName
);
databaseName);
}
}
}

View File

@@ -16,4 +16,4 @@ class ChatModelProperties {
Integer maxToken;
Boolean logRequests;
Boolean logResponses;
}
}

View File

@@ -13,4 +13,4 @@ class EmbeddingModelProperties {
Integer maxRetries;
Boolean logRequests;
Boolean logResponses;
}
}

View File

@@ -12,12 +12,9 @@ public class Properties {
static final String PREFIX = "langchain4j.zhipu";
@NestedConfigurationProperty
ChatModelProperties chatModel;
@NestedConfigurationProperty ChatModelProperties chatModel;
@NestedConfigurationProperty
ChatModelProperties streamingChatModel;
@NestedConfigurationProperty ChatModelProperties streamingChatModel;
@NestedConfigurationProperty
EmbeddingModelProperties embeddingModel;
}
@NestedConfigurationProperty EmbeddingModelProperties embeddingModel;
}

View File

@@ -1,7 +1,5 @@
package dev.langchain4j.zhipu.spring;
import static dev.langchain4j.zhipu.spring.Properties.PREFIX;
import dev.langchain4j.model.zhipu.ZhipuAiChatModel;
import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;
import dev.langchain4j.model.zhipu.ZhipuAiStreamingChatModel;
@@ -10,6 +8,8 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import static dev.langchain4j.zhipu.spring.Properties.PREFIX;
@Configuration
@EnableConfigurationProperties(Properties.class)
public class ZhipuAutoConfig {
@@ -60,4 +60,4 @@ public class ZhipuAutoConfig {
.logResponses(embeddingModelProperties.getLogResponses())
.build();
}
}
}