mirror of
https://github.com/tencentmusic/supersonic.git
synced 2026-04-22 14:54:21 +08:00
Compare commits
5 Commits
be8b56bdde
...
3fc1ec42be
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3fc1ec42be | ||
|
|
c4992501bd | ||
|
|
acffc03c79 | ||
|
|
763def2de0 | ||
|
|
d0a67af684 |
@@ -61,7 +61,7 @@ function runJavaService {
|
||||
fi
|
||||
export PATH=$JAVA_HOME/bin:$PATH
|
||||
command="-Dfile.encoding=UTF-8 -Duser.language=Zh -Duser.region=CN -Duser.timezone=GMT+08
|
||||
-Dapp_name=${local_app_name} -Xms1024m -Xmx2048m $main_class"
|
||||
-Dapp_name=${local_app_name} -Xms1024m -Xmx2048m -XX:+UseZGC -XX:+ZGenerational $main_class"
|
||||
|
||||
mkdir -p $javaRunDir/logs
|
||||
java -Dspring.profiles.active="$profile" $command >/dev/null 2>$javaRunDir/logs/error.log &
|
||||
|
||||
@@ -21,7 +21,10 @@
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-validation</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-autoconfigure-processor</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.slf4j</groupId>
|
||||
@@ -33,7 +36,7 @@
|
||||
<dependency>
|
||||
<groupId>org.apache.httpcomponents.client5</groupId>
|
||||
<artifactId>httpclient5</artifactId>
|
||||
<version>${httpclient5.version}</version> <!-- 请确认使用最新稳定版本 -->
|
||||
<version>${httpclient5.version}</version>
|
||||
</dependency>
|
||||
<!-- <dependency>-->
|
||||
<!-- <groupId>org.apache.httpcomponents</groupId>-->
|
||||
@@ -182,10 +185,6 @@
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-pgvector</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-azure-open-ai</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-embeddings-bge-small-zh</artifactId>
|
||||
@@ -198,34 +197,6 @@
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-qianfan</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-zhipu-ai</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-dashscope</artifactId>
|
||||
<exclusions>
|
||||
<exclusion>
|
||||
<groupId>org.slf4j</groupId>
|
||||
<artifactId>slf4j-simple</artifactId>
|
||||
</exclusion>
|
||||
</exclusions>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-chatglm</artifactId>
|
||||
<exclusions>
|
||||
<exclusion>
|
||||
<groupId>org.slf4j</groupId>
|
||||
<artifactId>slf4j-simple</artifactId>
|
||||
</exclusion>
|
||||
</exclusions>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-ollama</artifactId>
|
||||
@@ -237,11 +208,6 @@
|
||||
<version>${hanlp.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-autoconfigure-processor</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.google.code.gson</groupId>
|
||||
<artifactId>gson</artifactId>
|
||||
|
||||
@@ -4,14 +4,10 @@ import com.google.common.collect.ImmutableMap;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.Parameter;
|
||||
import dev.langchain4j.provider.AzureModelFactory;
|
||||
import dev.langchain4j.provider.DashscopeModelFactory;
|
||||
import dev.langchain4j.provider.EmbeddingModelConstant;
|
||||
import dev.langchain4j.provider.InMemoryModelFactory;
|
||||
import dev.langchain4j.provider.OllamaModelFactory;
|
||||
import dev.langchain4j.provider.OpenAiModelFactory;
|
||||
import dev.langchain4j.provider.QianfanModelFactory;
|
||||
import dev.langchain4j.provider.ZhipuModelFactory;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@@ -70,52 +66,31 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
|
||||
|
||||
private static ArrayList<String> getCandidateValues() {
|
||||
return Lists.newArrayList(InMemoryModelFactory.PROVIDER, OpenAiModelFactory.PROVIDER,
|
||||
OllamaModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER,
|
||||
QianfanModelFactory.PROVIDER, ZhipuModelFactory.PROVIDER,
|
||||
AzureModelFactory.PROVIDER);
|
||||
OllamaModelFactory.PROVIDER);
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getBaseUrlDependency() {
|
||||
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER,
|
||||
AzureModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER,
|
||||
QianfanModelFactory.PROVIDER, ZhipuModelFactory.PROVIDER),
|
||||
Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER),
|
||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL,
|
||||
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL,
|
||||
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL,
|
||||
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL,
|
||||
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL,
|
||||
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_BASE_URL));
|
||||
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL));
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getApiKeyDependency() {
|
||||
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(OpenAiModelFactory.PROVIDER, AzureModelFactory.PROVIDER,
|
||||
DashscopeModelFactory.PROVIDER, QianfanModelFactory.PROVIDER,
|
||||
ZhipuModelFactory.PROVIDER),
|
||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER, DEMO, AzureModelFactory.PROVIDER, DEMO,
|
||||
DashscopeModelFactory.PROVIDER, DEMO, QianfanModelFactory.PROVIDER, DEMO,
|
||||
ZhipuModelFactory.PROVIDER, DEMO));
|
||||
Lists.newArrayList(OpenAiModelFactory.PROVIDER),
|
||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER, DEMO));
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getModelNameDependency() {
|
||||
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(InMemoryModelFactory.PROVIDER, OpenAiModelFactory.PROVIDER,
|
||||
OllamaModelFactory.PROVIDER, AzureModelFactory.PROVIDER,
|
||||
DashscopeModelFactory.PROVIDER, QianfanModelFactory.PROVIDER,
|
||||
ZhipuModelFactory.PROVIDER),
|
||||
OllamaModelFactory.PROVIDER),
|
||||
ImmutableMap.of(InMemoryModelFactory.PROVIDER, EmbeddingModelConstant.BGE_SMALL_ZH,
|
||||
OpenAiModelFactory.PROVIDER,
|
||||
OpenAiModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
|
||||
OllamaModelFactory.PROVIDER,
|
||||
OllamaModelFactory.DEFAULT_EMBEDDING_MODEL_NAME, AzureModelFactory.PROVIDER,
|
||||
AzureModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
|
||||
DashscopeModelFactory.PROVIDER,
|
||||
DashscopeModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
|
||||
QianfanModelFactory.PROVIDER,
|
||||
QianfanModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
|
||||
ZhipuModelFactory.PROVIDER,
|
||||
ZhipuModelFactory.DEFAULT_EMBEDDING_MODEL_NAME));
|
||||
OllamaModelFactory.DEFAULT_EMBEDDING_MODEL_NAME));
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getModelPathDependency() {
|
||||
@@ -126,7 +101,7 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
|
||||
|
||||
private static List<Parameter.Dependency> getSecretKeyDependency() {
|
||||
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(QianfanModelFactory.PROVIDER),
|
||||
ImmutableMap.of(QianfanModelFactory.PROVIDER, DEMO));
|
||||
Lists.newArrayList(OpenAiModelFactory.PROVIDER),
|
||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER, DEMO));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -85,20 +85,20 @@ public class ChatModelParameters {
|
||||
|
||||
private static List<Parameter.Dependency> getEndpointDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(QianfanModelFactory.PROVIDER), ImmutableMap
|
||||
.of(QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_ENDPOINT));
|
||||
Lists.newArrayList(OpenAiModelFactory.PROVIDER), ImmutableMap
|
||||
.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_MODEL_NAME));
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getEnableSearchDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(DashscopeModelFactory.PROVIDER),
|
||||
ImmutableMap.of(DashscopeModelFactory.PROVIDER, "false"));
|
||||
Lists.newArrayList(OpenAiModelFactory.PROVIDER),
|
||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER, "false"));
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getSecretKeyDependency() {
|
||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||
Lists.newArrayList(QianfanModelFactory.PROVIDER), ImmutableMap.of(
|
||||
QianfanModelFactory.PROVIDER, ModelProvider.DEMO_CHAT_MODEL.getApiKey()));
|
||||
Lists.newArrayList(OpenAiModelFactory.PROVIDER), ImmutableMap.of(
|
||||
OpenAiModelFactory.PROVIDER, ModelProvider.DEMO_CHAT_MODEL.getApiKey()));
|
||||
}
|
||||
|
||||
private static List<Parameter.Dependency> getDependency(String dependencyParameterName,
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
package dev.langchain4j.dashscope.spring;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
class ChatModelProperties {
|
||||
|
||||
String baseUrl;
|
||||
String apiKey;
|
||||
String modelName;
|
||||
Double topP;
|
||||
Integer topK;
|
||||
Boolean enableSearch;
|
||||
Integer seed;
|
||||
Float repetitionPenalty;
|
||||
Float temperature;
|
||||
List<String> stops;
|
||||
Integer maxTokens;
|
||||
}
|
||||
@@ -1,84 +0,0 @@
|
||||
package dev.langchain4j.dashscope.spring;
|
||||
|
||||
import dev.langchain4j.model.dashscope.QwenChatModel;
|
||||
import dev.langchain4j.model.dashscope.QwenEmbeddingModel;
|
||||
import dev.langchain4j.model.dashscope.QwenLanguageModel;
|
||||
import dev.langchain4j.model.dashscope.QwenStreamingChatModel;
|
||||
import dev.langchain4j.model.dashscope.QwenStreamingLanguageModel;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
import static dev.langchain4j.dashscope.spring.Properties.PREFIX;
|
||||
|
||||
@Configuration
|
||||
@EnableConfigurationProperties(Properties.class)
|
||||
public class DashscopeAutoConfig {
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".chat-model.api-key")
|
||||
QwenChatModel qwenChatModel(Properties properties) {
|
||||
ChatModelProperties chatModelProperties = properties.getChatModel();
|
||||
return QwenChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
|
||||
.apiKey(chatModelProperties.getApiKey())
|
||||
.modelName(chatModelProperties.getModelName()).topP(chatModelProperties.getTopP())
|
||||
.topK(chatModelProperties.getTopK())
|
||||
.enableSearch(chatModelProperties.getEnableSearch())
|
||||
.seed(chatModelProperties.getSeed())
|
||||
.repetitionPenalty(chatModelProperties.getRepetitionPenalty())
|
||||
.temperature(chatModelProperties.getTemperature())
|
||||
.stops(chatModelProperties.getStops()).maxTokens(chatModelProperties.getMaxTokens())
|
||||
.build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".streaming-chat-model.api-key")
|
||||
QwenStreamingChatModel qwenStreamingChatModel(Properties properties) {
|
||||
ChatModelProperties chatModelProperties = properties.getStreamingChatModel();
|
||||
return QwenStreamingChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
|
||||
.apiKey(chatModelProperties.getApiKey())
|
||||
.modelName(chatModelProperties.getModelName()).topP(chatModelProperties.getTopP())
|
||||
.topK(chatModelProperties.getTopK())
|
||||
.enableSearch(chatModelProperties.getEnableSearch())
|
||||
.seed(chatModelProperties.getSeed())
|
||||
.repetitionPenalty(chatModelProperties.getRepetitionPenalty())
|
||||
.temperature(chatModelProperties.getTemperature())
|
||||
.stops(chatModelProperties.getStops()).maxTokens(chatModelProperties.getMaxTokens())
|
||||
.build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".language-model.api-key")
|
||||
QwenLanguageModel qwenLanguageModel(Properties properties) {
|
||||
ChatModelProperties languageModel = properties.getLanguageModel();
|
||||
return QwenLanguageModel.builder().baseUrl(languageModel.getBaseUrl())
|
||||
.apiKey(languageModel.getApiKey()).modelName(languageModel.getModelName())
|
||||
.topP(languageModel.getTopP()).topK(languageModel.getTopK())
|
||||
.enableSearch(languageModel.getEnableSearch()).seed(languageModel.getSeed())
|
||||
.repetitionPenalty(languageModel.getRepetitionPenalty())
|
||||
.temperature(languageModel.getTemperature()).stops(languageModel.getStops())
|
||||
.maxTokens(languageModel.getMaxTokens()).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".streaming-language-model.api-key")
|
||||
QwenStreamingLanguageModel qwenStreamingLanguageModel(Properties properties) {
|
||||
ChatModelProperties languageModel = properties.getStreamingLanguageModel();
|
||||
return QwenStreamingLanguageModel.builder().baseUrl(languageModel.getBaseUrl())
|
||||
.apiKey(languageModel.getApiKey()).modelName(languageModel.getModelName())
|
||||
.topP(languageModel.getTopP()).topK(languageModel.getTopK())
|
||||
.enableSearch(languageModel.getEnableSearch()).seed(languageModel.getSeed())
|
||||
.repetitionPenalty(languageModel.getRepetitionPenalty())
|
||||
.temperature(languageModel.getTemperature()).stops(languageModel.getStops())
|
||||
.maxTokens(languageModel.getMaxTokens()).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".embedding-model.api-key")
|
||||
QwenEmbeddingModel qwenEmbeddingModel(Properties properties) {
|
||||
EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel();
|
||||
return QwenEmbeddingModel.builder().apiKey(embeddingModelProperties.getApiKey())
|
||||
.modelName(embeddingModelProperties.getModelName()).build();
|
||||
}
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
package dev.langchain4j.dashscope.spring;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
class EmbeddingModelProperties {
|
||||
|
||||
private String apiKey;
|
||||
private String modelName;
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
package dev.langchain4j.dashscope.spring;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||
import org.springframework.boot.context.properties.NestedConfigurationProperty;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
@ConfigurationProperties(prefix = Properties.PREFIX)
|
||||
public class Properties {
|
||||
|
||||
static final String PREFIX = "langchain4j.dashscope";
|
||||
|
||||
@NestedConfigurationProperty
|
||||
ChatModelProperties chatModel;
|
||||
|
||||
@NestedConfigurationProperty
|
||||
ChatModelProperties streamingChatModel;
|
||||
|
||||
@NestedConfigurationProperty
|
||||
ChatModelProperties languageModel;
|
||||
|
||||
@NestedConfigurationProperty
|
||||
ChatModelProperties streamingLanguageModel;
|
||||
|
||||
@NestedConfigurationProperty
|
||||
EmbeddingModelProperties embeddingModel;
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import dev.langchain4j.data.message.ChatMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import lombok.Builder;
|
||||
import lombok.Setter;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@@ -32,6 +33,7 @@ public class DifyAiChatModel implements ChatLanguageModel {
|
||||
private final Double temperature;
|
||||
private final Long timeOut;
|
||||
|
||||
@Setter
|
||||
private String userName;
|
||||
|
||||
@Builder
|
||||
@@ -54,7 +56,7 @@ public class DifyAiChatModel implements ChatLanguageModel {
|
||||
@Override
|
||||
public String generate(String message) {
|
||||
DifyResult difyResult = this.difyClient.generate(message, this.getUserName());
|
||||
return difyResult.getAnswer().toString();
|
||||
return difyResult.getAnswer();
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -67,7 +69,7 @@ public class DifyAiChatModel implements ChatLanguageModel {
|
||||
List<ToolSpecification> toolSpecifications) {
|
||||
ensureNotEmpty(messages, "messages");
|
||||
DifyResult difyResult =
|
||||
this.difyClient.generate(messages.get(0).text(), this.getUserName());
|
||||
this.difyClient.generate(messages.get(0).toString(), this.getUserName());
|
||||
System.out.println(difyResult.toString());
|
||||
|
||||
if (!isNullOrEmpty(toolSpecifications)) {
|
||||
@@ -84,12 +86,8 @@ public class DifyAiChatModel implements ChatLanguageModel {
|
||||
toolSpecification != null ? singletonList(toolSpecification) : null);
|
||||
}
|
||||
|
||||
public void setUserName(String userName) {
|
||||
this.userName = userName;
|
||||
}
|
||||
|
||||
public String getUserName() {
|
||||
return null == userName ? "zhaodongsheng" : userName;
|
||||
return null == userName ? "admin" : userName;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ import static dev.langchain4j.model.openai.InternalOpenAiHelper.toOpenAiMessages
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toOpenAiResponseFormat;
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toTools;
|
||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.tokenUsageFrom;
|
||||
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
|
||||
import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_3_5_TURBO;
|
||||
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
|
||||
import static java.time.Duration.ofSeconds;
|
||||
import static java.util.Collections.emptyList;
|
||||
@@ -66,7 +66,6 @@ import static java.util.Collections.singletonList;
|
||||
@Slf4j
|
||||
public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
||||
|
||||
public static final String ZHIPU = "bigmodel";
|
||||
private final OpenAiClient client;
|
||||
private final String baseUrl;
|
||||
private final String modelName;
|
||||
@@ -111,7 +110,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
||||
.connectTimeout(timeout).readTimeout(timeout).writeTimeout(timeout).proxy(proxy)
|
||||
.logRequests(logRequests).logResponses(logResponses).userAgent(DEFAULT_USER_AGENT)
|
||||
.customHeaders(customHeaders).build();
|
||||
this.modelName = getOrDefault(modelName, GPT_3_5_TURBO);
|
||||
this.modelName = getOrDefault(modelName, GPT_3_5_TURBO.name());
|
||||
this.apiVersion = apiVersion;
|
||||
this.temperature = getOrDefault(temperature, 0.7);
|
||||
this.topP = topP;
|
||||
@@ -130,7 +129,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
||||
this.strictTools = getOrDefault(strictTools, false);
|
||||
this.parallelToolCalls = parallelToolCalls;
|
||||
this.maxRetries = getOrDefault(maxRetries, 3);
|
||||
this.tokenizer = getOrDefault(tokenizer, OpenAiTokenizer::new);
|
||||
this.tokenizer = getOrDefault(tokenizer, () -> new OpenAiTokenizer(this.modelName));
|
||||
this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners);
|
||||
}
|
||||
|
||||
@@ -192,9 +191,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
||||
.responseFormat(responseFormat).seed(seed).user(user)
|
||||
.parallelToolCalls(parallelToolCalls);
|
||||
|
||||
if (!(baseUrl.contains(ZHIPU))) {
|
||||
requestBuilder.temperature(temperature);
|
||||
}
|
||||
requestBuilder.temperature(temperature);
|
||||
|
||||
if (toolSpecifications != null && !toolSpecifications.isEmpty()) {
|
||||
requestBuilder.tools(toTools(toolSpecifications, strictTools));
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
package dev.langchain4j.model.zhipu;
|
||||
|
||||
public enum ChatCompletionModel {
|
||||
GLM_4("glm-4"), GLM_3_TURBO("glm-3-turbo"), CHATGLM_TURBO("chatglm_turbo");
|
||||
|
||||
private final String value;
|
||||
|
||||
ChatCompletionModel(String value) {
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return this.value;
|
||||
}
|
||||
}
|
||||
@@ -1,100 +0,0 @@
|
||||
package dev.langchain4j.model.zhipu;
|
||||
|
||||
import dev.langchain4j.agent.tool.ToolSpecification;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.data.message.ChatMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.model.zhipu.chat.ChatCompletionRequest;
|
||||
import dev.langchain4j.model.zhipu.chat.ChatCompletionResponse;
|
||||
import dev.langchain4j.model.zhipu.spi.ZhipuAiChatModelBuilderFactory;
|
||||
import lombok.Builder;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static dev.langchain4j.internal.RetryUtils.withRetry;
|
||||
import static dev.langchain4j.internal.Utils.getOrDefault;
|
||||
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
|
||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
|
||||
import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.aiMessageFrom;
|
||||
import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.finishReasonFrom;
|
||||
import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.toTools;
|
||||
import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.toZhipuAiMessages;
|
||||
import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.tokenUsageFrom;
|
||||
import static dev.langchain4j.model.zhipu.chat.ToolChoiceMode.AUTO;
|
||||
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
|
||||
import static java.util.Collections.singletonList;
|
||||
|
||||
/**
|
||||
* Represents an ZhipuAi language model with a chat completion interface, such as glm-3-turbo and
|
||||
* glm-4. You can find description of parameters
|
||||
* <a href="https://open.bigmodel.cn/dev/api">here</a>.
|
||||
*/
|
||||
public class ZhipuAiChatModel implements ChatLanguageModel {
|
||||
|
||||
private final String baseUrl;
|
||||
private final Double temperature;
|
||||
private final Double topP;
|
||||
private final String model;
|
||||
private final Integer maxRetries;
|
||||
private final Integer maxToken;
|
||||
private final ZhipuAiClient client;
|
||||
|
||||
@Builder
|
||||
public ZhipuAiChatModel(String baseUrl, String apiKey, Double temperature, Double topP,
|
||||
String model, Integer maxRetries, Integer maxToken, Boolean logRequests,
|
||||
Boolean logResponses) {
|
||||
this.baseUrl = getOrDefault(baseUrl, "https://open.bigmodel.cn/");
|
||||
this.temperature = getOrDefault(temperature, 0.7);
|
||||
this.topP = topP;
|
||||
this.model = getOrDefault(model, ChatCompletionModel.GLM_4.toString());
|
||||
this.maxRetries = getOrDefault(maxRetries, 3);
|
||||
this.maxToken = getOrDefault(maxToken, 512);
|
||||
this.client = ZhipuAiClient.builder().baseUrl(this.baseUrl).apiKey(apiKey)
|
||||
.logRequests(getOrDefault(logRequests, false))
|
||||
.logResponses(getOrDefault(logResponses, false)).build();
|
||||
}
|
||||
|
||||
public static ZhipuAiChatModelBuilder builder() {
|
||||
for (ZhipuAiChatModelBuilderFactory factories : loadFactories(
|
||||
ZhipuAiChatModelBuilderFactory.class)) {
|
||||
return factories.get();
|
||||
}
|
||||
return new ZhipuAiChatModelBuilder();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Response<AiMessage> generate(List<ChatMessage> messages) {
|
||||
return generate(messages, (ToolSpecification) null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Response<AiMessage> generate(List<ChatMessage> messages,
|
||||
List<ToolSpecification> toolSpecifications) {
|
||||
ensureNotEmpty(messages, "messages");
|
||||
|
||||
ChatCompletionRequest.Builder requestBuilder =
|
||||
ChatCompletionRequest.builder().model(this.model).maxTokens(maxToken).stream(false)
|
||||
.topP(topP).toolChoice(AUTO).messages(toZhipuAiMessages(messages));
|
||||
|
||||
if (!isNullOrEmpty(toolSpecifications)) {
|
||||
requestBuilder.tools(toTools(toolSpecifications));
|
||||
}
|
||||
|
||||
ChatCompletionResponse response =
|
||||
withRetry(() -> client.chatCompletion(requestBuilder.build()), maxRetries);
|
||||
return Response.from(aiMessageFrom(response), tokenUsageFrom(response.getUsage()),
|
||||
finishReasonFrom(response.getChoices().get(0).getFinishReason()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Response<AiMessage> generate(List<ChatMessage> messages,
|
||||
ToolSpecification toolSpecification) {
|
||||
return generate(messages,
|
||||
toolSpecification != null ? singletonList(toolSpecification) : null);
|
||||
}
|
||||
|
||||
public static class ZhipuAiChatModelBuilder {
|
||||
public ZhipuAiChatModelBuilder() {}
|
||||
}
|
||||
}
|
||||
@@ -1,51 +0,0 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.azure.AzureOpenAiChatModel;
|
||||
import dev.langchain4j.model.azure.AzureOpenAiEmbeddingModel;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
@Service
|
||||
public class AzureModelFactory implements ModelFactory, InitializingBean {
|
||||
public static final String PROVIDER = "AZURE";
|
||||
public static final String DEFAULT_BASE_URL = "https://your-resource-name.openai.azure.com/";
|
||||
public static final String DEFAULT_MODEL_NAME = "gpt-35-turbo";
|
||||
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "text-embedding-ada-002";
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
AzureOpenAiChatModel.Builder builder = AzureOpenAiChatModel.builder()
|
||||
.endpoint(modelConfig.getBaseUrl()).apiKey(modelConfig.getApiKey())
|
||||
.deploymentName(modelConfig.getModelName())
|
||||
.temperature(modelConfig.getTemperature()).maxRetries(modelConfig.getMaxRetries())
|
||||
.topP(modelConfig.getTopP())
|
||||
.timeout(Duration.ofSeconds(
|
||||
modelConfig.getTimeOut() == null ? 0L : modelConfig.getTimeOut()))
|
||||
.logRequestsAndResponses(
|
||||
modelConfig.getLogRequests() != null && modelConfig.getLogResponses());
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
||||
AzureOpenAiEmbeddingModel.Builder builder =
|
||||
AzureOpenAiEmbeddingModel.builder().endpoint(embeddingModelConfig.getBaseUrl())
|
||||
.apiKey(embeddingModelConfig.getApiKey())
|
||||
.deploymentName(embeddingModelConfig.getModelName())
|
||||
.maxRetries(embeddingModelConfig.getMaxRetries())
|
||||
.logRequestsAndResponses(embeddingModelConfig.getLogRequests() != null
|
||||
&& embeddingModelConfig.getLogResponses());
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(PROVIDER, this);
|
||||
}
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.dashscope.QwenChatModel;
|
||||
import dev.langchain4j.model.dashscope.QwenEmbeddingModel;
|
||||
import dev.langchain4j.model.dashscope.QwenModelName;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
public class DashscopeModelFactory implements ModelFactory, InitializingBean {
|
||||
public static final String PROVIDER = "DASHSCOPE";
|
||||
public static final String DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/api/v1";
|
||||
public static final String DEFAULT_MODEL_NAME = QwenModelName.QWEN_PLUS;
|
||||
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "text-embedding-v2";
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
return QwenChatModel.builder().baseUrl(modelConfig.getBaseUrl())
|
||||
.apiKey(modelConfig.getApiKey()).modelName(modelConfig.getModelName())
|
||||
.temperature(modelConfig.getTemperature() == null ? 0L
|
||||
: modelConfig.getTemperature().floatValue())
|
||||
.topP(modelConfig.getTopP()).enableSearch(modelConfig.getEnableSearch()).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
||||
return QwenEmbeddingModel.builder().apiKey(embeddingModelConfig.getApiKey())
|
||||
.modelName(embeddingModelConfig.getModelName()).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(PROVIDER, this);
|
||||
}
|
||||
}
|
||||
@@ -6,7 +6,7 @@ import com.tencent.supersonic.common.util.AESEncryptionUtil;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.dify.DifyAiChatModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@@ -27,8 +27,9 @@ public class DifyModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
||||
return ZhipuAiEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
|
||||
.apiKey(embeddingModelConfig.getApiKey()).model(embeddingModelConfig.getModelName())
|
||||
return OpenAiEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
|
||||
.apiKey(embeddingModelConfig.getApiKey())
|
||||
.modelName(embeddingModelConfig.getModelName())
|
||||
.maxRetries(embeddingModelConfig.getMaxRetries())
|
||||
.logRequests(embeddingModelConfig.getLogRequests())
|
||||
.logResponses(embeddingModelConfig.getLogResponses()).build();
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.qianfan.QianfanChatModel;
|
||||
import dev.langchain4j.model.qianfan.QianfanEmbeddingModel;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
public class QianfanModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
public static final String PROVIDER = "QIANFAN";
|
||||
public static final String DEFAULT_BASE_URL = "https://aip.baidubce.com";
|
||||
public static final String DEFAULT_MODEL_NAME = "Llama-2-70b-chat";
|
||||
|
||||
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "Embedding-V1";
|
||||
public static final String DEFAULT_ENDPOINT = "llama_2_70b";
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
return QianfanChatModel.builder().baseUrl(modelConfig.getBaseUrl())
|
||||
.apiKey(modelConfig.getApiKey()).secretKey(modelConfig.getSecretKey())
|
||||
.endpoint(modelConfig.getEndpoint()).modelName(modelConfig.getModelName())
|
||||
.temperature(modelConfig.getTemperature()).topP(modelConfig.getTopP())
|
||||
.maxRetries(modelConfig.getMaxRetries()).logRequests(modelConfig.getLogRequests())
|
||||
.logResponses(modelConfig.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
||||
return QianfanEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
|
||||
.apiKey(embeddingModelConfig.getApiKey())
|
||||
.secretKey(embeddingModelConfig.getSecretKey())
|
||||
.modelName(embeddingModelConfig.getModelName())
|
||||
.maxRetries(embeddingModelConfig.getMaxRetries())
|
||||
.logRequests(embeddingModelConfig.getLogRequests())
|
||||
.logResponses(embeddingModelConfig.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(PROVIDER, this);
|
||||
}
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.zhipu.ChatCompletionModel;
|
||||
import dev.langchain4j.model.zhipu.ZhipuAiChatModel;
|
||||
import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import static java.time.Duration.ofSeconds;
|
||||
|
||||
@Service
|
||||
public class ZhipuModelFactory implements ModelFactory, InitializingBean {
|
||||
public static final String PROVIDER = "ZHIPU";
|
||||
public static final String DEFAULT_BASE_URL = "https://open.bigmodel.cn/";
|
||||
public static final String DEFAULT_MODEL_NAME = ChatCompletionModel.GLM_4.toString();
|
||||
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "embedding-2";
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
return ZhipuAiChatModel.builder().baseUrl(modelConfig.getBaseUrl())
|
||||
.apiKey(modelConfig.getApiKey()).model(modelConfig.getModelName())
|
||||
.temperature(modelConfig.getTemperature()).topP(modelConfig.getTopP())
|
||||
.maxRetries(modelConfig.getMaxRetries()).logRequests(modelConfig.getLogRequests())
|
||||
.logResponses(modelConfig.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
||||
return ZhipuAiEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
|
||||
.apiKey(embeddingModelConfig.getApiKey()).model(embeddingModelConfig.getModelName())
|
||||
.maxRetries(embeddingModelConfig.getMaxRetries()).callTimeout(ofSeconds(60))
|
||||
.connectTimeout(ofSeconds(60)).writeTimeout(ofSeconds(60))
|
||||
.readTimeout(ofSeconds(60)).logRequests(embeddingModelConfig.getLogRequests())
|
||||
.logResponses(embeddingModelConfig.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(PROVIDER, this);
|
||||
}
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
package dev.langchain4j.qianfan.spring;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
class ChatModelProperties {
|
||||
private String baseUrl;
|
||||
private String apiKey;
|
||||
private String secretKey;
|
||||
private Double temperature;
|
||||
private Integer maxRetries;
|
||||
private Double topP;
|
||||
private String modelName;
|
||||
private String endpoint;
|
||||
private String responseFormat;
|
||||
private Double penaltyScore;
|
||||
private Boolean logRequests;
|
||||
private Boolean logResponses;
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
package dev.langchain4j.qianfan.spring;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
class EmbeddingModelProperties {
|
||||
private String baseUrl;
|
||||
private String apiKey;
|
||||
private String secretKey;
|
||||
private Integer maxRetries;
|
||||
private String modelName;
|
||||
private String endpoint;
|
||||
private String user;
|
||||
private Boolean logRequests;
|
||||
private Boolean logResponses;
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
package dev.langchain4j.qianfan.spring;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
class LanguageModelProperties {
|
||||
private String baseUrl;
|
||||
private String apiKey;
|
||||
private String secretKey;
|
||||
private Double temperature;
|
||||
private Integer maxRetries;
|
||||
private Integer topK;
|
||||
private Double topP;
|
||||
private String modelName;
|
||||
private String endpoint;
|
||||
private Double penaltyScore;
|
||||
private Boolean logRequests;
|
||||
private Boolean logResponses;
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
package dev.langchain4j.qianfan.spring;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||
import org.springframework.boot.context.properties.NestedConfigurationProperty;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
@ConfigurationProperties(prefix = Properties.PREFIX)
|
||||
public class Properties {
|
||||
|
||||
static final String PREFIX = "langchain4j.qianfan";
|
||||
|
||||
@NestedConfigurationProperty
|
||||
ChatModelProperties chatModel;
|
||||
|
||||
@NestedConfigurationProperty
|
||||
ChatModelProperties streamingChatModel;
|
||||
|
||||
@NestedConfigurationProperty
|
||||
LanguageModelProperties languageModel;
|
||||
|
||||
@NestedConfigurationProperty
|
||||
LanguageModelProperties streamingLanguageModel;
|
||||
|
||||
@NestedConfigurationProperty
|
||||
EmbeddingModelProperties embeddingModel;
|
||||
}
|
||||
@@ -1,102 +0,0 @@
|
||||
package dev.langchain4j.qianfan.spring;
|
||||
|
||||
import dev.langchain4j.model.qianfan.QianfanChatModel;
|
||||
import dev.langchain4j.model.qianfan.QianfanEmbeddingModel;
|
||||
import dev.langchain4j.model.qianfan.QianfanLanguageModel;
|
||||
import dev.langchain4j.model.qianfan.QianfanStreamingChatModel;
|
||||
import dev.langchain4j.model.qianfan.QianfanStreamingLanguageModel;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
import static dev.langchain4j.qianfan.spring.Properties.PREFIX;
|
||||
|
||||
@Configuration
|
||||
@EnableConfigurationProperties(Properties.class)
|
||||
public class QianfanAutoConfig {
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".chat-model.api-key")
|
||||
QianfanChatModel qianfanChatModel(Properties properties) {
|
||||
ChatModelProperties chatModelProperties = properties.getChatModel();
|
||||
return QianfanChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
|
||||
.apiKey(chatModelProperties.getApiKey())
|
||||
.secretKey(chatModelProperties.getSecretKey())
|
||||
.endpoint(chatModelProperties.getEndpoint())
|
||||
.penaltyScore(chatModelProperties.getPenaltyScore())
|
||||
.modelName(chatModelProperties.getModelName())
|
||||
.temperature(chatModelProperties.getTemperature())
|
||||
.topP(chatModelProperties.getTopP())
|
||||
.responseFormat(chatModelProperties.getResponseFormat())
|
||||
.maxRetries(chatModelProperties.getMaxRetries())
|
||||
.logRequests(chatModelProperties.getLogRequests())
|
||||
.logResponses(chatModelProperties.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".streaming-chat-model.api-key")
|
||||
QianfanStreamingChatModel qianfanStreamingChatModel(Properties properties) {
|
||||
ChatModelProperties chatModelProperties = properties.getStreamingChatModel();
|
||||
return QianfanStreamingChatModel.builder().endpoint(chatModelProperties.getEndpoint())
|
||||
.penaltyScore(chatModelProperties.getPenaltyScore())
|
||||
.temperature(chatModelProperties.getTemperature())
|
||||
.topP(chatModelProperties.getTopP()).baseUrl(chatModelProperties.getBaseUrl())
|
||||
.apiKey(chatModelProperties.getApiKey())
|
||||
.secretKey(chatModelProperties.getSecretKey())
|
||||
.modelName(chatModelProperties.getModelName())
|
||||
.responseFormat(chatModelProperties.getResponseFormat())
|
||||
.logRequests(chatModelProperties.getLogRequests())
|
||||
.logResponses(chatModelProperties.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".language-model.api-key")
|
||||
QianfanLanguageModel qianfanLanguageModel(Properties properties) {
|
||||
LanguageModelProperties languageModelProperties = properties.getLanguageModel();
|
||||
return QianfanLanguageModel.builder().endpoint(languageModelProperties.getEndpoint())
|
||||
.penaltyScore(languageModelProperties.getPenaltyScore())
|
||||
.topK(languageModelProperties.getTopK()).topP(languageModelProperties.getTopP())
|
||||
.baseUrl(languageModelProperties.getBaseUrl())
|
||||
.apiKey(languageModelProperties.getApiKey())
|
||||
.secretKey(languageModelProperties.getSecretKey())
|
||||
.modelName(languageModelProperties.getModelName())
|
||||
.temperature(languageModelProperties.getTemperature())
|
||||
.maxRetries(languageModelProperties.getMaxRetries())
|
||||
.logRequests(languageModelProperties.getLogRequests())
|
||||
.logResponses(languageModelProperties.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".streaming-language-model.api-key")
|
||||
QianfanStreamingLanguageModel qianfanStreamingLanguageModel(Properties properties) {
|
||||
LanguageModelProperties languageModelProperties = properties.getStreamingLanguageModel();
|
||||
return QianfanStreamingLanguageModel.builder()
|
||||
.endpoint(languageModelProperties.getEndpoint())
|
||||
.penaltyScore(languageModelProperties.getPenaltyScore())
|
||||
.topK(languageModelProperties.getTopK()).topP(languageModelProperties.getTopP())
|
||||
.baseUrl(languageModelProperties.getBaseUrl())
|
||||
.apiKey(languageModelProperties.getApiKey())
|
||||
.secretKey(languageModelProperties.getSecretKey())
|
||||
.modelName(languageModelProperties.getModelName())
|
||||
.temperature(languageModelProperties.getTemperature())
|
||||
.maxRetries(languageModelProperties.getMaxRetries())
|
||||
.logRequests(languageModelProperties.getLogRequests())
|
||||
.logResponses(languageModelProperties.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".embedding-model.api-key")
|
||||
QianfanEmbeddingModel qianfanEmbeddingModel(Properties properties) {
|
||||
EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel();
|
||||
return QianfanEmbeddingModel.builder().baseUrl(embeddingModelProperties.getBaseUrl())
|
||||
.endpoint(embeddingModelProperties.getEndpoint())
|
||||
.apiKey(embeddingModelProperties.getApiKey())
|
||||
.secretKey(embeddingModelProperties.getSecretKey())
|
||||
.modelName(embeddingModelProperties.getModelName())
|
||||
.user(embeddingModelProperties.getUser())
|
||||
.maxRetries(embeddingModelProperties.getMaxRetries())
|
||||
.logRequests(embeddingModelProperties.getLogRequests())
|
||||
.logResponses(embeddingModelProperties.getLogResponses()).build();
|
||||
}
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
package dev.langchain4j.zhipu.spring;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
class ChatModelProperties {
|
||||
|
||||
String baseUrl;
|
||||
String apiKey;
|
||||
Double temperature;
|
||||
Double topP;
|
||||
String modelName;
|
||||
Integer maxRetries;
|
||||
Integer maxToken;
|
||||
Boolean logRequests;
|
||||
Boolean logResponses;
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
package dev.langchain4j.zhipu.spring;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
class EmbeddingModelProperties {
|
||||
|
||||
String baseUrl;
|
||||
String apiKey;
|
||||
String model;
|
||||
Integer maxRetries;
|
||||
Boolean logRequests;
|
||||
Boolean logResponses;
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
package dev.langchain4j.zhipu.spring;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
||||
import org.springframework.boot.context.properties.NestedConfigurationProperty;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
@ConfigurationProperties(prefix = Properties.PREFIX)
|
||||
public class Properties {
|
||||
|
||||
static final String PREFIX = "langchain4j.zhipu";
|
||||
|
||||
@NestedConfigurationProperty
|
||||
ChatModelProperties chatModel;
|
||||
|
||||
@NestedConfigurationProperty
|
||||
ChatModelProperties streamingChatModel;
|
||||
|
||||
@NestedConfigurationProperty
|
||||
EmbeddingModelProperties embeddingModel;
|
||||
}
|
||||
@@ -1,53 +0,0 @@
|
||||
package dev.langchain4j.zhipu.spring;
|
||||
|
||||
import dev.langchain4j.model.zhipu.ZhipuAiChatModel;
|
||||
import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;
|
||||
import dev.langchain4j.model.zhipu.ZhipuAiStreamingChatModel;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
import static dev.langchain4j.zhipu.spring.Properties.PREFIX;
|
||||
|
||||
@Configuration
|
||||
@EnableConfigurationProperties(Properties.class)
|
||||
public class ZhipuAutoConfig {
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".chat-model.api-key")
|
||||
ZhipuAiChatModel zhipuAiChatModel(Properties properties) {
|
||||
ChatModelProperties chatModelProperties = properties.getChatModel();
|
||||
return ZhipuAiChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
|
||||
.apiKey(chatModelProperties.getApiKey()).model(chatModelProperties.getModelName())
|
||||
.temperature(chatModelProperties.getTemperature())
|
||||
.topP(chatModelProperties.getTopP()).maxRetries(chatModelProperties.getMaxRetries())
|
||||
.maxToken(chatModelProperties.getMaxToken())
|
||||
.logRequests(chatModelProperties.getLogRequests())
|
||||
.logResponses(chatModelProperties.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".streaming-chat-model.api-key")
|
||||
ZhipuAiStreamingChatModel zhipuStreamingChatModel(Properties properties) {
|
||||
ChatModelProperties chatModelProperties = properties.getStreamingChatModel();
|
||||
return ZhipuAiStreamingChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
|
||||
.apiKey(chatModelProperties.getApiKey()).model(chatModelProperties.getModelName())
|
||||
.temperature(chatModelProperties.getTemperature())
|
||||
.topP(chatModelProperties.getTopP()).maxToken(chatModelProperties.getMaxToken())
|
||||
.logRequests(chatModelProperties.getLogRequests())
|
||||
.logResponses(chatModelProperties.getLogResponses()).build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".embedding-model.api-key")
|
||||
ZhipuAiEmbeddingModel zhipuEmbeddingModel(Properties properties) {
|
||||
EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel();
|
||||
return ZhipuAiEmbeddingModel.builder().baseUrl(embeddingModelProperties.getBaseUrl())
|
||||
.apiKey(embeddingModelProperties.getApiKey())
|
||||
.model(embeddingModelProperties.getModel())
|
||||
.maxRetries(embeddingModelProperties.getMaxRetries())
|
||||
.logRequests(embeddingModelProperties.getLogRequests())
|
||||
.logResponses(embeddingModelProperties.getLogResponses()).build();
|
||||
}
|
||||
}
|
||||
@@ -8,20 +8,15 @@ import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.provider.AzureModelFactory;
|
||||
import dev.langchain4j.provider.DashscopeModelFactory;
|
||||
import dev.langchain4j.provider.EmbeddingModelConstant;
|
||||
import dev.langchain4j.provider.InMemoryModelFactory;
|
||||
import dev.langchain4j.provider.ModelProvider;
|
||||
import dev.langchain4j.provider.OpenAiModelFactory;
|
||||
import dev.langchain4j.provider.QianfanModelFactory;
|
||||
import dev.langchain4j.provider.ZhipuModelFactory;
|
||||
import org.junit.jupiter.api.Disabled;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junit.jupiter.api.TestInstance;
|
||||
|
||||
import static org.junit.Assert.assertNotNull;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
|
||||
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
|
||||
@Disabled
|
||||
@@ -40,65 +35,6 @@ public class ModelProviderTest extends BaseApplication {
|
||||
assertNotNull(response);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test_qianfan_chat_model() {
|
||||
ChatModelConfig modelConfig = new ChatModelConfig();
|
||||
modelConfig.setProvider(QianfanModelFactory.PROVIDER);
|
||||
modelConfig.setModelName(QianfanModelFactory.DEFAULT_MODEL_NAME);
|
||||
modelConfig.setBaseUrl(QianfanModelFactory.DEFAULT_BASE_URL);
|
||||
modelConfig.setSecretKey(ParameterConfig.DEMO);
|
||||
modelConfig.setApiKey(ParameterConfig.DEMO);
|
||||
modelConfig.setEndpoint(QianfanModelFactory.DEFAULT_ENDPOINT);
|
||||
|
||||
ChatLanguageModel chatModel = ModelProvider.getChatModel(modelConfig);
|
||||
assertThrows(RuntimeException.class, () -> {
|
||||
chatModel.generate("hi");
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test_zhipu_chat_model() {
|
||||
ChatModelConfig modelConfig = new ChatModelConfig();
|
||||
modelConfig.setProvider(ZhipuModelFactory.PROVIDER);
|
||||
modelConfig.setModelName(ZhipuModelFactory.DEFAULT_MODEL_NAME);
|
||||
modelConfig.setBaseUrl(ZhipuModelFactory.DEFAULT_BASE_URL);
|
||||
modelConfig.setApiKey("e2724491714b3b2a0274e987905f1001.5JyHgf4vbZVJ7gC5");
|
||||
|
||||
ChatLanguageModel chatModel = ModelProvider.getChatModel(modelConfig);
|
||||
assertThrows(RuntimeException.class, () -> {
|
||||
chatModel.generate("hi");
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test_dashscope_chat_model() {
|
||||
ChatModelConfig modelConfig = new ChatModelConfig();
|
||||
modelConfig.setProvider(DashscopeModelFactory.PROVIDER);
|
||||
modelConfig.setModelName(DashscopeModelFactory.DEFAULT_MODEL_NAME);
|
||||
modelConfig.setBaseUrl(DashscopeModelFactory.DEFAULT_BASE_URL);
|
||||
modelConfig.setEnableSearch(true);
|
||||
modelConfig.setApiKey(ParameterConfig.DEMO);
|
||||
|
||||
ChatLanguageModel chatModel = ModelProvider.getChatModel(modelConfig);
|
||||
assertThrows(RuntimeException.class, () -> {
|
||||
chatModel.generate("hi");
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test_azure_chat_model() {
|
||||
ChatModelConfig modelConfig = new ChatModelConfig();
|
||||
modelConfig.setProvider(AzureModelFactory.PROVIDER);
|
||||
modelConfig.setModelName(AzureModelFactory.DEFAULT_MODEL_NAME);
|
||||
modelConfig.setBaseUrl(AzureModelFactory.DEFAULT_BASE_URL);
|
||||
modelConfig.setApiKey(ParameterConfig.DEMO);
|
||||
|
||||
ChatLanguageModel chatModel = ModelProvider.getChatModel(modelConfig);
|
||||
assertThrows(RuntimeException.class, () -> {
|
||||
chatModel.generate("hi");
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test_in_memory_embedding_model() {
|
||||
EmbeddingModelConfig modelConfig = new EmbeddingModelConfig();
|
||||
@@ -122,61 +58,4 @@ public class ModelProviderTest extends BaseApplication {
|
||||
Response<Embedding> embed = embeddingModel.embed("hi");
|
||||
assertNotNull(embed);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test_azure_embedding_model() {
|
||||
EmbeddingModelConfig modelConfig = new EmbeddingModelConfig();
|
||||
modelConfig.setProvider(AzureModelFactory.PROVIDER);
|
||||
modelConfig.setModelName(AzureModelFactory.DEFAULT_EMBEDDING_MODEL_NAME);
|
||||
modelConfig.setBaseUrl(AzureModelFactory.DEFAULT_BASE_URL);
|
||||
modelConfig.setApiKey(ParameterConfig.DEMO);
|
||||
|
||||
EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(modelConfig);
|
||||
assertThrows(RuntimeException.class, () -> {
|
||||
embeddingModel.embed("hi");
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test_dashscope_embedding_model() {
|
||||
EmbeddingModelConfig modelConfig = new EmbeddingModelConfig();
|
||||
modelConfig.setProvider(DashscopeModelFactory.PROVIDER);
|
||||
modelConfig.setModelName(DashscopeModelFactory.DEFAULT_EMBEDDING_MODEL_NAME);
|
||||
modelConfig.setBaseUrl(DashscopeModelFactory.DEFAULT_BASE_URL);
|
||||
modelConfig.setApiKey(ParameterConfig.DEMO);
|
||||
|
||||
EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(modelConfig);
|
||||
assertThrows(RuntimeException.class, () -> {
|
||||
embeddingModel.embed("hi");
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test_qianfan_embedding_model() {
|
||||
EmbeddingModelConfig modelConfig = new EmbeddingModelConfig();
|
||||
modelConfig.setProvider(QianfanModelFactory.PROVIDER);
|
||||
modelConfig.setModelName(QianfanModelFactory.DEFAULT_EMBEDDING_MODEL_NAME);
|
||||
modelConfig.setBaseUrl(QianfanModelFactory.DEFAULT_BASE_URL);
|
||||
modelConfig.setApiKey(ParameterConfig.DEMO);
|
||||
modelConfig.setSecretKey(ParameterConfig.DEMO);
|
||||
|
||||
EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(modelConfig);
|
||||
assertThrows(RuntimeException.class, () -> {
|
||||
embeddingModel.embed("hi");
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test_zhipu_embedding_model() {
|
||||
EmbeddingModelConfig modelConfig = new EmbeddingModelConfig();
|
||||
modelConfig.setProvider(ZhipuModelFactory.PROVIDER);
|
||||
modelConfig.setModelName(ZhipuModelFactory.DEFAULT_EMBEDDING_MODEL_NAME);
|
||||
modelConfig.setBaseUrl(ZhipuModelFactory.DEFAULT_BASE_URL);
|
||||
modelConfig.setApiKey("e2724491714b3b2a0274e987905f1001.5JyHgf4vbZVJ7gC5");
|
||||
|
||||
EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(modelConfig);
|
||||
assertThrows(RuntimeException.class, () -> {
|
||||
embeddingModel.embed("hi");
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
25
pom.xml
25
pom.xml
@@ -146,31 +146,11 @@
|
||||
<artifactId>langchain4j-embeddings-bge-small-zh</artifactId>
|
||||
<version>${langchain4j.embedding.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-azure-open-ai</artifactId>
|
||||
<version>${langchain4j.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
|
||||
<version>${langchain4j.embedding.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-qianfan</artifactId>
|
||||
<version>${langchain4j.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-zhipu-ai</artifactId>
|
||||
<version>${langchain4j.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-dashscope</artifactId>
|
||||
<version>${langchain4j.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-milvus</artifactId>
|
||||
@@ -186,11 +166,6 @@
|
||||
<artifactId>langchain4j-pgvector</artifactId>
|
||||
<version>${langchain4j.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-chatglm</artifactId>
|
||||
<version>${langchain4j.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-ollama</artifactId>
|
||||
|
||||
@@ -81,6 +81,7 @@ export enum MsgContentTypeEnum {
|
||||
METRIC_TREND = 'METRIC_TREND',
|
||||
METRIC_BAR = 'METRIC_BAR',
|
||||
MARKDOWN = 'MARKDOWN',
|
||||
METRIC_PIE = 'METRIC_PIE',
|
||||
}
|
||||
|
||||
export enum ChatContextTypeQueryTypeEnum {
|
||||
|
||||
@@ -8,6 +8,7 @@ import Loading from './Loading';
|
||||
import { Prism as SyntaxHighlighter } from 'react-syntax-highlighter';
|
||||
import { solarizedlight } from 'react-syntax-highlighter/dist/esm/styles/prism';
|
||||
import React, { ReactNode, useState } from 'react';
|
||||
import ReactMarkdown from 'react-markdown';
|
||||
|
||||
type Props = {
|
||||
queryId?: number;
|
||||
@@ -122,9 +123,11 @@ const ExecuteItem: React.FC<Props> = ({
|
||||
)}
|
||||
</div>
|
||||
<div>
|
||||
{[MsgContentTypeEnum.METRIC_TREND, MsgContentTypeEnum.METRIC_BAR].includes(
|
||||
msgContentType as MsgContentTypeEnum
|
||||
) && (
|
||||
{[
|
||||
MsgContentTypeEnum.METRIC_TREND,
|
||||
MsgContentTypeEnum.METRIC_BAR,
|
||||
MsgContentTypeEnum.METRIC_PIE,
|
||||
].includes(msgContentType as MsgContentTypeEnum) && (
|
||||
<Switch
|
||||
checkedChildren="表格"
|
||||
unCheckedChildren="表格"
|
||||
@@ -151,7 +154,7 @@ const ExecuteItem: React.FC<Props> = ({
|
||||
{data.textSummary && (
|
||||
<p className={`${prefixCls}-step-title`}>
|
||||
<span style={{ marginRight: 5 }}>总结:</span>
|
||||
{data.textSummary}
|
||||
<ReactMarkdown>{data.textSummary}</ReactMarkdown>
|
||||
</p>
|
||||
)}
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ import { useExportByEcharts } from '../../../hooks';
|
||||
|
||||
type Props = {
|
||||
data: MsgDataType;
|
||||
question: string;
|
||||
question?: string;
|
||||
triggerResize?: boolean;
|
||||
loading: boolean;
|
||||
metricField: ColumnType;
|
||||
@@ -32,7 +32,7 @@ type Props = {
|
||||
|
||||
const BarChart: React.FC<Props> = ({
|
||||
data,
|
||||
question,
|
||||
question="",
|
||||
triggerResize,
|
||||
loading,
|
||||
metricField,
|
||||
|
||||
123
webapp/packages/chat-sdk/src/components/ChatMsg/Pie/PieChart.tsx
Normal file
123
webapp/packages/chat-sdk/src/components/ChatMsg/Pie/PieChart.tsx
Normal file
@@ -0,0 +1,123 @@
|
||||
import { PREFIX_CLS, THEME_COLOR_LIST } from '../../../common/constants';
|
||||
import { MsgDataType } from '../../../common/type';
|
||||
import { formatByDecimalPlaces, getFormattedValue } from '../../../utils/utils';
|
||||
import type { ECharts } from 'echarts';
|
||||
import * as echarts from 'echarts';
|
||||
import { useEffect, useRef } from 'react';
|
||||
import { ColumnType } from '../../../common/type';
|
||||
|
||||
type Props = {
|
||||
data: MsgDataType;
|
||||
metricField: ColumnType;
|
||||
categoryField: ColumnType;
|
||||
triggerResize?: boolean;
|
||||
};
|
||||
|
||||
const PieChart: React.FC<Props> = ({
|
||||
data,
|
||||
metricField,
|
||||
categoryField,
|
||||
triggerResize,
|
||||
}) => {
|
||||
const chartRef = useRef<any>();
|
||||
const instanceRef = useRef<ECharts>();
|
||||
|
||||
const { queryResults } = data;
|
||||
const categoryColumnName = categoryField?.bizName || '';
|
||||
const metricColumnName = metricField?.bizName || '';
|
||||
|
||||
const renderChart = () => {
|
||||
let instanceObj: any;
|
||||
if (!instanceRef.current) {
|
||||
instanceObj = echarts.init(chartRef.current);
|
||||
instanceRef.current = instanceObj;
|
||||
} else {
|
||||
instanceObj = instanceRef.current;
|
||||
}
|
||||
|
||||
const data = queryResults || [];
|
||||
const seriesData = data.map((item, index) => {
|
||||
const value = item[metricColumnName];
|
||||
const name = item[categoryColumnName] !== undefined ? item[categoryColumnName] : '未知';
|
||||
return {
|
||||
name,
|
||||
value,
|
||||
itemStyle: {
|
||||
color: THEME_COLOR_LIST[index % THEME_COLOR_LIST.length],
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
instanceObj.setOption({
|
||||
tooltip: {
|
||||
trigger: 'item',
|
||||
formatter: function (params: any) {
|
||||
const value = params.value;
|
||||
return `${params.name}: ${
|
||||
metricField.dataFormatType === 'percent'
|
||||
? `${formatByDecimalPlaces(
|
||||
metricField.dataFormat?.needMultiply100 ? +value * 100 : value,
|
||||
metricField.dataFormat?.decimalPlaces || 2
|
||||
)}%`
|
||||
: getFormattedValue(value)
|
||||
}`;
|
||||
},
|
||||
},
|
||||
legend: {
|
||||
orient: 'vertical',
|
||||
left: 'left',
|
||||
type: 'scroll',
|
||||
data: seriesData.map(item => item.name),
|
||||
selectedMode: true,
|
||||
textStyle: {
|
||||
color: '#666',
|
||||
},
|
||||
},
|
||||
series: [
|
||||
{
|
||||
name: '占比',
|
||||
type: 'pie',
|
||||
radius: ['40%', '70%'],
|
||||
avoidLabelOverlap: false,
|
||||
itemStyle: {
|
||||
borderRadius: 10,
|
||||
borderColor: '#fff',
|
||||
borderWidth: 2,
|
||||
},
|
||||
label: {
|
||||
show: false,
|
||||
position: 'center',
|
||||
},
|
||||
emphasis: {
|
||||
label: {
|
||||
show: true,
|
||||
fontSize: '14',
|
||||
fontWeight: 'bold',
|
||||
},
|
||||
},
|
||||
labelLine: {
|
||||
show: false,
|
||||
},
|
||||
data: seriesData,
|
||||
},
|
||||
],
|
||||
});
|
||||
instanceObj.resize();
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (queryResults && queryResults.length > 0) {
|
||||
renderChart();
|
||||
}
|
||||
}, [queryResults, metricField, categoryField]);
|
||||
|
||||
useEffect(() => {
|
||||
if (triggerResize && instanceRef.current) {
|
||||
instanceRef.current.resize();
|
||||
}
|
||||
}, [triggerResize]);
|
||||
|
||||
return <div className={`${PREFIX_CLS}-pie-chart`} ref={chartRef} />;
|
||||
};
|
||||
|
||||
export default PieChart;
|
||||
@@ -0,0 +1,88 @@
|
||||
import { PREFIX_CLS } from '../../../common/constants';
|
||||
import { MsgDataType } from '../../../common/type';
|
||||
import { useRef, useState } from 'react';
|
||||
import NoPermissionChart from '../NoPermissionChart';
|
||||
import { ColumnType } from '../../../common/type';
|
||||
import { Spin, Select } from 'antd';
|
||||
import PieChart from './PieChart';
|
||||
import Bar from '../Bar';
|
||||
|
||||
type Props = {
|
||||
data: MsgDataType;
|
||||
question: string;
|
||||
triggerResize?: boolean;
|
||||
loading: boolean;
|
||||
metricField: ColumnType;
|
||||
categoryField: ColumnType;
|
||||
onApplyAuth?: (model: string) => void;
|
||||
};
|
||||
|
||||
const metricChartSelectOptions = [
|
||||
{
|
||||
value: 'pie',
|
||||
label: '饼图',
|
||||
},
|
||||
{
|
||||
value: 'bar',
|
||||
label: '柱状图',
|
||||
},
|
||||
];
|
||||
|
||||
const Pie: React.FC<Props> = ({
|
||||
data,
|
||||
question,
|
||||
triggerResize,
|
||||
loading,
|
||||
metricField,
|
||||
categoryField,
|
||||
onApplyAuth,
|
||||
}) => {
|
||||
const [chartType, setChartType] = useState('pie');
|
||||
const { entityInfo } = data;
|
||||
|
||||
if (metricField && !metricField?.authorized) {
|
||||
return (
|
||||
<NoPermissionChart
|
||||
model={entityInfo?.dataSetInfo?.name || ''}
|
||||
chartType="pieChart"
|
||||
onApplyAuth={onApplyAuth}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
const prefixCls = `${PREFIX_CLS}-pie`;
|
||||
|
||||
return (
|
||||
<div className={prefixCls}>
|
||||
<div className={`${prefixCls}-metric-fields ${prefixCls}-metric-field-single`}>
|
||||
{question}
|
||||
</div>
|
||||
<div className={`${prefixCls}-select-options`}>
|
||||
<Select
|
||||
defaultValue="pie"
|
||||
bordered={false}
|
||||
options={metricChartSelectOptions}
|
||||
onChange={(value: string) => setChartType(value)}
|
||||
/>
|
||||
</div>
|
||||
{chartType === 'pie' ? (
|
||||
<PieChart
|
||||
data={data}
|
||||
metricField={metricField}
|
||||
categoryField={categoryField}
|
||||
triggerResize={triggerResize}
|
||||
/>
|
||||
) : (
|
||||
<Bar
|
||||
data={data}
|
||||
triggerResize={triggerResize}
|
||||
loading={loading}
|
||||
metricField={metricField}
|
||||
onApplyAuth={onApplyAuth}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default Pie;
|
||||
@@ -0,0 +1,43 @@
|
||||
@import '../../../styles/index.less';
|
||||
|
||||
@pie-prefix-cls: ~'@{prefix-cls}-pie';
|
||||
|
||||
.@{pie-prefix-cls} {
|
||||
&-select-options {
|
||||
display: flex;
|
||||
justify-content: flex-end;
|
||||
}
|
||||
|
||||
&-chart {
|
||||
width: 100%;
|
||||
height: 400px;
|
||||
min-height: 400px;
|
||||
}
|
||||
|
||||
&-metric-fields {
|
||||
display: flex;
|
||||
align-items: baseline;
|
||||
flex-wrap: wrap;
|
||||
column-gap: 8px;
|
||||
row-gap: 12px;
|
||||
}
|
||||
|
||||
&-metric-field-single {
|
||||
padding-left: 0;
|
||||
font-weight: 500;
|
||||
cursor: default;
|
||||
font-size: 15px;
|
||||
color: var(--text-color);
|
||||
|
||||
&:hover {
|
||||
color: var(--text-color);
|
||||
}
|
||||
}
|
||||
|
||||
&-indicator-name {
|
||||
font-size: 14px;
|
||||
color: var(--text-color);
|
||||
font-weight: 500;
|
||||
margin-top: 2px;
|
||||
}
|
||||
}
|
||||
@@ -12,6 +12,7 @@ import Text from './Text';
|
||||
import DrillDownDimensions from '../DrillDownDimensions';
|
||||
import MetricOptions from '../MetricOptions';
|
||||
import { isMobile } from '../../utils/utils';
|
||||
import Pie from './Pie';
|
||||
|
||||
type Props = {
|
||||
queryId?: number;
|
||||
@@ -120,6 +121,16 @@ const ChatMsg: React.FC<Props> = ({
|
||||
return MsgContentTypeEnum.METRIC_TREND;
|
||||
}
|
||||
|
||||
const isMetricPie =
|
||||
metricFields.length > 0 &&
|
||||
metricFields?.length === 1 &&
|
||||
(isMobile ? dataSource?.length <= 5 : dataSource?.length <= 10) &&
|
||||
dataSource.every(item => item[metricFields[0].bizName] > 0);
|
||||
|
||||
if (isMetricPie) {
|
||||
return MsgContentTypeEnum.METRIC_PIE;
|
||||
}
|
||||
|
||||
const isMetricBar =
|
||||
categoryField?.length > 0 &&
|
||||
metricFields?.length === 1 &&
|
||||
@@ -148,7 +159,7 @@ const ChatMsg: React.FC<Props> = ({
|
||||
[queryColumns.length > 5 ? 'width' : 'minWidth']: queryColumns.length * 150,
|
||||
};
|
||||
}
|
||||
if (type === MsgContentTypeEnum.METRIC_TREND) {
|
||||
if (type === MsgContentTypeEnum.METRIC_TREND || type === MsgContentTypeEnum.METRIC_PIE) {
|
||||
return { width: 'calc(100vw - 410px)' };
|
||||
}
|
||||
};
|
||||
@@ -213,6 +224,21 @@ const ChatMsg: React.FC<Props> = ({
|
||||
metricField={metricFields[0]}
|
||||
/>
|
||||
);
|
||||
case MsgContentTypeEnum.METRIC_PIE:
|
||||
const categoryField = columns.find(item => item.showType === 'CATEGORY');
|
||||
if (!categoryField) {
|
||||
return null;
|
||||
}
|
||||
return (
|
||||
<Pie
|
||||
data={{ ...data, queryColumns: columns, queryResults: dataSource }}
|
||||
question={question}
|
||||
triggerResize={triggerResize}
|
||||
loading={loading}
|
||||
metricField={metricFields[0]}
|
||||
categoryField={categoryField}
|
||||
/>
|
||||
);
|
||||
case MsgContentTypeEnum.MARKDOWN:
|
||||
return (
|
||||
<div style={{ maxHeight: 800 }}>
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
|
||||
@import "../components/ChatMsg/MetricTrend/style.less";
|
||||
|
||||
@import "../components/ChatMsg/Pie/style.less";
|
||||
|
||||
@import "../components/ChatMsg/ApplyAuth/style.less";
|
||||
|
||||
@import "../components/ChatMsg/NoPermissionChart/style.less";
|
||||
|
||||
Reference in New Issue
Block a user