(feat)Support for managing large models with Dify #1830;2、add user access token; #1829; 3、support change password #1824 (#1839)

This commit is contained in:
zhaodongsheng
2024-10-22 13:58:58 +08:00
committed by GitHub
parent bdb20ca462
commit 0ddcdf93ec
34 changed files with 1341 additions and 45 deletions

View File

@@ -2,8 +2,15 @@ package com.tencent.supersonic.common.pojo;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.Parameter;
import dev.langchain4j.provider.*;
import dev.langchain4j.provider.AzureModelFactory;
import dev.langchain4j.provider.DashscopeModelFactory;
import dev.langchain4j.provider.DifyModelFactory;
import dev.langchain4j.provider.LocalAiModelFactory;
import dev.langchain4j.provider.ModelProvider;
import dev.langchain4j.provider.OllamaModelFactory;
import dev.langchain4j.provider.OpenAiModelFactory;
import dev.langchain4j.provider.QianfanModelFactory;
import dev.langchain4j.provider.ZhipuModelFactory;
import java.util.ArrayList;
import java.util.List;
@@ -52,7 +59,7 @@ public class ChatModelParameters {
return Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER, ZhipuModelFactory.PROVIDER,
LocalAiModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER,
AzureModelFactory.PROVIDER);
AzureModelFactory.PROVIDER, DifyModelFactory.PROVIDER);
}
private static List<Parameter.Dependency> getBaseUrlDependency() {
@@ -63,20 +70,23 @@ public class ChatModelParameters {
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL,
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_BASE_URL,
LocalAiModelFactory.PROVIDER, LocalAiModelFactory.DEFAULT_BASE_URL,
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL));
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL,
DifyModelFactory.PROVIDER, DifyModelFactory.DEFAULT_BASE_URL));
}
private static List<Parameter.Dependency> getApiKeyDependency() {
return getDependency(CHAT_MODEL_PROVIDER.getName(),
Lists.newArrayList(OpenAiModelFactory.PROVIDER, QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER, LocalAiModelFactory.PROVIDER,
AzureModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER),
AzureModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER,
DifyModelFactory.PROVIDER),
ImmutableMap.of(OpenAiModelFactory.PROVIDER,
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), QianfanModelFactory.PROVIDER,
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), ZhipuModelFactory.PROVIDER,
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), LocalAiModelFactory.PROVIDER,
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), AzureModelFactory.PROVIDER,
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), DashscopeModelFactory.PROVIDER,
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), DifyModelFactory.PROVIDER,
ModelProvider.DEMO_CHAT_MODEL.getApiKey()));
}
@@ -88,7 +98,8 @@ public class ChatModelParameters {
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_MODEL_NAME,
LocalAiModelFactory.PROVIDER, LocalAiModelFactory.DEFAULT_MODEL_NAME,
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_MODEL_NAME,
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_MODEL_NAME));
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_MODEL_NAME,
DifyModelFactory.PROVIDER, DifyModelFactory.DEFAULT_MODEL_NAME));
}
private static List<Parameter.Dependency> getEndpointDependency() {

View File

@@ -0,0 +1,85 @@
package com.tencent.supersonic.common.util;
import lombok.extern.slf4j.Slf4j;
import java.util.HashMap;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@Slf4j
public class DifyClient {
private static final String DEFAULT_USER = "zhaodongsheng";
private static final String CONTENT_TYPE_JSON = "application/json";
private String difyURL;
private String difyKey;
public DifyClient(String difyURL, String difyKey) {
this.difyURL = difyURL;
this.difyKey = difyKey;
}
public DifyResult generate(String prompt) {
Map<String, String> headers = defaultHeaders();
DifyRequest request = new DifyRequest();
request.setQuery(prompt);
request.setUser(DEFAULT_USER);
return sendRequest(request, headers);
}
public DifyResult generate(String prompt, String user) {
Map<String, String> headers = defaultHeaders();
DifyRequest request = new DifyRequest();
request.setQuery(prompt);
request.setUser(user);
return sendRequest(request, headers);
}
public DifyResult generate(Map<String, String> inputs, String queryText, String user,
String conversationId) {
Map<String, String> headers = defaultHeaders();
DifyRequest request = new DifyRequest();
request.setInputs(inputs);
request.setQuery(queryText);
request.setUser(user);
if (conversationId != null && !conversationId.isEmpty()) {
request.setConversationId(conversationId);
}
return sendRequest(request, headers);
}
public DifyResult sendRequest(DifyRequest request, Map<String, String> headers) {
try {
log.debug("请求dify- header--->" + JsonUtil.toString(headers));
log.debug("请求dify- conversionId--->" + JsonUtil.toString(request));
return HttpUtils.post(difyURL, JsonUtil.toString(request), headers, DifyResult.class);
} catch (Exception e) {
log.error("请求dify失败---->" + e.getMessage());
throw new RuntimeException(e);
}
}
public String parseSQLResult(String sql) {
Pattern pattern = Pattern.compile("```(sql)?(.*)```", Pattern.DOTALL);
Matcher matcher = pattern.matcher(sql);
if (!matcher.find()) {
return sql.trim();
} else {
return matcher.group(2).trim();
}
}
private Map<String, String> defaultHeaders() {
Map<String, String> headers = new HashMap<>();
if (difyKey.contains("Bearer")) {
headers.put("Authorization", difyKey);
} else {
headers.put("Authorization", "Bearer " + difyKey);
}
headers.put("Content-Type", CONTENT_TYPE_JSON);
return headers;
}
}

View File

@@ -0,0 +1,19 @@
package com.tencent.supersonic.common.util;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;
import java.util.HashMap;
import java.util.Map;
@Data
public class DifyRequest {
private String query;
private Map<String, String> inputs = new HashMap<>();
private String responseMode = "blocking";
private String user;
@JsonProperty("conversation_id")
private String conversationId;
@JsonProperty("auto_generate_name")
private Boolean autoGenerateName = false;
}

View File

@@ -0,0 +1,13 @@
package com.tencent.supersonic.common.util;
import lombok.Data;
@Data
public class DifyResult {
private String event = "";
private String taskId = "";
private String conversationId = "";
private String id = "";
private String messageId = "";
private String answer = "";
}

View File

@@ -0,0 +1,170 @@
package com.tencent.supersonic.common.util;
import okhttp3.Dispatcher;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
public class HttpUtils {
private static final Logger logger = LoggerFactory.getLogger(HttpUtils.class);
// 重试参考okhttp3.RealCall.getResponseWithInterceptorChain
private static final OkHttpClient client = new OkHttpClient.Builder()
.readTimeout(3, TimeUnit.MINUTES).retryOnConnectionFailure(true).build();
static {
Dispatcher dispatcher = client.dispatcher();
dispatcher.setMaxRequestsPerHost(300);
dispatcher.setMaxRequests(200);
}
public static Response execute(String url) throws IOException {
Request request = new Request.Builder().url(url).build();
return client.newCall(request).execute();
}
public static String get(String url) throws IOException {
return doRequest(builder(url).build());
}
public static String get(String url, Map<String, String> headers) throws IOException {
return doRequest(headerBuilder(url, headers).build());
}
public static String get(String url, Map<String, String> headers, Map<String, Object> params)
throws IOException {
return doRequest(headerBuilder(url + buildUrlParams(params), headers).build());
}
public static <T> T get(String url, Class<T> classOfT) throws IOException {
return doRequest(builder(url).build(), classOfT);
}
public static <T> T get(String url, Map<String, String> headers, Class<T> classOfT)
throws IOException {
return doRequest(headerBuilder(url, headers).build(), classOfT);
}
public static <T> T get(String url, Map<String, String> headers, Map<String, Object> params,
Class<T> classOfT) throws IOException {
return doRequest(headerBuilder(url + buildUrlParams(params), headers).build(), classOfT);
}
// public static <T> T get(String url, TypeReference<T> type) throws IOException {
// return doRequest(builder(url).build(), type);
// }
// public static <T> T get(String url, Map<String, String> headers, TypeReference<T> type)
// throws IOException {
// return doRequest(headerBuilder(url, headers).build(), type);
// }
// public static <T> T get(String url, Map<String, String> headers, Map<String, Object> params,
// TypeReference<T> type) throws IOException {
// return doRequest(headerBuilder(url + buildUrlParams(params), headers).build(), type);
// }
public static String post(String url, Object body) throws IOException {
return doRequest(postRequest(url, body));
}
public static String post(String url, Object body, Map<String, String> headers)
throws IOException {
return doRequest(postRequest(url, body, headers));
}
public static <T> T post(String url, Object body, Class<T> classOfT) throws IOException {
return doRequest(postRequest(url, body), classOfT);
}
// public static <T> T post(String url, Object body, TypeReference<T> type) throws IOException {
// return doRequest(postRequest(url, body), type);
// }
public static <T> T post(String url, Object body, Map<String, String> headers,
Class<T> classOfT) throws IOException {
return doRequest(postRequest(url, body, headers), classOfT);
}
// public static <T> T post(String url, Object body, Map<String, String> headers,
// TypeReference<T> type) throws IOException {
// return doRequest(postRequest(url, body, headers), type);
// }
private static Request postRequest(String url, Object body) {
return builder(url).post(buildRequestBody(body, null)).build();
}
private static Request postRequest(String url, Object body, Map<String, String> headers) {
return headerBuilder(url, headers).post(buildRequestBody(body, headers)).build();
}
private static Request.Builder builder(String url) {
return new Request.Builder().url(url);
}
private static Request.Builder headerBuilder(String url, Map<String, String> headers) {
Request.Builder builder = new Request.Builder().url(url);
headers.forEach(builder::addHeader);
return builder;
}
private static <T> T doRequest(Request request, Class<T> classOfT) throws IOException {
return JsonUtil.toObject(doRequest(request), classOfT);
}
// private static <T> T doRequest(Request request, TypeReference<T> type) throws IOException {
// return JsonUtil.toObject(doRequest(request), type);
// }
private static String doRequest(Request request) throws IOException {
long beginTime = System.currentTimeMillis();
try {
Response response = client.newCall(request).execute();
if (response.isSuccessful()) {
return response.body().string();
} else {
throw new RuntimeException(
"Http请求失败[" + response.code() + "]:" + response.body().string() + "...");
}
} finally {
logger.info("begin to request : {}, execute costs(ms) : {}", request.url(),
System.currentTimeMillis() - beginTime);
}
}
private static RequestBody buildRequestBody(Object body, Map<String, String> headers) {
if (headers != null && headers.containsKey("Content-Type")) {
String contentType = headers.get("Content-Type");
return RequestBody.create(MediaType.parse(contentType), body.toString());
}
if (body instanceof String && ((String) body).contains("=")) {
return RequestBody.create(MediaType.parse("application/x-www-form-urlencoded"),
(String) body);
}
return RequestBody.create(MediaType.parse("application/json"), JsonUtil.toString(body));
}
private static String buildUrlParams(Map<String, Object> params) {
if (params.isEmpty()) {
return "";
}
return "?" + params.entrySet().stream().map(it -> it.getKey() + "=" + it.getValue())
.collect(Collectors.joining("&"));
}
}