(improvement)(headless) Upgrade to the latest version of langchain4j and add support for the embedding deletion and reset operations. (#1660)

This commit is contained in:
lexluo09
2024-09-12 18:16:16 +08:00
committed by GitHub
parent 693356e46a
commit 4b1dab8e4a
16 changed files with 13307 additions and 16497 deletions

View File

@@ -4,15 +4,23 @@ import dev.ai4j.openai4j.OpenAiClient;
import dev.ai4j.openai4j.OpenAiHttpException;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
import dev.ai4j.openai4j.chat.ChatCompletionResponse;
import dev.ai4j.openai4j.chat.ResponseFormat;
import dev.ai4j.openai4j.chat.ResponseFormatType;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.Tokenizer;
import dev.langchain4j.model.chat.Capability;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.TokenCountEstimator;
import dev.langchain4j.model.chat.listener.ChatLanguageModelRequest;
import dev.langchain4j.model.chat.listener.ChatLanguageModelResponse;
import dev.langchain4j.model.listener.ModelListener;
import dev.langchain4j.model.chat.listener.ChatModelErrorContext;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.chat.listener.ChatModelRequest;
import dev.langchain4j.model.chat.listener.ChatModelRequestContext;
import dev.langchain4j.model.chat.listener.ChatModelResponse;
import dev.langchain4j.model.chat.listener.ChatModelResponseContext;
import dev.langchain4j.model.chat.request.ChatRequest;
import dev.langchain4j.model.chat.response.ChatResponse;
import dev.langchain4j.model.openai.spi.OpenAiChatModelBuilderFactory;
import dev.langchain4j.model.output.Response;
import lombok.Builder;
@@ -21,11 +29,17 @@ import lombok.extern.slf4j.Slf4j;
import java.net.Proxy;
import java.time.Duration;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import static dev.ai4j.openai4j.chat.ResponseFormatType.JSON_SCHEMA;
import static dev.langchain4j.internal.RetryUtils.withRetry;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.model.chat.Capability.RESPONSE_FORMAT_JSON_SCHEMA;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.DEFAULT_USER_AGENT;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_DEMO_API_KEY;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.OPENAI_DEMO_URL;
@@ -35,6 +49,7 @@ import static dev.langchain4j.model.openai.InternalOpenAiHelper.createModelListe
import static dev.langchain4j.model.openai.InternalOpenAiHelper.createModelListenerResponse;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.finishReasonFrom;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toOpenAiMessages;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toOpenAiResponseFormat;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toTools;
import static dev.langchain4j.model.openai.InternalOpenAiHelper.tokenUsageFrom;
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
@@ -62,14 +77,15 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
private final Double presencePenalty;
private final Double frequencyPenalty;
private final Map<String, Integer> logitBias;
private final String responseFormat;
private final ResponseFormat responseFormat;
private final Boolean strictJsonSchema;
private final Integer seed;
private final String user;
private final Boolean strictTools;
private final Boolean parallelToolCalls;
private final Integer maxRetries;
private final Tokenizer tokenizer;
private final List<ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>>
listeners;
private final List<ChatModelListener> listeners;
@Builder
public OpenAiChatModel(
@@ -85,8 +101,11 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
Double frequencyPenalty,
Map<String, Integer> logitBias,
String responseFormat,
Boolean strictJsonSchema,
Integer seed,
String user,
Boolean strictTools,
Boolean parallelToolCalls,
Duration timeout,
Integer maxRetries,
Proxy proxy,
@@ -94,7 +113,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
Boolean logResponses,
Tokenizer tokenizer,
Map<String, String> customHeaders,
List<ModelListener<ChatLanguageModelRequest, ChatLanguageModelResponse>> listeners) {
List<ChatModelListener> listeners) {
baseUrl = getOrDefault(baseUrl, OPENAI_URL);
if (OPENAI_DEMO_API_KEY.equals(apiKey)) {
@@ -127,9 +146,19 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
this.presencePenalty = presencePenalty;
this.frequencyPenalty = frequencyPenalty;
this.logitBias = logitBias;
this.responseFormat = responseFormat;
this.responseFormat =
responseFormat == null
? null
: ResponseFormat.builder()
.type(
ResponseFormatType.valueOf(
responseFormat.toUpperCase(Locale.ROOT)))
.build();
this.strictJsonSchema = getOrDefault(strictJsonSchema, false);
this.seed = seed;
this.user = user;
this.strictTools = getOrDefault(strictTools, false);
this.parallelToolCalls = parallelToolCalls;
this.maxRetries = getOrDefault(maxRetries, 3);
this.tokenizer = getOrDefault(tokenizer, OpenAiTokenizer::new);
this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners);
@@ -139,27 +168,62 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
return modelName;
}
@Override
// Reports which optional chat-model capabilities this configured instance supports.
// Currently the only capability surfaced is JSON-schema structured output.
public Set<Capability> supportedCapabilities() {
Set<Capability> capabilities = new HashSet<>();
// Only advertise RESPONSE_FORMAT_JSON_SCHEMA when the model-level response
// format was explicitly configured as JSON_SCHEMA.
if (responseFormat != null && responseFormat.type() == JSON_SCHEMA) {
capabilities.add(RESPONSE_FORMAT_JSON_SCHEMA);
}
return capabilities;
}
@Override
public Response<AiMessage> generate(List<ChatMessage> messages) {
return generate(messages, null, null);
return generate(messages, null, null, this.responseFormat);
}
@Override
public Response<AiMessage> generate(
List<ChatMessage> messages, List<ToolSpecification> toolSpecifications) {
return generate(messages, toolSpecifications, null);
return generate(messages, toolSpecifications, null, this.responseFormat);
}
@Override
public Response<AiMessage> generate(
List<ChatMessage> messages, ToolSpecification toolSpecification) {
return generate(messages, singletonList(toolSpecification), toolSpecification);
return generate(
messages, singletonList(toolSpecification), toolSpecification, this.responseFormat);
}
@Override
// Entry point for the newer langchain4j ChatRequest/ChatResponse API: adapts the
// request into the legacy generate(...) path and repackages its Response<AiMessage>.
public ChatResponse chat(ChatRequest request) {
Response<AiMessage> response =
generate(
request.messages(),
request.toolSpecifications(),
null,
// Per-request response format takes precedence; fall back to the
// model-level default when the request does not specify one.
getOrDefault(
toOpenAiResponseFormat(request.responseFormat(), strictJsonSchema),
this.responseFormat));
// Copy content, token usage, and finish reason into the new response type.
return ChatResponse.builder()
.aiMessage(response.content())
.tokenUsage(response.tokenUsage())
.finishReason(response.finishReason())
.build();
}
private Response<AiMessage> generate(
List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications,
ToolSpecification toolThatMustBeExecuted) {
ToolSpecification toolThatMustBeExecuted,
ResponseFormat responseFormat) {
if (responseFormat != null
&& responseFormat.type() == JSON_SCHEMA
&& responseFormat.jsonSchema() == null) {
responseFormat = null;
}
ChatCompletionRequest.Builder requestBuilder =
ChatCompletionRequest.builder()
.model(modelName)
@@ -172,13 +236,15 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
.logitBias(logitBias)
.responseFormat(responseFormat)
.seed(seed)
.user(user);
.user(user)
.parallelToolCalls(parallelToolCalls);
if (!(baseUrl.contains(ZHIPU))) {
requestBuilder.temperature(temperature);
}
if (toolSpecifications != null && !toolSpecifications.isEmpty()) {
requestBuilder.tools(toTools(toolSpecifications));
requestBuilder.tools(toTools(toolSpecifications, strictTools));
}
if (toolThatMustBeExecuted != null) {
requestBuilder.toolChoice(toolThatMustBeExecuted.name());
@@ -186,12 +252,15 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
ChatCompletionRequest request = requestBuilder.build();
ChatLanguageModelRequest modelListenerRequest =
ChatModelRequest modelListenerRequest =
createModelListenerRequest(request, messages, toolSpecifications);
Map<Object, Object> attributes = new ConcurrentHashMap<>();
ChatModelRequestContext requestContext =
new ChatModelRequestContext(modelListenerRequest, attributes);
listeners.forEach(
listener -> {
try {
listener.onRequest(modelListenerRequest);
listener.onRequest(requestContext);
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}
@@ -208,13 +277,16 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
finishReasonFrom(
chatCompletionResponse.choices().get(0).finishReason()));
ChatLanguageModelResponse modelListenerResponse =
ChatModelResponse modelListenerResponse =
createModelListenerResponse(
chatCompletionResponse.id(), chatCompletionResponse.model(), response);
ChatModelResponseContext responseContext =
new ChatModelResponseContext(
modelListenerResponse, modelListenerRequest, attributes);
listeners.forEach(
listener -> {
try {
listener.onResponse(modelListenerResponse, modelListenerRequest);
listener.onResponse(responseContext);
} catch (Exception e) {
log.warn("Exception while calling model listener", e);
}
@@ -230,14 +302,18 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
error = e;
}
ChatModelErrorContext errorContext =
new ChatModelErrorContext(error, modelListenerRequest, null, attributes);
listeners.forEach(
listener -> {
try {
listener.onError(error, null, modelListenerRequest);
listener.onError(errorContext);
} catch (Exception e2) {
log.warn("Exception while calling model listener", e2);
}
});
throw e;
}
}
@@ -270,5 +346,10 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
this.modelName = modelName;
return this;
}
// Convenience overload: accepts the OpenAiChatModelName enum and stores its
// string representation in the builder's modelName field.
public OpenAiChatModelBuilder modelName(OpenAiChatModelName modelName) {
this.modelName = modelName.toString();
return this;
}
}
}