(improvement)(chat) Optimize system settings. (#1454)

This commit is contained in:
lexluo09
2024-07-25 15:20:13 +08:00
committed by GitHub
parent 0f5b49f7c5
commit 335902bd1f
5 changed files with 266 additions and 32 deletions

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.common.config; package com.tencent.supersonic.common.config;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.Parameter; import com.tencent.supersonic.common.pojo.Parameter;
@@ -8,39 +9,58 @@ import dev.langchain4j.provider.OpenAiModelFactory;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.List; import java.util.List;
@Service("ChatModelParameterConfig") @Service("ChatModelParameterConfig")
@Slf4j @Slf4j
public class ChatModelParameterConfig extends ParameterConfig { public class ChatModelParameterConfig extends ParameterConfig {
public static final Parameter CHAT_MODEL_PROVIDER = public static final Parameter CHAT_MODEL_PROVIDER =
new Parameter("s2.chat.model.provider", OpenAiModelFactory.PROVIDER, new Parameter("s2.chat.model.provider", OpenAiModelFactory.PROVIDER,
"接口协议", "", "接口协议", "",
"string", "对话模型配置", "list", "对话模型配置",
Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER)); getCandidateValues());
public static final Parameter CHAT_MODEL_BASE_URL = public static final Parameter CHAT_MODEL_BASE_URL =
new Parameter("s2.chat.model.base.url", "https://api.openai.com/v1", new Parameter("s2.chat.model.base.url", "https://api.openai.com/v1",
"BaseUrl", "", "BaseUrl", "", "string",
"string", "对话模型配置"); "对话模型配置", null,
getDependency(CHAT_MODEL_PROVIDER.getName(),
getCandidateValues(),
ImmutableMap.of(
OpenAiModelFactory.PROVIDER, "https://api.openai.com/v1",
OllamaModelFactory.PROVIDER, "http://localhost:11434")
)
);
public static final Parameter CHAT_MODEL_API_KEY = public static final Parameter CHAT_MODEL_API_KEY =
new Parameter("s2.chat.model.api.key", "demo", new Parameter("s2.chat.model.api.key", "demo",
"ApiKey", "", "ApiKey", "",
"string", "对话模型配置"); "string", "对话模型配置", null,
getDependency(CHAT_MODEL_PROVIDER.getName(),
Lists.newArrayList(OpenAiModelFactory.PROVIDER),
ImmutableMap.of(OpenAiModelFactory.PROVIDER, "demo"))
);
public static final Parameter CHAT_MODEL_NAME = public static final Parameter CHAT_MODEL_NAME =
new Parameter("s2.chat.model.name", "gpt-3.5-turbo", new Parameter("s2.chat.model.name", "gpt-3.5-turbo",
"ModelName", "", "ModelName", "",
"string", "对话模型配置"); "string", "对话模型配置", null,
getDependency(CHAT_MODEL_PROVIDER.getName(),
getCandidateValues(),
ImmutableMap.of(OpenAiModelFactory.PROVIDER, "gpt-3.5-turbo",
OllamaModelFactory.PROVIDER, "qwen:0.5b")
));
public static final Parameter CHAT_MODEL_TEMPERATURE = public static final Parameter CHAT_MODEL_TEMPERATURE =
new Parameter("s2.chat.model.temperature", "0.0", new Parameter("s2.chat.model.temperature", "0.0",
"Temperature", "", "Temperature", "",
"number", "对话模型配置"); "number", "对话模型配置", null,
getDependency(CHAT_MODEL_PROVIDER.getName(),
getCandidateValues(),
ImmutableMap.of(OpenAiModelFactory.PROVIDER, "0.0", OllamaModelFactory.PROVIDER, "0.0")));
public static final Parameter CHAT_MODEL_TIMEOUT = public static final Parameter CHAT_MODEL_TIMEOUT =
new Parameter("s2.chat.model.timeout", "60", new Parameter("s2.chat.model.timeout", "60",
@@ -72,4 +92,8 @@ public class ChatModelParameterConfig extends ParameterConfig {
.timeOut(Long.valueOf(chatModelTimeout)) .timeOut(Long.valueOf(chatModelTimeout))
.build(); .build();
} }
private static ArrayList<String> getCandidateValues() {
return Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER);
}
} }

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.common.config; package com.tencent.supersonic.common.config;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig; import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import com.tencent.supersonic.common.pojo.Parameter; import com.tencent.supersonic.common.pojo.Parameter;
@@ -14,6 +15,7 @@ import dev.langchain4j.provider.ZhipuModelFactory;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.List; import java.util.List;
@Service("EmbeddingModelParameterConfig") @Service("EmbeddingModelParameterConfig")
@@ -23,41 +25,105 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
public static final Parameter EMBEDDING_MODEL_PROVIDER = public static final Parameter EMBEDDING_MODEL_PROVIDER =
new Parameter("s2.embedding.model.provider", InMemoryModelFactory.PROVIDER, new Parameter("s2.embedding.model.provider", InMemoryModelFactory.PROVIDER,
"接口协议", "", "接口协议", "",
"string", "向量模型配置", "list", "向量模型配置",
Lists.newArrayList(InMemoryModelFactory.PROVIDER, getCandidateValues());
public static final Parameter EMBEDDING_MODEL_BASE_URL =
new Parameter("s2.embedding.model.base.url", "",
"BaseUrl", "",
"string", "向量模型配置", null,
getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
Lists.newArrayList(
OpenAiModelFactory.PROVIDER, OpenAiModelFactory.PROVIDER,
OllamaModelFactory.PROVIDER, OllamaModelFactory.PROVIDER,
AzureModelFactory.PROVIDER, AzureModelFactory.PROVIDER,
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER, QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER)); ZhipuModelFactory.PROVIDER
),
public static final Parameter EMBEDDING_MODEL_BASE_URL = ImmutableMap.of(
new Parameter("s2.embedding.model.base.url", "", OpenAiModelFactory.PROVIDER, "https://api.openai.com/v1",
"BaseUrl", "", OllamaModelFactory.PROVIDER, "http://localhost:11434",
"string", "向量模型配置"); AzureModelFactory.PROVIDER, "https://xxxx.openai.azure.com/",
DashscopeModelFactory.PROVIDER, "https://dashscope.aliyuncs.com/compatible-mode/v1",
QianfanModelFactory.PROVIDER, "https://aip.baidubce.com",
ZhipuModelFactory.PROVIDER, "https://open.bigmodel.cn/api/paas/v4/"
)
)
);
public static final Parameter EMBEDDING_MODEL_API_KEY = public static final Parameter EMBEDDING_MODEL_API_KEY =
new Parameter("s2.embedding.model.api.key", "", new Parameter("s2.embedding.model.api.key", "",
"ApiKey", "", "ApiKey", "",
"string", "向量模型配置"); "string", "向量模型配置", null,
getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
Lists.newArrayList(
OpenAiModelFactory.PROVIDER,
OllamaModelFactory.PROVIDER,
AzureModelFactory.PROVIDER,
DashscopeModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER
),
ImmutableMap.of(
OpenAiModelFactory.PROVIDER, "demo",
OllamaModelFactory.PROVIDER, "demo",
AzureModelFactory.PROVIDER, "demo",
DashscopeModelFactory.PROVIDER, "demo",
QianfanModelFactory.PROVIDER, "demo",
ZhipuModelFactory.PROVIDER, "demo"
)
));
public static final Parameter EMBEDDING_MODEL_NAME = public static final Parameter EMBEDDING_MODEL_NAME =
new Parameter("s2.embedding.model.name", InMemoryAutoConfig.BGE_SMALL_ZH, new Parameter("s2.embedding.model.name", InMemoryAutoConfig.BGE_SMALL_ZH,
"ModelName", "", "ModelName", "",
"string", "向量模型配置", "string", "向量模型配置", null,
Lists.newArrayList(InMemoryAutoConfig.BGE_SMALL_ZH, InMemoryAutoConfig.ALL_MINILM_L6_V2)); getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
Lists.newArrayList(
InMemoryModelFactory.PROVIDER,
OpenAiModelFactory.PROVIDER,
OllamaModelFactory.PROVIDER,
AzureModelFactory.PROVIDER,
DashscopeModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER
),
ImmutableMap.of(
InMemoryModelFactory.PROVIDER, InMemoryAutoConfig.BGE_SMALL_ZH,
OpenAiModelFactory.PROVIDER, "text-embedding-ada-002",
OllamaModelFactory.PROVIDER, "all-minilm",
AzureModelFactory.PROVIDER, "text-embedding-ada-002",
DashscopeModelFactory.PROVIDER, "text-embedding-ada-002",
QianfanModelFactory.PROVIDER, "text-embedding-ada-002",
ZhipuModelFactory.PROVIDER, "text-embedding-ada-002"
)
));
public static final Parameter EMBEDDING_MODEL_PATH = public static final Parameter EMBEDDING_MODEL_PATH =
new Parameter("s2.embedding.model.path", "", new Parameter("s2.embedding.model.path", "",
"模型路径", "", "模型路径", "",
"string", "向量模型配置"); "string", "向量模型配置", null,
getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
Lists.newArrayList(
InMemoryModelFactory.PROVIDER
),
ImmutableMap.of(
InMemoryModelFactory.PROVIDER, "/tmp"
)
));
public static final Parameter EMBEDDING_MODEL_VOCABULARY_PATH = public static final Parameter EMBEDDING_MODEL_VOCABULARY_PATH =
new Parameter("s2.embedding.model.vocabulary.path", "", new Parameter("s2.embedding.model.vocabulary.path", "",
"词汇表路径", "", "词汇表路径", "",
"string", "向量模型配置"); "string", "向量模型配置", null,
getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
Lists.newArrayList(
InMemoryModelFactory.PROVIDER
),
ImmutableMap.of(
InMemoryModelFactory.PROVIDER, "/tmp"
)));
@Override @Override
public List<Parameter> getSysParameters() { public List<Parameter> getSysParameters() {
@@ -85,4 +151,14 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
.build(); .build();
} }
private static ArrayList<String> getCandidateValues() {
return Lists.newArrayList(InMemoryModelFactory.PROVIDER,
OpenAiModelFactory.PROVIDER,
OllamaModelFactory.PROVIDER,
AzureModelFactory.PROVIDER,
DashscopeModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER);
}
} }

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.common.config; package com.tencent.supersonic.common.config;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig; import com.tencent.supersonic.common.pojo.EmbeddingStoreConfig;
import com.tencent.supersonic.common.pojo.Parameter; import com.tencent.supersonic.common.pojo.Parameter;
@@ -15,7 +16,7 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig {
public static final Parameter EMBEDDING_STORE_PROVIDER = public static final Parameter EMBEDDING_STORE_PROVIDER =
new Parameter("s2.embedding.store.provider", EmbeddingStoreType.IN_MEMORY.name(), new Parameter("s2.embedding.store.provider", EmbeddingStoreType.IN_MEMORY.name(),
"向量库类型", "", "向量库类型", "",
"string", "向量库配置", "list", "向量库配置",
Lists.newArrayList(EmbeddingStoreType.IN_MEMORY.name(), Lists.newArrayList(EmbeddingStoreType.IN_MEMORY.name(),
EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.MILVUS.name(),
EmbeddingStoreType.CHROMA.name())); EmbeddingStoreType.CHROMA.name()));
@@ -23,16 +24,41 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig {
public static final Parameter EMBEDDING_STORE_BASE_URL = public static final Parameter EMBEDDING_STORE_BASE_URL =
new Parameter("s2.embedding.store.base.url", "", new Parameter("s2.embedding.store.base.url", "",
"BaseUrl", "", "BaseUrl", "",
"string", "向量库配置"); "string", "向量库配置", null,
getDependency(EMBEDDING_STORE_PROVIDER.getName(),
Lists.newArrayList(
EmbeddingStoreType.MILVUS.name(),
EmbeddingStoreType.CHROMA.name()
),
ImmutableMap.of(
EmbeddingStoreType.MILVUS.name(), "http://localhost:19530",
EmbeddingStoreType.CHROMA.name(), "http://localhost:8000"
)
));
public static final Parameter EMBEDDING_STORE_API_KEY = public static final Parameter EMBEDDING_STORE_API_KEY =
new Parameter("s2.embedding.store.api.key", "", new Parameter("s2.embedding.store.api.key", "",
"ApiKey", "", "ApiKey", "",
"string", "向量库配置"); "string", "向量库配置", null,
getDependency(EMBEDDING_STORE_PROVIDER.getName(),
Lists.newArrayList(
EmbeddingStoreType.MILVUS.name()
),
ImmutableMap.of(
EmbeddingStoreType.MILVUS.name(), "demo"
)
));
public static final Parameter EMBEDDING_STORE_PERSIST_PATH = public static final Parameter EMBEDDING_STORE_PERSIST_PATH =
new Parameter("s2.embedding.store.persist.path", "/tmp", new Parameter("s2.embedding.store.persist.path", "/tmp",
"持久化路径", "", "持久化路径", "",
"string", "向量库配置"); "string", "向量库配置", null,
getDependency(EMBEDDING_STORE_PROVIDER.getName(),
Lists.newArrayList(
EmbeddingStoreType.IN_MEMORY.name()
),
ImmutableMap.of(
EmbeddingStoreType.IN_MEMORY.name(), "/tmp"
)));
public static final Parameter EMBEDDING_STORE_TIMEOUT = public static final Parameter EMBEDDING_STORE_TIMEOUT =
new Parameter("s2.embedding.store.timeout", "60", new Parameter("s2.embedding.store.timeout", "60",

View File

@@ -7,8 +7,10 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.env.Environment; import org.springframework.core.env.Environment;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map;
@Service @Service
public abstract class ParameterConfig { public abstract class ParameterConfig {
@@ -31,6 +33,7 @@ public abstract class ParameterConfig {
* 1. `system config` set with user interface * 1. `system config` set with user interface
* 2. `system property` set with application.yaml file * 2. `system property` set with application.yaml file
* 3. `default value` set with parameter declaration * 3. `default value` set with parameter declaration
*
* @param parameter instance * @param parameter instance
* @return parameter value * @return parameter value
*/ */
@@ -47,4 +50,22 @@ public abstract class ParameterConfig {
return value; return value;
} }
protected static List<Parameter.Dependency> getDependency(
String dependencyParameterName,
List<String> includesValue,
Map<String, String> setDefaultValue) {
Parameter.Dependency.Show show = new Parameter.Dependency.Show();
show.setIncludesValue(includesValue);
Parameter.Dependency dependency = new Parameter.Dependency();
dependency.setName(dependencyParameterName);
dependency.setShow(show);
dependency.setSetDefaultValue(setDefaultValue);
List<Parameter.Dependency> dependencies = new ArrayList<>();
dependencies.add(dependency);
return dependencies;
}
} }

View File

@@ -3,13 +3,53 @@ package com.tencent.supersonic.common.pojo;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import org.apache.commons.lang3.StringUtils;
import java.util.List; import java.util.List;
import java.util.Map;
@Data @Data
@AllArgsConstructor @AllArgsConstructor
@NoArgsConstructor @NoArgsConstructor
/**
* 1.Password Field:
*
* dataType: string
* name: password
* require: true/false or any value/empty
* placeholder: 'Please enter the relevant configuration information'
* value: initial value
* Text Input Field:
*
* 2.dataType: string
* require: true/false or any value/empty
* placeholder: 'Please enter the relevant configuration information'
* value: initial value
* Long Text Input Field:
*
* 3.dataType: longText
* require: true/false or any value/empty
* placeholder: 'Please enter the relevant configuration information'
* value: initial value
* Number Input Field:
*
* 4.dataType: number
* require: true/false or any value/empty
* placeholder: 'Please enter the relevant configuration information'
* value: initial value
* Switch Component:
*
* 5.dataType: bool
* require: true/false or any value/empty
* value: initial value
* Select Dropdown Component:
*
* 6.dataType: list
* candidateValues: ["OPEN_AI", "OLLAMA"] or
* [{label: 'Model Name 1', value: 'OPEN_AI'}, {label: 'Model Name 2', value: 'OLLAMA'}]
* require: true/false or any value/empty
* placeholder: 'Please enter the relevant configuration information'
* value: initial value
*/
public class Parameter { public class Parameter {
private String name; private String name;
private String defaultValue; private String defaultValue;
@@ -18,15 +58,21 @@ public class Parameter {
private String dataType; private String dataType;
private String module; private String module;
private String value; private String value;
private List<Object> candidateValues; private List<String> candidateValues;
private List<Dependency> dependencies;
public Parameter(String name, String defaultValue, String comment, public Parameter(String name, String defaultValue, String comment,
String description, String dataType, String module) { String description, String dataType, String module) {
this(name, defaultValue, comment, description, dataType, module, null); this(name, defaultValue, comment, description, dataType, module, null, null);
} }
public Parameter(String name, String defaultValue, String comment, String description, public Parameter(String name, String defaultValue, String comment, String description,
String dataType, String module, List<Object> candidateValues) { String dataType, String module, List<String> candidateValues) {
this(name, defaultValue, comment, description, dataType, module, candidateValues, null);
}
public Parameter(String name, String defaultValue, String comment, String description,
String dataType, String module, List<String> candidateValues, List<Dependency> dependencies) {
this.name = name; this.name = name;
this.defaultValue = defaultValue; this.defaultValue = defaultValue;
this.comment = comment; this.comment = comment;
@@ -34,13 +80,54 @@ public class Parameter {
this.dataType = dataType; this.dataType = dataType;
this.module = module; this.module = module;
this.candidateValues = candidateValues; this.candidateValues = candidateValues;
this.dependencies = dependencies;
} }
public String getValue() { public String getValue() {
if (StringUtils.isBlank(value)) { if (value == null || value.trim().isEmpty()) {
return defaultValue; return defaultValue;
} }
return value; return value;
} }
} public void setValue(String value) {
this.value = value;
}
public boolean isVisible(Map<String, String> otherParameterValues) {
if (dependencies == null) {
return true;
}
for (Dependency dependency : dependencies) {
String dependentValue = otherParameterValues.get(dependency.getName());
if (dependentValue == null || !dependency.getShow().getIncludesValue().contains(dependentValue)) {
return false;
}
}
return true;
}
public void applyDefaultValue(Map<String, String> otherParameterValues) {
if (dependencies == null) {
return;
}
for (Dependency dependency : dependencies) {
String dependentValue = otherParameterValues.get(dependency.getName());
if (dependentValue != null && dependency.getSetDefaultValue().containsKey(dependentValue)) {
this.defaultValue = dependency.getSetDefaultValue().get(dependentValue);
}
}
}
@Data
public static class Dependency {
private String name;
private Show show;
private Map<String, String> setDefaultValue;
@Data
public static class Show {
private List<String> includesValue;
}
}
}