(feat)Support for managing large models with Dify #1830;2、add user access token; #1829; 3、support change password #1824 (#1839)

This commit is contained in:
zhaodongsheng
2024-10-22 13:58:58 +08:00
committed by GitHub
parent bdb20ca462
commit 0ddcdf93ec
34 changed files with 1341 additions and 45 deletions

View File

@@ -2,8 +2,15 @@ package com.tencent.supersonic.common.pojo;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.Parameter;
import dev.langchain4j.provider.*;
import dev.langchain4j.provider.AzureModelFactory;
import dev.langchain4j.provider.DashscopeModelFactory;
import dev.langchain4j.provider.DifyModelFactory;
import dev.langchain4j.provider.LocalAiModelFactory;
import dev.langchain4j.provider.ModelProvider;
import dev.langchain4j.provider.OllamaModelFactory;
import dev.langchain4j.provider.OpenAiModelFactory;
import dev.langchain4j.provider.QianfanModelFactory;
import dev.langchain4j.provider.ZhipuModelFactory;
import java.util.ArrayList;
import java.util.List;
@@ -52,7 +59,7 @@ public class ChatModelParameters {
return Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER, ZhipuModelFactory.PROVIDER,
LocalAiModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER,
AzureModelFactory.PROVIDER);
AzureModelFactory.PROVIDER, DifyModelFactory.PROVIDER);
}
private static List<Parameter.Dependency> getBaseUrlDependency() {
@@ -63,20 +70,23 @@ public class ChatModelParameters {
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL,
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_BASE_URL,
LocalAiModelFactory.PROVIDER, LocalAiModelFactory.DEFAULT_BASE_URL,
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL));
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL,
DifyModelFactory.PROVIDER, DifyModelFactory.DEFAULT_BASE_URL));
}
private static List<Parameter.Dependency> getApiKeyDependency() {
return getDependency(CHAT_MODEL_PROVIDER.getName(),
Lists.newArrayList(OpenAiModelFactory.PROVIDER, QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER, LocalAiModelFactory.PROVIDER,
AzureModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER),
AzureModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER,
DifyModelFactory.PROVIDER),
ImmutableMap.of(OpenAiModelFactory.PROVIDER,
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), QianfanModelFactory.PROVIDER,
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), ZhipuModelFactory.PROVIDER,
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), LocalAiModelFactory.PROVIDER,
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), AzureModelFactory.PROVIDER,
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), DashscopeModelFactory.PROVIDER,
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), DifyModelFactory.PROVIDER,
ModelProvider.DEMO_CHAT_MODEL.getApiKey()));
}
@@ -88,7 +98,8 @@ public class ChatModelParameters {
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_MODEL_NAME,
LocalAiModelFactory.PROVIDER, LocalAiModelFactory.DEFAULT_MODEL_NAME,
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_MODEL_NAME,
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_MODEL_NAME));
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_MODEL_NAME,
DifyModelFactory.PROVIDER, DifyModelFactory.DEFAULT_MODEL_NAME));
}
private static List<Parameter.Dependency> getEndpointDependency() {

View File

@@ -0,0 +1,85 @@
package com.tencent.supersonic.common.util;
import lombok.extern.slf4j.Slf4j;
import java.util.HashMap;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@Slf4j
public class DifyClient {
private static final String DEFAULT_USER = "zhaodongsheng";
private static final String CONTENT_TYPE_JSON = "application/json";
private String difyURL;
private String difyKey;
public DifyClient(String difyURL, String difyKey) {
this.difyURL = difyURL;
this.difyKey = difyKey;
}
public DifyResult generate(String prompt) {
Map<String, String> headers = defaultHeaders();
DifyRequest request = new DifyRequest();
request.setQuery(prompt);
request.setUser(DEFAULT_USER);
return sendRequest(request, headers);
}
public DifyResult generate(String prompt, String user) {
Map<String, String> headers = defaultHeaders();
DifyRequest request = new DifyRequest();
request.setQuery(prompt);
request.setUser(user);
return sendRequest(request, headers);
}
public DifyResult generate(Map<String, String> inputs, String queryText, String user,
String conversationId) {
Map<String, String> headers = defaultHeaders();
DifyRequest request = new DifyRequest();
request.setInputs(inputs);
request.setQuery(queryText);
request.setUser(user);
if (conversationId != null && !conversationId.isEmpty()) {
request.setConversationId(conversationId);
}
return sendRequest(request, headers);
}
public DifyResult sendRequest(DifyRequest request, Map<String, String> headers) {
try {
log.debug("请求dify- header--->" + JsonUtil.toString(headers));
log.debug("请求dify- conversionId--->" + JsonUtil.toString(request));
return HttpUtils.post(difyURL, JsonUtil.toString(request), headers, DifyResult.class);
} catch (Exception e) {
log.error("请求dify失败---->" + e.getMessage());
throw new RuntimeException(e);
}
}
public String parseSQLResult(String sql) {
Pattern pattern = Pattern.compile("```(sql)?(.*)```", Pattern.DOTALL);
Matcher matcher = pattern.matcher(sql);
if (!matcher.find()) {
return sql.trim();
} else {
return matcher.group(2).trim();
}
}
private Map<String, String> defaultHeaders() {
Map<String, String> headers = new HashMap<>();
if (difyKey.contains("Bearer")) {
headers.put("Authorization", difyKey);
} else {
headers.put("Authorization", "Bearer " + difyKey);
}
headers.put("Content-Type", CONTENT_TYPE_JSON);
return headers;
}
}

View File

@@ -0,0 +1,19 @@
package com.tencent.supersonic.common.util;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;
import java.util.HashMap;
import java.util.Map;
@Data
public class DifyRequest {
private String query;
private Map<String, String> inputs = new HashMap<>();
private String responseMode = "blocking";
private String user;
@JsonProperty("conversation_id")
private String conversationId;
@JsonProperty("auto_generate_name")
private Boolean autoGenerateName = false;
}

View File

@@ -0,0 +1,13 @@
package com.tencent.supersonic.common.util;
import lombok.Data;
@Data
public class DifyResult {
private String event = "";
private String taskId = "";
private String conversationId = "";
private String id = "";
private String messageId = "";
private String answer = "";
}

View File

@@ -0,0 +1,170 @@
package com.tencent.supersonic.common.util;
import okhttp3.Dispatcher;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
public class HttpUtils {
private static final Logger logger = LoggerFactory.getLogger(HttpUtils.class);
// 重试参考okhttp3.RealCall.getResponseWithInterceptorChain
private static final OkHttpClient client = new OkHttpClient.Builder()
.readTimeout(3, TimeUnit.MINUTES).retryOnConnectionFailure(true).build();
static {
Dispatcher dispatcher = client.dispatcher();
dispatcher.setMaxRequestsPerHost(300);
dispatcher.setMaxRequests(200);
}
public static Response execute(String url) throws IOException {
Request request = new Request.Builder().url(url).build();
return client.newCall(request).execute();
}
public static String get(String url) throws IOException {
return doRequest(builder(url).build());
}
public static String get(String url, Map<String, String> headers) throws IOException {
return doRequest(headerBuilder(url, headers).build());
}
public static String get(String url, Map<String, String> headers, Map<String, Object> params)
throws IOException {
return doRequest(headerBuilder(url + buildUrlParams(params), headers).build());
}
public static <T> T get(String url, Class<T> classOfT) throws IOException {
return doRequest(builder(url).build(), classOfT);
}
public static <T> T get(String url, Map<String, String> headers, Class<T> classOfT)
throws IOException {
return doRequest(headerBuilder(url, headers).build(), classOfT);
}
public static <T> T get(String url, Map<String, String> headers, Map<String, Object> params,
Class<T> classOfT) throws IOException {
return doRequest(headerBuilder(url + buildUrlParams(params), headers).build(), classOfT);
}
// public static <T> T get(String url, TypeReference<T> type) throws IOException {
// return doRequest(builder(url).build(), type);
// }
// public static <T> T get(String url, Map<String, String> headers, TypeReference<T> type)
// throws IOException {
// return doRequest(headerBuilder(url, headers).build(), type);
// }
// public static <T> T get(String url, Map<String, String> headers, Map<String, Object> params,
// TypeReference<T> type) throws IOException {
// return doRequest(headerBuilder(url + buildUrlParams(params), headers).build(), type);
// }
public static String post(String url, Object body) throws IOException {
return doRequest(postRequest(url, body));
}
public static String post(String url, Object body, Map<String, String> headers)
throws IOException {
return doRequest(postRequest(url, body, headers));
}
public static <T> T post(String url, Object body, Class<T> classOfT) throws IOException {
return doRequest(postRequest(url, body), classOfT);
}
// public static <T> T post(String url, Object body, TypeReference<T> type) throws IOException {
// return doRequest(postRequest(url, body), type);
// }
public static <T> T post(String url, Object body, Map<String, String> headers,
Class<T> classOfT) throws IOException {
return doRequest(postRequest(url, body, headers), classOfT);
}
// public static <T> T post(String url, Object body, Map<String, String> headers,
// TypeReference<T> type) throws IOException {
// return doRequest(postRequest(url, body, headers), type);
// }
private static Request postRequest(String url, Object body) {
return builder(url).post(buildRequestBody(body, null)).build();
}
private static Request postRequest(String url, Object body, Map<String, String> headers) {
return headerBuilder(url, headers).post(buildRequestBody(body, headers)).build();
}
private static Request.Builder builder(String url) {
return new Request.Builder().url(url);
}
private static Request.Builder headerBuilder(String url, Map<String, String> headers) {
Request.Builder builder = new Request.Builder().url(url);
headers.forEach(builder::addHeader);
return builder;
}
private static <T> T doRequest(Request request, Class<T> classOfT) throws IOException {
return JsonUtil.toObject(doRequest(request), classOfT);
}
// private static <T> T doRequest(Request request, TypeReference<T> type) throws IOException {
// return JsonUtil.toObject(doRequest(request), type);
// }
private static String doRequest(Request request) throws IOException {
long beginTime = System.currentTimeMillis();
try {
Response response = client.newCall(request).execute();
if (response.isSuccessful()) {
return response.body().string();
} else {
throw new RuntimeException(
"Http请求失败[" + response.code() + "]:" + response.body().string() + "...");
}
} finally {
logger.info("begin to request : {}, execute costs(ms) : {}", request.url(),
System.currentTimeMillis() - beginTime);
}
}
private static RequestBody buildRequestBody(Object body, Map<String, String> headers) {
if (headers != null && headers.containsKey("Content-Type")) {
String contentType = headers.get("Content-Type");
return RequestBody.create(MediaType.parse(contentType), body.toString());
}
if (body instanceof String && ((String) body).contains("=")) {
return RequestBody.create(MediaType.parse("application/x-www-form-urlencoded"),
(String) body);
}
return RequestBody.create(MediaType.parse("application/json"), JsonUtil.toString(body));
}
private static String buildUrlParams(Map<String, Object> params) {
if (params.isEmpty()) {
return "";
}
return "?" + params.entrySet().stream().map(it -> it.getKey() + "=" + it.getValue())
.collect(Collectors.joining("&"));
}
}

View File

@@ -0,0 +1,95 @@
package dev.langchain4j.model.dify;
import com.tencent.supersonic.common.util.AESEncryptionUtil;
import com.tencent.supersonic.common.util.DifyClient;
import com.tencent.supersonic.common.util.DifyResult;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response;
import lombok.Builder;
import java.util.List;
import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
import static java.util.Collections.singletonList;
public class DifyAiChatModel implements ChatLanguageModel {
private static final String CONTENT_TYPE_JSON = "application/json";
private final String baseUrl;
private final String apiKey;
private final DifyClient difyClient;
private final Integer maxRetries;
private final Integer maxToken;
private final String appName;
private final Double temperature;
private final Long timeOut;
private String userName;
@Builder
public DifyAiChatModel(String baseUrl, String apiKey, Integer maxRetries, Integer maxToken,
String modelName, Double temperature, Long timeOut) {
this.baseUrl = baseUrl;
this.maxRetries = getOrDefault(maxRetries, 3);
this.maxToken = getOrDefault(maxToken, 512);
try {
this.apiKey = AESEncryptionUtil.aesDecryptECB(apiKey);
} catch (Exception e) {
throw new RuntimeException(e);
}
this.appName = modelName;
this.temperature = temperature;
this.timeOut = timeOut;
this.difyClient = new DifyClient(this.baseUrl, this.apiKey);
}
@Override
public String generate(String message) {
DifyResult difyResult = this.difyClient.generate(message, this.getUserName());
return difyResult.getAnswer().toString();
}
@Override
public Response<AiMessage> generate(List<ChatMessage> messages) {
return generate(messages, (ToolSpecification) null);
}
@Override
public Response<AiMessage> generate(List<ChatMessage> messages,
List<ToolSpecification> toolSpecifications) {
ensureNotEmpty(messages, "messages");
DifyResult difyResult =
this.difyClient.generate(messages.get(0).text(), this.getUserName());
System.out.println(difyResult.toString());
if (!isNullOrEmpty(toolSpecifications)) {
// TODO
}
return Response.from(AiMessage.from(difyResult.getAnswer()));
}
@Override
public Response<AiMessage> generate(List<ChatMessage> messages,
ToolSpecification toolSpecification) {
return generate(messages,
toolSpecification != null ? singletonList(toolSpecification) : null);
}
public void setUserName(String userName) {
this.userName = userName;
}
public String getUserName() {
return null == userName ? "zhaodongsheng" : userName;
}
}

View File

@@ -0,0 +1,41 @@
package dev.langchain4j.provider;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import com.tencent.supersonic.common.util.AESEncryptionUtil;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.dify.DifyAiChatModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Service;
@Service
public class DifyModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "DIFY";
public static final String DEFAULT_BASE_URL = "https://dify.com/v1/chat-messages";
public static final String DEFAULT_MODEL_NAME = "demo-预留-可不填写";
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "all-minilm";
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return DifyAiChatModel.builder().baseUrl(modelConfig.getBaseUrl())
.apiKey(AESEncryptionUtil.aesDecryptECB(modelConfig.getApiKey()))
.modelName(modelConfig.getModelName()).timeOut(modelConfig.getTimeOut()).build();
}
@Override
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
return ZhipuAiEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
.apiKey(embeddingModelConfig.getApiKey()).model(embeddingModelConfig.getModelName())
.maxRetries(embeddingModelConfig.getMaxRetries())
.logRequests(embeddingModelConfig.getLogRequests())
.logResponses(embeddingModelConfig.getLogResponses()).build();
}
@Override
public void afterPropertiesSet() {
ModelProvider.add(PROVIDER, this);
}
}