mirror of
https://github.com/tencentmusic/supersonic.git
synced 2026-04-30 13:04:27 +08:00
Compare commits
5 Commits
be8b56bdde
...
3fc1ec42be
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3fc1ec42be | ||
|
|
c4992501bd | ||
|
|
acffc03c79 | ||
|
|
763def2de0 | ||
|
|
d0a67af684 |
@@ -61,7 +61,7 @@ function runJavaService {
|
|||||||
fi
|
fi
|
||||||
export PATH=$JAVA_HOME/bin:$PATH
|
export PATH=$JAVA_HOME/bin:$PATH
|
||||||
command="-Dfile.encoding=UTF-8 -Duser.language=Zh -Duser.region=CN -Duser.timezone=GMT+08
|
command="-Dfile.encoding=UTF-8 -Duser.language=Zh -Duser.region=CN -Duser.timezone=GMT+08
|
||||||
-Dapp_name=${local_app_name} -Xms1024m -Xmx2048m $main_class"
|
-Dapp_name=${local_app_name} -Xms1024m -Xmx2048m -XX:+UseZGC -XX:+ZGenerational $main_class"
|
||||||
|
|
||||||
mkdir -p $javaRunDir/logs
|
mkdir -p $javaRunDir/logs
|
||||||
java -Dspring.profiles.active="$profile" $command >/dev/null 2>$javaRunDir/logs/error.log &
|
java -Dspring.profiles.active="$profile" $command >/dev/null 2>$javaRunDir/logs/error.log &
|
||||||
|
|||||||
@@ -21,7 +21,10 @@
|
|||||||
<groupId>org.springframework.boot</groupId>
|
<groupId>org.springframework.boot</groupId>
|
||||||
<artifactId>spring-boot-starter-validation</artifactId>
|
<artifactId>spring-boot-starter-validation</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.springframework.boot</groupId>
|
||||||
|
<artifactId>spring-boot-autoconfigure-processor</artifactId>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.slf4j</groupId>
|
<groupId>org.slf4j</groupId>
|
||||||
@@ -33,7 +36,7 @@
|
|||||||
<dependency>
|
<dependency>
|
||||||
<groupId>org.apache.httpcomponents.client5</groupId>
|
<groupId>org.apache.httpcomponents.client5</groupId>
|
||||||
<artifactId>httpclient5</artifactId>
|
<artifactId>httpclient5</artifactId>
|
||||||
<version>${httpclient5.version}</version> <!-- 请确认使用最新稳定版本 -->
|
<version>${httpclient5.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
<!-- <dependency>-->
|
<!-- <dependency>-->
|
||||||
<!-- <groupId>org.apache.httpcomponents</groupId>-->
|
<!-- <groupId>org.apache.httpcomponents</groupId>-->
|
||||||
@@ -182,10 +185,6 @@
|
|||||||
<groupId>dev.langchain4j</groupId>
|
<groupId>dev.langchain4j</groupId>
|
||||||
<artifactId>langchain4j-pgvector</artifactId>
|
<artifactId>langchain4j-pgvector</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
|
||||||
<groupId>dev.langchain4j</groupId>
|
|
||||||
<artifactId>langchain4j-azure-open-ai</artifactId>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>dev.langchain4j</groupId>
|
<groupId>dev.langchain4j</groupId>
|
||||||
<artifactId>langchain4j-embeddings-bge-small-zh</artifactId>
|
<artifactId>langchain4j-embeddings-bge-small-zh</artifactId>
|
||||||
@@ -198,34 +197,6 @@
|
|||||||
<groupId>dev.langchain4j</groupId>
|
<groupId>dev.langchain4j</groupId>
|
||||||
<artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
|
<artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
|
||||||
<groupId>dev.langchain4j</groupId>
|
|
||||||
<artifactId>langchain4j-qianfan</artifactId>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>dev.langchain4j</groupId>
|
|
||||||
<artifactId>langchain4j-zhipu-ai</artifactId>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>dev.langchain4j</groupId>
|
|
||||||
<artifactId>langchain4j-dashscope</artifactId>
|
|
||||||
<exclusions>
|
|
||||||
<exclusion>
|
|
||||||
<groupId>org.slf4j</groupId>
|
|
||||||
<artifactId>slf4j-simple</artifactId>
|
|
||||||
</exclusion>
|
|
||||||
</exclusions>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>dev.langchain4j</groupId>
|
|
||||||
<artifactId>langchain4j-chatglm</artifactId>
|
|
||||||
<exclusions>
|
|
||||||
<exclusion>
|
|
||||||
<groupId>org.slf4j</groupId>
|
|
||||||
<artifactId>slf4j-simple</artifactId>
|
|
||||||
</exclusion>
|
|
||||||
</exclusions>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>dev.langchain4j</groupId>
|
<groupId>dev.langchain4j</groupId>
|
||||||
<artifactId>langchain4j-ollama</artifactId>
|
<artifactId>langchain4j-ollama</artifactId>
|
||||||
@@ -237,11 +208,6 @@
|
|||||||
<version>${hanlp.version}</version>
|
<version>${hanlp.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.springframework.boot</groupId>
|
|
||||||
<artifactId>spring-boot-autoconfigure-processor</artifactId>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.google.code.gson</groupId>
|
<groupId>com.google.code.gson</groupId>
|
||||||
<artifactId>gson</artifactId>
|
<artifactId>gson</artifactId>
|
||||||
|
|||||||
@@ -4,14 +4,10 @@ import com.google.common.collect.ImmutableMap;
|
|||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
||||||
import com.tencent.supersonic.common.pojo.Parameter;
|
import com.tencent.supersonic.common.pojo.Parameter;
|
||||||
import dev.langchain4j.provider.AzureModelFactory;
|
|
||||||
import dev.langchain4j.provider.DashscopeModelFactory;
|
|
||||||
import dev.langchain4j.provider.EmbeddingModelConstant;
|
import dev.langchain4j.provider.EmbeddingModelConstant;
|
||||||
import dev.langchain4j.provider.InMemoryModelFactory;
|
import dev.langchain4j.provider.InMemoryModelFactory;
|
||||||
import dev.langchain4j.provider.OllamaModelFactory;
|
import dev.langchain4j.provider.OllamaModelFactory;
|
||||||
import dev.langchain4j.provider.OpenAiModelFactory;
|
import dev.langchain4j.provider.OpenAiModelFactory;
|
||||||
import dev.langchain4j.provider.QianfanModelFactory;
|
|
||||||
import dev.langchain4j.provider.ZhipuModelFactory;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
@@ -70,52 +66,31 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
|
|||||||
|
|
||||||
private static ArrayList<String> getCandidateValues() {
|
private static ArrayList<String> getCandidateValues() {
|
||||||
return Lists.newArrayList(InMemoryModelFactory.PROVIDER, OpenAiModelFactory.PROVIDER,
|
return Lists.newArrayList(InMemoryModelFactory.PROVIDER, OpenAiModelFactory.PROVIDER,
|
||||||
OllamaModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER,
|
OllamaModelFactory.PROVIDER);
|
||||||
QianfanModelFactory.PROVIDER, ZhipuModelFactory.PROVIDER,
|
|
||||||
AzureModelFactory.PROVIDER);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<Parameter.Dependency> getBaseUrlDependency() {
|
private static List<Parameter.Dependency> getBaseUrlDependency() {
|
||||||
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||||
Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER,
|
Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER),
|
||||||
AzureModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER,
|
|
||||||
QianfanModelFactory.PROVIDER, ZhipuModelFactory.PROVIDER),
|
|
||||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL,
|
ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL,
|
||||||
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL,
|
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL));
|
||||||
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL,
|
|
||||||
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL,
|
|
||||||
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL,
|
|
||||||
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_BASE_URL));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<Parameter.Dependency> getApiKeyDependency() {
|
private static List<Parameter.Dependency> getApiKeyDependency() {
|
||||||
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||||
Lists.newArrayList(OpenAiModelFactory.PROVIDER, AzureModelFactory.PROVIDER,
|
Lists.newArrayList(OpenAiModelFactory.PROVIDER),
|
||||||
DashscopeModelFactory.PROVIDER, QianfanModelFactory.PROVIDER,
|
ImmutableMap.of(OpenAiModelFactory.PROVIDER, DEMO));
|
||||||
ZhipuModelFactory.PROVIDER),
|
|
||||||
ImmutableMap.of(OpenAiModelFactory.PROVIDER, DEMO, AzureModelFactory.PROVIDER, DEMO,
|
|
||||||
DashscopeModelFactory.PROVIDER, DEMO, QianfanModelFactory.PROVIDER, DEMO,
|
|
||||||
ZhipuModelFactory.PROVIDER, DEMO));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<Parameter.Dependency> getModelNameDependency() {
|
private static List<Parameter.Dependency> getModelNameDependency() {
|
||||||
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||||
Lists.newArrayList(InMemoryModelFactory.PROVIDER, OpenAiModelFactory.PROVIDER,
|
Lists.newArrayList(InMemoryModelFactory.PROVIDER, OpenAiModelFactory.PROVIDER,
|
||||||
OllamaModelFactory.PROVIDER, AzureModelFactory.PROVIDER,
|
OllamaModelFactory.PROVIDER),
|
||||||
DashscopeModelFactory.PROVIDER, QianfanModelFactory.PROVIDER,
|
|
||||||
ZhipuModelFactory.PROVIDER),
|
|
||||||
ImmutableMap.of(InMemoryModelFactory.PROVIDER, EmbeddingModelConstant.BGE_SMALL_ZH,
|
ImmutableMap.of(InMemoryModelFactory.PROVIDER, EmbeddingModelConstant.BGE_SMALL_ZH,
|
||||||
OpenAiModelFactory.PROVIDER,
|
OpenAiModelFactory.PROVIDER,
|
||||||
OpenAiModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
|
OpenAiModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
|
||||||
OllamaModelFactory.PROVIDER,
|
OllamaModelFactory.PROVIDER,
|
||||||
OllamaModelFactory.DEFAULT_EMBEDDING_MODEL_NAME, AzureModelFactory.PROVIDER,
|
OllamaModelFactory.DEFAULT_EMBEDDING_MODEL_NAME));
|
||||||
AzureModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
|
|
||||||
DashscopeModelFactory.PROVIDER,
|
|
||||||
DashscopeModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
|
|
||||||
QianfanModelFactory.PROVIDER,
|
|
||||||
QianfanModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
|
|
||||||
ZhipuModelFactory.PROVIDER,
|
|
||||||
ZhipuModelFactory.DEFAULT_EMBEDDING_MODEL_NAME));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<Parameter.Dependency> getModelPathDependency() {
|
private static List<Parameter.Dependency> getModelPathDependency() {
|
||||||
@@ -126,7 +101,7 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
|
|||||||
|
|
||||||
private static List<Parameter.Dependency> getSecretKeyDependency() {
|
private static List<Parameter.Dependency> getSecretKeyDependency() {
|
||||||
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
|
||||||
Lists.newArrayList(QianfanModelFactory.PROVIDER),
|
Lists.newArrayList(OpenAiModelFactory.PROVIDER),
|
||||||
ImmutableMap.of(QianfanModelFactory.PROVIDER, DEMO));
|
ImmutableMap.of(OpenAiModelFactory.PROVIDER, DEMO));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -85,20 +85,20 @@ public class ChatModelParameters {
|
|||||||
|
|
||||||
private static List<Parameter.Dependency> getEndpointDependency() {
|
private static List<Parameter.Dependency> getEndpointDependency() {
|
||||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||||
Lists.newArrayList(QianfanModelFactory.PROVIDER), ImmutableMap
|
Lists.newArrayList(OpenAiModelFactory.PROVIDER), ImmutableMap
|
||||||
.of(QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_ENDPOINT));
|
.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_MODEL_NAME));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<Parameter.Dependency> getEnableSearchDependency() {
|
private static List<Parameter.Dependency> getEnableSearchDependency() {
|
||||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||||
Lists.newArrayList(DashscopeModelFactory.PROVIDER),
|
Lists.newArrayList(OpenAiModelFactory.PROVIDER),
|
||||||
ImmutableMap.of(DashscopeModelFactory.PROVIDER, "false"));
|
ImmutableMap.of(OpenAiModelFactory.PROVIDER, "false"));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<Parameter.Dependency> getSecretKeyDependency() {
|
private static List<Parameter.Dependency> getSecretKeyDependency() {
|
||||||
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
return getDependency(CHAT_MODEL_PROVIDER.getName(),
|
||||||
Lists.newArrayList(QianfanModelFactory.PROVIDER), ImmutableMap.of(
|
Lists.newArrayList(OpenAiModelFactory.PROVIDER), ImmutableMap.of(
|
||||||
QianfanModelFactory.PROVIDER, ModelProvider.DEMO_CHAT_MODEL.getApiKey()));
|
OpenAiModelFactory.PROVIDER, ModelProvider.DEMO_CHAT_MODEL.getApiKey()));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static List<Parameter.Dependency> getDependency(String dependencyParameterName,
|
private static List<Parameter.Dependency> getDependency(String dependencyParameterName,
|
||||||
|
|||||||
@@ -1,23 +0,0 @@
|
|||||||
package dev.langchain4j.dashscope.spring;
|
|
||||||
|
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.Setter;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
class ChatModelProperties {
|
|
||||||
|
|
||||||
String baseUrl;
|
|
||||||
String apiKey;
|
|
||||||
String modelName;
|
|
||||||
Double topP;
|
|
||||||
Integer topK;
|
|
||||||
Boolean enableSearch;
|
|
||||||
Integer seed;
|
|
||||||
Float repetitionPenalty;
|
|
||||||
Float temperature;
|
|
||||||
List<String> stops;
|
|
||||||
Integer maxTokens;
|
|
||||||
}
|
|
||||||
@@ -1,84 +0,0 @@
|
|||||||
package dev.langchain4j.dashscope.spring;
|
|
||||||
|
|
||||||
import dev.langchain4j.model.dashscope.QwenChatModel;
|
|
||||||
import dev.langchain4j.model.dashscope.QwenEmbeddingModel;
|
|
||||||
import dev.langchain4j.model.dashscope.QwenLanguageModel;
|
|
||||||
import dev.langchain4j.model.dashscope.QwenStreamingChatModel;
|
|
||||||
import dev.langchain4j.model.dashscope.QwenStreamingLanguageModel;
|
|
||||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
|
||||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
|
||||||
import org.springframework.context.annotation.Bean;
|
|
||||||
import org.springframework.context.annotation.Configuration;
|
|
||||||
|
|
||||||
import static dev.langchain4j.dashscope.spring.Properties.PREFIX;
|
|
||||||
|
|
||||||
@Configuration
|
|
||||||
@EnableConfigurationProperties(Properties.class)
|
|
||||||
public class DashscopeAutoConfig {
|
|
||||||
|
|
||||||
@Bean
|
|
||||||
@ConditionalOnProperty(PREFIX + ".chat-model.api-key")
|
|
||||||
QwenChatModel qwenChatModel(Properties properties) {
|
|
||||||
ChatModelProperties chatModelProperties = properties.getChatModel();
|
|
||||||
return QwenChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
|
|
||||||
.apiKey(chatModelProperties.getApiKey())
|
|
||||||
.modelName(chatModelProperties.getModelName()).topP(chatModelProperties.getTopP())
|
|
||||||
.topK(chatModelProperties.getTopK())
|
|
||||||
.enableSearch(chatModelProperties.getEnableSearch())
|
|
||||||
.seed(chatModelProperties.getSeed())
|
|
||||||
.repetitionPenalty(chatModelProperties.getRepetitionPenalty())
|
|
||||||
.temperature(chatModelProperties.getTemperature())
|
|
||||||
.stops(chatModelProperties.getStops()).maxTokens(chatModelProperties.getMaxTokens())
|
|
||||||
.build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Bean
|
|
||||||
@ConditionalOnProperty(PREFIX + ".streaming-chat-model.api-key")
|
|
||||||
QwenStreamingChatModel qwenStreamingChatModel(Properties properties) {
|
|
||||||
ChatModelProperties chatModelProperties = properties.getStreamingChatModel();
|
|
||||||
return QwenStreamingChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
|
|
||||||
.apiKey(chatModelProperties.getApiKey())
|
|
||||||
.modelName(chatModelProperties.getModelName()).topP(chatModelProperties.getTopP())
|
|
||||||
.topK(chatModelProperties.getTopK())
|
|
||||||
.enableSearch(chatModelProperties.getEnableSearch())
|
|
||||||
.seed(chatModelProperties.getSeed())
|
|
||||||
.repetitionPenalty(chatModelProperties.getRepetitionPenalty())
|
|
||||||
.temperature(chatModelProperties.getTemperature())
|
|
||||||
.stops(chatModelProperties.getStops()).maxTokens(chatModelProperties.getMaxTokens())
|
|
||||||
.build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Bean
|
|
||||||
@ConditionalOnProperty(PREFIX + ".language-model.api-key")
|
|
||||||
QwenLanguageModel qwenLanguageModel(Properties properties) {
|
|
||||||
ChatModelProperties languageModel = properties.getLanguageModel();
|
|
||||||
return QwenLanguageModel.builder().baseUrl(languageModel.getBaseUrl())
|
|
||||||
.apiKey(languageModel.getApiKey()).modelName(languageModel.getModelName())
|
|
||||||
.topP(languageModel.getTopP()).topK(languageModel.getTopK())
|
|
||||||
.enableSearch(languageModel.getEnableSearch()).seed(languageModel.getSeed())
|
|
||||||
.repetitionPenalty(languageModel.getRepetitionPenalty())
|
|
||||||
.temperature(languageModel.getTemperature()).stops(languageModel.getStops())
|
|
||||||
.maxTokens(languageModel.getMaxTokens()).build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Bean
|
|
||||||
@ConditionalOnProperty(PREFIX + ".streaming-language-model.api-key")
|
|
||||||
QwenStreamingLanguageModel qwenStreamingLanguageModel(Properties properties) {
|
|
||||||
ChatModelProperties languageModel = properties.getStreamingLanguageModel();
|
|
||||||
return QwenStreamingLanguageModel.builder().baseUrl(languageModel.getBaseUrl())
|
|
||||||
.apiKey(languageModel.getApiKey()).modelName(languageModel.getModelName())
|
|
||||||
.topP(languageModel.getTopP()).topK(languageModel.getTopK())
|
|
||||||
.enableSearch(languageModel.getEnableSearch()).seed(languageModel.getSeed())
|
|
||||||
.repetitionPenalty(languageModel.getRepetitionPenalty())
|
|
||||||
.temperature(languageModel.getTemperature()).stops(languageModel.getStops())
|
|
||||||
.maxTokens(languageModel.getMaxTokens()).build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Bean
|
|
||||||
@ConditionalOnProperty(PREFIX + ".embedding-model.api-key")
|
|
||||||
QwenEmbeddingModel qwenEmbeddingModel(Properties properties) {
|
|
||||||
EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel();
|
|
||||||
return QwenEmbeddingModel.builder().apiKey(embeddingModelProperties.getApiKey())
|
|
||||||
.modelName(embeddingModelProperties.getModelName()).build();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,12 +0,0 @@
|
|||||||
package dev.langchain4j.dashscope.spring;
|
|
||||||
|
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.Setter;
|
|
||||||
|
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
class EmbeddingModelProperties {
|
|
||||||
|
|
||||||
private String apiKey;
|
|
||||||
private String modelName;
|
|
||||||
}
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
package dev.langchain4j.dashscope.spring;
|
|
||||||
|
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.Setter;
|
|
||||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
|
||||||
import org.springframework.boot.context.properties.NestedConfigurationProperty;
|
|
||||||
|
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
@ConfigurationProperties(prefix = Properties.PREFIX)
|
|
||||||
public class Properties {
|
|
||||||
|
|
||||||
static final String PREFIX = "langchain4j.dashscope";
|
|
||||||
|
|
||||||
@NestedConfigurationProperty
|
|
||||||
ChatModelProperties chatModel;
|
|
||||||
|
|
||||||
@NestedConfigurationProperty
|
|
||||||
ChatModelProperties streamingChatModel;
|
|
||||||
|
|
||||||
@NestedConfigurationProperty
|
|
||||||
ChatModelProperties languageModel;
|
|
||||||
|
|
||||||
@NestedConfigurationProperty
|
|
||||||
ChatModelProperties streamingLanguageModel;
|
|
||||||
|
|
||||||
@NestedConfigurationProperty
|
|
||||||
EmbeddingModelProperties embeddingModel;
|
|
||||||
}
|
|
||||||
@@ -9,6 +9,7 @@ import dev.langchain4j.data.message.ChatMessage;
|
|||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.model.output.Response;
|
import dev.langchain4j.model.output.Response;
|
||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
|
import lombok.Setter;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@@ -32,6 +33,7 @@ public class DifyAiChatModel implements ChatLanguageModel {
|
|||||||
private final Double temperature;
|
private final Double temperature;
|
||||||
private final Long timeOut;
|
private final Long timeOut;
|
||||||
|
|
||||||
|
@Setter
|
||||||
private String userName;
|
private String userName;
|
||||||
|
|
||||||
@Builder
|
@Builder
|
||||||
@@ -54,7 +56,7 @@ public class DifyAiChatModel implements ChatLanguageModel {
|
|||||||
@Override
|
@Override
|
||||||
public String generate(String message) {
|
public String generate(String message) {
|
||||||
DifyResult difyResult = this.difyClient.generate(message, this.getUserName());
|
DifyResult difyResult = this.difyClient.generate(message, this.getUserName());
|
||||||
return difyResult.getAnswer().toString();
|
return difyResult.getAnswer();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -67,7 +69,7 @@ public class DifyAiChatModel implements ChatLanguageModel {
|
|||||||
List<ToolSpecification> toolSpecifications) {
|
List<ToolSpecification> toolSpecifications) {
|
||||||
ensureNotEmpty(messages, "messages");
|
ensureNotEmpty(messages, "messages");
|
||||||
DifyResult difyResult =
|
DifyResult difyResult =
|
||||||
this.difyClient.generate(messages.get(0).text(), this.getUserName());
|
this.difyClient.generate(messages.get(0).toString(), this.getUserName());
|
||||||
System.out.println(difyResult.toString());
|
System.out.println(difyResult.toString());
|
||||||
|
|
||||||
if (!isNullOrEmpty(toolSpecifications)) {
|
if (!isNullOrEmpty(toolSpecifications)) {
|
||||||
@@ -84,12 +86,8 @@ public class DifyAiChatModel implements ChatLanguageModel {
|
|||||||
toolSpecification != null ? singletonList(toolSpecification) : null);
|
toolSpecification != null ? singletonList(toolSpecification) : null);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void setUserName(String userName) {
|
|
||||||
this.userName = userName;
|
|
||||||
}
|
|
||||||
|
|
||||||
public String getUserName() {
|
public String getUserName() {
|
||||||
return null == userName ? "zhaodongsheng" : userName;
|
return null == userName ? "admin" : userName;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ import static dev.langchain4j.model.openai.InternalOpenAiHelper.toOpenAiMessages
|
|||||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toOpenAiResponseFormat;
|
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toOpenAiResponseFormat;
|
||||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toTools;
|
import static dev.langchain4j.model.openai.InternalOpenAiHelper.toTools;
|
||||||
import static dev.langchain4j.model.openai.InternalOpenAiHelper.tokenUsageFrom;
|
import static dev.langchain4j.model.openai.InternalOpenAiHelper.tokenUsageFrom;
|
||||||
import static dev.langchain4j.model.openai.OpenAiModelName.GPT_3_5_TURBO;
|
import static dev.langchain4j.model.openai.OpenAiChatModelName.GPT_3_5_TURBO;
|
||||||
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
|
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
|
||||||
import static java.time.Duration.ofSeconds;
|
import static java.time.Duration.ofSeconds;
|
||||||
import static java.util.Collections.emptyList;
|
import static java.util.Collections.emptyList;
|
||||||
@@ -66,7 +66,6 @@ import static java.util.Collections.singletonList;
|
|||||||
@Slf4j
|
@Slf4j
|
||||||
public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
||||||
|
|
||||||
public static final String ZHIPU = "bigmodel";
|
|
||||||
private final OpenAiClient client;
|
private final OpenAiClient client;
|
||||||
private final String baseUrl;
|
private final String baseUrl;
|
||||||
private final String modelName;
|
private final String modelName;
|
||||||
@@ -111,7 +110,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
|||||||
.connectTimeout(timeout).readTimeout(timeout).writeTimeout(timeout).proxy(proxy)
|
.connectTimeout(timeout).readTimeout(timeout).writeTimeout(timeout).proxy(proxy)
|
||||||
.logRequests(logRequests).logResponses(logResponses).userAgent(DEFAULT_USER_AGENT)
|
.logRequests(logRequests).logResponses(logResponses).userAgent(DEFAULT_USER_AGENT)
|
||||||
.customHeaders(customHeaders).build();
|
.customHeaders(customHeaders).build();
|
||||||
this.modelName = getOrDefault(modelName, GPT_3_5_TURBO);
|
this.modelName = getOrDefault(modelName, GPT_3_5_TURBO.name());
|
||||||
this.apiVersion = apiVersion;
|
this.apiVersion = apiVersion;
|
||||||
this.temperature = getOrDefault(temperature, 0.7);
|
this.temperature = getOrDefault(temperature, 0.7);
|
||||||
this.topP = topP;
|
this.topP = topP;
|
||||||
@@ -130,7 +129,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
|||||||
this.strictTools = getOrDefault(strictTools, false);
|
this.strictTools = getOrDefault(strictTools, false);
|
||||||
this.parallelToolCalls = parallelToolCalls;
|
this.parallelToolCalls = parallelToolCalls;
|
||||||
this.maxRetries = getOrDefault(maxRetries, 3);
|
this.maxRetries = getOrDefault(maxRetries, 3);
|
||||||
this.tokenizer = getOrDefault(tokenizer, OpenAiTokenizer::new);
|
this.tokenizer = getOrDefault(tokenizer, () -> new OpenAiTokenizer(this.modelName));
|
||||||
this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners);
|
this.listeners = listeners == null ? emptyList() : new ArrayList<>(listeners);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -192,9 +191,7 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator {
|
|||||||
.responseFormat(responseFormat).seed(seed).user(user)
|
.responseFormat(responseFormat).seed(seed).user(user)
|
||||||
.parallelToolCalls(parallelToolCalls);
|
.parallelToolCalls(parallelToolCalls);
|
||||||
|
|
||||||
if (!(baseUrl.contains(ZHIPU))) {
|
requestBuilder.temperature(temperature);
|
||||||
requestBuilder.temperature(temperature);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (toolSpecifications != null && !toolSpecifications.isEmpty()) {
|
if (toolSpecifications != null && !toolSpecifications.isEmpty()) {
|
||||||
requestBuilder.tools(toTools(toolSpecifications, strictTools));
|
requestBuilder.tools(toTools(toolSpecifications, strictTools));
|
||||||
|
|||||||
@@ -1,16 +0,0 @@
|
|||||||
package dev.langchain4j.model.zhipu;
|
|
||||||
|
|
||||||
public enum ChatCompletionModel {
|
|
||||||
GLM_4("glm-4"), GLM_3_TURBO("glm-3-turbo"), CHATGLM_TURBO("chatglm_turbo");
|
|
||||||
|
|
||||||
private final String value;
|
|
||||||
|
|
||||||
ChatCompletionModel(String value) {
|
|
||||||
this.value = value;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String toString() {
|
|
||||||
return this.value;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,100 +0,0 @@
|
|||||||
package dev.langchain4j.model.zhipu;
|
|
||||||
|
|
||||||
import dev.langchain4j.agent.tool.ToolSpecification;
|
|
||||||
import dev.langchain4j.data.message.AiMessage;
|
|
||||||
import dev.langchain4j.data.message.ChatMessage;
|
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
|
||||||
import dev.langchain4j.model.output.Response;
|
|
||||||
import dev.langchain4j.model.zhipu.chat.ChatCompletionRequest;
|
|
||||||
import dev.langchain4j.model.zhipu.chat.ChatCompletionResponse;
|
|
||||||
import dev.langchain4j.model.zhipu.spi.ZhipuAiChatModelBuilderFactory;
|
|
||||||
import lombok.Builder;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
import static dev.langchain4j.internal.RetryUtils.withRetry;
|
|
||||||
import static dev.langchain4j.internal.Utils.getOrDefault;
|
|
||||||
import static dev.langchain4j.internal.Utils.isNullOrEmpty;
|
|
||||||
import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty;
|
|
||||||
import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.aiMessageFrom;
|
|
||||||
import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.finishReasonFrom;
|
|
||||||
import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.toTools;
|
|
||||||
import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.toZhipuAiMessages;
|
|
||||||
import static dev.langchain4j.model.zhipu.DefaultZhipuAiHelper.tokenUsageFrom;
|
|
||||||
import static dev.langchain4j.model.zhipu.chat.ToolChoiceMode.AUTO;
|
|
||||||
import static dev.langchain4j.spi.ServiceHelper.loadFactories;
|
|
||||||
import static java.util.Collections.singletonList;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents an ZhipuAi language model with a chat completion interface, such as glm-3-turbo and
|
|
||||||
* glm-4. You can find description of parameters
|
|
||||||
* <a href="https://open.bigmodel.cn/dev/api">here</a>.
|
|
||||||
*/
|
|
||||||
public class ZhipuAiChatModel implements ChatLanguageModel {
|
|
||||||
|
|
||||||
private final String baseUrl;
|
|
||||||
private final Double temperature;
|
|
||||||
private final Double topP;
|
|
||||||
private final String model;
|
|
||||||
private final Integer maxRetries;
|
|
||||||
private final Integer maxToken;
|
|
||||||
private final ZhipuAiClient client;
|
|
||||||
|
|
||||||
@Builder
|
|
||||||
public ZhipuAiChatModel(String baseUrl, String apiKey, Double temperature, Double topP,
|
|
||||||
String model, Integer maxRetries, Integer maxToken, Boolean logRequests,
|
|
||||||
Boolean logResponses) {
|
|
||||||
this.baseUrl = getOrDefault(baseUrl, "https://open.bigmodel.cn/");
|
|
||||||
this.temperature = getOrDefault(temperature, 0.7);
|
|
||||||
this.topP = topP;
|
|
||||||
this.model = getOrDefault(model, ChatCompletionModel.GLM_4.toString());
|
|
||||||
this.maxRetries = getOrDefault(maxRetries, 3);
|
|
||||||
this.maxToken = getOrDefault(maxToken, 512);
|
|
||||||
this.client = ZhipuAiClient.builder().baseUrl(this.baseUrl).apiKey(apiKey)
|
|
||||||
.logRequests(getOrDefault(logRequests, false))
|
|
||||||
.logResponses(getOrDefault(logResponses, false)).build();
|
|
||||||
}
|
|
||||||
|
|
||||||
public static ZhipuAiChatModelBuilder builder() {
|
|
||||||
for (ZhipuAiChatModelBuilderFactory factories : loadFactories(
|
|
||||||
ZhipuAiChatModelBuilderFactory.class)) {
|
|
||||||
return factories.get();
|
|
||||||
}
|
|
||||||
return new ZhipuAiChatModelBuilder();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Response<AiMessage> generate(List<ChatMessage> messages) {
|
|
||||||
return generate(messages, (ToolSpecification) null);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Response<AiMessage> generate(List<ChatMessage> messages,
|
|
||||||
List<ToolSpecification> toolSpecifications) {
|
|
||||||
ensureNotEmpty(messages, "messages");
|
|
||||||
|
|
||||||
ChatCompletionRequest.Builder requestBuilder =
|
|
||||||
ChatCompletionRequest.builder().model(this.model).maxTokens(maxToken).stream(false)
|
|
||||||
.topP(topP).toolChoice(AUTO).messages(toZhipuAiMessages(messages));
|
|
||||||
|
|
||||||
if (!isNullOrEmpty(toolSpecifications)) {
|
|
||||||
requestBuilder.tools(toTools(toolSpecifications));
|
|
||||||
}
|
|
||||||
|
|
||||||
ChatCompletionResponse response =
|
|
||||||
withRetry(() -> client.chatCompletion(requestBuilder.build()), maxRetries);
|
|
||||||
return Response.from(aiMessageFrom(response), tokenUsageFrom(response.getUsage()),
|
|
||||||
finishReasonFrom(response.getChoices().get(0).getFinishReason()));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Response<AiMessage> generate(List<ChatMessage> messages,
|
|
||||||
ToolSpecification toolSpecification) {
|
|
||||||
return generate(messages,
|
|
||||||
toolSpecification != null ? singletonList(toolSpecification) : null);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static class ZhipuAiChatModelBuilder {
|
|
||||||
public ZhipuAiChatModelBuilder() {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,51 +0,0 @@
|
|||||||
package dev.langchain4j.provider;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
|
||||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
|
||||||
import dev.langchain4j.model.azure.AzureOpenAiChatModel;
|
|
||||||
import dev.langchain4j.model.azure.AzureOpenAiEmbeddingModel;
|
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
|
||||||
import org.springframework.beans.factory.InitializingBean;
|
|
||||||
import org.springframework.stereotype.Service;
|
|
||||||
|
|
||||||
import java.time.Duration;
|
|
||||||
|
|
||||||
@Service
|
|
||||||
public class AzureModelFactory implements ModelFactory, InitializingBean {
|
|
||||||
public static final String PROVIDER = "AZURE";
|
|
||||||
public static final String DEFAULT_BASE_URL = "https://your-resource-name.openai.azure.com/";
|
|
||||||
public static final String DEFAULT_MODEL_NAME = "gpt-35-turbo";
|
|
||||||
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "text-embedding-ada-002";
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
|
||||||
AzureOpenAiChatModel.Builder builder = AzureOpenAiChatModel.builder()
|
|
||||||
.endpoint(modelConfig.getBaseUrl()).apiKey(modelConfig.getApiKey())
|
|
||||||
.deploymentName(modelConfig.getModelName())
|
|
||||||
.temperature(modelConfig.getTemperature()).maxRetries(modelConfig.getMaxRetries())
|
|
||||||
.topP(modelConfig.getTopP())
|
|
||||||
.timeout(Duration.ofSeconds(
|
|
||||||
modelConfig.getTimeOut() == null ? 0L : modelConfig.getTimeOut()))
|
|
||||||
.logRequestsAndResponses(
|
|
||||||
modelConfig.getLogRequests() != null && modelConfig.getLogResponses());
|
|
||||||
return builder.build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
|
||||||
AzureOpenAiEmbeddingModel.Builder builder =
|
|
||||||
AzureOpenAiEmbeddingModel.builder().endpoint(embeddingModelConfig.getBaseUrl())
|
|
||||||
.apiKey(embeddingModelConfig.getApiKey())
|
|
||||||
.deploymentName(embeddingModelConfig.getModelName())
|
|
||||||
.maxRetries(embeddingModelConfig.getMaxRetries())
|
|
||||||
.logRequestsAndResponses(embeddingModelConfig.getLogRequests() != null
|
|
||||||
&& embeddingModelConfig.getLogResponses());
|
|
||||||
return builder.build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void afterPropertiesSet() {
|
|
||||||
ModelProvider.add(PROVIDER, this);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,39 +0,0 @@
|
|||||||
package dev.langchain4j.provider;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
|
||||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
|
||||||
import dev.langchain4j.model.dashscope.QwenChatModel;
|
|
||||||
import dev.langchain4j.model.dashscope.QwenEmbeddingModel;
|
|
||||||
import dev.langchain4j.model.dashscope.QwenModelName;
|
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
|
||||||
import org.springframework.beans.factory.InitializingBean;
|
|
||||||
import org.springframework.stereotype.Service;
|
|
||||||
|
|
||||||
@Service
|
|
||||||
public class DashscopeModelFactory implements ModelFactory, InitializingBean {
|
|
||||||
public static final String PROVIDER = "DASHSCOPE";
|
|
||||||
public static final String DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/api/v1";
|
|
||||||
public static final String DEFAULT_MODEL_NAME = QwenModelName.QWEN_PLUS;
|
|
||||||
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "text-embedding-v2";
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
|
||||||
return QwenChatModel.builder().baseUrl(modelConfig.getBaseUrl())
|
|
||||||
.apiKey(modelConfig.getApiKey()).modelName(modelConfig.getModelName())
|
|
||||||
.temperature(modelConfig.getTemperature() == null ? 0L
|
|
||||||
: modelConfig.getTemperature().floatValue())
|
|
||||||
.topP(modelConfig.getTopP()).enableSearch(modelConfig.getEnableSearch()).build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
|
||||||
return QwenEmbeddingModel.builder().apiKey(embeddingModelConfig.getApiKey())
|
|
||||||
.modelName(embeddingModelConfig.getModelName()).build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void afterPropertiesSet() {
|
|
||||||
ModelProvider.add(PROVIDER, this);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -6,7 +6,7 @@ import com.tencent.supersonic.common.util.AESEncryptionUtil;
|
|||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.model.dify.DifyAiChatModel;
|
import dev.langchain4j.model.dify.DifyAiChatModel;
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;
|
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
|
||||||
import org.springframework.beans.factory.InitializingBean;
|
import org.springframework.beans.factory.InitializingBean;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
@@ -27,8 +27,9 @@ public class DifyModelFactory implements ModelFactory, InitializingBean {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
||||||
return ZhipuAiEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
|
return OpenAiEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
|
||||||
.apiKey(embeddingModelConfig.getApiKey()).model(embeddingModelConfig.getModelName())
|
.apiKey(embeddingModelConfig.getApiKey())
|
||||||
|
.modelName(embeddingModelConfig.getModelName())
|
||||||
.maxRetries(embeddingModelConfig.getMaxRetries())
|
.maxRetries(embeddingModelConfig.getMaxRetries())
|
||||||
.logRequests(embeddingModelConfig.getLogRequests())
|
.logRequests(embeddingModelConfig.getLogRequests())
|
||||||
.logResponses(embeddingModelConfig.getLogResponses()).build();
|
.logResponses(embeddingModelConfig.getLogResponses()).build();
|
||||||
|
|||||||
@@ -1,47 +0,0 @@
|
|||||||
package dev.langchain4j.provider;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
|
||||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
|
||||||
import dev.langchain4j.model.qianfan.QianfanChatModel;
|
|
||||||
import dev.langchain4j.model.qianfan.QianfanEmbeddingModel;
|
|
||||||
import org.springframework.beans.factory.InitializingBean;
|
|
||||||
import org.springframework.stereotype.Service;
|
|
||||||
|
|
||||||
@Service
|
|
||||||
public class QianfanModelFactory implements ModelFactory, InitializingBean {
|
|
||||||
|
|
||||||
public static final String PROVIDER = "QIANFAN";
|
|
||||||
public static final String DEFAULT_BASE_URL = "https://aip.baidubce.com";
|
|
||||||
public static final String DEFAULT_MODEL_NAME = "Llama-2-70b-chat";
|
|
||||||
|
|
||||||
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "Embedding-V1";
|
|
||||||
public static final String DEFAULT_ENDPOINT = "llama_2_70b";
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
|
||||||
return QianfanChatModel.builder().baseUrl(modelConfig.getBaseUrl())
|
|
||||||
.apiKey(modelConfig.getApiKey()).secretKey(modelConfig.getSecretKey())
|
|
||||||
.endpoint(modelConfig.getEndpoint()).modelName(modelConfig.getModelName())
|
|
||||||
.temperature(modelConfig.getTemperature()).topP(modelConfig.getTopP())
|
|
||||||
.maxRetries(modelConfig.getMaxRetries()).logRequests(modelConfig.getLogRequests())
|
|
||||||
.logResponses(modelConfig.getLogResponses()).build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
|
||||||
return QianfanEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
|
|
||||||
.apiKey(embeddingModelConfig.getApiKey())
|
|
||||||
.secretKey(embeddingModelConfig.getSecretKey())
|
|
||||||
.modelName(embeddingModelConfig.getModelName())
|
|
||||||
.maxRetries(embeddingModelConfig.getMaxRetries())
|
|
||||||
.logRequests(embeddingModelConfig.getLogRequests())
|
|
||||||
.logResponses(embeddingModelConfig.getLogResponses()).build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void afterPropertiesSet() {
|
|
||||||
ModelProvider.add(PROVIDER, this);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,45 +0,0 @@
|
|||||||
package dev.langchain4j.provider;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
|
||||||
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
|
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
|
||||||
import dev.langchain4j.model.zhipu.ChatCompletionModel;
|
|
||||||
import dev.langchain4j.model.zhipu.ZhipuAiChatModel;
|
|
||||||
import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;
|
|
||||||
import org.springframework.beans.factory.InitializingBean;
|
|
||||||
import org.springframework.stereotype.Service;
|
|
||||||
|
|
||||||
import static java.time.Duration.ofSeconds;
|
|
||||||
|
|
||||||
@Service
|
|
||||||
public class ZhipuModelFactory implements ModelFactory, InitializingBean {
|
|
||||||
public static final String PROVIDER = "ZHIPU";
|
|
||||||
public static final String DEFAULT_BASE_URL = "https://open.bigmodel.cn/";
|
|
||||||
public static final String DEFAULT_MODEL_NAME = ChatCompletionModel.GLM_4.toString();
|
|
||||||
public static final String DEFAULT_EMBEDDING_MODEL_NAME = "embedding-2";
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
|
||||||
return ZhipuAiChatModel.builder().baseUrl(modelConfig.getBaseUrl())
|
|
||||||
.apiKey(modelConfig.getApiKey()).model(modelConfig.getModelName())
|
|
||||||
.temperature(modelConfig.getTemperature()).topP(modelConfig.getTopP())
|
|
||||||
.maxRetries(modelConfig.getMaxRetries()).logRequests(modelConfig.getLogRequests())
|
|
||||||
.logResponses(modelConfig.getLogResponses()).build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
|
||||||
return ZhipuAiEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())
|
|
||||||
.apiKey(embeddingModelConfig.getApiKey()).model(embeddingModelConfig.getModelName())
|
|
||||||
.maxRetries(embeddingModelConfig.getMaxRetries()).callTimeout(ofSeconds(60))
|
|
||||||
.connectTimeout(ofSeconds(60)).writeTimeout(ofSeconds(60))
|
|
||||||
.readTimeout(ofSeconds(60)).logRequests(embeddingModelConfig.getLogRequests())
|
|
||||||
.logResponses(embeddingModelConfig.getLogResponses()).build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void afterPropertiesSet() {
|
|
||||||
ModelProvider.add(PROVIDER, this);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
package dev.langchain4j.qianfan.spring;
|
|
||||||
|
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.Setter;
|
|
||||||
|
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
class ChatModelProperties {
|
|
||||||
private String baseUrl;
|
|
||||||
private String apiKey;
|
|
||||||
private String secretKey;
|
|
||||||
private Double temperature;
|
|
||||||
private Integer maxRetries;
|
|
||||||
private Double topP;
|
|
||||||
private String modelName;
|
|
||||||
private String endpoint;
|
|
||||||
private String responseFormat;
|
|
||||||
private Double penaltyScore;
|
|
||||||
private Boolean logRequests;
|
|
||||||
private Boolean logResponses;
|
|
||||||
}
|
|
||||||
@@ -1,18 +0,0 @@
|
|||||||
package dev.langchain4j.qianfan.spring;
|
|
||||||
|
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.Setter;
|
|
||||||
|
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
class EmbeddingModelProperties {
|
|
||||||
private String baseUrl;
|
|
||||||
private String apiKey;
|
|
||||||
private String secretKey;
|
|
||||||
private Integer maxRetries;
|
|
||||||
private String modelName;
|
|
||||||
private String endpoint;
|
|
||||||
private String user;
|
|
||||||
private Boolean logRequests;
|
|
||||||
private Boolean logResponses;
|
|
||||||
}
|
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
package dev.langchain4j.qianfan.spring;
|
|
||||||
|
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.Setter;
|
|
||||||
|
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
class LanguageModelProperties {
|
|
||||||
private String baseUrl;
|
|
||||||
private String apiKey;
|
|
||||||
private String secretKey;
|
|
||||||
private Double temperature;
|
|
||||||
private Integer maxRetries;
|
|
||||||
private Integer topK;
|
|
||||||
private Double topP;
|
|
||||||
private String modelName;
|
|
||||||
private String endpoint;
|
|
||||||
private Double penaltyScore;
|
|
||||||
private Boolean logRequests;
|
|
||||||
private Boolean logResponses;
|
|
||||||
}
|
|
||||||
@@ -1,29 +0,0 @@
|
|||||||
package dev.langchain4j.qianfan.spring;
|
|
||||||
|
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.Setter;
|
|
||||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
|
||||||
import org.springframework.boot.context.properties.NestedConfigurationProperty;
|
|
||||||
|
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
@ConfigurationProperties(prefix = Properties.PREFIX)
|
|
||||||
public class Properties {
|
|
||||||
|
|
||||||
static final String PREFIX = "langchain4j.qianfan";
|
|
||||||
|
|
||||||
@NestedConfigurationProperty
|
|
||||||
ChatModelProperties chatModel;
|
|
||||||
|
|
||||||
@NestedConfigurationProperty
|
|
||||||
ChatModelProperties streamingChatModel;
|
|
||||||
|
|
||||||
@NestedConfigurationProperty
|
|
||||||
LanguageModelProperties languageModel;
|
|
||||||
|
|
||||||
@NestedConfigurationProperty
|
|
||||||
LanguageModelProperties streamingLanguageModel;
|
|
||||||
|
|
||||||
@NestedConfigurationProperty
|
|
||||||
EmbeddingModelProperties embeddingModel;
|
|
||||||
}
|
|
||||||
@@ -1,102 +0,0 @@
|
|||||||
package dev.langchain4j.qianfan.spring;
|
|
||||||
|
|
||||||
import dev.langchain4j.model.qianfan.QianfanChatModel;
|
|
||||||
import dev.langchain4j.model.qianfan.QianfanEmbeddingModel;
|
|
||||||
import dev.langchain4j.model.qianfan.QianfanLanguageModel;
|
|
||||||
import dev.langchain4j.model.qianfan.QianfanStreamingChatModel;
|
|
||||||
import dev.langchain4j.model.qianfan.QianfanStreamingLanguageModel;
|
|
||||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
|
||||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
|
||||||
import org.springframework.context.annotation.Bean;
|
|
||||||
import org.springframework.context.annotation.Configuration;
|
|
||||||
|
|
||||||
import static dev.langchain4j.qianfan.spring.Properties.PREFIX;
|
|
||||||
|
|
||||||
@Configuration
|
|
||||||
@EnableConfigurationProperties(Properties.class)
|
|
||||||
public class QianfanAutoConfig {
|
|
||||||
|
|
||||||
@Bean
|
|
||||||
@ConditionalOnProperty(PREFIX + ".chat-model.api-key")
|
|
||||||
QianfanChatModel qianfanChatModel(Properties properties) {
|
|
||||||
ChatModelProperties chatModelProperties = properties.getChatModel();
|
|
||||||
return QianfanChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
|
|
||||||
.apiKey(chatModelProperties.getApiKey())
|
|
||||||
.secretKey(chatModelProperties.getSecretKey())
|
|
||||||
.endpoint(chatModelProperties.getEndpoint())
|
|
||||||
.penaltyScore(chatModelProperties.getPenaltyScore())
|
|
||||||
.modelName(chatModelProperties.getModelName())
|
|
||||||
.temperature(chatModelProperties.getTemperature())
|
|
||||||
.topP(chatModelProperties.getTopP())
|
|
||||||
.responseFormat(chatModelProperties.getResponseFormat())
|
|
||||||
.maxRetries(chatModelProperties.getMaxRetries())
|
|
||||||
.logRequests(chatModelProperties.getLogRequests())
|
|
||||||
.logResponses(chatModelProperties.getLogResponses()).build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Bean
|
|
||||||
@ConditionalOnProperty(PREFIX + ".streaming-chat-model.api-key")
|
|
||||||
QianfanStreamingChatModel qianfanStreamingChatModel(Properties properties) {
|
|
||||||
ChatModelProperties chatModelProperties = properties.getStreamingChatModel();
|
|
||||||
return QianfanStreamingChatModel.builder().endpoint(chatModelProperties.getEndpoint())
|
|
||||||
.penaltyScore(chatModelProperties.getPenaltyScore())
|
|
||||||
.temperature(chatModelProperties.getTemperature())
|
|
||||||
.topP(chatModelProperties.getTopP()).baseUrl(chatModelProperties.getBaseUrl())
|
|
||||||
.apiKey(chatModelProperties.getApiKey())
|
|
||||||
.secretKey(chatModelProperties.getSecretKey())
|
|
||||||
.modelName(chatModelProperties.getModelName())
|
|
||||||
.responseFormat(chatModelProperties.getResponseFormat())
|
|
||||||
.logRequests(chatModelProperties.getLogRequests())
|
|
||||||
.logResponses(chatModelProperties.getLogResponses()).build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Bean
|
|
||||||
@ConditionalOnProperty(PREFIX + ".language-model.api-key")
|
|
||||||
QianfanLanguageModel qianfanLanguageModel(Properties properties) {
|
|
||||||
LanguageModelProperties languageModelProperties = properties.getLanguageModel();
|
|
||||||
return QianfanLanguageModel.builder().endpoint(languageModelProperties.getEndpoint())
|
|
||||||
.penaltyScore(languageModelProperties.getPenaltyScore())
|
|
||||||
.topK(languageModelProperties.getTopK()).topP(languageModelProperties.getTopP())
|
|
||||||
.baseUrl(languageModelProperties.getBaseUrl())
|
|
||||||
.apiKey(languageModelProperties.getApiKey())
|
|
||||||
.secretKey(languageModelProperties.getSecretKey())
|
|
||||||
.modelName(languageModelProperties.getModelName())
|
|
||||||
.temperature(languageModelProperties.getTemperature())
|
|
||||||
.maxRetries(languageModelProperties.getMaxRetries())
|
|
||||||
.logRequests(languageModelProperties.getLogRequests())
|
|
||||||
.logResponses(languageModelProperties.getLogResponses()).build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Bean
|
|
||||||
@ConditionalOnProperty(PREFIX + ".streaming-language-model.api-key")
|
|
||||||
QianfanStreamingLanguageModel qianfanStreamingLanguageModel(Properties properties) {
|
|
||||||
LanguageModelProperties languageModelProperties = properties.getStreamingLanguageModel();
|
|
||||||
return QianfanStreamingLanguageModel.builder()
|
|
||||||
.endpoint(languageModelProperties.getEndpoint())
|
|
||||||
.penaltyScore(languageModelProperties.getPenaltyScore())
|
|
||||||
.topK(languageModelProperties.getTopK()).topP(languageModelProperties.getTopP())
|
|
||||||
.baseUrl(languageModelProperties.getBaseUrl())
|
|
||||||
.apiKey(languageModelProperties.getApiKey())
|
|
||||||
.secretKey(languageModelProperties.getSecretKey())
|
|
||||||
.modelName(languageModelProperties.getModelName())
|
|
||||||
.temperature(languageModelProperties.getTemperature())
|
|
||||||
.maxRetries(languageModelProperties.getMaxRetries())
|
|
||||||
.logRequests(languageModelProperties.getLogRequests())
|
|
||||||
.logResponses(languageModelProperties.getLogResponses()).build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Bean
|
|
||||||
@ConditionalOnProperty(PREFIX + ".embedding-model.api-key")
|
|
||||||
QianfanEmbeddingModel qianfanEmbeddingModel(Properties properties) {
|
|
||||||
EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel();
|
|
||||||
return QianfanEmbeddingModel.builder().baseUrl(embeddingModelProperties.getBaseUrl())
|
|
||||||
.endpoint(embeddingModelProperties.getEndpoint())
|
|
||||||
.apiKey(embeddingModelProperties.getApiKey())
|
|
||||||
.secretKey(embeddingModelProperties.getSecretKey())
|
|
||||||
.modelName(embeddingModelProperties.getModelName())
|
|
||||||
.user(embeddingModelProperties.getUser())
|
|
||||||
.maxRetries(embeddingModelProperties.getMaxRetries())
|
|
||||||
.logRequests(embeddingModelProperties.getLogRequests())
|
|
||||||
.logResponses(embeddingModelProperties.getLogResponses()).build();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
package dev.langchain4j.zhipu.spring;
|
|
||||||
|
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.Setter;
|
|
||||||
|
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
class ChatModelProperties {
|
|
||||||
|
|
||||||
String baseUrl;
|
|
||||||
String apiKey;
|
|
||||||
Double temperature;
|
|
||||||
Double topP;
|
|
||||||
String modelName;
|
|
||||||
Integer maxRetries;
|
|
||||||
Integer maxToken;
|
|
||||||
Boolean logRequests;
|
|
||||||
Boolean logResponses;
|
|
||||||
}
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
package dev.langchain4j.zhipu.spring;
|
|
||||||
|
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.Setter;
|
|
||||||
|
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
class EmbeddingModelProperties {
|
|
||||||
|
|
||||||
String baseUrl;
|
|
||||||
String apiKey;
|
|
||||||
String model;
|
|
||||||
Integer maxRetries;
|
|
||||||
Boolean logRequests;
|
|
||||||
Boolean logResponses;
|
|
||||||
}
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
package dev.langchain4j.zhipu.spring;
|
|
||||||
|
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.Setter;
|
|
||||||
import org.springframework.boot.context.properties.ConfigurationProperties;
|
|
||||||
import org.springframework.boot.context.properties.NestedConfigurationProperty;
|
|
||||||
|
|
||||||
@Getter
|
|
||||||
@Setter
|
|
||||||
@ConfigurationProperties(prefix = Properties.PREFIX)
|
|
||||||
public class Properties {
|
|
||||||
|
|
||||||
static final String PREFIX = "langchain4j.zhipu";
|
|
||||||
|
|
||||||
@NestedConfigurationProperty
|
|
||||||
ChatModelProperties chatModel;
|
|
||||||
|
|
||||||
@NestedConfigurationProperty
|
|
||||||
ChatModelProperties streamingChatModel;
|
|
||||||
|
|
||||||
@NestedConfigurationProperty
|
|
||||||
EmbeddingModelProperties embeddingModel;
|
|
||||||
}
|
|
||||||
@@ -1,53 +0,0 @@
|
|||||||
package dev.langchain4j.zhipu.spring;
|
|
||||||
|
|
||||||
import dev.langchain4j.model.zhipu.ZhipuAiChatModel;
|
|
||||||
import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;
|
|
||||||
import dev.langchain4j.model.zhipu.ZhipuAiStreamingChatModel;
|
|
||||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
|
||||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
|
||||||
import org.springframework.context.annotation.Bean;
|
|
||||||
import org.springframework.context.annotation.Configuration;
|
|
||||||
|
|
||||||
import static dev.langchain4j.zhipu.spring.Properties.PREFIX;
|
|
||||||
|
|
||||||
@Configuration
|
|
||||||
@EnableConfigurationProperties(Properties.class)
|
|
||||||
public class ZhipuAutoConfig {
|
|
||||||
|
|
||||||
@Bean
|
|
||||||
@ConditionalOnProperty(PREFIX + ".chat-model.api-key")
|
|
||||||
ZhipuAiChatModel zhipuAiChatModel(Properties properties) {
|
|
||||||
ChatModelProperties chatModelProperties = properties.getChatModel();
|
|
||||||
return ZhipuAiChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
|
|
||||||
.apiKey(chatModelProperties.getApiKey()).model(chatModelProperties.getModelName())
|
|
||||||
.temperature(chatModelProperties.getTemperature())
|
|
||||||
.topP(chatModelProperties.getTopP()).maxRetries(chatModelProperties.getMaxRetries())
|
|
||||||
.maxToken(chatModelProperties.getMaxToken())
|
|
||||||
.logRequests(chatModelProperties.getLogRequests())
|
|
||||||
.logResponses(chatModelProperties.getLogResponses()).build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Bean
|
|
||||||
@ConditionalOnProperty(PREFIX + ".streaming-chat-model.api-key")
|
|
||||||
ZhipuAiStreamingChatModel zhipuStreamingChatModel(Properties properties) {
|
|
||||||
ChatModelProperties chatModelProperties = properties.getStreamingChatModel();
|
|
||||||
return ZhipuAiStreamingChatModel.builder().baseUrl(chatModelProperties.getBaseUrl())
|
|
||||||
.apiKey(chatModelProperties.getApiKey()).model(chatModelProperties.getModelName())
|
|
||||||
.temperature(chatModelProperties.getTemperature())
|
|
||||||
.topP(chatModelProperties.getTopP()).maxToken(chatModelProperties.getMaxToken())
|
|
||||||
.logRequests(chatModelProperties.getLogRequests())
|
|
||||||
.logResponses(chatModelProperties.getLogResponses()).build();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Bean
|
|
||||||
@ConditionalOnProperty(PREFIX + ".embedding-model.api-key")
|
|
||||||
ZhipuAiEmbeddingModel zhipuEmbeddingModel(Properties properties) {
|
|
||||||
EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel();
|
|
||||||
return ZhipuAiEmbeddingModel.builder().baseUrl(embeddingModelProperties.getBaseUrl())
|
|
||||||
.apiKey(embeddingModelProperties.getApiKey())
|
|
||||||
.model(embeddingModelProperties.getModel())
|
|
||||||
.maxRetries(embeddingModelProperties.getMaxRetries())
|
|
||||||
.logRequests(embeddingModelProperties.getLogRequests())
|
|
||||||
.logResponses(embeddingModelProperties.getLogResponses()).build();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -8,20 +8,15 @@ import dev.langchain4j.data.embedding.Embedding;
|
|||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
import dev.langchain4j.model.output.Response;
|
import dev.langchain4j.model.output.Response;
|
||||||
import dev.langchain4j.provider.AzureModelFactory;
|
|
||||||
import dev.langchain4j.provider.DashscopeModelFactory;
|
|
||||||
import dev.langchain4j.provider.EmbeddingModelConstant;
|
import dev.langchain4j.provider.EmbeddingModelConstant;
|
||||||
import dev.langchain4j.provider.InMemoryModelFactory;
|
import dev.langchain4j.provider.InMemoryModelFactory;
|
||||||
import dev.langchain4j.provider.ModelProvider;
|
import dev.langchain4j.provider.ModelProvider;
|
||||||
import dev.langchain4j.provider.OpenAiModelFactory;
|
import dev.langchain4j.provider.OpenAiModelFactory;
|
||||||
import dev.langchain4j.provider.QianfanModelFactory;
|
|
||||||
import dev.langchain4j.provider.ZhipuModelFactory;
|
|
||||||
import org.junit.jupiter.api.Disabled;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.junit.jupiter.api.TestInstance;
|
import org.junit.jupiter.api.TestInstance;
|
||||||
|
|
||||||
import static org.junit.Assert.assertNotNull;
|
import static org.junit.Assert.assertNotNull;
|
||||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
|
||||||
|
|
||||||
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
|
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
|
||||||
@Disabled
|
@Disabled
|
||||||
@@ -40,65 +35,6 @@ public class ModelProviderTest extends BaseApplication {
|
|||||||
assertNotNull(response);
|
assertNotNull(response);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
public void test_qianfan_chat_model() {
|
|
||||||
ChatModelConfig modelConfig = new ChatModelConfig();
|
|
||||||
modelConfig.setProvider(QianfanModelFactory.PROVIDER);
|
|
||||||
modelConfig.setModelName(QianfanModelFactory.DEFAULT_MODEL_NAME);
|
|
||||||
modelConfig.setBaseUrl(QianfanModelFactory.DEFAULT_BASE_URL);
|
|
||||||
modelConfig.setSecretKey(ParameterConfig.DEMO);
|
|
||||||
modelConfig.setApiKey(ParameterConfig.DEMO);
|
|
||||||
modelConfig.setEndpoint(QianfanModelFactory.DEFAULT_ENDPOINT);
|
|
||||||
|
|
||||||
ChatLanguageModel chatModel = ModelProvider.getChatModel(modelConfig);
|
|
||||||
assertThrows(RuntimeException.class, () -> {
|
|
||||||
chatModel.generate("hi");
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void test_zhipu_chat_model() {
|
|
||||||
ChatModelConfig modelConfig = new ChatModelConfig();
|
|
||||||
modelConfig.setProvider(ZhipuModelFactory.PROVIDER);
|
|
||||||
modelConfig.setModelName(ZhipuModelFactory.DEFAULT_MODEL_NAME);
|
|
||||||
modelConfig.setBaseUrl(ZhipuModelFactory.DEFAULT_BASE_URL);
|
|
||||||
modelConfig.setApiKey("e2724491714b3b2a0274e987905f1001.5JyHgf4vbZVJ7gC5");
|
|
||||||
|
|
||||||
ChatLanguageModel chatModel = ModelProvider.getChatModel(modelConfig);
|
|
||||||
assertThrows(RuntimeException.class, () -> {
|
|
||||||
chatModel.generate("hi");
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void test_dashscope_chat_model() {
|
|
||||||
ChatModelConfig modelConfig = new ChatModelConfig();
|
|
||||||
modelConfig.setProvider(DashscopeModelFactory.PROVIDER);
|
|
||||||
modelConfig.setModelName(DashscopeModelFactory.DEFAULT_MODEL_NAME);
|
|
||||||
modelConfig.setBaseUrl(DashscopeModelFactory.DEFAULT_BASE_URL);
|
|
||||||
modelConfig.setEnableSearch(true);
|
|
||||||
modelConfig.setApiKey(ParameterConfig.DEMO);
|
|
||||||
|
|
||||||
ChatLanguageModel chatModel = ModelProvider.getChatModel(modelConfig);
|
|
||||||
assertThrows(RuntimeException.class, () -> {
|
|
||||||
chatModel.generate("hi");
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void test_azure_chat_model() {
|
|
||||||
ChatModelConfig modelConfig = new ChatModelConfig();
|
|
||||||
modelConfig.setProvider(AzureModelFactory.PROVIDER);
|
|
||||||
modelConfig.setModelName(AzureModelFactory.DEFAULT_MODEL_NAME);
|
|
||||||
modelConfig.setBaseUrl(AzureModelFactory.DEFAULT_BASE_URL);
|
|
||||||
modelConfig.setApiKey(ParameterConfig.DEMO);
|
|
||||||
|
|
||||||
ChatLanguageModel chatModel = ModelProvider.getChatModel(modelConfig);
|
|
||||||
assertThrows(RuntimeException.class, () -> {
|
|
||||||
chatModel.generate("hi");
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void test_in_memory_embedding_model() {
|
public void test_in_memory_embedding_model() {
|
||||||
EmbeddingModelConfig modelConfig = new EmbeddingModelConfig();
|
EmbeddingModelConfig modelConfig = new EmbeddingModelConfig();
|
||||||
@@ -122,61 +58,4 @@ public class ModelProviderTest extends BaseApplication {
|
|||||||
Response<Embedding> embed = embeddingModel.embed("hi");
|
Response<Embedding> embed = embeddingModel.embed("hi");
|
||||||
assertNotNull(embed);
|
assertNotNull(embed);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
|
||||||
public void test_azure_embedding_model() {
|
|
||||||
EmbeddingModelConfig modelConfig = new EmbeddingModelConfig();
|
|
||||||
modelConfig.setProvider(AzureModelFactory.PROVIDER);
|
|
||||||
modelConfig.setModelName(AzureModelFactory.DEFAULT_EMBEDDING_MODEL_NAME);
|
|
||||||
modelConfig.setBaseUrl(AzureModelFactory.DEFAULT_BASE_URL);
|
|
||||||
modelConfig.setApiKey(ParameterConfig.DEMO);
|
|
||||||
|
|
||||||
EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(modelConfig);
|
|
||||||
assertThrows(RuntimeException.class, () -> {
|
|
||||||
embeddingModel.embed("hi");
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void test_dashscope_embedding_model() {
|
|
||||||
EmbeddingModelConfig modelConfig = new EmbeddingModelConfig();
|
|
||||||
modelConfig.setProvider(DashscopeModelFactory.PROVIDER);
|
|
||||||
modelConfig.setModelName(DashscopeModelFactory.DEFAULT_EMBEDDING_MODEL_NAME);
|
|
||||||
modelConfig.setBaseUrl(DashscopeModelFactory.DEFAULT_BASE_URL);
|
|
||||||
modelConfig.setApiKey(ParameterConfig.DEMO);
|
|
||||||
|
|
||||||
EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(modelConfig);
|
|
||||||
assertThrows(RuntimeException.class, () -> {
|
|
||||||
embeddingModel.embed("hi");
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void test_qianfan_embedding_model() {
|
|
||||||
EmbeddingModelConfig modelConfig = new EmbeddingModelConfig();
|
|
||||||
modelConfig.setProvider(QianfanModelFactory.PROVIDER);
|
|
||||||
modelConfig.setModelName(QianfanModelFactory.DEFAULT_EMBEDDING_MODEL_NAME);
|
|
||||||
modelConfig.setBaseUrl(QianfanModelFactory.DEFAULT_BASE_URL);
|
|
||||||
modelConfig.setApiKey(ParameterConfig.DEMO);
|
|
||||||
modelConfig.setSecretKey(ParameterConfig.DEMO);
|
|
||||||
|
|
||||||
EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(modelConfig);
|
|
||||||
assertThrows(RuntimeException.class, () -> {
|
|
||||||
embeddingModel.embed("hi");
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void test_zhipu_embedding_model() {
|
|
||||||
EmbeddingModelConfig modelConfig = new EmbeddingModelConfig();
|
|
||||||
modelConfig.setProvider(ZhipuModelFactory.PROVIDER);
|
|
||||||
modelConfig.setModelName(ZhipuModelFactory.DEFAULT_EMBEDDING_MODEL_NAME);
|
|
||||||
modelConfig.setBaseUrl(ZhipuModelFactory.DEFAULT_BASE_URL);
|
|
||||||
modelConfig.setApiKey("e2724491714b3b2a0274e987905f1001.5JyHgf4vbZVJ7gC5");
|
|
||||||
|
|
||||||
EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(modelConfig);
|
|
||||||
assertThrows(RuntimeException.class, () -> {
|
|
||||||
embeddingModel.embed("hi");
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
25
pom.xml
25
pom.xml
@@ -146,31 +146,11 @@
|
|||||||
<artifactId>langchain4j-embeddings-bge-small-zh</artifactId>
|
<artifactId>langchain4j-embeddings-bge-small-zh</artifactId>
|
||||||
<version>${langchain4j.embedding.version}</version>
|
<version>${langchain4j.embedding.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
|
||||||
<groupId>dev.langchain4j</groupId>
|
|
||||||
<artifactId>langchain4j-azure-open-ai</artifactId>
|
|
||||||
<version>${langchain4j.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>dev.langchain4j</groupId>
|
<groupId>dev.langchain4j</groupId>
|
||||||
<artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
|
<artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
|
||||||
<version>${langchain4j.embedding.version}</version>
|
<version>${langchain4j.embedding.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
|
||||||
<groupId>dev.langchain4j</groupId>
|
|
||||||
<artifactId>langchain4j-qianfan</artifactId>
|
|
||||||
<version>${langchain4j.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>dev.langchain4j</groupId>
|
|
||||||
<artifactId>langchain4j-zhipu-ai</artifactId>
|
|
||||||
<version>${langchain4j.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>dev.langchain4j</groupId>
|
|
||||||
<artifactId>langchain4j-dashscope</artifactId>
|
|
||||||
<version>${langchain4j.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>dev.langchain4j</groupId>
|
<groupId>dev.langchain4j</groupId>
|
||||||
<artifactId>langchain4j-milvus</artifactId>
|
<artifactId>langchain4j-milvus</artifactId>
|
||||||
@@ -186,11 +166,6 @@
|
|||||||
<artifactId>langchain4j-pgvector</artifactId>
|
<artifactId>langchain4j-pgvector</artifactId>
|
||||||
<version>${langchain4j.version}</version>
|
<version>${langchain4j.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
|
||||||
<groupId>dev.langchain4j</groupId>
|
|
||||||
<artifactId>langchain4j-chatglm</artifactId>
|
|
||||||
<version>${langchain4j.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>dev.langchain4j</groupId>
|
<groupId>dev.langchain4j</groupId>
|
||||||
<artifactId>langchain4j-ollama</artifactId>
|
<artifactId>langchain4j-ollama</artifactId>
|
||||||
|
|||||||
@@ -81,6 +81,7 @@ export enum MsgContentTypeEnum {
|
|||||||
METRIC_TREND = 'METRIC_TREND',
|
METRIC_TREND = 'METRIC_TREND',
|
||||||
METRIC_BAR = 'METRIC_BAR',
|
METRIC_BAR = 'METRIC_BAR',
|
||||||
MARKDOWN = 'MARKDOWN',
|
MARKDOWN = 'MARKDOWN',
|
||||||
|
METRIC_PIE = 'METRIC_PIE',
|
||||||
}
|
}
|
||||||
|
|
||||||
export enum ChatContextTypeQueryTypeEnum {
|
export enum ChatContextTypeQueryTypeEnum {
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import Loading from './Loading';
|
|||||||
import { Prism as SyntaxHighlighter } from 'react-syntax-highlighter';
|
import { Prism as SyntaxHighlighter } from 'react-syntax-highlighter';
|
||||||
import { solarizedlight } from 'react-syntax-highlighter/dist/esm/styles/prism';
|
import { solarizedlight } from 'react-syntax-highlighter/dist/esm/styles/prism';
|
||||||
import React, { ReactNode, useState } from 'react';
|
import React, { ReactNode, useState } from 'react';
|
||||||
|
import ReactMarkdown from 'react-markdown';
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
queryId?: number;
|
queryId?: number;
|
||||||
@@ -122,9 +123,11 @@ const ExecuteItem: React.FC<Props> = ({
|
|||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
<div>
|
<div>
|
||||||
{[MsgContentTypeEnum.METRIC_TREND, MsgContentTypeEnum.METRIC_BAR].includes(
|
{[
|
||||||
msgContentType as MsgContentTypeEnum
|
MsgContentTypeEnum.METRIC_TREND,
|
||||||
) && (
|
MsgContentTypeEnum.METRIC_BAR,
|
||||||
|
MsgContentTypeEnum.METRIC_PIE,
|
||||||
|
].includes(msgContentType as MsgContentTypeEnum) && (
|
||||||
<Switch
|
<Switch
|
||||||
checkedChildren="表格"
|
checkedChildren="表格"
|
||||||
unCheckedChildren="表格"
|
unCheckedChildren="表格"
|
||||||
@@ -151,7 +154,7 @@ const ExecuteItem: React.FC<Props> = ({
|
|||||||
{data.textSummary && (
|
{data.textSummary && (
|
||||||
<p className={`${prefixCls}-step-title`}>
|
<p className={`${prefixCls}-step-title`}>
|
||||||
<span style={{ marginRight: 5 }}>总结:</span>
|
<span style={{ marginRight: 5 }}>总结:</span>
|
||||||
{data.textSummary}
|
<ReactMarkdown>{data.textSummary}</ReactMarkdown>
|
||||||
</p>
|
</p>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ import { useExportByEcharts } from '../../../hooks';
|
|||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
data: MsgDataType;
|
data: MsgDataType;
|
||||||
question: string;
|
question?: string;
|
||||||
triggerResize?: boolean;
|
triggerResize?: boolean;
|
||||||
loading: boolean;
|
loading: boolean;
|
||||||
metricField: ColumnType;
|
metricField: ColumnType;
|
||||||
@@ -32,7 +32,7 @@ type Props = {
|
|||||||
|
|
||||||
const BarChart: React.FC<Props> = ({
|
const BarChart: React.FC<Props> = ({
|
||||||
data,
|
data,
|
||||||
question,
|
question="",
|
||||||
triggerResize,
|
triggerResize,
|
||||||
loading,
|
loading,
|
||||||
metricField,
|
metricField,
|
||||||
|
|||||||
123
webapp/packages/chat-sdk/src/components/ChatMsg/Pie/PieChart.tsx
Normal file
123
webapp/packages/chat-sdk/src/components/ChatMsg/Pie/PieChart.tsx
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
import { PREFIX_CLS, THEME_COLOR_LIST } from '../../../common/constants';
|
||||||
|
import { MsgDataType } from '../../../common/type';
|
||||||
|
import { formatByDecimalPlaces, getFormattedValue } from '../../../utils/utils';
|
||||||
|
import type { ECharts } from 'echarts';
|
||||||
|
import * as echarts from 'echarts';
|
||||||
|
import { useEffect, useRef } from 'react';
|
||||||
|
import { ColumnType } from '../../../common/type';
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
data: MsgDataType;
|
||||||
|
metricField: ColumnType;
|
||||||
|
categoryField: ColumnType;
|
||||||
|
triggerResize?: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
const PieChart: React.FC<Props> = ({
|
||||||
|
data,
|
||||||
|
metricField,
|
||||||
|
categoryField,
|
||||||
|
triggerResize,
|
||||||
|
}) => {
|
||||||
|
const chartRef = useRef<any>();
|
||||||
|
const instanceRef = useRef<ECharts>();
|
||||||
|
|
||||||
|
const { queryResults } = data;
|
||||||
|
const categoryColumnName = categoryField?.bizName || '';
|
||||||
|
const metricColumnName = metricField?.bizName || '';
|
||||||
|
|
||||||
|
const renderChart = () => {
|
||||||
|
let instanceObj: any;
|
||||||
|
if (!instanceRef.current) {
|
||||||
|
instanceObj = echarts.init(chartRef.current);
|
||||||
|
instanceRef.current = instanceObj;
|
||||||
|
} else {
|
||||||
|
instanceObj = instanceRef.current;
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = queryResults || [];
|
||||||
|
const seriesData = data.map((item, index) => {
|
||||||
|
const value = item[metricColumnName];
|
||||||
|
const name = item[categoryColumnName] !== undefined ? item[categoryColumnName] : '未知';
|
||||||
|
return {
|
||||||
|
name,
|
||||||
|
value,
|
||||||
|
itemStyle: {
|
||||||
|
color: THEME_COLOR_LIST[index % THEME_COLOR_LIST.length],
|
||||||
|
},
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
|
instanceObj.setOption({
|
||||||
|
tooltip: {
|
||||||
|
trigger: 'item',
|
||||||
|
formatter: function (params: any) {
|
||||||
|
const value = params.value;
|
||||||
|
return `${params.name}: ${
|
||||||
|
metricField.dataFormatType === 'percent'
|
||||||
|
? `${formatByDecimalPlaces(
|
||||||
|
metricField.dataFormat?.needMultiply100 ? +value * 100 : value,
|
||||||
|
metricField.dataFormat?.decimalPlaces || 2
|
||||||
|
)}%`
|
||||||
|
: getFormattedValue(value)
|
||||||
|
}`;
|
||||||
|
},
|
||||||
|
},
|
||||||
|
legend: {
|
||||||
|
orient: 'vertical',
|
||||||
|
left: 'left',
|
||||||
|
type: 'scroll',
|
||||||
|
data: seriesData.map(item => item.name),
|
||||||
|
selectedMode: true,
|
||||||
|
textStyle: {
|
||||||
|
color: '#666',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
series: [
|
||||||
|
{
|
||||||
|
name: '占比',
|
||||||
|
type: 'pie',
|
||||||
|
radius: ['40%', '70%'],
|
||||||
|
avoidLabelOverlap: false,
|
||||||
|
itemStyle: {
|
||||||
|
borderRadius: 10,
|
||||||
|
borderColor: '#fff',
|
||||||
|
borderWidth: 2,
|
||||||
|
},
|
||||||
|
label: {
|
||||||
|
show: false,
|
||||||
|
position: 'center',
|
||||||
|
},
|
||||||
|
emphasis: {
|
||||||
|
label: {
|
||||||
|
show: true,
|
||||||
|
fontSize: '14',
|
||||||
|
fontWeight: 'bold',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
labelLine: {
|
||||||
|
show: false,
|
||||||
|
},
|
||||||
|
data: seriesData,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
instanceObj.resize();
|
||||||
|
};
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (queryResults && queryResults.length > 0) {
|
||||||
|
renderChart();
|
||||||
|
}
|
||||||
|
}, [queryResults, metricField, categoryField]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (triggerResize && instanceRef.current) {
|
||||||
|
instanceRef.current.resize();
|
||||||
|
}
|
||||||
|
}, [triggerResize]);
|
||||||
|
|
||||||
|
return <div className={`${PREFIX_CLS}-pie-chart`} ref={chartRef} />;
|
||||||
|
};
|
||||||
|
|
||||||
|
export default PieChart;
|
||||||
@@ -0,0 +1,88 @@
|
|||||||
|
import { PREFIX_CLS } from '../../../common/constants';
|
||||||
|
import { MsgDataType } from '../../../common/type';
|
||||||
|
import { useRef, useState } from 'react';
|
||||||
|
import NoPermissionChart from '../NoPermissionChart';
|
||||||
|
import { ColumnType } from '../../../common/type';
|
||||||
|
import { Spin, Select } from 'antd';
|
||||||
|
import PieChart from './PieChart';
|
||||||
|
import Bar from '../Bar';
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
data: MsgDataType;
|
||||||
|
question: string;
|
||||||
|
triggerResize?: boolean;
|
||||||
|
loading: boolean;
|
||||||
|
metricField: ColumnType;
|
||||||
|
categoryField: ColumnType;
|
||||||
|
onApplyAuth?: (model: string) => void;
|
||||||
|
};
|
||||||
|
|
||||||
|
const metricChartSelectOptions = [
|
||||||
|
{
|
||||||
|
value: 'pie',
|
||||||
|
label: '饼图',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
value: 'bar',
|
||||||
|
label: '柱状图',
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
const Pie: React.FC<Props> = ({
|
||||||
|
data,
|
||||||
|
question,
|
||||||
|
triggerResize,
|
||||||
|
loading,
|
||||||
|
metricField,
|
||||||
|
categoryField,
|
||||||
|
onApplyAuth,
|
||||||
|
}) => {
|
||||||
|
const [chartType, setChartType] = useState('pie');
|
||||||
|
const { entityInfo } = data;
|
||||||
|
|
||||||
|
if (metricField && !metricField?.authorized) {
|
||||||
|
return (
|
||||||
|
<NoPermissionChart
|
||||||
|
model={entityInfo?.dataSetInfo?.name || ''}
|
||||||
|
chartType="pieChart"
|
||||||
|
onApplyAuth={onApplyAuth}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
const prefixCls = `${PREFIX_CLS}-pie`;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className={prefixCls}>
|
||||||
|
<div className={`${prefixCls}-metric-fields ${prefixCls}-metric-field-single`}>
|
||||||
|
{question}
|
||||||
|
</div>
|
||||||
|
<div className={`${prefixCls}-select-options`}>
|
||||||
|
<Select
|
||||||
|
defaultValue="pie"
|
||||||
|
bordered={false}
|
||||||
|
options={metricChartSelectOptions}
|
||||||
|
onChange={(value: string) => setChartType(value)}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
{chartType === 'pie' ? (
|
||||||
|
<PieChart
|
||||||
|
data={data}
|
||||||
|
metricField={metricField}
|
||||||
|
categoryField={categoryField}
|
||||||
|
triggerResize={triggerResize}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
|
<Bar
|
||||||
|
data={data}
|
||||||
|
triggerResize={triggerResize}
|
||||||
|
loading={loading}
|
||||||
|
metricField={metricField}
|
||||||
|
onApplyAuth={onApplyAuth}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default Pie;
|
||||||
@@ -0,0 +1,43 @@
|
|||||||
|
@import '../../../styles/index.less';
|
||||||
|
|
||||||
|
@pie-prefix-cls: ~'@{prefix-cls}-pie';
|
||||||
|
|
||||||
|
.@{pie-prefix-cls} {
|
||||||
|
&-select-options {
|
||||||
|
display: flex;
|
||||||
|
justify-content: flex-end;
|
||||||
|
}
|
||||||
|
|
||||||
|
&-chart {
|
||||||
|
width: 100%;
|
||||||
|
height: 400px;
|
||||||
|
min-height: 400px;
|
||||||
|
}
|
||||||
|
|
||||||
|
&-metric-fields {
|
||||||
|
display: flex;
|
||||||
|
align-items: baseline;
|
||||||
|
flex-wrap: wrap;
|
||||||
|
column-gap: 8px;
|
||||||
|
row-gap: 12px;
|
||||||
|
}
|
||||||
|
|
||||||
|
&-metric-field-single {
|
||||||
|
padding-left: 0;
|
||||||
|
font-weight: 500;
|
||||||
|
cursor: default;
|
||||||
|
font-size: 15px;
|
||||||
|
color: var(--text-color);
|
||||||
|
|
||||||
|
&:hover {
|
||||||
|
color: var(--text-color);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
&-indicator-name {
|
||||||
|
font-size: 14px;
|
||||||
|
color: var(--text-color);
|
||||||
|
font-weight: 500;
|
||||||
|
margin-top: 2px;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -12,6 +12,7 @@ import Text from './Text';
|
|||||||
import DrillDownDimensions from '../DrillDownDimensions';
|
import DrillDownDimensions from '../DrillDownDimensions';
|
||||||
import MetricOptions from '../MetricOptions';
|
import MetricOptions from '../MetricOptions';
|
||||||
import { isMobile } from '../../utils/utils';
|
import { isMobile } from '../../utils/utils';
|
||||||
|
import Pie from './Pie';
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
queryId?: number;
|
queryId?: number;
|
||||||
@@ -120,6 +121,16 @@ const ChatMsg: React.FC<Props> = ({
|
|||||||
return MsgContentTypeEnum.METRIC_TREND;
|
return MsgContentTypeEnum.METRIC_TREND;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const isMetricPie =
|
||||||
|
metricFields.length > 0 &&
|
||||||
|
metricFields?.length === 1 &&
|
||||||
|
(isMobile ? dataSource?.length <= 5 : dataSource?.length <= 10) &&
|
||||||
|
dataSource.every(item => item[metricFields[0].bizName] > 0);
|
||||||
|
|
||||||
|
if (isMetricPie) {
|
||||||
|
return MsgContentTypeEnum.METRIC_PIE;
|
||||||
|
}
|
||||||
|
|
||||||
const isMetricBar =
|
const isMetricBar =
|
||||||
categoryField?.length > 0 &&
|
categoryField?.length > 0 &&
|
||||||
metricFields?.length === 1 &&
|
metricFields?.length === 1 &&
|
||||||
@@ -148,7 +159,7 @@ const ChatMsg: React.FC<Props> = ({
|
|||||||
[queryColumns.length > 5 ? 'width' : 'minWidth']: queryColumns.length * 150,
|
[queryColumns.length > 5 ? 'width' : 'minWidth']: queryColumns.length * 150,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
if (type === MsgContentTypeEnum.METRIC_TREND) {
|
if (type === MsgContentTypeEnum.METRIC_TREND || type === MsgContentTypeEnum.METRIC_PIE) {
|
||||||
return { width: 'calc(100vw - 410px)' };
|
return { width: 'calc(100vw - 410px)' };
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -213,6 +224,21 @@ const ChatMsg: React.FC<Props> = ({
|
|||||||
metricField={metricFields[0]}
|
metricField={metricFields[0]}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
|
case MsgContentTypeEnum.METRIC_PIE:
|
||||||
|
const categoryField = columns.find(item => item.showType === 'CATEGORY');
|
||||||
|
if (!categoryField) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return (
|
||||||
|
<Pie
|
||||||
|
data={{ ...data, queryColumns: columns, queryResults: dataSource }}
|
||||||
|
question={question}
|
||||||
|
triggerResize={triggerResize}
|
||||||
|
loading={loading}
|
||||||
|
metricField={metricFields[0]}
|
||||||
|
categoryField={categoryField}
|
||||||
|
/>
|
||||||
|
);
|
||||||
case MsgContentTypeEnum.MARKDOWN:
|
case MsgContentTypeEnum.MARKDOWN:
|
||||||
return (
|
return (
|
||||||
<div style={{ maxHeight: 800 }}>
|
<div style={{ maxHeight: 800 }}>
|
||||||
|
|||||||
@@ -16,6 +16,8 @@
|
|||||||
|
|
||||||
@import "../components/ChatMsg/MetricTrend/style.less";
|
@import "../components/ChatMsg/MetricTrend/style.less";
|
||||||
|
|
||||||
|
@import "../components/ChatMsg/Pie/style.less";
|
||||||
|
|
||||||
@import "../components/ChatMsg/ApplyAuth/style.less";
|
@import "../components/ChatMsg/ApplyAuth/style.less";
|
||||||
|
|
||||||
@import "../components/ChatMsg/NoPermissionChart/style.less";
|
@import "../components/ChatMsg/NoPermissionChart/style.less";
|
||||||
|
|||||||
Reference in New Issue
Block a user