mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-13 04:57:28 +00:00
(feat)Support for managing large models with Dify #1830;2、add user access token; #1829; 3、support change password #1824 (#1839)
This commit is contained in:
@@ -2,8 +2,15 @@ package com.tencent.supersonic.common.pojo;
|
||||
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.Parameter;
|
||||
import dev.langchain4j.provider.*;
|
||||
import dev.langchain4j.provider.AzureModelFactory;
|
||||
import dev.langchain4j.provider.DashscopeModelFactory;
|
||||
import dev.langchain4j.provider.DifyModelFactory;
|
||||
import dev.langchain4j.provider.LocalAiModelFactory;
|
||||
import dev.langchain4j.provider.ModelProvider;
|
||||
import dev.langchain4j.provider.OllamaModelFactory;
|
||||
import dev.langchain4j.provider.OpenAiModelFactory;
|
||||
import dev.langchain4j.provider.QianfanModelFactory;
|
||||
import dev.langchain4j.provider.ZhipuModelFactory;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
@@ -52,7 +59,7 @@ public class ChatModelParameters {
|
||||
return Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER,
|
||||
QianfanModelFactory.PROVIDER, ZhipuModelFactory.PROVIDER,
|
||||
LocalAiModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER,
|
||||
AzureModelFactory.PROVIDER);
|
||||
AzureModelFactory.PROVIDER, DifyModelFactory.PROVIDER);
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getBaseUrlDependency() {
|
||||
@@ -63,20 +70,23 @@ public class ChatModelParameters {
|
||||
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL,
|
||||
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_BASE_URL,
|
||||
LocalAiModelFactory.PROVIDER, LocalAiModelFactory.DEFAULT_BASE_URL,
|
||||
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL));
|
||||
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL,
|
||||
DifyModelFactory.PROVIDER, DifyModelFactory.DEFAULT_BASE_URL));
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getApiKeyDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(OpenAiModelFactory.PROVIDER, QianfanModelFactory.PROVIDER,
|
||||
ZhipuModelFactory.PROVIDER, LocalAiModelFactory.PROVIDER,
|
||||
AzureModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER),
|
||||
AzureModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER,
|
||||
DifyModelFactory.PROVIDER),
|
||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER,
|
||||
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), QianfanModelFactory.PROVIDER,
|
||||
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), ZhipuModelFactory.PROVIDER,
|
||||
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), LocalAiModelFactory.PROVIDER,
|
||||
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), AzureModelFactory.PROVIDER,
|
||||
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), DashscopeModelFactory.PROVIDER,
|
||||
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), DifyModelFactory.PROVIDER,
|
||||
ModelProvider.DEMO_CHAT_MODEL.getApiKey()));
|
||||
}
|
||||
|
||||
@@ -88,7 +98,8 @@ public class ChatModelParameters {
|
||||
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_MODEL_NAME,
|
||||
LocalAiModelFactory.PROVIDER, LocalAiModelFactory.DEFAULT_MODEL_NAME,
|
||||
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_MODEL_NAME,
|
||||
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_MODEL_NAME));
|
||||
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_MODEL_NAME,
|
||||
DifyModelFactory.PROVIDER, DifyModelFactory.DEFAULT_MODEL_NAME));
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getEndpointDependency() {
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
package com.tencent.supersonic.common.util;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
|
||||
@Slf4j
|
||||
public class DifyClient {
|
||||
private static final String DEFAULT_USER = "zhaodongsheng";
|
||||
private static final String CONTENT_TYPE_JSON = "application/json";
|
||||
|
||||
private String difyURL;
|
||||
private String difyKey;
|
||||
|
||||
public DifyClient(String difyURL, String difyKey) {
|
||||
this.difyURL = difyURL;
|
||||
this.difyKey = difyKey;
|
||||
}
|
||||
|
||||
public DifyResult generate(String prompt) {
|
||||
Map<String, String> headers = defaultHeaders();
|
||||
DifyRequest request = new DifyRequest();
|
||||
request.setQuery(prompt);
|
||||
request.setUser(DEFAULT_USER);
|
||||
return sendRequest(request, headers);
|
||||
}
|
||||
|
||||
public DifyResult generate(String prompt, String user) {
|
||||
Map<String, String> headers = defaultHeaders();
|
||||
DifyRequest request = new DifyRequest();
|
||||
request.setQuery(prompt);
|
||||
request.setUser(user);
|
||||
return sendRequest(request, headers);
|
||||
}
|
||||
|
||||
public DifyResult generate(Map<String, String> inputs, String queryText, String user,
|
||||
String conversationId) {
|
||||
Map<String, String> headers = defaultHeaders();
|
||||
DifyRequest request = new DifyRequest();
|
||||
request.setInputs(inputs);
|
||||
request.setQuery(queryText);
|
||||
request.setUser(user);
|
||||
if (conversationId != null && !conversationId.isEmpty()) {
|
||||
request.setConversationId(conversationId);
|
||||
}
|
||||
return sendRequest(request, headers);
|
||||
}
|
||||
|
||||
public DifyResult sendRequest(DifyRequest request, Map<String, String> headers) {
|
||||
try {
|
||||
log.debug("请求dify- header--->" + JsonUtil.toString(headers));
|
||||
log.debug("请求dify- conversionId--->" + JsonUtil.toString(request));
|
||||
return HttpUtils.post(difyURL, JsonUtil.toString(request), headers, DifyResult.class);
|
||||
} catch (Exception e) {
|
||||
log.error("请求dify失败---->" + e.getMessage());
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
public String parseSQLResult(String sql) {
|
||||
Pattern pattern = Pattern.compile("```(sql)?(.*)```", Pattern.DOTALL);
|
||||
Matcher matcher = pattern.matcher(sql);
|
||||
if (!matcher.find()) {
|
||||
return sql.trim();
|
||||
} else {
|
||||
return matcher.group(2).trim();
|
||||
}
|
||||
}
|
||||
|
||||
private Map<String, String> defaultHeaders() {
|
||||
Map<String, String> headers = new HashMap<>();
|
||||
if (difyKey.contains("Bearer")) {
|
||||
headers.put("Authorization", difyKey);
|
||||
} else {
|
||||
headers.put("Authorization", "Bearer " + difyKey);
|
||||
}
|
||||
|
||||
headers.put("Content-Type", CONTENT_TYPE_JSON);
|
||||
return headers;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
package com.tencent.supersonic.common.util;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonProperty;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
public class DifyRequest {
|
||||
private String query;
|
||||
private Map<String, String> inputs = new HashMap<>();
|
||||
private String responseMode = "blocking";
|
||||
private String user;
|
||||
@JsonProperty("conversation_id")
|
||||
private String conversationId;
|
||||
@JsonProperty("auto_generate_name")
|
||||
private Boolean autoGenerateName = false;
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package com.tencent.supersonic.common.util;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class DifyResult {
|
||||
private String event = "";
|
||||
private String taskId = "";
|
||||
private String conversationId = "";
|
||||
private String id = "";
|
||||
private String messageId = "";
|
||||
private String answer = "";
|
||||
}
|
||||
@@ -0,0 +1,170 @@
|
||||
package com.tencent.supersonic.common.util;
|
||||
|
||||
import okhttp3.Dispatcher;
|
||||
import okhttp3.MediaType;
|
||||
import okhttp3.OkHttpClient;
|
||||
import okhttp3.Request;
|
||||
import okhttp3.RequestBody;
|
||||
import okhttp3.Response;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class HttpUtils {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(HttpUtils.class);
|
||||
|
||||
// 重试参考:okhttp3.RealCall.getResponseWithInterceptorChain
|
||||
private static final OkHttpClient client = new OkHttpClient.Builder()
|
||||
.readTimeout(3, TimeUnit.MINUTES).retryOnConnectionFailure(true).build();
|
||||
|
||||
static {
|
||||
Dispatcher dispatcher = client.dispatcher();
|
||||
dispatcher.setMaxRequestsPerHost(300);
|
||||
dispatcher.setMaxRequests(200);
|
||||
}
|
||||
|
||||
public static Response execute(String url) throws IOException {
|
||||
Request request = new Request.Builder().url(url).build();
|
||||
return client.newCall(request).execute();
|
||||
}
|
||||
|
||||
public static String get(String url) throws IOException {
|
||||
return doRequest(builder(url).build());
|
||||
}
|
||||
|
||||
public static String get(String url, Map<String, String> headers) throws IOException {
|
||||
return doRequest(headerBuilder(url, headers).build());
|
||||
}
|
||||
|
||||
public static String get(String url, Map<String, String> headers, Map<String, Object> params)
|
||||
throws IOException {
|
||||
return doRequest(headerBuilder(url + buildUrlParams(params), headers).build());
|
||||
}
|
||||
|
||||
public static <T> T get(String url, Class<T> classOfT) throws IOException {
|
||||
return doRequest(builder(url).build(), classOfT);
|
||||
}
|
||||
|
||||
public static <T> T get(String url, Map<String, String> headers, Class<T> classOfT)
|
||||
throws IOException {
|
||||
return doRequest(headerBuilder(url, headers).build(), classOfT);
|
||||
}
|
||||
|
||||
public static <T> T get(String url, Map<String, String> headers, Map<String, Object> params,
|
||||
Class<T> classOfT) throws IOException {
|
||||
return doRequest(headerBuilder(url + buildUrlParams(params), headers).build(), classOfT);
|
||||
}
|
||||
|
||||
// public static <T> T get(String url, TypeReference<T> type) throws IOException {
|
||||
// return doRequest(builder(url).build(), type);
|
||||
// }
|
||||
|
||||
// public static <T> T get(String url, Map<String, String> headers, TypeReference<T> type)
|
||||
// throws IOException {
|
||||
// return doRequest(headerBuilder(url, headers).build(), type);
|
||||
// }
|
||||
|
||||
// public static <T> T get(String url, Map<String, String> headers, Map<String, Object> params,
|
||||
// TypeReference<T> type) throws IOException {
|
||||
// return doRequest(headerBuilder(url + buildUrlParams(params), headers).build(), type);
|
||||
// }
|
||||
|
||||
public static String post(String url, Object body) throws IOException {
|
||||
return doRequest(postRequest(url, body));
|
||||
}
|
||||
|
||||
public static String post(String url, Object body, Map<String, String> headers)
|
||||
throws IOException {
|
||||
return doRequest(postRequest(url, body, headers));
|
||||
}
|
||||
|
||||
public static <T> T post(String url, Object body, Class<T> classOfT) throws IOException {
|
||||
return doRequest(postRequest(url, body), classOfT);
|
||||
}
|
||||
|
||||
// public static <T> T post(String url, Object body, TypeReference<T> type) throws IOException {
|
||||
// return doRequest(postRequest(url, body), type);
|
||||
// }
|
||||
|
||||
public static <T> T post(String url, Object body, Map<String, String> headers,
|
||||
Class<T> classOfT) throws IOException {
|
||||
return doRequest(postRequest(url, body, headers), classOfT);
|
||||
}
|
||||
|
||||
// public static <T> T post(String url, Object body, Map<String, String> headers,
|
||||
// TypeReference<T> type) throws IOException {
|
||||
// return doRequest(postRequest(url, body, headers), type);
|
||||
// }
|
||||
|
||||
private static Request postRequest(String url, Object body) {
|
||||
return builder(url).post(buildRequestBody(body, null)).build();
|
||||
}
|
||||
|
||||
private static Request postRequest(String url, Object body, Map<String, String> headers) {
|
||||
return headerBuilder(url, headers).post(buildRequestBody(body, headers)).build();
|
||||
}
|
||||
|
||||
private static Request.Builder builder(String url) {
|
||||
return new Request.Builder().url(url);
|
||||
}
|
||||
|
||||
private static Request.Builder headerBuilder(String url, Map<String, String> headers) {
|
||||
Request.Builder builder = new Request.Builder().url(url);
|
||||
headers.forEach(builder::addHeader);
|
||||
|
||||
return builder;
|
||||
}
|
||||
|
||||
private static <T> T doRequest(Request request, Class<T> classOfT) throws IOException {
|
||||
return JsonUtil.toObject(doRequest(request), classOfT);
|
||||
}
|
||||
|
||||
// private static <T> T doRequest(Request request, TypeReference<T> type) throws IOException {
|
||||
// return JsonUtil.toObject(doRequest(request), type);
|
||||
// }
|
||||
|
||||
private static String doRequest(Request request) throws IOException {
|
||||
long beginTime = System.currentTimeMillis();
|
||||
|
||||
try {
|
||||
Response response = client.newCall(request).execute();
|
||||
if (response.isSuccessful()) {
|
||||
return response.body().string();
|
||||
} else {
|
||||
throw new RuntimeException(
|
||||
"Http请求失败[" + response.code() + "]:" + response.body().string() + "...");
|
||||
}
|
||||
} finally {
|
||||
logger.info("begin to request : {}, execute costs(ms) : {}", request.url(),
|
||||
System.currentTimeMillis() - beginTime);
|
||||
}
|
||||
}
|
||||
|
||||
private static RequestBody buildRequestBody(Object body, Map<String, String> headers) {
|
||||
if (headers != null && headers.containsKey("Content-Type")) {
|
||||
String contentType = headers.get("Content-Type");
|
||||
return RequestBody.create(MediaType.parse(contentType), body.toString());
|
||||
}
|
||||
|
||||
if (body instanceof String && ((String) body).contains("=")) {
|
||||
return RequestBody.create(MediaType.parse("application/x-www-form-urlencoded"),
|
||||
(String) body);
|
||||
}
|
||||
|
||||
return RequestBody.create(MediaType.parse("application/json"), JsonUtil.toString(body));
|
||||
}
|
||||
|
||||
private static String buildUrlParams(Map<String, Object> params) {
|
||||
if (params.isEmpty()) {
|
||||
return "";
|
||||
}
|
||||
|
||||
return "?" + params.entrySet().stream().map(it -> it.getKey() + "=" + it.getValue())
|
||||
.collect(Collectors.joining("&"));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,95 @@
|
||||
package dev.langchain4j.model.dify;
|
||||
|
||||
import com.tencent.supersonic.common.util.AESEncryptionUtil;
|
||||
import com.tencent.supersonic.common.util.DifyClient;
|
||||
import com.tencent.supersonic.common.util.DifyResult;
|
||||
import dev.langchain4j.agent.tool.ToolSpecification;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.data.message.ChatMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import lombok.Builder;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static dev.langchain4j.internal.Utils.getOrDefault;
|
||||
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
|
||||
import static java.util.Collections.singletonList;
|
||||
|
||||
public class DifyAiChatModel implements ChatLanguageModel {
|
||||
|
||||
private static final String CONTENT_TYPE_JSON = "application/json";
|
||||
|
||||
private final String baseUrl;
|
||||
private final String apiKey;
|
||||
|
||||
private final DifyClient difyClient;
|
||||
private final Integer maxRetries;
|
||||
private final Integer maxToken;
|
||||
|
||||
private final String appName;
|
||||
private final Double temperature;
|
||||
private final Long timeOut;
|
||||
|
||||
private String userName;
|
||||
|
||||
@Builder
|
||||
public DifyAiChatModel(String baseUrl, String apiKey, Integer maxRetries, Integer maxToken,
|
||||
String modelName, Double temperature, Long timeOut) {
|
||||
this.baseUrl = baseUrl;
|
||||
this.maxRetries = getOrDefault(maxRetries, 3);
|
||||
this.maxToken = getOrDefault(maxToken, 512);
|
||||
try {
|
||||
this.apiKey = AESEncryptionUtil.aesDecryptECB(apiKey);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
this.appName = modelName;
|
||||
this.temperature = temperature;
|
||||
this.timeOut = timeOut;
|
||||
this.difyClient = new DifyClient(this.baseUrl, this.apiKey);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String generate(String message) {
|
||||
DifyResult difyResult = this.difyClient.generate(message, this.getUserName());
|
||||
return difyResult.getAnswer().toString();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Response<AiMessage> generate(List<ChatMessage> messages) {
|
||||
return generate(messages, (ToolSpecification) null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Response<AiMessage> generate(List<ChatMessage> messages,
|
||||
List<ToolSpecification> toolSpecifications) {
|
||||
ensureNotEmpty(messages, "messages");
|
||||
DifyResult difyResult =
|
||||
this.difyClient.generate(messages.get(0).text(), this.getUserName());
|
||||
System.out.println(difyResult.toString());
|
||||
|
||||
if (!isNullOrEmpty(toolSpecifications)) {
|
||||
// TODO
|
||||
}
|
||||
|
||||
return Response.from(AiMessage.from(difyResult.getAnswer()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Response<AiMessage> generate(List<ChatMessage> messages,
|
||||
ToolSpecification toolSpecification) {
|
||||
return generate(messages,
|
||||
toolSpecification != null ? singletonList(toolSpecification) : null);
|
||||
}
|
||||
|
||||
public void setUserName(String userName) {
|
||||
this.userName = userName;
|
||||
}
|
||||
|
||||
public String getUserName() {
|
||||
return null == userName ? "zhaodongsheng" : userName;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||
import com.tencent.supersonic.common.util.AESEncryptionUtil;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.dify.DifyAiChatModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
public class DifyModelFactory implements ModelFactory, InitializingBean {
|
||||
public static final String PROVIDER = "DIFY";
|
||||
|
||||
public static final String DEFAULT_BASE_URL = "https://dify.com/v1/chat-messages";
|
||||
public static final String DEFAULT_MODEL_NAME = "demo-预留-可不填写";
|
||||
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "all-minilm";
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
return DifyAiChatModel.builder().baseUrl(modelConfig.getBaseUrl())
|
||||
.apiKey(AESEncryptionUtil.aesDecryptECB(modelConfig.getApiKey()))
|
||||
.modelName(modelConfig.getModelName()).timeOut(modelConfig.getTimeOut()).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
||||
return ZhipuAiEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
|
||||
.apiKey(embeddingModelConfig.getApiKey()).model(embeddingModelConfig.getModelName())
|
||||
.maxRetries(embeddingModelConfig.getMaxRetries())
|
||||
.logRequests(embeddingModelConfig.getLogRequests())
|
||||
.logResponses(embeddingModelConfig.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(PROVIDER, this);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user