(improvement)(chat) Optimize system settings. (#1454)

This commit is contained in:
lexluo09
2024-07-25 15:20:13 +08:00
committed by GitHub
parent 0f5b49f7c5
commit 335902bd1f
5 changed files with 266 additions and 32 deletions

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.common.config;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.Parameter;
@@ -8,39 +9,58 @@ import dev.langchain4j.provider.OpenAiModelFactory;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.List;
@Service("ChatModelParameterConfig")
@Slf4j
public class ChatModelParameterConfig extends ParameterConfig {
public static final Parameter CHAT_MODEL_PROVIDER =
new Parameter("s2.chat.model.provider", OpenAiModelFactory.PROVIDER,
"接口协议", "",
"string", "对话模型配置",
Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER));
"list", "对话模型配置",
getCandidateValues());
public static final Parameter CHAT_MODEL_BASE_URL =
new Parameter("s2.chat.model.base.url", "https://api.openai.com/v1",
"BaseUrl", "",
"string", "对话模型配置");
"BaseUrl", "", "string",
"对话模型配置", null,
getDependency(CHAT_MODEL_PROVIDER.getName(),
getCandidateValues(),
ImmutableMap.of(
OpenAiModelFactory.PROVIDER, "https://api.openai.com/v1",
OllamaModelFactory.PROVIDER, "http://localhost:11434")
)
);
public static final Parameter CHAT_MODEL_API_KEY =
new Parameter("s2.chat.model.api.key", "demo",
"ApiKey", "",
"string", "对话模型配置");
"string", "对话模型配置", null,
getDependency(CHAT_MODEL_PROVIDER.getName(),
Lists.newArrayList(OpenAiModelFactory.PROVIDER),
ImmutableMap.of(OpenAiModelFactory.PROVIDER, "demo"))
);
public static final Parameter CHAT_MODEL_NAME =
new Parameter("s2.chat.model.name", "gpt-3.5-turbo",
"ModelName", "",
"string", "对话模型配置");
"string", "对话模型配置", null,
getDependency(CHAT_MODEL_PROVIDER.getName(),
getCandidateValues(),
ImmutableMap.of(OpenAiModelFactory.PROVIDER, "gpt-3.5-turbo",
OllamaModelFactory.PROVIDER, "qwen:0.5b")
));
public static final Parameter CHAT_MODEL_TEMPERATURE =
new Parameter("s2.chat.model.temperature", "0.0",
"Temperature", "",
"number", "对话模型配置");
"number", "对话模型配置", null,
getDependency(CHAT_MODEL_PROVIDER.getName(),
getCandidateValues(),
ImmutableMap.of(OpenAiModelFactory.PROVIDER, "0.0", OllamaModelFactory.PROVIDER, "0.0")));
public static final Parameter CHAT_MODEL_TIMEOUT =
new Parameter("s2.chat.model.timeout", "60",
@@ -72,4 +92,8 @@ public class ChatModelParameterConfig extends ParameterConfig {
.timeOut(Long.valueOf(chatModelTimeout))
.build();
}
private static ArrayList<String> getCandidateValues() {
return Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER);
}
}

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.common.config;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import com.tencent.supersonic.common.pojo.Parameter;
@@ -14,6 +15,7 @@ import dev.langchain4j.provider.ZhipuModelFactory;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.List;
@Service("EmbeddingModelParameterConfig")
@@ -23,41 +25,105 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
public static final Parameter EMBEDDING_MODEL_PROVIDER =
new Parameter("s2.embedding.model.provider", InMemoryModelFactory.PROVIDER,
"接口协议", "",
"string", "向量模型配置",
Lists.newArrayList(InMemoryModelFactory.PROVIDER,
"list", "向量模型配置",
getCandidateValues());
public static final Parameter EMBEDDING_MODEL_BASE_URL =
new Parameter("s2.embedding.model.base.url", "",
"BaseUrl", "",
"string", "向量模型配置", null,
getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
Lists.newArrayList(
OpenAiModelFactory.PROVIDER,
OllamaModelFactory.PROVIDER,
AzureModelFactory.PROVIDER,
DashscopeModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER));
public static final Parameter EMBEDDING_MODEL_BASE_URL =
new Parameter("s2.embedding.model.base.url", "",
"BaseUrl", "",
"string", "向量模型配置");
ZhipuModelFactory.PROVIDER
),
ImmutableMap.of(
OpenAiModelFactory.PROVIDER, "https://api.openai.com/v1",
OllamaModelFactory.PROVIDER, "http://localhost:11434",
AzureModelFactory.PROVIDER, "https://xxxx.openai.azure.com/",
DashscopeModelFactory.PROVIDER, "https://dashscope.aliyuncs.com/compatible-mode/v1",
QianfanModelFactory.PROVIDER, "https://aip.baidubce.com",
ZhipuModelFactory.PROVIDER, "https://open.bigmodel.cn/api/paas/v4/"
)
)
);
public static final Parameter EMBEDDING_MODEL_API_KEY =
new Parameter("s2.embedding.model.api.key", "",
"ApiKey", "",
"string", "向量模型配置");
"string", "向量模型配置", null,
getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
Lists.newArrayList(
OpenAiModelFactory.PROVIDER,
OllamaModelFactory.PROVIDER,
AzureModelFactory.PROVIDER,
DashscopeModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER
),
ImmutableMap.of(
OpenAiModelFactory.PROVIDER, "demo",
OllamaModelFactory.PROVIDER, "demo",
AzureModelFactory.PROVIDER, "demo",
DashscopeModelFactory.PROVIDER, "demo",
QianfanModelFactory.PROVIDER, "demo",
ZhipuModelFactory.PROVIDER, "demo"
)
));
public static final Parameter EMBEDDING_MODEL_NAME =
new Parameter("s2.embedding.model.name", InMemoryAutoConfig.BGE_SMALL_ZH,
"ModelName", "",
"string", "向量模型配置",
Lists.newArrayList(InMemoryAutoConfig.BGE_SMALL_ZH, InMemoryAutoConfig.ALL_MINILM_L6_V2));
"string", "向量模型配置", null,
getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
Lists.newArrayList(
InMemoryModelFactory.PROVIDER,
OpenAiModelFactory.PROVIDER,
OllamaModelFactory.PROVIDER,
AzureModelFactory.PROVIDER,
DashscopeModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER
),
ImmutableMap.of(
InMemoryModelFactory.PROVIDER, InMemoryAutoConfig.BGE_SMALL_ZH,
OpenAiModelFactory.PROVIDER, "text-embedding-ada-002",
OllamaModelFactory.PROVIDER, "all-minilm",
AzureModelFactory.PROVIDER, "text-embedding-ada-002",
DashscopeModelFactory.PROVIDER, "text-embedding-ada-002",
QianfanModelFactory.PROVIDER, "text-embedding-ada-002",
ZhipuModelFactory.PROVIDER, "text-embedding-ada-002"
)
));
public static final Parameter EMBEDDING_MODEL_PATH =
new Parameter("s2.embedding.model.path", "",
"模型路径", "",
"string", "向量模型配置");
"string", "向量模型配置", null,
getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
Lists.newArrayList(
InMemoryModelFactory.PROVIDER
),
ImmutableMap.of(
InMemoryModelFactory.PROVIDER, "/tmp"
)
));
public static final Parameter EMBEDDING_MODEL_VOCABULARY_PATH =
new Parameter("s2.embedding.model.vocabulary.path", "",
"词汇表路径", "",
"string", "向量模型配置");
"string", "向量模型配置", null,
getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
Lists.newArrayList(
InMemoryModelFactory.PROVIDER
),
ImmutableMap.of(
InMemoryModelFactory.PROVIDER, "/tmp"
)));
@Override
public List<Parameter> getSysParameters() {
@@ -85,4 +151,14 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
.build();
}
private static ArrayList<String> getCandidateValues() {
return Lists.newArrayList(InMemoryModelFactory.PROVIDER,
OpenAiModelFactory.PROVIDER,
OllamaModelFactory.PROVIDER,
AzureModelFactory.PROVIDER,
DashscopeModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER);
}
}

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.common.config;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig;
import com.tencent.supersonic.common.pojo.Parameter;
@@ -15,7 +16,7 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig {
public static final Parameter EMBEDDING_STORE_PROVIDER =
new Parameter("s2.embedding.store.provider", EmbeddingStoreType.IN_MEMORY.name(),
"向量库类型", "",
"string", "向量库配置",
"list", "向量库配置",
Lists.newArrayList(EmbeddingStoreType.IN_MEMORY.name(),
EmbeddingStoreType.MILVUS.name(),
EmbeddingStoreType.CHROMA.name()));
@@ -23,16 +24,41 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig {
public static final Parameter EMBEDDING_STORE_BASE_URL =
new Parameter("s2.embedding.store.base.url", "",
"BaseUrl", "",
"string", "向量库配置");
"string", "向量库配置", null,
getDependency(EMBEDDING_STORE_PROVIDER.getName(),
Lists.newArrayList(
EmbeddingStoreType.MILVUS.name(),
EmbeddingStoreType.CHROMA.name()
),
ImmutableMap.of(
EmbeddingStoreType.MILVUS.name(), "http://localhost:19530",
EmbeddingStoreType.CHROMA.name(), "http://localhost:8000"
)
));
public static final Parameter EMBEDDING_STORE_API_KEY =
new Parameter("s2.embedding.store.api.key", "",
"ApiKey", "",
"string", "向量库配置");
"string", "向量库配置", null,
getDependency(EMBEDDING_STORE_PROVIDER.getName(),
Lists.newArrayList(
EmbeddingStoreType.MILVUS.name()
),
ImmutableMap.of(
EmbeddingStoreType.MILVUS.name(), "demo"
)
));
public static final Parameter EMBEDDING_STORE_PERSIST_PATH =
new Parameter("s2.embedding.store.persist.path", "/tmp",
"持久化路径", "",
"string", "向量库配置");
"string", "向量库配置", null,
getDependency(EMBEDDING_STORE_PROVIDER.getName(),
Lists.newArrayList(
EmbeddingStoreType.IN_MEMORY.name()
),
ImmutableMap.of(
EmbeddingStoreType.IN_MEMORY.name(), "/tmp"
)));
public static final Parameter EMBEDDING_STORE_TIMEOUT =
new Parameter("s2.embedding.store.timeout", "60",

View File

@@ -7,8 +7,10 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.env.Environment;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@Service
public abstract class ParameterConfig {
@@ -31,6 +33,7 @@ public abstract class ParameterConfig {
* 1. `system config` set with user interface
* 2. `system property` set with application.yaml file
* 3. `default value` set with parameter declaration
*
* @param parameter instance
* @return parameter value
*/
@@ -47,4 +50,22 @@ public abstract class ParameterConfig {
return value;
}
protected static List<Parameter.Dependency> getDependency(
String dependencyParameterName,
List<String> includesValue,
Map<String, String> setDefaultValue) {
Parameter.Dependency.Show show = new Parameter.Dependency.Show();
show.setIncludesValue(includesValue);
Parameter.Dependency dependency = new Parameter.Dependency();
dependency.setName(dependencyParameterName);
dependency.setShow(show);
dependency.setSetDefaultValue(setDefaultValue);
List<Parameter.Dependency> dependencies = new ArrayList<>();
dependencies.add(dependency);
return dependencies;
}
}

View File

@@ -3,13 +3,53 @@ package com.tencent.supersonic.common.pojo;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.apache.commons.lang3.StringUtils;
import java.util.List;
import java.util.Map;
@Data
@AllArgsConstructor
@NoArgsConstructor
/**
* 1.Password Field:
*
* dataType: string
* name: password
* require: true/false or any value/empty
* placeholder: 'Please enter the relevant configuration information'
* value: initial value
* Text Input Field:
*
* 2.dataType: string
* require: true/false or any value/empty
* placeholder: 'Please enter the relevant configuration information'
* value: initial value
* Long Text Input Field:
*
* 3.dataType: longText
* require: true/false or any value/empty
* placeholder: 'Please enter the relevant configuration information'
* value: initial value
* Number Input Field:
*
* 4.dataType: number
* require: true/false or any value/empty
* placeholder: 'Please enter the relevant configuration information'
* value: initial value
* Switch Component:
*
* 5.dataType: bool
* require: true/false or any value/empty
* value: initial value
* Select Dropdown Component:
*
* 6.dataType: list
* candidateValues: ["OPEN_AI", "OLLAMA"] or
* [{label: 'Model Name 1', value: 'OPEN_AI'}, {label: 'Model Name 2', value: 'OLLAMA'}]
* require: true/false or any value/empty
* placeholder: 'Please enter the relevant configuration information'
* value: initial value
*/
public class Parameter {
private String name;
private String defaultValue;
@@ -18,15 +58,21 @@ public class Parameter {
private String dataType;
private String module;
private String value;
private List<Object> candidateValues;
private List<String> candidateValues;
private List<Dependency> dependencies;
public Parameter(String name, String defaultValue, String comment,
String description, String dataType, String module) {
this(name, defaultValue, comment, description, dataType, module, null);
this(name, defaultValue, comment, description, dataType, module, null, null);
}
public Parameter(String name, String defaultValue, String comment, String description,
String dataType, String module, List<Object> candidateValues) {
String dataType, String module, List<String> candidateValues) {
this(name, defaultValue, comment, description, dataType, module, candidateValues, null);
}
public Parameter(String name, String defaultValue, String comment, String description,
String dataType, String module, List<String> candidateValues, List<Dependency> dependencies) {
this.name = name;
this.defaultValue = defaultValue;
this.comment = comment;
@@ -34,13 +80,54 @@ public class Parameter {
this.dataType = dataType;
this.module = module;
this.candidateValues = candidateValues;
this.dependencies = dependencies;
}
public String getValue() {
if (StringUtils.isBlank(value)) {
if (value == null || value.trim().isEmpty()) {
return defaultValue;
}
return value;
}
}
public void setValue(String value) {
this.value = value;
}
public boolean isVisible(Map<String, String> otherParameterValues) {
if (dependencies == null) {
return true;
}
for (Dependency dependency : dependencies) {
String dependentValue = otherParameterValues.get(dependency.getName());
if (dependentValue == null || !dependency.getShow().getIncludesValue().contains(dependentValue)) {
return false;
}
}
return true;
}
public void applyDefaultValue(Map<String, String> otherParameterValues) {
if (dependencies == null) {
return;
}
for (Dependency dependency : dependencies) {
String dependentValue = otherParameterValues.get(dependency.getName());
if (dependentValue != null && dependency.getSetDefaultValue().containsKey(dependentValue)) {
this.defaultValue = dependency.getSetDefaultValue().get(dependentValue);
}
}
}
@Data
public static class Dependency {
private String name;
private Show show;
private Map<String, String> setDefaultValue;
@Data
public static class Show {
private List<String> includesValue;
}
}
}