(improvement)(chat) Remove langchain4j configuration file and perform all configuration for the large model through the UI interface. (#1442)

This commit is contained in:
lexluo09
2024-07-20 21:30:46 +08:00
committed by GitHub
parent 3797cc2ce8
commit d64ed02df9
12 changed files with 35 additions and 170 deletions

View File

@@ -14,6 +14,8 @@ import java.util.List;
@Slf4j
public class ChatModelParameterConfig extends ParameterConfig {
public static final Parameter CHAT_MODEL_PROVIDER =
new Parameter("s2.chat.model.provider", OpenAiModelFactory.PROVIDER,
"接口协议", "",
@@ -21,17 +23,17 @@ public class ChatModelParameterConfig extends ParameterConfig {
Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER));
public static final Parameter CHAT_MODEL_BASE_URL =
new Parameter("s2.chat.model.base.url", "",
new Parameter("s2.chat.model.base.url", "https://api.openai.com/v1",
"BaseUrl", "",
"string", "对话模型配置");
public static final Parameter CHAT_MODEL_API_KEY =
new Parameter("s2.chat.model.api.key", "",
new Parameter("s2.chat.model.api.key", "demo",
"ApiKey", "",
"string", "对话模型配置");
public static final Parameter CHAT_MODEL_NAME =
new Parameter("s2.chat.model.name", "",
new Parameter("s2.chat.model.name", "gpt-3.5-turbo",
"ModelName", "",
"string", "对话模型配置");

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.common.config;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import com.tencent.supersonic.common.pojo.Parameter;
import dev.langchain4j.inmemory.spring.InMemoryAutoConfig;
import dev.langchain4j.provider.AzureModelFactory;
import dev.langchain4j.provider.DashscopeModelFactory;
import dev.langchain4j.provider.InMemoryModelFactory;
@@ -20,7 +21,7 @@ import java.util.List;
public class EmbeddingModelParameterConfig extends ParameterConfig {
public static final Parameter EMBEDDING_MODEL_PROVIDER =
new Parameter("s2.embedding.model.provider", "",
new Parameter("s2.embedding.model.provider", InMemoryModelFactory.PROVIDER,
"接口协议", "",
"string", "向量模型配置",
Lists.newArrayList(InMemoryModelFactory.PROVIDER,
@@ -43,9 +44,10 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
public static final Parameter EMBEDDING_MODEL_NAME =
new Parameter("s2.embedding.model.name", "",
new Parameter("s2.embedding.model.name", InMemoryAutoConfig.BGE_SMALL_ZH,
"ModelName", "",
"string", "向量模型配置");
"string", "向量模型配置",
Lists.newArrayList(InMemoryAutoConfig.BGE_SMALL_ZH, InMemoryAutoConfig.ALL_MINILM_L6_V2));
public static final Parameter EMBEDDING_MODEL_PATH =
new Parameter("s2.embedding.model.path", "",

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.common.config;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig;
import com.tencent.supersonic.common.pojo.Parameter;
import dev.langchain4j.store.embedding.EmbeddingStoreType;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
@@ -12,9 +13,12 @@ import java.util.List;
@Slf4j
public class EmbeddingStoreParameterConfig extends ParameterConfig {
public static final Parameter EMBEDDING_STORE_PROVIDER =
new Parameter("s2.embedding.store.provider", "",
new Parameter("s2.embedding.store.provider", EmbeddingStoreType.IN_MEMORY.name(),
"向量库类型", "",
"string", "向量库配置");
"string", "向量库配置",
Lists.newArrayList(EmbeddingStoreType.IN_MEMORY.name(),
EmbeddingStoreType.MILVUS.name(),
EmbeddingStoreType.CHROMA.name()));
public static final Parameter EMBEDDING_STORE_BASE_URL =
new Parameter("s2.embedding.store.base.url", "",

View File

@@ -1,5 +1,6 @@
package dev.langchain4j.provider;
import com.tencent.supersonic.common.config.ChatModelParameterConfig;
import com.tencent.supersonic.common.config.EmbeddingModelParameterConfig;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
@@ -18,11 +19,16 @@ public class ModelProvider {
factories.put(provider, modelFactory);
}
public static ChatLanguageModel getChatModel() {
return getChatModel(null);
}
public static ChatLanguageModel getChatModel(ChatModelConfig modelConfig) {
if (modelConfig == null
|| StringUtils.isBlank(modelConfig.getProvider())
if (modelConfig == null || StringUtils.isBlank(modelConfig.getProvider())
|| StringUtils.isBlank(modelConfig.getBaseUrl())) {
return ContextUtils.getBean(ChatLanguageModel.class);
ChatModelParameterConfig parameterConfig = ContextUtils.getBean(
ChatModelParameterConfig.class);
modelConfig = parameterConfig.convert();
}
ModelFactory modelFactory = factories.get(modelConfig.getProvider().toUpperCase());
if (modelFactory != null) {
@@ -33,15 +39,14 @@ public class ModelProvider {
}
public static EmbeddingModel getEmbeddingModel() {
EmbeddingModelParameterConfig parameterConfig = ContextUtils.getBean(
EmbeddingModelParameterConfig.class);
EmbeddingModelConfig embeddingModelConfig = parameterConfig.convert();
return getEmbeddingModel(embeddingModelConfig);
return getEmbeddingModel(null);
}
public static EmbeddingModel getEmbeddingModel(EmbeddingModelConfig embeddingModel) {
if (embeddingModel == null || StringUtils.isBlank(embeddingModel.getProvider())) {
return ContextUtils.getBean(EmbeddingModel.class);
EmbeddingModelParameterConfig parameterConfig = ContextUtils.getBean(
EmbeddingModelParameterConfig.class);
embeddingModel = parameterConfig.convert();
}
ModelFactory modelFactory = factories.get(embeddingModel.getProvider().toUpperCase());
if (modelFactory != null) {