(improvement)(chat) Support large models qianfan, zhipu, Azure, LocalAi, Dashscope, and handle the apiKey configuration as hidden. (#1552)

This commit is contained in:
lexluo09
2024-08-11 23:28:24 +08:00
committed by GitHub
parent 2f2f493d17
commit 8b01dac8d4
12 changed files with 315 additions and 184 deletions

View File

@@ -4,8 +4,15 @@ import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.Parameter; import com.tencent.supersonic.common.pojo.Parameter;
import dev.ai4j.openai4j.chat.ChatCompletionModel;
import dev.langchain4j.model.dashscope.QwenModelName;
import dev.langchain4j.provider.AzureModelFactory;
import dev.langchain4j.provider.DashscopeModelFactory;
import dev.langchain4j.provider.LocalAiModelFactory;
import dev.langchain4j.provider.OllamaModelFactory; import dev.langchain4j.provider.OllamaModelFactory;
import dev.langchain4j.provider.OpenAiModelFactory; import dev.langchain4j.provider.OpenAiModelFactory;
import dev.langchain4j.provider.QianfanModelFactory;
import dev.langchain4j.provider.ZhipuModelFactory;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@@ -18,49 +25,36 @@ public class ChatModelParameterConfig extends ParameterConfig {
public static final Parameter CHAT_MODEL_PROVIDER = public static final Parameter CHAT_MODEL_PROVIDER =
new Parameter("s2.chat.model.provider", OpenAiModelFactory.PROVIDER, new Parameter("s2.chat.model.provider", OpenAiModelFactory.PROVIDER,
"接口协议", "", "接口协议", "", "list",
"list", "对话模型配置", "对话模型配置", getCandidateValues());
getCandidateValues());
public static final Parameter CHAT_MODEL_BASE_URL = public static final Parameter CHAT_MODEL_BASE_URL =
new Parameter("s2.chat.model.base.url", "https://api.openai.com/v1", new Parameter("s2.chat.model.base.url", OpenAiModelFactory.DEFAULT_BASE_URL,
"BaseUrl", "", "string", "BaseUrl", "", "string",
"对话模型配置", null, "对话模型配置", null, getBaseUrlDependency());
getDependency(CHAT_MODEL_PROVIDER.getName(), public static final Parameter CHAT_MODEL_ENDPOINT =
getCandidateValues(), new Parameter("s2.chat.model.endpoint", "llama_2_70b",
ImmutableMap.of( "Endpoint", "", "string",
OpenAiModelFactory.PROVIDER, "https://api.openai.com/v1", "对话模型配置", null, getEndpointDependency());
OllamaModelFactory.PROVIDER, "http://localhost:11434")
)
);
public static final Parameter CHAT_MODEL_API_KEY = public static final Parameter CHAT_MODEL_API_KEY =
new Parameter("s2.chat.model.api.key", "demo", new Parameter("s2.chat.model.api.key", DEMO,
"ApiKey", "", "ApiKey", "", "password",
"string", "对话模型配置", null, "对话模型配置", null, getApiKeyDependency()
getDependency(CHAT_MODEL_PROVIDER.getName(),
Lists.newArrayList(OpenAiModelFactory.PROVIDER),
ImmutableMap.of(OpenAiModelFactory.PROVIDER, "demo"))
); );
public static final Parameter CHAT_MODEL_SECRET_KEY =
new Parameter("s2.chat.model.secretKey", "demo",
"SecretKey", "", "password",
"对话模型配置", null, getSecretKeyDependency());
public static final Parameter CHAT_MODEL_NAME = public static final Parameter CHAT_MODEL_NAME =
new Parameter("s2.chat.model.name", "gpt-3.5-turbo", new Parameter("s2.chat.model.name", "gpt-3.5-turbo",
"ModelName", "", "ModelName", "", "string",
"string", "对话模型配置", null, "对话模型配置", null, getModelNameDependency());
getDependency(CHAT_MODEL_PROVIDER.getName(),
getCandidateValues(),
ImmutableMap.of(OpenAiModelFactory.PROVIDER, "gpt-3.5-turbo",
OllamaModelFactory.PROVIDER, "qwen:0.5b")
));
public static final Parameter CHAT_MODEL_TEMPERATURE = public static final Parameter CHAT_MODEL_TEMPERATURE =
new Parameter("s2.chat.model.temperature", "0.0", new Parameter("s2.chat.model.temperature", "0.0",
"Temperature", "", "Temperature", "",
"slider", "对话模型配置", null, "slider", "对话模型配置");
getDependency(CHAT_MODEL_PROVIDER.getName(),
getCandidateValues(),
ImmutableMap.of(OpenAiModelFactory.PROVIDER, "0.0", OllamaModelFactory.PROVIDER, "0.0")));
public static final Parameter CHAT_MODEL_TIMEOUT = public static final Parameter CHAT_MODEL_TIMEOUT =
new Parameter("s2.chat.model.timeout", "60", new Parameter("s2.chat.model.timeout", "60",
@@ -70,8 +64,9 @@ public class ChatModelParameterConfig extends ParameterConfig {
@Override @Override
public List<Parameter> getSysParameters() { public List<Parameter> getSysParameters() {
return Lists.newArrayList( return Lists.newArrayList(
CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_API_KEY, CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_ENDPOINT,
CHAT_MODEL_NAME, CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT CHAT_MODEL_API_KEY, CHAT_MODEL_SECRET_KEY, CHAT_MODEL_NAME,
CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT
); );
} }
@@ -82,6 +77,8 @@ public class ChatModelParameterConfig extends ParameterConfig {
String chatModelName = getParameterValue(CHAT_MODEL_NAME); String chatModelName = getParameterValue(CHAT_MODEL_NAME);
String chatModelTemperature = getParameterValue(CHAT_MODEL_TEMPERATURE); String chatModelTemperature = getParameterValue(CHAT_MODEL_TEMPERATURE);
String chatModelTimeout = getParameterValue(CHAT_MODEL_TIMEOUT); String chatModelTimeout = getParameterValue(CHAT_MODEL_TIMEOUT);
String endpoint = getParameterValue(CHAT_MODEL_ENDPOINT);
String secretKey = getParameterValue(CHAT_MODEL_SECRET_KEY);
return ChatModelConfig.builder() return ChatModelConfig.builder()
.provider(chatModelProvider) .provider(chatModelProvider)
@@ -90,10 +87,94 @@ public class ChatModelParameterConfig extends ParameterConfig {
.modelName(chatModelName) .modelName(chatModelName)
.temperature(Double.valueOf(chatModelTemperature)) .temperature(Double.valueOf(chatModelTemperature))
.timeOut(Long.valueOf(chatModelTimeout)) .timeOut(Long.valueOf(chatModelTimeout))
.endpoint(endpoint)
.secretKey(secretKey)
.build(); .build();
} }
private static ArrayList<String> getCandidateValues() { private static List<String> getCandidateValues() {
return Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER); List<String> candidateValues = getBaseUrlCandidateValues();
candidateValues.add(AzureModelFactory.PROVIDER);
return candidateValues;
}
private static ArrayList<String> getBaseUrlCandidateValues() {
return Lists.newArrayList(
OpenAiModelFactory.PROVIDER,
OllamaModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER,
LocalAiModelFactory.PROVIDER,
DashscopeModelFactory.PROVIDER);
}
private static List<Parameter.Dependency> getBaseUrlDependency() {
return getDependency(CHAT_MODEL_PROVIDER.getName(),
getBaseUrlCandidateValues(),
ImmutableMap.of(
OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL,
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL,
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL,
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_BASE_URL,
LocalAiModelFactory.PROVIDER, LocalAiModelFactory.DEFAULT_BASE_URL,
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL)
);
}
private static List<Parameter.Dependency> getApiKeyDependency() {
return getDependency(CHAT_MODEL_PROVIDER.getName(),
Lists.newArrayList(
OpenAiModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER,
LocalAiModelFactory.PROVIDER,
AzureModelFactory.PROVIDER,
DashscopeModelFactory.PROVIDER
),
ImmutableMap.of(
OpenAiModelFactory.PROVIDER, DEMO,
QianfanModelFactory.PROVIDER, DEMO,
ZhipuModelFactory.PROVIDER, DEMO,
LocalAiModelFactory.PROVIDER, DEMO,
AzureModelFactory.PROVIDER, DEMO,
DashscopeModelFactory.PROVIDER, DEMO
));
}
private static List<Parameter.Dependency> getModelNameDependency() {
return getDependency(CHAT_MODEL_PROVIDER.getName(),
getCandidateValues(),
ImmutableMap.of(
OpenAiModelFactory.PROVIDER, "gpt-3.5-turbo",
OllamaModelFactory.PROVIDER, "qwen:0.5b",
QianfanModelFactory.PROVIDER, "Llama-2-70b-chat",
ZhipuModelFactory.PROVIDER, ChatCompletionModel.GPT_4.toString(),
LocalAiModelFactory.PROVIDER, "ggml-gpt4all-j",
AzureModelFactory.PROVIDER, "gpt-35-turbo",
DashscopeModelFactory.PROVIDER, QwenModelName.QWEN_PLUS
)
);
}
private static List<Parameter.Dependency> getEndpointDependency() {
return getDependency(CHAT_MODEL_PROVIDER.getName(),
Lists.newArrayList(
AzureModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER
),
ImmutableMap.of(
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL,
QianfanModelFactory.PROVIDER, "llama_2_70b"
)
);
}
private static List<Parameter.Dependency> getSecretKeyDependency() {
return getDependency(CHAT_MODEL_PROVIDER.getName(),
Lists.newArrayList(QianfanModelFactory.PROVIDER),
ImmutableMap.of(
QianfanModelFactory.PROVIDER, DEMO
)
);
} }
} }

View File

@@ -21,109 +21,34 @@ import java.util.List;
@Service("EmbeddingModelParameterConfig") @Service("EmbeddingModelParameterConfig")
@Slf4j @Slf4j
public class EmbeddingModelParameterConfig extends ParameterConfig { public class EmbeddingModelParameterConfig extends ParameterConfig {
public static final Parameter EMBEDDING_MODEL_PROVIDER = public static final Parameter EMBEDDING_MODEL_PROVIDER =
new Parameter("s2.embedding.model.provider", InMemoryModelFactory.PROVIDER, new Parameter("s2.embedding.model.provider", InMemoryModelFactory.PROVIDER,
"接口协议", "", "接口协议", "", "list",
"list", "向量模型配置", "向量模型配置", getCandidateValues());
getCandidateValues());
public static final Parameter EMBEDDING_MODEL_BASE_URL = public static final Parameter EMBEDDING_MODEL_BASE_URL =
new Parameter("s2.embedding.model.base.url", "", new Parameter("s2.embedding.model.base.url", "",
"BaseUrl", "", "BaseUrl", "", "string",
"string", "向量模型配置", null, "向量模型配置", null, getBaseUrlDependency()
getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
Lists.newArrayList(
OpenAiModelFactory.PROVIDER,
OllamaModelFactory.PROVIDER,
AzureModelFactory.PROVIDER,
DashscopeModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER
),
ImmutableMap.of(
OpenAiModelFactory.PROVIDER, "https://api.openai.com/v1",
OllamaModelFactory.PROVIDER, "http://localhost:11434",
AzureModelFactory.PROVIDER, "https://xxxx.openai.azure.com/",
DashscopeModelFactory.PROVIDER, "https://dashscope.aliyuncs.com/compatible-mode/v1",
QianfanModelFactory.PROVIDER, "https://aip.baidubce.com",
ZhipuModelFactory.PROVIDER, "https://open.bigmodel.cn/api/paas/v4/"
)
)
); );
public static final Parameter EMBEDDING_MODEL_API_KEY = public static final Parameter EMBEDDING_MODEL_API_KEY =
new Parameter("s2.embedding.model.api.key", "", new Parameter("s2.embedding.model.api.key", "",
"ApiKey", "", "ApiKey", "", "password",
"string", "向量模型配置", null, "向量模型配置", null, getApiKeyDependency());
getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
Lists.newArrayList(
OpenAiModelFactory.PROVIDER,
OllamaModelFactory.PROVIDER,
AzureModelFactory.PROVIDER,
DashscopeModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER
),
ImmutableMap.of(
OpenAiModelFactory.PROVIDER, "demo",
OllamaModelFactory.PROVIDER, "demo",
AzureModelFactory.PROVIDER, "demo",
DashscopeModelFactory.PROVIDER, "demo",
QianfanModelFactory.PROVIDER, "demo",
ZhipuModelFactory.PROVIDER, "demo"
)
));
public static final Parameter EMBEDDING_MODEL_NAME = public static final Parameter EMBEDDING_MODEL_NAME =
new Parameter("s2.embedding.model.name", EmbeddingModelConstant.BGE_SMALL_ZH, new Parameter("s2.embedding.model.name", EmbeddingModelConstant.BGE_SMALL_ZH,
"ModelName", "", "ModelName", "", "string",
"string", "向量模型配置", null, "向量模型配置", null, getModelNameDependency());
getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
Lists.newArrayList(
InMemoryModelFactory.PROVIDER,
OpenAiModelFactory.PROVIDER,
OllamaModelFactory.PROVIDER,
AzureModelFactory.PROVIDER,
DashscopeModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER
),
ImmutableMap.of(
InMemoryModelFactory.PROVIDER, EmbeddingModelConstant.BGE_SMALL_ZH,
OpenAiModelFactory.PROVIDER, "text-embedding-ada-002",
OllamaModelFactory.PROVIDER, "all-minilm",
AzureModelFactory.PROVIDER, "text-embedding-ada-002",
DashscopeModelFactory.PROVIDER, "text-embedding-ada-002",
QianfanModelFactory.PROVIDER, "text-embedding-ada-002",
ZhipuModelFactory.PROVIDER, "text-embedding-ada-002"
)
));
public static final Parameter EMBEDDING_MODEL_PATH = public static final Parameter EMBEDDING_MODEL_PATH =
new Parameter("s2.embedding.model.path", "", new Parameter("s2.embedding.model.path", "",
"模型路径", "", "模型路径", "", "string",
"string", "向量模型配置", null, "向量模型配置", null, getModelPathDependency());
getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
Lists.newArrayList(
InMemoryModelFactory.PROVIDER
),
ImmutableMap.of(
InMemoryModelFactory.PROVIDER, ""
)
));
public static final Parameter EMBEDDING_MODEL_VOCABULARY_PATH = public static final Parameter EMBEDDING_MODEL_VOCABULARY_PATH =
new Parameter("s2.embedding.model.vocabulary.path", "", new Parameter("s2.embedding.model.vocabulary.path", "",
"词汇表路径", "", "词汇表路径", "", "string",
"string", "向量模型配置", null, "向量模型配置", null, getModelPathDependency());
getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
Lists.newArrayList(
InMemoryModelFactory.PROVIDER
),
ImmutableMap.of(
InMemoryModelFactory.PROVIDER, ""
)));
@Override @Override
public List<Parameter> getSysParameters() { public List<Parameter> getSysParameters() {
@@ -152,13 +77,80 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
} }
private static ArrayList<String> getCandidateValues() { private static ArrayList<String> getCandidateValues() {
return Lists.newArrayList(InMemoryModelFactory.PROVIDER, return Lists.newArrayList(
InMemoryModelFactory.PROVIDER,
OpenAiModelFactory.PROVIDER, OpenAiModelFactory.PROVIDER,
OllamaModelFactory.PROVIDER, OllamaModelFactory.PROVIDER,
AzureModelFactory.PROVIDER, AzureModelFactory.PROVIDER,
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER, QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER); ZhipuModelFactory.PROVIDER
);
} }
private static List<Parameter.Dependency> getBaseUrlDependency() {
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
Lists.newArrayList(OpenAiModelFactory.PROVIDER,
OllamaModelFactory.PROVIDER,
AzureModelFactory.PROVIDER,
DashscopeModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER),
ImmutableMap.of(
OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL,
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL,
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL,
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL,
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL,
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_BASE_URL
)
);
}
private static List<Parameter.Dependency> getApiKeyDependency() {
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
Lists.newArrayList(OpenAiModelFactory.PROVIDER,
OllamaModelFactory.PROVIDER,
AzureModelFactory.PROVIDER,
DashscopeModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER),
ImmutableMap.of(OpenAiModelFactory.PROVIDER, DEMO,
OllamaModelFactory.PROVIDER, DEMO,
AzureModelFactory.PROVIDER, DEMO,
DashscopeModelFactory.PROVIDER, DEMO,
QianfanModelFactory.PROVIDER, DEMO,
ZhipuModelFactory.PROVIDER, DEMO)
);
}
private static List<Parameter.Dependency> getModelNameDependency() {
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
Lists.newArrayList(
InMemoryModelFactory.PROVIDER,
OpenAiModelFactory.PROVIDER,
OllamaModelFactory.PROVIDER,
AzureModelFactory.PROVIDER,
DashscopeModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER
),
ImmutableMap.of(
InMemoryModelFactory.PROVIDER, EmbeddingModelConstant.BGE_SMALL_ZH,
OpenAiModelFactory.PROVIDER, "text-embedding-ada-002",
OllamaModelFactory.PROVIDER, "all-minilm",
AzureModelFactory.PROVIDER, "text-embedding-ada-002",
DashscopeModelFactory.PROVIDER, "text-embedding-ada-002",
QianfanModelFactory.PROVIDER, "text-embedding-ada-002",
ZhipuModelFactory.PROVIDER, "text-embedding-ada-002"
)
);
}
private static List<Parameter.Dependency> getModelPathDependency() {
return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
Lists.newArrayList(InMemoryModelFactory.PROVIDER),
ImmutableMap.of(InMemoryModelFactory.PROVIDER, "")
);
}
} }

View File

@@ -9,6 +9,7 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.List; import java.util.List;
@Service("EmbeddingStoreParameterConfig") @Service("EmbeddingStoreParameterConfig")
@@ -16,50 +17,23 @@ import java.util.List;
public class EmbeddingStoreParameterConfig extends ParameterConfig { public class EmbeddingStoreParameterConfig extends ParameterConfig {
public static final Parameter EMBEDDING_STORE_PROVIDER = public static final Parameter EMBEDDING_STORE_PROVIDER =
new Parameter("s2.embedding.store.provider", EmbeddingStoreType.IN_MEMORY.name(), new Parameter("s2.embedding.store.provider", EmbeddingStoreType.IN_MEMORY.name(),
"向量库类型", "", "向量库类型", "", "list",
"list", "向量库配置", "向量库配置", getCandidateValues());
Lists.newArrayList(EmbeddingStoreType.IN_MEMORY.name(),
EmbeddingStoreType.MILVUS.name(),
EmbeddingStoreType.CHROMA.name()));
public static final Parameter EMBEDDING_STORE_BASE_URL = public static final Parameter EMBEDDING_STORE_BASE_URL =
new Parameter("s2.embedding.store.base.url", "", new Parameter("s2.embedding.store.base.url", "",
"BaseUrl", "", "BaseUrl", "", "string",
"string", "向量库配置", null, "向量库配置", null, getBaseUrlDependency());
getDependency(EMBEDDING_STORE_PROVIDER.getName(),
Lists.newArrayList(
EmbeddingStoreType.MILVUS.name(),
EmbeddingStoreType.CHROMA.name()
),
ImmutableMap.of(
EmbeddingStoreType.MILVUS.name(), "http://localhost:19530",
EmbeddingStoreType.CHROMA.name(), "http://localhost:8000"
)
));
public static final Parameter EMBEDDING_STORE_API_KEY = public static final Parameter EMBEDDING_STORE_API_KEY =
new Parameter("s2.embedding.store.api.key", "", new Parameter("s2.embedding.store.api.key", "",
"ApiKey", "", "ApiKey", "", "password",
"string", "向量库配置", null, "向量库配置", null, getApiKeyDependency());
getDependency(EMBEDDING_STORE_PROVIDER.getName(),
Lists.newArrayList(
EmbeddingStoreType.MILVUS.name()
),
ImmutableMap.of(
EmbeddingStoreType.MILVUS.name(), "demo"
)
));
public static final Parameter EMBEDDING_STORE_PERSIST_PATH = public static final Parameter EMBEDDING_STORE_PERSIST_PATH =
new Parameter("s2.embedding.store.persist.path", "/tmp", new Parameter("s2.embedding.store.persist.path", "/tmp",
"持久化路径", "", "持久化路径", "", "string",
"string", "向量库配置", null, "向量库配置", null, getPathDependency());
getDependency(EMBEDDING_STORE_PROVIDER.getName(),
Lists.newArrayList(
EmbeddingStoreType.IN_MEMORY.name()
),
ImmutableMap.of(
EmbeddingStoreType.IN_MEMORY.name(), "/tmp"
)));
public static final Parameter EMBEDDING_STORE_TIMEOUT = public static final Parameter EMBEDDING_STORE_TIMEOUT =
new Parameter("s2.embedding.store.timeout", "60", new Parameter("s2.embedding.store.timeout", "60",
@@ -68,16 +42,8 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig {
public static final Parameter EMBEDDING_STORE_DIMENSION = public static final Parameter EMBEDDING_STORE_DIMENSION =
new Parameter("s2.embedding.store.dimension", "", new Parameter("s2.embedding.store.dimension", "",
"纬度", "", "纬度", "", "number",
"number", "向量库配置", null, "向量库配置", null, getDimensionDependency());
getDependency(EMBEDDING_STORE_PROVIDER.getName(),
Lists.newArrayList(
EmbeddingStoreType.MILVUS.name()
),
ImmutableMap.of(
EmbeddingStoreType.MILVUS.name(), "384"
)
));
@Override @Override
public List<Parameter> getSysParameters() { public List<Parameter> getSysParameters() {
@@ -97,13 +63,50 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig {
if (StringUtils.isNumeric(getParameterValue(EMBEDDING_STORE_DIMENSION))) { if (StringUtils.isNumeric(getParameterValue(EMBEDDING_STORE_DIMENSION))) {
dimension = Integer.valueOf(getParameterValue(EMBEDDING_STORE_DIMENSION)); dimension = Integer.valueOf(getParameterValue(EMBEDDING_STORE_DIMENSION));
} }
return EmbeddingStoreConfig.builder() return EmbeddingStoreConfig.builder().provider(provider)
.provider(provider) .baseUrl(baseUrl).apiKey(apiKey).persistPath(persistPath)
.baseUrl(baseUrl) .timeOut(Long.valueOf(timeOut)).dimension(dimension).build();
.apiKey(apiKey) }
.persistPath(persistPath)
.timeOut(Long.valueOf(timeOut)) private static ArrayList<String> getCandidateValues() {
.dimension(dimension) return Lists.newArrayList(
.build(); EmbeddingStoreType.IN_MEMORY.name(),
EmbeddingStoreType.MILVUS.name(),
EmbeddingStoreType.CHROMA.name());
}
private static List<Parameter.Dependency> getBaseUrlDependency() {
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
Lists.newArrayList(
EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.CHROMA.name()),
ImmutableMap.of(
EmbeddingStoreType.MILVUS.name(), "http://localhost:19530",
EmbeddingStoreType.CHROMA.name(), "http://localhost:8000"
)
);
}
private static List<Parameter.Dependency> getApiKeyDependency() {
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
Lists.newArrayList(EmbeddingStoreType.MILVUS.name()),
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), DEMO)
);
}
private static List<Parameter.Dependency> getPathDependency() {
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
Lists.newArrayList(EmbeddingStoreType.IN_MEMORY.name()),
ImmutableMap.of(EmbeddingStoreType.IN_MEMORY.name(), "/tmp"));
}
private static List<Parameter.Dependency> getDimensionDependency() {
return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
Lists.newArrayList(
EmbeddingStoreType.MILVUS.name()
),
ImmutableMap.of(
EmbeddingStoreType.MILVUS.name(), "384"
)
);
} }
} }

View File

@@ -14,7 +14,7 @@ import java.util.Map;
@Service @Service
public abstract class ParameterConfig { public abstract class ParameterConfig {
public static final String DEMO = "demo";
@Autowired @Autowired
private SystemConfigService sysConfigService; private SystemConfigService sysConfigService;

View File

@@ -20,6 +20,12 @@ public class ChatModelConfig implements Serializable {
private String modelName; private String modelName;
private Double temperature = 0.0d; private Double temperature = 0.0d;
private Long timeOut = 60L; private Long timeOut = 60L;
private String endpoint;
private String secretKey;
private Double topP;
private Integer maxRetries = 3;
private Boolean logRequests = false;
private Boolean logResponses = false;
public String keyDecrypt() { public String keyDecrypt() {
return AESEncryptionUtil.aesDecryptECB(getApiKey()); return AESEncryptionUtil.aesDecryptECB(getApiKey());

View File

@@ -14,15 +14,19 @@ import java.time.Duration;
@Service @Service
public class AzureModelFactory implements ModelFactory, InitializingBean { public class AzureModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "AZURE"; public static final String PROVIDER = "AZURE";
public static final String DEFAULT_BASE_URL = "https://xxxx.openai.azure.com/";
@Override @Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
AzureOpenAiChatModel.Builder builder = AzureOpenAiChatModel.builder() AzureOpenAiChatModel.Builder builder = AzureOpenAiChatModel.builder()
.endpoint(modelConfig.getBaseUrl()) .endpoint(modelConfig.getEndpoint())
.apiKey(modelConfig.getApiKey()) .apiKey(modelConfig.getApiKey())
.deploymentName(modelConfig.getModelName()) .deploymentName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature()) .temperature(modelConfig.getTemperature())
.timeout(Duration.ofSeconds(modelConfig.getTimeOut() == null ? 0L : modelConfig.getTimeOut())); .maxRetries(modelConfig.getMaxRetries())
.topP(modelConfig.getTopP())
.timeout(Duration.ofSeconds(modelConfig.getTimeOut() == null ? 0L : modelConfig.getTimeOut()))
.logRequestsAndResponses(modelConfig.getLogRequests() != null && modelConfig.getLogResponses());
return builder.build(); return builder.build();
} }

View File

@@ -12,6 +12,7 @@ import org.springframework.stereotype.Service;
@Service @Service
public class DashscopeModelFactory implements ModelFactory, InitializingBean { public class DashscopeModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "DASHSCOPE"; public static final String PROVIDER = "DASHSCOPE";
public static final String DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1";
@Override @Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
@@ -21,6 +22,7 @@ public class DashscopeModelFactory implements ModelFactory, InitializingBean {
.modelName(modelConfig.getModelName()) .modelName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature() == null ? 0L : .temperature(modelConfig.getTemperature() == null ? 0L :
modelConfig.getTemperature().floatValue()) modelConfig.getTemperature().floatValue())
.topP(modelConfig.getTopP())
.build(); .build();
} }

View File

@@ -14,7 +14,7 @@ import java.time.Duration;
@Service @Service
public class LocalAiModelFactory implements ModelFactory, InitializingBean { public class LocalAiModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "LOCAL_AI"; public static final String PROVIDER = "LOCAL_AI";
public static final String DEFAULT_BASE_URL = "http://localhost:8080";
@Override @Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return LocalAiChatModel return LocalAiChatModel
@@ -23,6 +23,10 @@ public class LocalAiModelFactory implements ModelFactory, InitializingBean {
.modelName(modelConfig.getModelName()) .modelName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature()) .temperature(modelConfig.getTemperature())
.timeout(Duration.ofSeconds(modelConfig.getTimeOut())) .timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
.topP(modelConfig.getTopP())
.logRequests(modelConfig.getLogRequests())
.logResponses(modelConfig.getLogResponses())
.maxRetries(modelConfig.getMaxRetries())
.build(); .build();
} }
@@ -31,6 +35,9 @@ public class LocalAiModelFactory implements ModelFactory, InitializingBean {
return LocalAiEmbeddingModel.builder() return LocalAiEmbeddingModel.builder()
.baseUrl(embeddingModel.getBaseUrl()) .baseUrl(embeddingModel.getBaseUrl())
.modelName(embeddingModel.getModelName()) .modelName(embeddingModel.getModelName())
.maxRetries(embeddingModel.getMaxRetries())
.logRequests(embeddingModel.getLogRequests())
.logResponses(embeddingModel.getLogResponses())
.build(); .build();
} }

View File

@@ -13,7 +13,9 @@ import java.time.Duration;
@Service @Service
public class OllamaModelFactory implements ModelFactory, InitializingBean { public class OllamaModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "OLLAMA"; public static final String PROVIDER = "OLLAMA";
public static final String DEFAULT_BASE_URL = "http://localhost:11434";
@Override @Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
@@ -23,6 +25,10 @@ public class OllamaModelFactory implements ModelFactory, InitializingBean {
.modelName(modelConfig.getModelName()) .modelName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature()) .temperature(modelConfig.getTemperature())
.timeout(Duration.ofSeconds(modelConfig.getTimeOut())) .timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
.topP(modelConfig.getTopP())
.maxRetries(modelConfig.getMaxRetries())
.logRequests(modelConfig.getLogRequests())
.logResponses(modelConfig.getLogResponses())
.build(); .build();
} }

View File

@@ -13,7 +13,9 @@ import java.time.Duration;
@Service @Service
public class OpenAiModelFactory implements ModelFactory, InitializingBean { public class OpenAiModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "OPEN_AI"; public static final String PROVIDER = "OPEN_AI";
public static final String DEFAULT_BASE_URL = "https://api.openai.com/v1";
@Override @Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
@@ -23,7 +25,11 @@ public class OpenAiModelFactory implements ModelFactory, InitializingBean {
.modelName(modelConfig.getModelName()) .modelName(modelConfig.getModelName())
.apiKey(modelConfig.keyDecrypt()) .apiKey(modelConfig.keyDecrypt())
.temperature(modelConfig.getTemperature()) .temperature(modelConfig.getTemperature())
.topP(modelConfig.getTopP())
.maxRetries(modelConfig.getMaxRetries())
.timeout(Duration.ofSeconds(modelConfig.getTimeOut())) .timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
.logRequests(modelConfig.getLogRequests())
.logResponses(modelConfig.getLogResponses())
.build(); .build();
} }

View File

@@ -4,6 +4,7 @@ import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig; import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.qianfan.QianfanChatModel;
import dev.langchain4j.model.qianfan.QianfanEmbeddingModel; import dev.langchain4j.model.qianfan.QianfanEmbeddingModel;
import org.springframework.beans.factory.InitializingBean; import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@@ -12,10 +13,22 @@ import org.springframework.stereotype.Service;
public class QianfanModelFactory implements ModelFactory, InitializingBean { public class QianfanModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "QIANFAN"; public static final String PROVIDER = "QIANFAN";
public static final String DEFAULT_BASE_URL = "https://aip.baidubce.com";
@Override @Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return null; return QianfanChatModel.builder()
.baseUrl(modelConfig.getBaseUrl())
.apiKey(modelConfig.getApiKey())
.secretKey(modelConfig.getSecretKey())
.endpoint(modelConfig.getEndpoint())
.modelName(modelConfig.getModelName())
.temperature(modelConfig.getTemperature())
.topP(modelConfig.getTopP())
.maxRetries(modelConfig.getMaxRetries())
.logRequests(modelConfig.getLogRequests())
.logResponses(modelConfig.getLogResponses())
.build();
} }
@Override @Override

View File

@@ -4,6 +4,7 @@ import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig; import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.zhipu.ZhipuAiChatModel;
import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel; import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;
import org.springframework.beans.factory.InitializingBean; import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@@ -11,10 +12,20 @@ import org.springframework.stereotype.Service;
@Service @Service
public class ZhipuModelFactory implements ModelFactory, InitializingBean { public class ZhipuModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "ZHIPU"; public static final String PROVIDER = "ZHIPU";
public static final String DEFAULT_BASE_URL = "https://open.bigmodel.cn/api/paas/v4";
@Override @Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
return null; return ZhipuAiChatModel.builder()
.baseUrl(modelConfig.getBaseUrl())
.apiKey(modelConfig.getApiKey())
.model(modelConfig.getModelName())
.temperature(modelConfig.getTemperature())
.topP(modelConfig.getTopP())
.maxRetries(modelConfig.getMaxRetries())
.logRequests(modelConfig.getLogRequests())
.logResponses(modelConfig.getLogResponses())
.build();
} }
@Override @Override