[feature][chat]Refactor chat model config related codes.#1739

This commit is contained in:
jerryjzhang
2024-10-09 17:27:07 +08:00
parent 60b0a1a1a1
commit 248f4f83f6
53 changed files with 275 additions and 251 deletions

View File

@@ -29,7 +29,7 @@ public class User {
return new User(id, name, name, name, 0); return new User(id, name, name, name, 0);
} }
public static User getFakeUser() { public static User getDefaultUser() {
return new User(1L, "admin", "admin", "admin@email", 1); return new User(1L, "admin", "admin", "admin@email", 1);
} }

View File

@@ -69,7 +69,7 @@ public class DefaultAuthenticationInterceptor extends AuthenticationInterceptor
private void setFakerUser(HttpServletRequest request) { private void setFakerUser(HttpServletRequest request) {
String token = userTokenUtils.generateAdminToken(request); String token = userTokenUtils.generateAdminToken(request);
reflectSetParam(request, authenticationConfig.getTokenHttpHeaderKey(), token); reflectSetParam(request, authenticationConfig.getTokenHttpHeaderKey(), token);
setContext(User.getFakeUser().getName(), request); setContext(User.getDefaultUser().getName(), request);
} }
private void setContext(String userName, HttpServletRequest request) { private void setContext(String userName, HttpServletRequest request) {

View File

@@ -17,11 +17,11 @@ public class FakeUserStrategy implements UserStrategy {
@Override @Override
public User findUser(HttpServletRequest request, HttpServletResponse response) { public User findUser(HttpServletRequest request, HttpServletResponse response) {
return User.getFakeUser(); return User.getDefaultUser();
} }
@Override @Override
public User findUser(String token, String appKey) { public User findUser(String token, String appKey) {
return User.getFakeUser(); return User.getDefaultUser();
} }
} }

View File

@@ -3,8 +3,6 @@ package com.tencent.supersonic.chat.server.agent;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.google.common.collect.Sets; import com.google.common.collect.Sets;
import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.config.VisualConfig;
import com.tencent.supersonic.common.pojo.RecordInfo; import com.tencent.supersonic.common.pojo.RecordInfo;
import com.tencent.supersonic.common.pojo.enums.ChatModelType; import com.tencent.supersonic.common.pojo.enums.ChatModelType;
import lombok.Data; import lombok.Data;
@@ -17,22 +15,21 @@ import java.util.stream.Collectors;
public class Agent extends RecordInfo { public class Agent extends RecordInfo {
private Integer id; private Integer id;
private Integer enableSearch;
private Integer enableMemoryReview;
private String name; private String name;
private String description; private String description;
/** 0 offline, 1 online */ /** 0 offline, 1 online */
private Integer status; private Integer status;
private List<String> examples; private List<String> examples;
private String agentConfig; private Integer enableSearch;
private Map<ChatModelType, Integer> modelConfig = Collections.EMPTY_MAP; private Integer enableMemoryReview;
private String toolConfig;
private Map<ChatModelType, Integer> chatModelConfig = Collections.EMPTY_MAP;
private PromptConfig promptConfig; private PromptConfig promptConfig;
private MultiTurnConfig multiTurnConfig; private MultiTurnConfig multiTurnConfig;
private VisualConfig visualConfig; private VisualConfig visualConfig;
public List<String> getTools(AgentToolType type) { public List<String> getTools(AgentToolType type) {
Map map = JSONObject.parseObject(agentConfig, Map.class); Map map = JSONObject.parseObject(toolConfig, Map.class);
if (CollectionUtils.isEmpty(map) || map.get("tools") == null) { if (CollectionUtils.isEmpty(map) || map.get("tools") == null) {
return Lists.newArrayList(); return Lists.newArrayList();
} }
@@ -84,7 +81,7 @@ public class Agent extends RecordInfo {
} }
public boolean containsAnyTool() { public boolean containsAnyTool() {
Map map = JSONObject.parseObject(agentConfig, Map.class); Map map = JSONObject.parseObject(toolConfig, Map.class);
if (CollectionUtils.isEmpty(map)) { if (CollectionUtils.isEmpty(map)) {
return false; return false;
} }

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.common.config; package com.tencent.supersonic.chat.server.agent;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Data; import lombok.Data;

View File

@@ -10,7 +10,7 @@ import java.util.List;
@Data @Data
@NoArgsConstructor @NoArgsConstructor
@AllArgsConstructor @AllArgsConstructor
public class AgentConfig { public class ToolConfig {
List<AgentTool> tools = Lists.newArrayList(); private List<AgentTool> tools = Lists.newArrayList();
} }

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.common.config; package com.tencent.supersonic.chat.server.agent;
import lombok.Data; import lombok.Data;

View File

@@ -1,77 +1,54 @@
package com.tencent.supersonic.common.config; package com.tencent.supersonic.chat.server.config;
import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.Parameter; import com.tencent.supersonic.common.pojo.Parameter;
import dev.langchain4j.provider.AzureModelFactory; import dev.langchain4j.provider.*;
import dev.langchain4j.provider.DashscopeModelFactory;
import dev.langchain4j.provider.LocalAiModelFactory;
import dev.langchain4j.provider.OllamaModelFactory;
import dev.langchain4j.provider.OpenAiModelFactory;
import dev.langchain4j.provider.QianfanModelFactory;
import dev.langchain4j.provider.ZhipuModelFactory;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Map;
@Service("ChatModelParameterConfig") public class ChatModelParameters {
@Slf4j private static final String MODULE_NAME = "对话模型配置";
public class ChatModelParameterConfig extends ParameterConfig {
public static final Parameter CHAT_MODEL_PROVIDER = new Parameter("s2.chat.model.provider", public static final Parameter CHAT_MODEL_PROVIDER =
OpenAiModelFactory.PROVIDER, "接口协议", "", "list", "对话模型配置", getCandidateValues()); new Parameter("provider", ModelProvider.DEMO_CHAT_MODEL.getProvider(), "接口协议", "",
"list", MODULE_NAME, getCandidateValues());
public static final Parameter CHAT_MODEL_BASE_URL = public static final Parameter CHAT_MODEL_BASE_URL =
new Parameter("s2.chat.model.base.url", OpenAiModelFactory.DEFAULT_BASE_URL, "BaseUrl", new Parameter("baseUrl", ModelProvider.DEMO_CHAT_MODEL.getBaseUrl(), "BaseUrl", "",
"", "string", "对话模型配置", null, getBaseUrlDependency()); "string", MODULE_NAME, null, getBaseUrlDependency());
public static final Parameter CHAT_MODEL_ENDPOINT = new Parameter("s2.chat.model.endpoint",
"llama_2_70b", "Endpoint", "", "string", "对话模型配置", null, getEndpointDependency());
public static final Parameter CHAT_MODEL_API_KEY = new Parameter("s2.chat.model.api.key", DEMO,
"ApiKey", "", "password", "对话模型配置", null, getApiKeyDependency());
public static final Parameter CHAT_MODEL_SECRET_KEY = new Parameter("s2.chat.model.secretKey",
"demo", "SecretKey", "", "password", "对话模型配置", null, getSecretKeyDependency());
public static final Parameter CHAT_MODEL_NAME = new Parameter("s2.chat.model.name", public static final Parameter CHAT_MODEL_NAME =
"gpt-4o-mini", "ModelName", "", "string", "对话模型配置", null, getModelNameDependency()); new Parameter("modelName", ModelProvider.DEMO_CHAT_MODEL.getModelName(), "ModelName",
"", "string", MODULE_NAME, null, getModelNameDependency());
public static final Parameter CHAT_MODEL_ENABLE_SEARCH = public static final Parameter CHAT_MODEL_API_KEY =
new Parameter("s2.chat.model.enableSearch", "false", "是否启用搜索增强功能设为false表示不启用", "", new Parameter("apiKey", ModelProvider.DEMO_CHAT_MODEL.getApiKey(), "ApiKey", "",
"bool", "对话模型配置", null, getEnableSearchDependency()); "password", MODULE_NAME, null, getApiKeyDependency());
public static final Parameter CHAT_MODEL_TEMPERATURE = new Parameter( public static final Parameter CHAT_MODEL_ENDPOINT = new Parameter("endpoint", "llama_2_70b",
"s2.chat.model.temperature", "0.0", "Temperature", "", "slider", "对话模型配置"); "Endpoint", "", "string", MODULE_NAME, null, getEndpointDependency());
public static final Parameter CHAT_MODEL_SECRET_KEY = new Parameter("secretKey", "demo",
"SecretKey", "", "password", MODULE_NAME, null, getSecretKeyDependency());
public static final Parameter CHAT_MODEL_ENABLE_SEARCH = new Parameter("enableSearch", "false",
"是否启用搜索增强功能设为false表示不启用", "", "bool", MODULE_NAME, null, getEnableSearchDependency());
public static final Parameter CHAT_MODEL_TEMPERATURE =
new Parameter("temperature", "0.0", "Temperature", "", "slider", MODULE_NAME);
public static final Parameter CHAT_MODEL_TIMEOUT = public static final Parameter CHAT_MODEL_TIMEOUT =
new Parameter("s2.chat.model.timeout", "60", "超时时间(秒)", "", "number", "对话模型配置"); new Parameter("timeOut", "60", "超时时间(秒)", "", "number", MODULE_NAME);
@Override public static List<Parameter> getParameters() {
public List<Parameter> getSysParameters() {
return Lists.newArrayList(CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_ENDPOINT, return Lists.newArrayList(CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_ENDPOINT,
CHAT_MODEL_API_KEY, CHAT_MODEL_SECRET_KEY, CHAT_MODEL_NAME, CHAT_MODEL_API_KEY, CHAT_MODEL_SECRET_KEY, CHAT_MODEL_NAME,
CHAT_MODEL_ENABLE_SEARCH, CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT); CHAT_MODEL_ENABLE_SEARCH, CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT);
} }
public ChatModelConfig convert() {
String chatModelProvider = getParameterValue(CHAT_MODEL_PROVIDER);
String chatModelBaseUrl = getParameterValue(CHAT_MODEL_BASE_URL);
String chatModelApiKey = getParameterValue(CHAT_MODEL_API_KEY);
String chatModelName = getParameterValue(CHAT_MODEL_NAME);
String chatModelTemperature = getParameterValue(CHAT_MODEL_TEMPERATURE);
String chatModelTimeout = getParameterValue(CHAT_MODEL_TIMEOUT);
String endpoint = getParameterValue(CHAT_MODEL_ENDPOINT);
String secretKey = getParameterValue(CHAT_MODEL_SECRET_KEY);
String enableSearch = getParameterValue(CHAT_MODEL_ENABLE_SEARCH);
return ChatModelConfig.builder().provider(chatModelProvider).baseUrl(chatModelBaseUrl)
.apiKey(chatModelApiKey).modelName(chatModelName)
.enableSearch(Boolean.valueOf(enableSearch))
.temperature(Double.valueOf(chatModelTemperature))
.timeOut(Long.valueOf(chatModelTimeout)).endpoint(endpoint).secretKey(secretKey)
.build();
}
private static List<String> getCandidateValues() { private static List<String> getCandidateValues() {
return Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER, return Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER, ZhipuModelFactory.PROVIDER, QianfanModelFactory.PROVIDER, ZhipuModelFactory.PROVIDER,
@@ -95,9 +72,13 @@ public class ChatModelParameterConfig extends ParameterConfig {
Lists.newArrayList(OpenAiModelFactory.PROVIDER, QianfanModelFactory.PROVIDER, Lists.newArrayList(OpenAiModelFactory.PROVIDER, QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER, LocalAiModelFactory.PROVIDER, ZhipuModelFactory.PROVIDER, LocalAiModelFactory.PROVIDER,
AzureModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER), AzureModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER),
ImmutableMap.of(OpenAiModelFactory.PROVIDER, DEMO, QianfanModelFactory.PROVIDER, ImmutableMap.of(OpenAiModelFactory.PROVIDER,
DEMO, ZhipuModelFactory.PROVIDER, DEMO, LocalAiModelFactory.PROVIDER, DEMO, ModelProvider.DEMO_CHAT_MODEL.getApiKey(), QianfanModelFactory.PROVIDER,
AzureModelFactory.PROVIDER, DEMO, DashscopeModelFactory.PROVIDER, DEMO)); ModelProvider.DEMO_CHAT_MODEL.getApiKey(), ZhipuModelFactory.PROVIDER,
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), LocalAiModelFactory.PROVIDER,
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), AzureModelFactory.PROVIDER,
ModelProvider.DEMO_CHAT_MODEL.getApiKey(), DashscopeModelFactory.PROVIDER,
ModelProvider.DEMO_CHAT_MODEL.getApiKey()));
} }
private static List<Parameter.Dependency> getModelNameDependency() { private static List<Parameter.Dependency> getModelNameDependency() {
@@ -125,7 +106,22 @@ public class ChatModelParameterConfig extends ParameterConfig {
private static List<Parameter.Dependency> getSecretKeyDependency() { private static List<Parameter.Dependency> getSecretKeyDependency() {
return getDependency(CHAT_MODEL_PROVIDER.getName(), return getDependency(CHAT_MODEL_PROVIDER.getName(),
Lists.newArrayList(QianfanModelFactory.PROVIDER), Lists.newArrayList(QianfanModelFactory.PROVIDER), ImmutableMap.of(
ImmutableMap.of(QianfanModelFactory.PROVIDER, DEMO)); QianfanModelFactory.PROVIDER, ModelProvider.DEMO_CHAT_MODEL.getApiKey()));
}
private static List<Parameter.Dependency> getDependency(String dependencyParameterName,
List<String> includesValue, Map<String, String> setDefaultValue) {
Parameter.Dependency.Show show = new Parameter.Dependency.Show();
show.setIncludesValue(includesValue);
Parameter.Dependency dependency = new Parameter.Dependency();
dependency.setName(dependencyParameterName);
dependency.setShow(show);
dependency.setSetDefaultValue(setDefaultValue);
List<Parameter.Dependency> dependencies = new ArrayList<>();
dependencies.add(dependency);
return dependencies;
} }
} }

View File

@@ -11,5 +11,5 @@ public class ParserConfig extends ParameterConfig {
public static final Parameter PARSER_MULTI_TURN_ENABLE = public static final Parameter PARSER_MULTI_TURN_ENABLE =
new Parameter("s2.parser.multi-turn.enable", "false", "是否开启多轮对话", "开启多轮对话将消耗更多token", new Parameter("s2.parser.multi-turn.enable", "false", "是否开启多轮对话", "开启多轮对话将消耗更多token",
"bool", "Parser相关配置"); "bool", "语义解析配置");
} }

View File

@@ -10,43 +10,34 @@ import java.util.Date;
@Data @Data
@TableName("s2_agent") @TableName("s2_agent")
public class AgentDO { public class AgentDO {
/** */
@TableId(type = IdType.AUTO) @TableId(type = IdType.AUTO)
private Integer id; private Integer id;
/** */
private String name; private String name;
/** */
private String description; private String description;
/** 0 offline, 1 online */ /** 0 offline, 1 online */
private Integer status; private Integer status;
/** */
private String examples; private String examples;
/** */
private String config;
/** */
private String createdBy; private String createdBy;
/** */
private Date createdAt; private Date createdAt;
/** */
private String updatedBy; private String updatedBy;
/** */
private Date updatedAt; private Date updatedAt;
/** */
private Integer enableSearch; private Integer enableSearch;
private Integer enableMemoryReview; private Integer enableMemoryReview;
private String modelConfig; private String toolConfig;
private String chatModelConfig;
private String multiTurnConfig; private String multiTurnConfig;

View File

@@ -59,8 +59,4 @@ public class AgentController {
return AgentToolType.getToolTypes(); return AgentToolType.getToolTypes();
} }
@PostMapping("/testLLMConn")
public boolean testLLMConn(@RequestBody ChatModelConfig modelConfig) {
return ModelConfigHelper.testConnection(modelConfig);
}
} }

View File

@@ -6,10 +6,12 @@ import javax.servlet.http.HttpServletResponse;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.api.pojo.response.ChatModelTypeResp; import com.tencent.supersonic.chat.api.pojo.response.ChatModelTypeResp;
import com.tencent.supersonic.chat.server.config.ChatModelParameters;
import com.tencent.supersonic.chat.server.pojo.ChatModel; import com.tencent.supersonic.chat.server.pojo.ChatModel;
import com.tencent.supersonic.chat.server.service.ChatModelService; import com.tencent.supersonic.chat.server.service.ChatModelService;
import com.tencent.supersonic.chat.server.util.ModelConfigHelper; import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.Parameter;
import com.tencent.supersonic.common.pojo.enums.ChatModelType; import com.tencent.supersonic.common.pojo.enums.ChatModelType;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
@@ -56,6 +58,11 @@ public class ChatModelController {
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
@RequestMapping("/getModelParameters")
public List<Parameter> getModelParameters() {
return ChatModelParameters.getParameters();
}
@PostMapping("/testConnection") @PostMapping("/testConnection")
public boolean testConnection(@RequestBody ChatModelConfig modelConfig) { public boolean testConnection(@RequestBody ChatModelConfig modelConfig) {
return ModelConfigHelper.testConnection(modelConfig); return ModelConfigHelper.testConnection(modelConfig);

View File

@@ -6,6 +6,8 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq; import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig; import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.agent.PromptConfig;
import com.tencent.supersonic.chat.server.agent.VisualConfig;
import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO; import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.persistence.mapper.AgentDOMapper; import com.tencent.supersonic.chat.server.persistence.mapper.AgentDOMapper;
@@ -14,8 +16,6 @@ import com.tencent.supersonic.chat.server.service.ChatModelService;
import com.tencent.supersonic.chat.server.service.ChatQueryService; import com.tencent.supersonic.chat.server.service.ChatQueryService;
import com.tencent.supersonic.chat.server.service.MemoryService; import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.chat.server.util.ModelConfigHelper; import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.config.VisualConfig;
import com.tencent.supersonic.common.pojo.enums.ChatModelType; import com.tencent.supersonic.common.pojo.enums.ChatModelType;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.chat.parser.llm.OnePassSCSqlGenStrategy; import com.tencent.supersonic.headless.chat.parser.llm.OnePassSCSqlGenStrategy;
@@ -140,10 +140,10 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
} }
Agent agent = new Agent(); Agent agent = new Agent();
BeanUtils.copyProperties(agentDO, agent); BeanUtils.copyProperties(agentDO, agent);
agent.setAgentConfig(agentDO.getConfig()); agent.setToolConfig(agentDO.getToolConfig());
agent.setExamples(JsonUtil.toList(agentDO.getExamples(), String.class)); agent.setExamples(JsonUtil.toList(agentDO.getExamples(), String.class));
agent.setModelConfig( agent.setChatModelConfig(
JsonUtil.toMap(agentDO.getModelConfig(), ChatModelType.class, Integer.class)); JsonUtil.toMap(agentDO.getChatModelConfig(), ChatModelType.class, Integer.class));
agent.setPromptConfig(JsonUtil.toObject(agentDO.getPromptConfig(), PromptConfig.class)); agent.setPromptConfig(JsonUtil.toObject(agentDO.getPromptConfig(), PromptConfig.class));
agent.setMultiTurnConfig( agent.setMultiTurnConfig(
JsonUtil.toObject(agentDO.getMultiTurnConfig(), MultiTurnConfig.class)); JsonUtil.toObject(agentDO.getMultiTurnConfig(), MultiTurnConfig.class));
@@ -154,9 +154,9 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
private AgentDO convert(Agent agent) { private AgentDO convert(Agent agent) {
AgentDO agentDO = new AgentDO(); AgentDO agentDO = new AgentDO();
BeanUtils.copyProperties(agent, agentDO); BeanUtils.copyProperties(agent, agentDO);
agentDO.setConfig(agent.getAgentConfig()); agentDO.setToolConfig(agent.getToolConfig());
agentDO.setExamples(JsonUtil.toString(agent.getExamples())); agentDO.setExamples(JsonUtil.toString(agent.getExamples()));
agentDO.setModelConfig(JsonUtil.toString(agent.getModelConfig())); agentDO.setChatModelConfig(JsonUtil.toString(agent.getChatModelConfig()));
agentDO.setMultiTurnConfig(JsonUtil.toString(agent.getMultiTurnConfig())); agentDO.setMultiTurnConfig(JsonUtil.toString(agent.getMultiTurnConfig()));
agentDO.setVisualConfig(JsonUtil.toString(agent.getVisualConfig())); agentDO.setVisualConfig(JsonUtil.toString(agent.getVisualConfig()));
agentDO.setPromptConfig(JsonUtil.toString(agent.getPromptConfig())); agentDO.setPromptConfig(JsonUtil.toString(agent.getPromptConfig()));

View File

@@ -47,6 +47,7 @@ public class ChatModelServiceImpl extends ServiceImpl<ChatModelMapper, ChatModel
chatModelDO.setAdmin(user.getName()); chatModelDO.setAdmin(user.getName());
} }
save(chatModelDO); save(chatModelDO);
chatModel.setId(chatModelDO.getId());
return chatModel; return chatModel;
} }

View File

@@ -158,7 +158,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
executeReq.setParseId(parseResp.getSelectedParses().get(0).getId()); executeReq.setParseId(parseResp.getSelectedParses().get(0).getId());
executeReq.setQueryText(chatParseReq.getQueryText()); executeReq.setQueryText(chatParseReq.getQueryText());
executeReq.setChatId(chatParseReq.getChatId()); executeReq.setChatId(chatParseReq.getChatId());
executeReq.setUser(User.getFakeUser()); executeReq.setUser(User.getDefaultUser());
executeReq.setAgentId(chatParseReq.getAgentId()); executeReq.setAgentId(chatParseReq.getAgentId());
executeReq.setSaveAnswer(true); executeReq.setSaveAnswer(true);
return execute(executeReq); return execute(executeReq);

View File

@@ -29,8 +29,8 @@ public class ModelConfigHelper {
public static ChatModelConfig getChatModelConfig(Agent agent, ChatModelType modelType) { public static ChatModelConfig getChatModelConfig(Agent agent, ChatModelType modelType) {
ChatModelConfig chatModelConfig = null; ChatModelConfig chatModelConfig = null;
if (agent.getModelConfig().containsKey(modelType)) { if (agent.getChatModelConfig().containsKey(modelType)) {
Integer chatModelId = agent.getModelConfig().get(modelType); Integer chatModelId = agent.getChatModelConfig().get(modelType);
ChatModelService chatModelService = ContextUtils.getBean(ChatModelService.class); ChatModelService chatModelService = ContextUtils.getBean(ChatModelService.class);
chatModelConfig = chatModelService.getChatModel(chatModelId).getConfig(); chatModelConfig = chatModelService.getChatModel(chatModelId).getConfig();
} }

View File

@@ -50,7 +50,7 @@ public class QueryReqConverter {
queryNLReq.setMapInfo(queryNLReq.getMapInfo()); queryNLReq.setMapInfo(queryNLReq.getMapInfo());
} }
queryNLReq.setModelConfig(chatModelConfig); queryNLReq.setModelConfig(chatModelConfig);
queryNLReq.setPromptConfig(agent.getPromptConfig()); queryNLReq.setCustomPrompt(agent.getPromptConfig().getPromptTemplate());
if (chatCtx != null) { if (chatCtx != null) {
queryNLReq.setContextParseInfo(chatCtx.getParseInfo()); queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
} }

View File

@@ -21,30 +21,32 @@ import java.util.List;
@Service("EmbeddingModelParameterConfig") @Service("EmbeddingModelParameterConfig")
@Slf4j @Slf4j
public class EmbeddingModelParameterConfig extends ParameterConfig { public class EmbeddingModelParameterConfig extends ParameterConfig {
private static final String MODULE_NAME = "嵌入模型配置";
public static final Parameter EMBEDDING_MODEL_PROVIDER = public static final Parameter EMBEDDING_MODEL_PROVIDER =
new Parameter("s2.embedding.model.provider", InMemoryModelFactory.PROVIDER, "接口协议", "", new Parameter("s2.embedding.model.provider", InMemoryModelFactory.PROVIDER, "接口协议", "",
"list", "向量模型配置", getCandidateValues()); "list", MODULE_NAME, getCandidateValues());
public static final Parameter EMBEDDING_MODEL_BASE_URL = public static final Parameter EMBEDDING_MODEL_BASE_URL =
new Parameter("s2.embedding.model.base.url", "", "BaseUrl", "", "string", "向量模型配置", new Parameter("s2.embedding.model.base.url", "", "BaseUrl", "", "string", MODULE_NAME,
null, getBaseUrlDependency()); null, getBaseUrlDependency());
public static final Parameter EMBEDDING_MODEL_API_KEY = public static final Parameter EMBEDDING_MODEL_API_KEY =
new Parameter("s2.embedding.model.api.key", "", "ApiKey", "", "password", "向量模型配置", new Parameter("s2.embedding.model.api.key", "", "ApiKey", "", "password", MODULE_NAME,
null, getApiKeyDependency()); null, getApiKeyDependency());
public static final Parameter EMBEDDING_MODEL_SECRET_KEY = public static final Parameter EMBEDDING_MODEL_SECRET_KEY =
new Parameter("s2.embedding.model.secretKey", "demo", "SecretKey", "", "password", new Parameter("s2.embedding.model.secretKey", "demo", "SecretKey", "", "password",
"向量模型配置", null, getSecretKeyDependency()); MODULE_NAME, null, getSecretKeyDependency());
public static final Parameter EMBEDDING_MODEL_NAME = public static final Parameter EMBEDDING_MODEL_NAME =
new Parameter("s2.embedding.model.name", EmbeddingModelConstant.BGE_SMALL_ZH, new Parameter("s2.embedding.model.name", EmbeddingModelConstant.BGE_SMALL_ZH,
"ModelName", "", "string", "向量模型配置", null, getModelNameDependency()); "ModelName", "", "string", MODULE_NAME, null, getModelNameDependency());
public static final Parameter EMBEDDING_MODEL_PATH = new Parameter("s2.embedding.model.path", public static final Parameter EMBEDDING_MODEL_PATH = new Parameter("s2.embedding.model.path",
"", "模型路径", "", "string", "向量模型配置", null, getModelPathDependency()); "", "模型路径", "", "string", MODULE_NAME, null, getModelPathDependency());
public static final Parameter EMBEDDING_MODEL_VOCABULARY_PATH = public static final Parameter EMBEDDING_MODEL_VOCABULARY_PATH =
new Parameter("s2.embedding.model.vocabulary.path", "", "词汇表路径", "", "string", "向量模型配置", new Parameter("s2.embedding.model.vocabulary.path", "", "词汇表路径", "", "string",
null, getModelPathDependency()); MODULE_NAME, null, getModelPathDependency());
@Override @Override
public List<Parameter> getSysParameters() { public List<Parameter> getSysParameters() {

View File

@@ -15,32 +15,34 @@ import java.util.List;
@Service("EmbeddingStoreParameterConfig") @Service("EmbeddingStoreParameterConfig")
@Slf4j @Slf4j
public class EmbeddingStoreParameterConfig extends ParameterConfig { public class EmbeddingStoreParameterConfig extends ParameterConfig {
private static final String MODULE_NAME = "向量数据库配置";
public static final Parameter EMBEDDING_STORE_PROVIDER = new Parameter( public static final Parameter EMBEDDING_STORE_PROVIDER = new Parameter(
"s2.embedding.store.provider", EmbeddingStoreType.IN_MEMORY.name(), "向量库类型", "s2.embedding.store.provider", EmbeddingStoreType.IN_MEMORY.name(), "向量库类型",
"目前支持三种类型IN_MEMORY、MILVUS、CHROMA", "list", "向量库配置", getCandidateValues()); "目前支持三种类型IN_MEMORY、MILVUS、CHROMA", "list", MODULE_NAME, getCandidateValues());
public static final Parameter EMBEDDING_STORE_BASE_URL = public static final Parameter EMBEDDING_STORE_BASE_URL =
new Parameter("s2.embedding.store.base.url", "", "BaseUrl", "", "string", "向量库配置", null, new Parameter("s2.embedding.store.base.url", "", "BaseUrl", "", "string", MODULE_NAME,
getBaseUrlDependency()); null, getBaseUrlDependency());
public static final Parameter EMBEDDING_STORE_API_KEY = public static final Parameter EMBEDDING_STORE_API_KEY =
new Parameter("s2.embedding.store.api.key", "", "ApiKey", "", "password", "向量库配置", null, new Parameter("s2.embedding.store.api.key", "", "ApiKey", "", "password", MODULE_NAME,
getApiKeyDependency()); null, getApiKeyDependency());
public static final Parameter EMBEDDING_STORE_PERSIST_PATH = public static final Parameter EMBEDDING_STORE_PERSIST_PATH =
new Parameter("s2.embedding.store.persist.path", "", "持久化路径", new Parameter("s2.embedding.store.persist.path", "", "持久化路径",
"默认不持久化,如需持久化请填写持久化路径。" + "注意:如果变更了向量模型需删除该路径下已保存的文件或修改持久化路径", "string", "默认不持久化,如需持久化请填写持久化路径。" + "注意:如果变更了向量模型需删除该路径下已保存的文件或修改持久化路径", "string",
"向量库配置", null, getPathDependency()); MODULE_NAME, null, getPathDependency());
public static final Parameter EMBEDDING_STORE_TIMEOUT = public static final Parameter EMBEDDING_STORE_TIMEOUT =
new Parameter("s2.embedding.store.timeout", "60", "超时时间(秒)", "", "number", "向量库配置"); new Parameter("s2.embedding.store.timeout", "60", "超时时间(秒)", "", "number", MODULE_NAME);
public static final Parameter EMBEDDING_STORE_DIMENSION = public static final Parameter EMBEDDING_STORE_DIMENSION =
new Parameter("s2.embedding.store.dimension", "", "纬度", "", "number", "向量库配置", null, new Parameter("s2.embedding.store.dimension", "", "纬度", "", "number", MODULE_NAME, null,
getDimensionDependency()); getDimensionDependency());
public static final Parameter EMBEDDING_STORE_DATABASE_NAME = public static final Parameter EMBEDDING_STORE_DATABASE_NAME =
new Parameter("s2.embedding.store.databaseName", "", "DatabaseName", "", "string", new Parameter("s2.embedding.store.databaseName", "", "DatabaseName", "", "string",
"向量库配置", null, getDatabaseNameDependency()); MODULE_NAME, null, getDatabaseNameDependency());
@Override @Override
public List<Parameter> getSysParameters() { public List<Parameter> getSysParameters() {

View File

@@ -1,6 +1,5 @@
package dev.langchain4j.provider; package dev.langchain4j.provider;
import com.tencent.supersonic.common.config.ChatModelParameterConfig;
import com.tencent.supersonic.common.config.EmbeddingModelParameterConfig; import com.tencent.supersonic.common.config.EmbeddingModelParameterConfig;
import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig; import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
@@ -14,6 +13,10 @@ import java.util.Map;
public class ModelProvider { public class ModelProvider {
public static final ChatModelConfig DEMO_CHAT_MODEL =
ChatModelConfig.builder().provider("open_ai").baseUrl("https://api.openai.com/v1")
.apiKey("demo").modelName("gpt-4o-mini").temperature(0.0).timeOut(60L).build();
private static final Map<String, ModelFactory> factories = new HashMap<>(); private static final Map<String, ModelFactory> factories = new HashMap<>();
public static void add(String provider, ModelFactory modelFactory) { public static void add(String provider, ModelFactory modelFactory) {
@@ -27,9 +30,7 @@ public class ModelProvider {
public static ChatLanguageModel getChatModel(ChatModelConfig modelConfig) { public static ChatLanguageModel getChatModel(ChatModelConfig modelConfig) {
if (modelConfig == null || StringUtils.isBlank(modelConfig.getProvider()) if (modelConfig == null || StringUtils.isBlank(modelConfig.getProvider())
|| StringUtils.isBlank(modelConfig.getBaseUrl())) { || StringUtils.isBlank(modelConfig.getBaseUrl())) {
ChatModelParameterConfig parameterConfig = modelConfig = DEMO_CHAT_MODEL;
ContextUtils.getBean(ChatModelParameterConfig.class);
modelConfig = parameterConfig.convert();
} }
ModelFactory modelFactory = factories.get(modelConfig.getProvider().toUpperCase()); ModelFactory modelFactory = factories.get(modelConfig.getProvider().toUpperCase());
if (modelFactory != null) { if (modelFactory != null) {

View File

@@ -289,7 +289,7 @@ def build_agent(dataSetId):
# "description":"DuSQL", # "description":"DuSQL",
# "status":1, # "status":1,
# "examples":[], # "examples":[],
# "agentConfig":json.dumps({ # "toolConfig":json.dumps({
# "tools":[{ # "tools":[{
# "id":1, # "id":1,
# "type":"NL2SQL_LLM", # "type":"NL2SQL_LLM",
@@ -303,7 +303,7 @@ def build_agent(dataSetId):
"description":"DuSQL", "description":"DuSQL",
"status":1, "status":1,
"examples":[], "examples":[],
"agentConfig":json.dumps({ "toolConfig":json.dumps({
"tools":[{ "tools":[{
"id":1, "id":1,
"type":"NL2SQL_LLM", "type":"NL2SQL_LLM",

View File

@@ -3,7 +3,6 @@ package com.tencent.supersonic.headless.api.pojo.request;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.google.common.collect.Sets; import com.google.common.collect.Sets;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
@@ -28,7 +27,7 @@ public class QueryNLReq {
private SchemaMapInfo mapInfo = new SchemaMapInfo(); private SchemaMapInfo mapInfo = new SchemaMapInfo();
private QueryDataType queryDataType = QueryDataType.ALL; private QueryDataType queryDataType = QueryDataType.ALL;
private ChatModelConfig modelConfig; private ChatModelConfig modelConfig;
private PromptConfig promptConfig; private String customPrompt;
private List<Text2SQLExemplar> dynamicExemplars = Lists.newArrayList(); private List<Text2SQLExemplar> dynamicExemplars = Lists.newArrayList();
private SemanticParseInfo contextParseInfo; private SemanticParseInfo contextParseInfo;
} }

View File

@@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.chat;
import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnore;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
@@ -54,7 +53,7 @@ public class ChatQueryContext {
private ChatWorkflowState chatWorkflowState; private ChatWorkflowState chatWorkflowState;
private QueryDataType queryDataType = QueryDataType.ALL; private QueryDataType queryDataType = QueryDataType.ALL;
private ChatModelConfig modelConfig; private ChatModelConfig modelConfig;
private PromptConfig promptConfig; private String customPrompt;
private List<Text2SQLExemplar> dynamicExemplars; private List<Text2SQLExemplar> dynamicExemplars;
public List<SemanticQuery> getCandidateQueries() { public List<SemanticQuery> getCandidateQueries() {

View File

@@ -14,40 +14,40 @@ public class ParserConfig extends ParameterConfig {
public static final Parameter PARSER_STRATEGY_TYPE = public static final Parameter PARSER_STRATEGY_TYPE =
new Parameter("s2.parser.s2sql.strategy", "ONE_PASS_SELF_CONSISTENCY", "LLM解析生成S2SQL策略", new Parameter("s2.parser.s2sql.strategy", "ONE_PASS_SELF_CONSISTENCY", "LLM解析生成S2SQL策略",
"ONE_PASS_SELF_CONSISTENCY: 通过投票方式一步生成sql", "list", "Parser相关配置", "ONE_PASS_SELF_CONSISTENCY: 通过投票方式一步生成sql", "list", "语义解析配置",
Lists.newArrayList("ONE_PASS_SELF_CONSISTENCY")); Lists.newArrayList("ONE_PASS_SELF_CONSISTENCY"));
public static final Parameter PARSER_LINKING_VALUE_ENABLE = public static final Parameter PARSER_LINKING_VALUE_ENABLE =
new Parameter("s2.parser.linking.value.enable", "true", "是否将Mapper探测识别到的维度值提供给大模型", new Parameter("s2.parser.linking.value.enable", "true", "是否将Mapper探测识别到的维度值提供给大模型",
"为了数据安全考虑, 这里可进行开关选择", "bool", "Parser相关配置"); "为了数据安全考虑, 这里可进行开关选择", "bool", "语义解析配置");
public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD = public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD =
new Parameter("s2.parser.text.length.threshold", "10", "用户输入文本长短阈值", "文本超过该阈值为长文本", new Parameter("s2.parser.text.length.threshold", "10", "用户输入文本长短阈值", "文本超过该阈值为长文本",
"number", "Parser相关配置"); "number", "语义解析配置");
public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD_SHORT = public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD_SHORT =
new Parameter("s2.parser.text.threshold.short", "0.5", "短文本匹配阈值", new Parameter("s2.parser.text.threshold.short", "0.5", "短文本匹配阈值",
"由于请求大模型耗时较长, 因此如果有规则类型的Query得分达到阈值,则跳过大模型的调用," "由于请求大模型耗时较长, 因此如果有规则类型的Query得分达到阈值,则跳过大模型的调用,"
+ "\n如果是短文本, 若query得分/文本长度>该阈值, 则跳过当前parser", + "\n如果是短文本, 若query得分/文本长度>该阈值, 则跳过当前parser",
"number", "Parser相关配置"); "number", "语义解析配置");
public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD_LONG = public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD_LONG =
new Parameter("s2.parser.text.threshold.long", "0.8", "长文本匹配阈值", new Parameter("s2.parser.text.threshold.long", "0.8", "长文本匹配阈值",
"如果是长文本, 若query得分/文本长度>该阈值, 则跳过当前parser", "number", "Parser相关配置"); "如果是长文本, 若query得分/文本长度>该阈值, 则跳过当前parser", "number", "语义解析配置");
public static final Parameter PARSER_EXEMPLAR_RECALL_NUMBER = new Parameter( public static final Parameter PARSER_EXEMPLAR_RECALL_NUMBER = new Parameter(
"s2.parser.exemplar-recall.number", "10", "exemplar召回个数", "", "number", "Parser相关配置"); "s2.parser.exemplar-recall.number", "10", "exemplar召回个数", "", "number", "语义解析配置");
public static final Parameter PARSER_FEW_SHOT_NUMBER = public static final Parameter PARSER_FEW_SHOT_NUMBER =
new Parameter("s2.parser.few-shot.number", "3", "few-shot样例个数", "样例越多效果可能越好但token消耗越大", new Parameter("s2.parser.few-shot.number", "3", "few-shot样例个数", "样例越多效果可能越好但token消耗越大",
"number", "Parser相关配置"); "number", "语义解析配置");
public static final Parameter PARSER_SELF_CONSISTENCY_NUMBER = public static final Parameter PARSER_SELF_CONSISTENCY_NUMBER =
new Parameter("s2.parser.self-consistency.number", "1", "self-consistency执行个数", new Parameter("s2.parser.self-consistency.number", "1", "self-consistency执行个数",
"执行越多效果可能越好但token消耗越大", "number", "Parser相关配置"); "执行越多效果可能越好但token消耗越大", "number", "语义解析配置");
public static final Parameter PARSER_SHOW_COUNT = new Parameter("s2.parser.show.count", "3", public static final Parameter PARSER_SHOW_COUNT =
"解析结果展示个数", "前端展示的解析个数", "number", "Parser相关配置"); new Parameter("s2.parser.show.count", "3", "解析结果展示个数", "前端展示的解析个数", "number", "语义解析配置");
@Override @Override
public List<Parameter> getSysParameters() { public List<Parameter> getSysParameters() {

View File

@@ -75,7 +75,7 @@ public class LLMRequestService {
llmReq.setSqlGenType( llmReq.setSqlGenType(
LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE))); LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE)));
llmReq.setModelConfig(queryCtx.getModelConfig()); llmReq.setModelConfig(queryCtx.getModelConfig());
llmReq.setPromptConfig(queryCtx.getPromptConfig()); llmReq.setCustomPrompt(queryCtx.getCustomPrompt());
llmReq.setDynamicExemplars(queryCtx.getDynamicExemplars()); llmReq.setDynamicExemplars(queryCtx.getDynamicExemplars());
return llmReq; return llmReq;

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.headless.chat.parser.llm; package com.tencent.supersonic.headless.chat.parser.llm;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
@@ -112,10 +111,9 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
variable.put("information", sideInformation); variable.put("information", sideInformation);
// use custom prompt template if provided. // use custom prompt template if provided.
PromptConfig promptConfig = llmReq.getPromptConfig();
String promptTemplate = INSTRUCTION; String promptTemplate = INSTRUCTION;
if (promptConfig != null && StringUtils.isNotBlank(promptConfig.getPromptTemplate())) { if (StringUtils.isNotBlank(llmReq.getCustomPrompt())) {
promptTemplate = promptConfig.getPromptTemplate(); promptTemplate = llmReq.getCustomPrompt();
} }
return PromptTemplate.from(promptTemplate).apply(variable); return PromptTemplate.from(promptTemplate).apply(variable);
} }

View File

@@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.chat.query.llm.s2sql;
import com.fasterxml.jackson.annotation.JsonValue; import com.fasterxml.jackson.annotation.JsonValue;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElement;
@@ -23,7 +22,7 @@ public class LLMReq {
private String priorExts; private String priorExts;
private SqlGenType sqlGenType; private SqlGenType sqlGenType;
private ChatModelConfig modelConfig; private ChatModelConfig modelConfig;
private PromptConfig promptConfig; private String customPrompt;
private List<Text2SQLExemplar> dynamicExemplars; private List<Text2SQLExemplar> dynamicExemplars;
@Data @Data

View File

@@ -254,7 +254,7 @@ public class DataSetServiceImpl extends ServiceImpl<DataSetDOMapper, DataSetDO>
@Override @Override
public Map<Long, List<Long>> getModelIdToDataSetIds() { public Map<Long, List<Long>> getModelIdToDataSetIds() {
return getModelIdToDataSetIds(Lists.newArrayList(), User.getFakeUser()); return getModelIdToDataSetIds(Lists.newArrayList(), User.getDefaultUser());
} }
private void conflictCheck(DataSetResp dataSetResp) { private void conflictCheck(DataSetResp dataSetResp) {

View File

@@ -70,7 +70,7 @@ public class RetrieveServiceImpl implements RetrieveService {
List<SchemaElement> metricsDb = semanticSchemaDb.getMetrics(); List<SchemaElement> metricsDb = semanticSchemaDb.getMetrics();
final Map<Long, String> dataSetIdToName = semanticSchemaDb.getDataSetIdToName(); final Map<Long, String> dataSetIdToName = semanticSchemaDb.getDataSetIdToName();
Map<Long, List<Long>> modelIdToDataSetIds = dataSetService.getModelIdToDataSetIds( Map<Long, List<Long>> modelIdToDataSetIds = dataSetService.getModelIdToDataSetIds(
new ArrayList<>(dataSetIdToName.keySet()), User.getFakeUser()); new ArrayList<>(dataSetIdToName.keySet()), User.getDefaultUser());
// 2.detect by segment // 2.detect by segment
List<S2Term> originals = knowledgeBaseService.getTerms(queryText, modelIdToDataSetIds); List<S2Term> originals = knowledgeBaseService.getTerms(queryText, modelIdToDataSetIds);
log.debug("hanlp parse result: {}", originals); log.debug("hanlp parse result: {}", originals);

View File

@@ -162,7 +162,7 @@ public class DictUtils {
dictItemResp.setBizName(dimension.getBizName()); dictItemResp.setBizName(dimension.getBizName());
} }
if (TypeEnums.TAG.equals(TypeEnums.valueOf(dictConfDO.getType()))) { if (TypeEnums.TAG.equals(TypeEnums.valueOf(dictConfDO.getType()))) {
TagResp tagResp = tagMetaService.getTag(dictConfDO.getItemId(), User.getFakeUser()); TagResp tagResp = tagMetaService.getTag(dictConfDO.getItemId(), User.getDefaultUser());
dictItemResp.setModelId(tagResp.getModelId()); dictItemResp.setModelId(tagResp.getModelId());
dictItemResp.setBizName(tagResp.getBizName()); dictItemResp.setBizName(tagResp.getBizName());
} }

View File

@@ -43,7 +43,7 @@ public class MetricServiceImplTest {
MetricReq metricReq = buildMetricReq(); MetricReq metricReq = buildMetricReq();
when(modelService.getModel(metricReq.getModelId())).thenReturn(mockModelResp()); when(modelService.getModel(metricReq.getModelId())).thenReturn(mockModelResp());
when(modelService.getModelByDomainIds(any())).thenReturn(Lists.newArrayList()); when(modelService.getModelByDomainIds(any())).thenReturn(Lists.newArrayList());
MetricResp actualMetricResp = metricService.createMetric(metricReq, User.getFakeUser()); MetricResp actualMetricResp = metricService.createMetric(metricReq, User.getDefaultUser());
MetricResp expectedMetricResp = buildExpectedMetricResp(); MetricResp expectedMetricResp = buildExpectedMetricResp();
Assertions.assertEquals(expectedMetricResp, actualMetricResp); Assertions.assertEquals(expectedMetricResp, actualMetricResp);
} }
@@ -58,7 +58,7 @@ public class MetricServiceImplTest {
when(modelService.getModelByDomainIds(any())).thenReturn(Lists.newArrayList()); when(modelService.getModelByDomainIds(any())).thenReturn(Lists.newArrayList());
MetricDO metricDO = MetricConverter.convert2MetricDO(buildMetricReq()); MetricDO metricDO = MetricConverter.convert2MetricDO(buildMetricReq());
when(metricRepository.getMetricById(metricDO.getId())).thenReturn(metricDO); when(metricRepository.getMetricById(metricDO.getId())).thenReturn(metricDO);
MetricResp actualMetricResp = metricService.updateMetric(metricReq, User.getFakeUser()); MetricResp actualMetricResp = metricService.updateMetric(metricReq, User.getDefaultUser());
MetricResp expectedMetricResp = buildExpectedMetricResp(); MetricResp expectedMetricResp = buildExpectedMetricResp();
Assertions.assertEquals(expectedMetricResp, actualMetricResp); Assertions.assertEquals(expectedMetricResp, actualMetricResp);
} }

View File

@@ -34,7 +34,7 @@ class ModelServiceImplTest {
void createModel() throws Exception { void createModel() throws Exception {
ModelRepository modelRepository = Mockito.mock(ModelRepository.class); ModelRepository modelRepository = Mockito.mock(ModelRepository.class);
ModelService modelService = mockModelService(modelRepository); ModelService modelService = mockModelService(modelRepository);
ModelResp actualModelResp = modelService.createModel(mockModelReq(), User.getFakeUser()); ModelResp actualModelResp = modelService.createModel(mockModelReq(), User.getDefaultUser());
ModelResp expectedModelResp = buildExpectedModelResp(); ModelResp expectedModelResp = buildExpectedModelResp();
Assertions.assertEquals(expectedModelResp, actualModelResp); Assertions.assertEquals(expectedModelResp, actualModelResp);
} }
@@ -44,9 +44,9 @@ class ModelServiceImplTest {
ModelRepository modelRepository = Mockito.mock(ModelRepository.class); ModelRepository modelRepository = Mockito.mock(ModelRepository.class);
ModelService modelService = mockModelService(modelRepository); ModelService modelService = mockModelService(modelRepository);
ModelReq modelReq = mockModelReq_update(); ModelReq modelReq = mockModelReq_update();
ModelDO modelDO = ModelConverter.convert(mockModelReq(), User.getFakeUser()); ModelDO modelDO = ModelConverter.convert(mockModelReq(), User.getDefaultUser());
when(modelRepository.getModelById(modelReq.getId())).thenReturn(modelDO); when(modelRepository.getModelById(modelReq.getId())).thenReturn(modelDO);
User user = User.getFakeUser(); User user = User.getDefaultUser();
user.setName("alice"); user.setName("alice");
ModelResp actualModelResp = modelService.updateModel(modelReq, user); ModelResp actualModelResp = modelService.updateModel(modelReq, user);
ModelResp expectedModelResp = buildExpectedModelResp_update(); ModelResp expectedModelResp = buildExpectedModelResp_update();
@@ -60,9 +60,9 @@ class ModelServiceImplTest {
ModelRepository modelRepository = Mockito.mock(ModelRepository.class); ModelRepository modelRepository = Mockito.mock(ModelRepository.class);
ModelService modelService = mockModelService(modelRepository); ModelService modelService = mockModelService(modelRepository);
ModelReq modelReq = mockModelReq_updateAdmin(); ModelReq modelReq = mockModelReq_updateAdmin();
ModelDO modelDO = ModelConverter.convert(mockModelReq(), User.getFakeUser()); ModelDO modelDO = ModelConverter.convert(mockModelReq(), User.getDefaultUser());
when(modelRepository.getModelById(modelReq.getId())).thenReturn(modelDO); when(modelRepository.getModelById(modelReq.getId())).thenReturn(modelDO);
ModelResp actualModelResp = modelService.updateModel(modelReq, User.getFakeUser()); ModelResp actualModelResp = modelService.updateModel(modelReq, User.getDefaultUser());
ModelResp expectedModelResp = buildExpectedModelResp(); ModelResp expectedModelResp = buildExpectedModelResp();
Assertions.assertEquals(expectedModelResp, actualModelResp); Assertions.assertEquals(expectedModelResp, actualModelResp);
} }

View File

@@ -224,7 +224,7 @@ public class CspiderDemo extends S2BaseDemo {
queryConfig.setDetailTypeDefaultConfig(detailTypeDefaultConfig); queryConfig.setDetailTypeDefaultConfig(detailTypeDefaultConfig);
queryConfig.setAggregateTypeDefaultConfig(aggregateTypeDefaultConfig); queryConfig.setAggregateTypeDefaultConfig(aggregateTypeDefaultConfig);
dataSetReq.setQueryConfig(queryConfig); dataSetReq.setQueryConfig(queryConfig);
dataSetService.save(dataSetReq, User.getFakeUser()); dataSetService.save(dataSetReq, User.getDefaultUser());
} }
public void addModelRela_1(DomainResp s2Domain, ModelResp genreModelResp, public void addModelRela_1(DomainResp s2Domain, ModelResp genreModelResp,
@@ -296,6 +296,6 @@ public class CspiderDemo extends S2BaseDemo {
private void batchPushlishMetric() { private void batchPushlishMetric() {
List<Long> ids = Lists.newArrayList(1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L); List<Long> ids = Lists.newArrayList(1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L);
metricService.batchPublish(ids, User.getFakeUser()); metricService.batchPublish(ids, User.getDefaultUser());
} }
} }

View File

@@ -4,9 +4,9 @@ import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentConfig;
import com.tencent.supersonic.chat.server.agent.AgentToolType; import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.LLMParserTool; import com.tencent.supersonic.chat.server.agent.LLMParserTool;
import com.tencent.supersonic.chat.server.agent.ToolConfig;
import com.tencent.supersonic.common.pojo.JoinCondition; import com.tencent.supersonic.common.pojo.JoinCondition;
import com.tencent.supersonic.common.pojo.ModelRela; import com.tencent.supersonic.common.pojo.ModelRela;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
@@ -274,7 +274,7 @@ public class DuSQLDemo extends S2BaseDemo {
aggregateTypeDefaultConfig.setTimeDefaultConfig(timeDefaultConfig); aggregateTypeDefaultConfig.setTimeDefaultConfig(timeDefaultConfig);
queryConfig.setAggregateTypeDefaultConfig(aggregateTypeDefaultConfig); queryConfig.setAggregateTypeDefaultConfig(aggregateTypeDefaultConfig);
dataSetReq.setQueryConfig(queryConfig); dataSetReq.setQueryConfig(queryConfig);
dataSetService.save(dataSetReq, User.getFakeUser()); dataSetService.save(dataSetReq, User.getDefaultUser());
} }
public void addModelRela_1() { public void addModelRela_1() {
@@ -334,16 +334,16 @@ public class DuSQLDemo extends S2BaseDemo {
agent.setStatus(1); agent.setStatus(1);
agent.setEnableSearch(1); agent.setEnableSearch(1);
agent.setExamples(Lists.newArrayList()); agent.setExamples(Lists.newArrayList());
AgentConfig agentConfig = new AgentConfig(); ToolConfig toolConfig = new ToolConfig();
LLMParserTool llmParserTool = new LLMParserTool(); LLMParserTool llmParserTool = new LLMParserTool();
llmParserTool.setId("1"); llmParserTool.setId("1");
llmParserTool.setType(AgentToolType.NL2SQL_LLM); llmParserTool.setType(AgentToolType.NL2SQL_LLM);
llmParserTool.setDataSetIds(Lists.newArrayList(4L)); llmParserTool.setDataSetIds(Lists.newArrayList(4L));
agentConfig.getTools().add(llmParserTool); toolConfig.getTools().add(llmParserTool);
agent.setAgentConfig(JSONObject.toJSONString(agentConfig)); agent.setToolConfig(JSONObject.toJSONString(toolConfig));
log.info("agent:{}", JsonUtil.toString(agent)); log.info("agent:{}", JsonUtil.toString(agent));
agentService.createAgent(agent, User.getFakeUser()); agentService.createAgent(agent, User.getDefaultUser());
} }
} }

View File

@@ -4,10 +4,10 @@ import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentConfig;
import com.tencent.supersonic.chat.server.agent.AgentToolType; import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.LLMParserTool; import com.tencent.supersonic.chat.server.agent.LLMParserTool;
import com.tencent.supersonic.chat.server.agent.RuleParserTool; import com.tencent.supersonic.chat.server.agent.RuleParserTool;
import com.tencent.supersonic.chat.server.agent.ToolConfig;
import com.tencent.supersonic.common.pojo.enums.StatusEnum; import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import com.tencent.supersonic.common.pojo.enums.TypeEnums; import com.tencent.supersonic.common.pojo.enums.TypeEnums;
import com.tencent.supersonic.headless.api.pojo.*; import com.tencent.supersonic.headless.api.pojo.*;
@@ -69,7 +69,7 @@ public class S2ArtistDemo extends S2BaseDemo {
tagObjectReq.setDomainId(singerDomain.getId()); tagObjectReq.setDomainId(singerDomain.getId());
tagObjectReq.setName("歌手"); tagObjectReq.setName("歌手");
tagObjectReq.setBizName("singer"); tagObjectReq.setBizName("singer");
User user = User.getFakeUser(); User user = User.getDefaultUser();
return tagObjectService.create(tagObjectReq, user); return tagObjectService.create(tagObjectReq, user);
} }
@@ -159,7 +159,7 @@ public class S2ArtistDemo extends S2BaseDemo {
queryConfig.setDetailTypeDefaultConfig(detailTypeDefaultConfig); queryConfig.setDetailTypeDefaultConfig(detailTypeDefaultConfig);
queryConfig.setAggregateTypeDefaultConfig(aggregateTypeDefaultConfig); queryConfig.setAggregateTypeDefaultConfig(aggregateTypeDefaultConfig);
dataSetReq.setQueryConfig(queryConfig); dataSetReq.setQueryConfig(queryConfig);
DataSetResp dataSetResp = dataSetService.save(dataSetReq, User.getFakeUser()); DataSetResp dataSetResp = dataSetService.save(dataSetReq, User.getDefaultUser());
return dataSetResp.getId(); return dataSetResp.getId();
} }
@@ -170,21 +170,21 @@ public class S2ArtistDemo extends S2BaseDemo {
agent.setStatus(1); agent.setStatus(1);
agent.setEnableSearch(1); agent.setEnableSearch(1);
agent.setExamples(Lists.newArrayList("国风流派歌手", "港台歌手", "周杰伦流派")); agent.setExamples(Lists.newArrayList("国风流派歌手", "港台歌手", "周杰伦流派"));
AgentConfig agentConfig = new AgentConfig(); ToolConfig toolConfig = new ToolConfig();
RuleParserTool ruleQueryTool = new RuleParserTool(); RuleParserTool ruleQueryTool = new RuleParserTool();
ruleQueryTool.setId("0"); ruleQueryTool.setId("0");
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE); ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
ruleQueryTool.setDataSetIds(Lists.newArrayList(dataSetId)); ruleQueryTool.setDataSetIds(Lists.newArrayList(dataSetId));
agentConfig.getTools().add(ruleQueryTool); toolConfig.getTools().add(ruleQueryTool);
if (demoEnableLlm) { if (demoEnableLlm) {
LLMParserTool llmParserTool = new LLMParserTool(); LLMParserTool llmParserTool = new LLMParserTool();
llmParserTool.setId("1"); llmParserTool.setId("1");
llmParserTool.setType(AgentToolType.NL2SQL_LLM); llmParserTool.setType(AgentToolType.NL2SQL_LLM);
llmParserTool.setDataSetIds(Lists.newArrayList(dataSetId)); llmParserTool.setDataSetIds(Lists.newArrayList(dataSetId));
agentConfig.getTools().add(llmParserTool); toolConfig.getTools().add(llmParserTool);
} }
agent.setAgentConfig(JSONObject.toJSONString(agentConfig)); agent.setToolConfig(JSONObject.toJSONString(toolConfig));
agentService.createAgent(agent, User.getFakeUser()); agentService.createAgent(agent, User.getDefaultUser());
} }
} }

View File

@@ -3,10 +3,8 @@ package com.tencent.supersonic.demo;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authorization.service.AuthService; import com.tencent.supersonic.auth.api.authorization.service.AuthService;
import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.chat.server.pojo.ChatModel;
import com.tencent.supersonic.chat.server.service.ChatManageService; import com.tencent.supersonic.chat.server.service.*;
import com.tencent.supersonic.chat.server.service.ChatQueryService;
import com.tencent.supersonic.chat.server.service.PluginService;
import com.tencent.supersonic.common.service.SystemConfigService; import com.tencent.supersonic.common.service.SystemConfigService;
import com.tencent.supersonic.common.util.AESEncryptionUtil; import com.tencent.supersonic.common.util.AESEncryptionUtil;
import com.tencent.supersonic.headless.api.pojo.DataSetModelConfig; import com.tencent.supersonic.headless.api.pojo.DataSetModelConfig;
@@ -33,6 +31,7 @@ import com.tencent.supersonic.headless.server.service.TagMetaService;
import com.tencent.supersonic.headless.server.service.TagObjectService; import com.tencent.supersonic.headless.server.service.TagObjectService;
import com.tencent.supersonic.headless.server.service.TermService; import com.tencent.supersonic.headless.server.service.TermService;
import com.tencent.supersonic.headless.server.service.impl.DictWordService; import com.tencent.supersonic.headless.server.service.impl.DictWordService;
import dev.langchain4j.provider.ModelProvider;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
@@ -47,8 +46,9 @@ import java.util.stream.Collectors;
@Slf4j @Slf4j
public abstract class S2BaseDemo implements CommandLineRunner { public abstract class S2BaseDemo implements CommandLineRunner {
protected DatabaseResp demoDatabaseResp; protected DatabaseResp demoDatabaseResp;
protected ChatModel chatModel;
protected User user = User.getFakeUser(); protected User user = User.getDefaultUser();
@Autowired @Autowired
protected DatabaseService databaseService; protected DatabaseService databaseService;
@Autowired @Autowired
@@ -87,6 +87,8 @@ public abstract class S2BaseDemo implements CommandLineRunner {
protected CanvasService canvasService; protected CanvasService canvasService;
@Autowired @Autowired
protected DictWordService dictWordService; protected DictWordService dictWordService;
@Autowired
protected ChatModelService chatModelService;
@Value("${s2.demo.names:S2VisitsDemo}") @Value("${s2.demo.names:S2VisitsDemo}")
protected List<String> demoList; protected List<String> demoList;
@@ -96,6 +98,7 @@ public abstract class S2BaseDemo implements CommandLineRunner {
public void run(String... args) { public void run(String... args) {
demoDatabaseResp = addDatabaseIfNotExist(); demoDatabaseResp = addDatabaseIfNotExist();
addChatModelIfNotExist();
if (demoList != null && demoList.contains(getClass().getSimpleName())) { if (demoList != null && demoList.contains(getClass().getSimpleName())) {
if (checkNeedToRun()) { if (checkNeedToRun()) {
doRun(); doRun();
@@ -108,7 +111,7 @@ public abstract class S2BaseDemo implements CommandLineRunner {
abstract boolean checkNeedToRun(); abstract boolean checkNeedToRun();
protected DatabaseResp addDatabaseIfNotExist() { protected DatabaseResp addDatabaseIfNotExist() {
List<DatabaseResp> databaseList = databaseService.getDatabaseList(User.getFakeUser()); List<DatabaseResp> databaseList = databaseService.getDatabaseList(User.getDefaultUser());
if (!CollectionUtils.isEmpty(databaseList)) { if (!CollectionUtils.isEmpty(databaseList)) {
return databaseList.get(0); return databaseList.get(0);
} }
@@ -130,6 +133,17 @@ public abstract class S2BaseDemo implements CommandLineRunner {
return databaseService.createOrUpdateDatabase(databaseReq, user); return databaseService.createOrUpdateDatabase(databaseReq, user);
} }
protected void addChatModelIfNotExist() {
if (chatModelService.getChatModels().size() > 0) {
return;
}
chatModel = new ChatModel();
chatModel.setName("OpenAI模型DEMO");
chatModel.setDescription("由langchain4j社区提供仅用于体验单次请求最大token数1000");
chatModel.setConfig(ModelProvider.DEMO_CHAT_MODEL);
chatModel = chatModelService.createChatModel(chatModel, User.getDefaultUser());
}
protected MetricResp getMetric(String bizName, ModelResp model) { protected MetricResp getMetric(String bizName, ModelResp model) {
return metricService.getMetric(model.getId(), bizName); return metricService.getMetric(model.getId(), bizName);
} }
@@ -160,7 +174,7 @@ public abstract class S2BaseDemo implements CommandLineRunner {
TagReq tagReq = new TagReq(); TagReq tagReq = new TagReq();
tagReq.setTagDefineType(tagDefineType); tagReq.setTagDefineType(tagDefineType);
tagReq.setItemId(itemId); tagReq.setItemId(itemId);
tagMetaService.create(tagReq, User.getFakeUser()); tagMetaService.create(tagReq, User.getDefaultUser());
} }
protected DimensionResp getDimension(String bizName, ModelResp model) { protected DimensionResp getDimension(String bizName, ModelResp model) {

View File

@@ -2,16 +2,17 @@ package com.tencent.supersonic.demo;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup; import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup;
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRule; import com.tencent.supersonic.auth.api.authorization.pojo.AuthRule;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq; import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentConfig;
import com.tencent.supersonic.chat.server.agent.AgentToolType; import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.LLMParserTool; import com.tencent.supersonic.chat.server.agent.LLMParserTool;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig; import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.agent.RuleParserTool; import com.tencent.supersonic.chat.server.agent.RuleParserTool;
import com.tencent.supersonic.chat.server.agent.ToolConfig;
import com.tencent.supersonic.chat.server.plugin.ChatPlugin; import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
import com.tencent.supersonic.chat.server.plugin.PluginParseConfig; import com.tencent.supersonic.chat.server.plugin.PluginParseConfig;
import com.tencent.supersonic.chat.server.plugin.build.WebBase; import com.tencent.supersonic.chat.server.plugin.build.WebBase;
@@ -19,12 +20,7 @@ import com.tencent.supersonic.chat.server.plugin.build.webpage.WebPageQuery;
import com.tencent.supersonic.chat.server.plugin.build.webservice.WebServiceQuery; import com.tencent.supersonic.chat.server.plugin.build.webservice.WebServiceQuery;
import com.tencent.supersonic.common.pojo.JoinCondition; import com.tencent.supersonic.common.pojo.JoinCondition;
import com.tencent.supersonic.common.pojo.ModelRela; import com.tencent.supersonic.common.pojo.ModelRela;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; import com.tencent.supersonic.common.pojo.enums.*;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.SensitiveLevelEnum;
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.DataSetDetail; import com.tencent.supersonic.headless.api.pojo.DataSetDetail;
import com.tencent.supersonic.headless.api.pojo.DataSetModelConfig; import com.tencent.supersonic.headless.api.pojo.DataSetModelConfig;
@@ -63,10 +59,7 @@ import lombok.extern.slf4j.Slf4j;
import org.springframework.core.annotation.Order; import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import java.util.ArrayList; import java.util.*;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
@Component @Component
@Slf4j @Slf4j
@@ -146,7 +139,7 @@ public class S2VisitsDemo extends S2BaseDemo {
private void submitText(int chatId, int agentId, String queryText) { private void submitText(int chatId, int agentId, String queryText) {
chatQueryService.parseAndExecute(ChatParseReq.builder().chatId(chatId).agentId(agentId) chatQueryService.parseAndExecute(ChatParseReq.builder().chatId(chatId).agentId(agentId)
.queryText(queryText).user(User.getFakeUser()).disableLLM(true).build()); .queryText(queryText).user(User.getDefaultUser()).disableLLM(true).build());
} }
private Integer addAgent(long dataSetId) { private Integer addAgent(long dataSetId) {
@@ -157,23 +150,32 @@ public class S2VisitsDemo extends S2BaseDemo {
agent.setEnableSearch(1); agent.setEnableSearch(1);
agent.setExamples(Lists.newArrayList("近15天超音数访问次数汇总", "按部门统计超音数的访问人数", "对比alice和lucy的停留时长", agent.setExamples(Lists.newArrayList("近15天超音数访问次数汇总", "按部门统计超音数的访问人数", "对比alice和lucy的停留时长",
"过去30天访问次数最高的部门top3", "近1个月总访问次数超过100次的部门有几个", "过去半个月每个核心用户的总停留时长")); "过去30天访问次数最高的部门top3", "近1个月总访问次数超过100次的部门有几个", "过去半个月每个核心用户的总停留时长"));
AgentConfig agentConfig = new AgentConfig(); // configure tools
ToolConfig toolConfig = new ToolConfig();
RuleParserTool ruleQueryTool = new RuleParserTool(); RuleParserTool ruleQueryTool = new RuleParserTool();
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE); ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
ruleQueryTool.setId("0"); ruleQueryTool.setId("0");
ruleQueryTool.setDataSetIds(Lists.newArrayList(dataSetId)); ruleQueryTool.setDataSetIds(Lists.newArrayList(dataSetId));
agentConfig.getTools().add(ruleQueryTool); toolConfig.getTools().add(ruleQueryTool);
if (demoEnableLlm) { if (demoEnableLlm) {
LLMParserTool llmParserTool = new LLMParserTool(); LLMParserTool llmParserTool = new LLMParserTool();
llmParserTool.setId("1"); llmParserTool.setId("1");
llmParserTool.setType(AgentToolType.NL2SQL_LLM); llmParserTool.setType(AgentToolType.NL2SQL_LLM);
llmParserTool.setDataSetIds(Lists.newArrayList(dataSetId)); llmParserTool.setDataSetIds(Lists.newArrayList(dataSetId));
agentConfig.getTools().add(llmParserTool); toolConfig.getTools().add(llmParserTool);
} }
agent.setAgentConfig(JSONObject.toJSONString(agentConfig)); agent.setToolConfig(JSONObject.toJSONString(toolConfig));
// configure chat models
Map<ChatModelType, Integer> chatModelConfig = Maps.newHashMap();
chatModelConfig.put(ChatModelType.TEXT_TO_SQL, chatModel.getId());
chatModelConfig.put(ChatModelType.MEMORY_REVIEW, chatModel.getId());
chatModelConfig.put(ChatModelType.RESPONSE_GENERATE, chatModel.getId());
chatModelConfig.put(ChatModelType.MULTI_TURN_REWRITE, chatModel.getId());
agent.setChatModelConfig(chatModelConfig);
MultiTurnConfig multiTurnConfig = new MultiTurnConfig(true); MultiTurnConfig multiTurnConfig = new MultiTurnConfig(true);
agent.setMultiTurnConfig(multiTurnConfig); agent.setMultiTurnConfig(multiTurnConfig);
Agent agentCreated = agentService.createAgent(agent, User.getFakeUser()); Agent agentCreated = agentService.createAgent(agent, User.getDefaultUser());
return agentCreated.getId(); return agentCreated.getId();
} }
@@ -460,7 +462,7 @@ public class S2VisitsDemo extends S2BaseDemo {
dataSetDetail.setDataSetModelConfigs(dataSetModelConfigs); dataSetDetail.setDataSetModelConfigs(dataSetModelConfigs);
dataSetReq.setDataSetDetail(dataSetDetail); dataSetReq.setDataSetDetail(dataSetDetail);
dataSetReq.setTypeEnum(TypeEnums.DATASET); dataSetReq.setTypeEnum(TypeEnums.DATASET);
return dataSetService.save(dataSetReq, User.getFakeUser()); return dataSetService.save(dataSetReq, User.getDefaultUser());
} }
public void addTerm(DomainResp s2Domain) { public void addTerm(DomainResp s2Domain) {
@@ -469,7 +471,7 @@ public class S2VisitsDemo extends S2BaseDemo {
termReq.setDescription("指近10天"); termReq.setDescription("指近10天");
termReq.setAlias(Lists.newArrayList("近一段时间")); termReq.setAlias(Lists.newArrayList("近一段时间"));
termReq.setDomainId(s2Domain.getId()); termReq.setDomainId(s2Domain.getId());
termService.saveOrUpdate(termReq, User.getFakeUser()); termService.saveOrUpdate(termReq, User.getDefaultUser());
} }
public void addTerm_1(DomainResp s2Domain) { public void addTerm_1(DomainResp s2Domain) {
@@ -478,7 +480,7 @@ public class S2VisitsDemo extends S2BaseDemo {
termReq.setDescription("用户为tom和lucy"); termReq.setDescription("用户为tom和lucy");
termReq.setAlias(Lists.newArrayList("VIP用户")); termReq.setAlias(Lists.newArrayList("VIP用户"));
termReq.setDomainId(s2Domain.getId()); termReq.setDomainId(s2Domain.getId());
termService.saveOrUpdate(termReq, User.getFakeUser()); termService.saveOrUpdate(termReq, User.getDefaultUser());
} }
public void addAuthGroup_1(ModelResp stayTimeModel) { public void addAuthGroup_1(ModelResp stayTimeModel) {
@@ -553,7 +555,7 @@ public class S2VisitsDemo extends S2BaseDemo {
tagObjectReq.setDomainId(s2Domain.getId()); tagObjectReq.setDomainId(s2Domain.getId());
tagObjectReq.setName("用户"); tagObjectReq.setName("用户");
tagObjectReq.setBizName("user"); tagObjectReq.setBizName("user");
User user = User.getFakeUser(); User user = User.getDefaultUser();
return tagObjectService.create(tagObjectReq, user); return tagObjectService.create(tagObjectReq, user);
} }

View File

@@ -4,8 +4,8 @@ import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentConfig;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig; import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.agent.ToolConfig;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.core.annotation.Order; import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
@@ -24,14 +24,14 @@ public class SmallTalkDemo extends S2BaseDemo {
agent.setDescription("直接与大模型对话,验证连通性"); agent.setDescription("直接与大模型对话,验证连通性");
agent.setStatus(1); agent.setStatus(1);
agent.setEnableSearch(0); agent.setEnableSearch(0);
AgentConfig agentConfig = new AgentConfig(); ToolConfig toolConfig = new ToolConfig();
agent.setAgentConfig(JSONObject.toJSONString(agentConfig)); agent.setToolConfig(JSONObject.toJSONString(toolConfig));
agent.setExamples(Lists.newArrayList("如何才能变帅", "如何才能赚更多钱", "如何才能世界和平")); agent.setExamples(Lists.newArrayList("如何才能变帅", "如何才能赚更多钱", "如何才能世界和平"));
MultiTurnConfig multiTurnConfig = new MultiTurnConfig(); MultiTurnConfig multiTurnConfig = new MultiTurnConfig();
multiTurnConfig.setEnableMultiTurn(true); multiTurnConfig.setEnableMultiTurn(true);
agent.setMultiTurnConfig(multiTurnConfig); agent.setMultiTurnConfig(multiTurnConfig);
agentService.createAgent(agent, User.getFakeUser()); agentService.createAgent(agent, User.getDefaultUser());
} }
@Override @Override

View File

@@ -370,3 +370,20 @@ alter table singer drop column imp_date;
--20240913 --20240913
ALTER TABLE s2_model MODIFY COLUMN drill_down_dimensions TEXT DEFAULT NULL; ALTER TABLE s2_model MODIFY COLUMN drill_down_dimensions TEXT DEFAULT NULL;
--20241009
CREATE TABLE IF NOT EXISTS `s2_chat_model` (
`id` bigint(20) NOT NULL AUTO_INCREMENT,
`name` varchar(255) NOT NULL COMMENT '名称',
`description` varchar(500) DEFAULT NULL COMMENT '描述',
`config` text NOT NULL COMMENT '配置信息',
`created_at` datetime NOT NULL COMMENT '创建时间',
`created_by` varchar(100) NOT NULL COMMENT '创建人',
`updated_at` datetime NOT NULL COMMENT '更新时间',
`updated_by` varchar(100) NOT NULL COMMENT '更新人',
`admin` varchar(500) DEFAULT NULL,
`viewer` varchar(500) DEFAULT NULL,
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COMMENT='对话大模型实例表';
ALTER TABLE s2_agent RENAME COLUMN config TO tool_config;
ALTER TABLE s2_agent RENAME COLUMN model_config TO chat_model_config;

View File

@@ -388,9 +388,9 @@ CREATE TABLE IF NOT EXISTS s2_agent
description varchar(500) null, description varchar(500) null,
status int null, status int null,
examples varchar(500) null, examples varchar(500) null,
config varchar(2000) null, tool_config varchar(2000) null,
llm_config varchar(2000) null, llm_config varchar(2000) null,
model_config varchar(6000) null, chat_model_config varchar(6000) null,
prompt_config varchar(5000) null, prompt_config varchar(5000) null,
multi_turn_config varchar(2000) null, multi_turn_config varchar(2000) null,
visual_config varchar(2000) null, visual_config varchar(2000) null,

View File

@@ -70,9 +70,9 @@ CREATE TABLE IF NOT EXISTS `s2_agent` (
`examples` TEXT COLLATE utf8_unicode_ci DEFAULT NULL, `examples` TEXT COLLATE utf8_unicode_ci DEFAULT NULL,
`status` tinyint DEFAULT NULL, `status` tinyint DEFAULT NULL,
`model` varchar(100) COLLATE utf8_unicode_ci DEFAULT NULL, `model` varchar(100) COLLATE utf8_unicode_ci DEFAULT NULL,
`config` varchar(6000) COLLATE utf8_unicode_ci DEFAULT NULL, `tool_config` varchar(6000) COLLATE utf8_unicode_ci DEFAULT NULL,
`llm_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL, `llm_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL,
`model_config` text COLLATE utf8_unicode_ci DEFAULT NULL, `chat_model_config` text COLLATE utf8_unicode_ci DEFAULT NULL,
`prompt_config` text COLLATE utf8_unicode_ci DEFAULT NULL, `prompt_config` text COLLATE utf8_unicode_ci DEFAULT NULL,
`multi_turn_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL, `multi_turn_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL,
`visual_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL, `visual_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL,

View File

@@ -7,10 +7,10 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.BaseTest; import com.tencent.supersonic.chat.BaseTest;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentConfig;
import com.tencent.supersonic.chat.server.agent.AgentToolType; import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig; import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.agent.RuleParserTool; import com.tencent.supersonic.chat.server.agent.RuleParserTool;
import com.tencent.supersonic.chat.server.agent.ToolConfig;
import com.tencent.supersonic.chat.server.pojo.ChatModel; import com.tencent.supersonic.chat.server.pojo.ChatModel;
import com.tencent.supersonic.common.pojo.enums.ChatModelType; import com.tencent.supersonic.common.pojo.enums.ChatModelType;
import com.tencent.supersonic.util.DataUtils; import com.tencent.supersonic.util.DataUtils;
@@ -49,7 +49,7 @@ public class Text2SQLEval extends BaseTest {
QueryResult result = submitNewChat("近30天总访问次数", agentId); QueryResult result = submitNewChat("近30天总访问次数", agentId);
durations.add(System.currentTimeMillis() - start); durations.add(System.currentTimeMillis() - start);
assert result.getQueryColumns().size() == 1; assert result.getQueryColumns().size() == 1;
assert result.getQueryColumns().get(0).getName().contains("访问次数"); assert result.getTextResult().contains("511");
} }
@Test @Test
@@ -58,8 +58,8 @@ public class Text2SQLEval extends BaseTest {
QueryResult result = submitNewChat("近30日每天的访问次数", agentId); QueryResult result = submitNewChat("近30日每天的访问次数", agentId);
durations.add(System.currentTimeMillis() - start); durations.add(System.currentTimeMillis() - start);
assert result.getQueryColumns().size() == 2; assert result.getQueryColumns().size() == 2;
assert result.getQueryColumns().get(0).getName().equalsIgnoreCase("date"); assert result.getQueryResults().size() == 30;
assert result.getQueryColumns().get(1).getName().contains("访问次数"); assert result.getTextResult().contains("date");
} }
@Test @Test
@@ -68,9 +68,11 @@ public class Text2SQLEval extends BaseTest {
QueryResult result = submitNewChat("过去30天每个部门的汇总访问次数", agentId); QueryResult result = submitNewChat("过去30天每个部门的汇总访问次数", agentId);
durations.add(System.currentTimeMillis() - start); durations.add(System.currentTimeMillis() - start);
assert result.getQueryColumns().size() == 2; assert result.getQueryColumns().size() == 2;
assert result.getQueryColumns().get(0).getName().equalsIgnoreCase("部门");
assert result.getQueryColumns().get(1).getName().contains("访问次数");
assert result.getQueryResults().size() == 4; assert result.getQueryResults().size() == 4;
assert result.getTextResult().contains("marketing");
assert result.getTextResult().contains("sales");
assert result.getTextResult().contains("strategy");
assert result.getTextResult().contains("HR");
} }
@Test @Test
@@ -134,16 +136,16 @@ public class Text2SQLEval extends BaseTest {
public Agent getLLMAgent(boolean enableMultiturn) { public Agent getLLMAgent(boolean enableMultiturn) {
Agent agent = new Agent(); Agent agent = new Agent();
agent.setName("Agent for Test"); agent.setName("Agent for Test");
AgentConfig agentConfig = new AgentConfig(); ToolConfig toolConfig = new ToolConfig();
agentConfig.getTools().add(getLLMQueryTool()); toolConfig.getTools().add(getLLMQueryTool());
agent.setAgentConfig(JSONObject.toJSONString(agentConfig)); agent.setToolConfig(JSONObject.toJSONString(toolConfig));
ChatModel chatModel = new ChatModel(); ChatModel chatModel = new ChatModel();
chatModel.setName("Text2SQL LLM"); chatModel.setName("Text2SQL LLM");
chatModel.setConfig(LLMConfigUtils.getLLMConfig(LLMConfigUtils.LLMType.OLLAMA_LLAMA3)); chatModel.setConfig(LLMConfigUtils.getLLMConfig(LLMConfigUtils.LLMType.OLLAMA_LLAMA3));
chatModel = chatModelService.createChatModel(chatModel, User.getFakeUser()); chatModel = chatModelService.createChatModel(chatModel, User.getDefaultUser());
Map<ChatModelType, Integer> chatModelConfig = Maps.newHashMap(); Map<ChatModelType, Integer> chatModelConfig = Maps.newHashMap();
chatModelConfig.put(ChatModelType.TEXT_TO_SQL, chatModel.getId()); chatModelConfig.put(ChatModelType.TEXT_TO_SQL, chatModel.getId());
agent.setModelConfig(chatModelConfig); agent.setChatModelConfig(chatModelConfig);
MultiTurnConfig multiTurnConfig = new MultiTurnConfig(); MultiTurnConfig multiTurnConfig = new MultiTurnConfig();
multiTurnConfig.setEnableMultiTurn(enableMultiturn); multiTurnConfig.setEnableMultiTurn(enableMultiturn);
agent.setMultiTurnConfig(multiTurnConfig); agent.setMultiTurnConfig(multiTurnConfig);

View File

@@ -34,7 +34,7 @@ public class BaseTest extends BaseApplication {
private DomainRepository domainRepository; private DomainRepository domainRepository;
protected SemanticQueryResp queryBySql(String sql) throws Exception { protected SemanticQueryResp queryBySql(String sql) throws Exception {
return queryBySql(sql, User.getFakeUser()); return queryBySql(sql, User.getDefaultUser());
} }
protected SemanticQueryResp queryBySql(String sql, User user) throws Exception { protected SemanticQueryResp queryBySql(String sql, User user) throws Exception {

View File

@@ -22,7 +22,7 @@ public class MetaDiscoveryTest extends BaseTest {
QueryMapReq queryMapReq = new QueryMapReq(); QueryMapReq queryMapReq = new QueryMapReq();
queryMapReq.setQueryText("对比alice和lucy的访问次数"); queryMapReq.setQueryText("对比alice和lucy的访问次数");
queryMapReq.setTopN(10); queryMapReq.setTopN(10);
queryMapReq.setUser(User.getFakeUser()); queryMapReq.setUser(User.getDefaultUser());
queryMapReq.setDataSetNames(Collections.singletonList("超音数数据集")); queryMapReq.setDataSetNames(Collections.singletonList("超音数数据集"));
MapInfoResp mapMeta = chatLayerService.map(queryMapReq); MapInfoResp mapMeta = chatLayerService.map(queryMapReq);
@@ -36,7 +36,7 @@ public class MetaDiscoveryTest extends BaseTest {
QueryMapReq queryMapReq = new QueryMapReq(); QueryMapReq queryMapReq = new QueryMapReq();
queryMapReq.setQueryText("风格为流行的艺人"); queryMapReq.setQueryText("风格为流行的艺人");
queryMapReq.setTopN(10); queryMapReq.setTopN(10);
queryMapReq.setUser(User.getFakeUser()); queryMapReq.setUser(User.getDefaultUser());
queryMapReq.setDataSetNames(Collections.singletonList("艺人库")); queryMapReq.setDataSetNames(Collections.singletonList("艺人库"));
queryMapReq.setQueryDataType(QueryDataType.TAG); queryMapReq.setQueryDataType(QueryDataType.TAG);
MapInfoResp mapMeta = chatLayerService.map(queryMapReq); MapInfoResp mapMeta = chatLayerService.map(queryMapReq);
@@ -48,7 +48,7 @@ public class MetaDiscoveryTest extends BaseTest {
QueryMapReq queryMapReq = new QueryMapReq(); QueryMapReq queryMapReq = new QueryMapReq();
queryMapReq.setQueryText("超音数访问次数最高的部门"); queryMapReq.setQueryText("超音数访问次数最高的部门");
queryMapReq.setTopN(10); queryMapReq.setTopN(10);
queryMapReq.setUser(User.getFakeUser()); queryMapReq.setUser(User.getDefaultUser());
queryMapReq.setDataSetNames(Collections.singletonList("超音数")); queryMapReq.setDataSetNames(Collections.singletonList("超音数"));
queryMapReq.setQueryDataType(QueryDataType.METRIC); queryMapReq.setQueryDataType(QueryDataType.METRIC);
MapInfoResp mapMeta = chatLayerService.map(queryMapReq); MapInfoResp mapMeta = chatLayerService.map(queryMapReq);

View File

@@ -23,7 +23,7 @@ public class QueryByMetricTest extends BaseTest {
QueryMetricReq queryMetricReq = new QueryMetricReq(); QueryMetricReq queryMetricReq = new QueryMetricReq();
queryMetricReq.setMetricNames(Arrays.asList("stay_hours", "pv")); queryMetricReq.setMetricNames(Arrays.asList("stay_hours", "pv"));
queryMetricReq.setDimensionNames(Arrays.asList("user_name", "department")); queryMetricReq.setDimensionNames(Arrays.asList("user_name", "department"));
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getFakeUser()); SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getDefaultUser());
Assert.assertNotNull(queryResp.getResultList()); Assert.assertNotNull(queryResp.getResultList());
Assert.assertEquals(6, queryResp.getResultList().size()); Assert.assertEquals(6, queryResp.getResultList().size());
} }
@@ -33,7 +33,7 @@ public class QueryByMetricTest extends BaseTest {
QueryMetricReq queryMetricReq = new QueryMetricReq(); QueryMetricReq queryMetricReq = new QueryMetricReq();
queryMetricReq.setMetricNames(Arrays.asList("停留时长", "访问次数")); queryMetricReq.setMetricNames(Arrays.asList("停留时长", "访问次数"));
queryMetricReq.setDimensionNames(Arrays.asList("用户", "部门")); queryMetricReq.setDimensionNames(Arrays.asList("用户", "部门"));
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getFakeUser()); SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getDefaultUser());
Assert.assertNotNull(queryResp.getResultList()); Assert.assertNotNull(queryResp.getResultList());
Assert.assertEquals(6, queryResp.getResultList().size()); Assert.assertEquals(6, queryResp.getResultList().size());
} }
@@ -44,7 +44,7 @@ public class QueryByMetricTest extends BaseTest {
queryMetricReq.setDomainId(1L); queryMetricReq.setDomainId(1L);
queryMetricReq.setMetricNames(Arrays.asList("stay_hours", "pv")); queryMetricReq.setMetricNames(Arrays.asList("stay_hours", "pv"));
queryMetricReq.setDimensionNames(Arrays.asList("user_name", "department")); queryMetricReq.setDimensionNames(Arrays.asList("user_name", "department"));
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getFakeUser()); SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getDefaultUser());
Assert.assertNotNull(queryResp.getResultList()); Assert.assertNotNull(queryResp.getResultList());
Assert.assertEquals(6, queryResp.getResultList().size()); Assert.assertEquals(6, queryResp.getResultList().size());
@@ -52,7 +52,7 @@ public class QueryByMetricTest extends BaseTest {
queryMetricReq.setMetricNames(Arrays.asList("stay_hours", "pv")); queryMetricReq.setMetricNames(Arrays.asList("stay_hours", "pv"));
queryMetricReq.setDimensionNames(Arrays.asList("user_name", "department")); queryMetricReq.setDimensionNames(Arrays.asList("user_name", "department"));
assertThrows(IllegalArgumentException.class, assertThrows(IllegalArgumentException.class,
() -> queryByMetric(queryMetricReq, User.getFakeUser())); () -> queryByMetric(queryMetricReq, User.getDefaultUser()));
} }
@Test @Test
@@ -61,7 +61,7 @@ public class QueryByMetricTest extends BaseTest {
queryMetricReq.setDomainId(1L); queryMetricReq.setDomainId(1L);
queryMetricReq.setMetricIds(Arrays.asList(1L, 3L)); queryMetricReq.setMetricIds(Arrays.asList(1L, 3L));
queryMetricReq.setDimensionIds(Arrays.asList(1L, 2L)); queryMetricReq.setDimensionIds(Arrays.asList(1L, 2L));
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getFakeUser()); SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getDefaultUser());
Assert.assertNotNull(queryResp.getResultList()); Assert.assertNotNull(queryResp.getResultList());
Assert.assertEquals(6, queryResp.getResultList().size()); Assert.assertEquals(6, queryResp.getResultList().size());
} }

View File

@@ -49,7 +49,7 @@ public class QueryByStructTest extends BaseTest {
QueryStructReq queryStructReq = QueryStructReq queryStructReq =
buildQueryStructReq(Arrays.asList("user_name", "department"), QueryType.DETAIL); buildQueryStructReq(Arrays.asList("user_name", "department"), QueryType.DETAIL);
SemanticQueryResp semanticQueryResp = SemanticQueryResp semanticQueryResp =
semanticLayerService.queryByReq(queryStructReq, User.getFakeUser()); semanticLayerService.queryByReq(queryStructReq, User.getDefaultUser());
assertEquals(3, semanticQueryResp.getColumns().size()); assertEquals(3, semanticQueryResp.getColumns().size());
QueryColumn firstColumn = semanticQueryResp.getColumns().get(0); QueryColumn firstColumn = semanticQueryResp.getColumns().get(0);
assertEquals("用户", firstColumn.getName()); assertEquals("用户", firstColumn.getName());
@@ -64,7 +64,7 @@ public class QueryByStructTest extends BaseTest {
public void testSumQuery() throws Exception { public void testSumQuery() throws Exception {
QueryStructReq queryStructReq = buildQueryStructReq(null); QueryStructReq queryStructReq = buildQueryStructReq(null);
SemanticQueryResp semanticQueryResp = SemanticQueryResp semanticQueryResp =
semanticLayerService.queryByReq(queryStructReq, User.getFakeUser()); semanticLayerService.queryByReq(queryStructReq, User.getDefaultUser());
assertEquals(1, semanticQueryResp.getColumns().size()); assertEquals(1, semanticQueryResp.getColumns().size());
QueryColumn queryColumn = semanticQueryResp.getColumns().get(0); QueryColumn queryColumn = semanticQueryResp.getColumns().get(0);
assertEquals("访问次数", queryColumn.getName()); assertEquals("访问次数", queryColumn.getName());
@@ -75,7 +75,7 @@ public class QueryByStructTest extends BaseTest {
public void testGroupByQuery() throws Exception { public void testGroupByQuery() throws Exception {
QueryStructReq queryStructReq = buildQueryStructReq(Arrays.asList("department")); QueryStructReq queryStructReq = buildQueryStructReq(Arrays.asList("department"));
SemanticQueryResp result = SemanticQueryResp result =
semanticLayerService.queryByReq(queryStructReq, User.getFakeUser()); semanticLayerService.queryByReq(queryStructReq, User.getDefaultUser());
assertEquals(2, result.getColumns().size()); assertEquals(2, result.getColumns().size());
QueryColumn firstColumn = result.getColumns().get(0); QueryColumn firstColumn = result.getColumns().get(0);
QueryColumn secondColumn = result.getColumns().get(1); QueryColumn secondColumn = result.getColumns().get(1);
@@ -97,7 +97,7 @@ public class QueryByStructTest extends BaseTest {
queryStructReq.setDimensionFilters(dimensionFilters); queryStructReq.setDimensionFilters(dimensionFilters);
SemanticQueryResp result = SemanticQueryResp result =
semanticLayerService.queryByReq(queryStructReq, User.getFakeUser()); semanticLayerService.queryByReq(queryStructReq, User.getDefaultUser());
assertEquals(2, result.getColumns().size()); assertEquals(2, result.getColumns().size());
QueryColumn firstColumn = result.getColumns().get(0); QueryColumn firstColumn = result.getColumns().get(0);
QueryColumn secondColumn = result.getColumns().get(1); QueryColumn secondColumn = result.getColumns().get(1);

View File

@@ -15,7 +15,7 @@ public class QueryDimensionTest extends BaseTest {
queryDimValueReq.setBizName("department"); queryDimValueReq.setBizName("department");
SemanticQueryResp queryResp = SemanticQueryResp queryResp =
semanticLayerService.queryDimensionValue(queryDimValueReq, User.getFakeUser()); semanticLayerService.queryDimensionValue(queryDimValueReq, User.getDefaultUser());
Assert.assertNotNull(queryResp.getResultList()); Assert.assertNotNull(queryResp.getResultList());
Assert.assertEquals(4, queryResp.getResultList().size()); Assert.assertEquals(4, queryResp.getResultList().size());
} }

View File

@@ -22,7 +22,7 @@ public class QueryRuleTest extends BaseTest {
@Autowired @Autowired
private QueryRuleService queryRuleService; private QueryRuleService queryRuleService;
private User user = User.getFakeUser(); private User user = User.getDefaultUser();
public QueryRuleReq addSystemRule() { public QueryRuleReq addSystemRule() {
QueryRuleReq queryRuleReq = new QueryRuleReq(); QueryRuleReq queryRuleReq = new QueryRuleReq();

View File

@@ -18,7 +18,7 @@ public class TagObjectTest extends BaseTest {
@Test @Test
void testCreateTagObject() throws Exception { void testCreateTagObject() throws Exception {
User user = User.getFakeUser(); User user = User.getDefaultUser();
TagObjectReq tagObjectReq = newTagObjectReq(); TagObjectReq tagObjectReq = newTagObjectReq();
TagObjectResp tagObjectResp = tagObjectService.create(tagObjectReq, user); TagObjectResp tagObjectResp = tagObjectService.create(tagObjectReq, user);
tagObjectService.delete(tagObjectResp.getId(), user, false); tagObjectService.delete(tagObjectResp.getId(), user, false);
@@ -27,24 +27,25 @@ public class TagObjectTest extends BaseTest {
@Test @Test
void testUpdateTagObject() throws Exception { void testUpdateTagObject() throws Exception {
TagObjectReq tagObjectReq = newTagObjectReq(); TagObjectReq tagObjectReq = newTagObjectReq();
TagObjectResp tagObjectResp = tagObjectService.create(tagObjectReq, User.getFakeUser()); TagObjectResp tagObjectResp = tagObjectService.create(tagObjectReq, User.getDefaultUser());
TagObjectReq tagObjectReqUpdate = new TagObjectReq(); TagObjectReq tagObjectReqUpdate = new TagObjectReq();
BeanUtils.copyProperties(tagObjectResp, tagObjectReqUpdate); BeanUtils.copyProperties(tagObjectResp, tagObjectReqUpdate);
tagObjectReqUpdate.setName("艺人1"); tagObjectReqUpdate.setName("艺人1");
tagObjectService.update(tagObjectReqUpdate, User.getFakeUser()); tagObjectService.update(tagObjectReqUpdate, User.getDefaultUser());
TagObjectResp tagObject = TagObjectResp tagObject =
tagObjectService.getTagObject(tagObjectReqUpdate.getId(), User.getFakeUser()); tagObjectService.getTagObject(tagObjectReqUpdate.getId(), User.getDefaultUser());
tagObjectService.delete(tagObject.getId(), User.getFakeUser(), false); tagObjectService.delete(tagObject.getId(), User.getDefaultUser(), false);
} }
@Test @Test
void testQueryTagObject() throws Exception { void testQueryTagObject() throws Exception {
TagObjectReq tagObjectReq = newTagObjectReq(); TagObjectReq tagObjectReq = newTagObjectReq();
TagObjectResp tagObjectResp = tagObjectService.create(tagObjectReq, User.getFakeUser()); TagObjectResp tagObjectResp = tagObjectService.create(tagObjectReq, User.getDefaultUser());
TagObjectFilter filter = new TagObjectFilter(); TagObjectFilter filter = new TagObjectFilter();
List<TagObjectResp> tagObjects = tagObjectService.getTagObjects(filter, User.getFakeUser()); List<TagObjectResp> tagObjects =
tagObjectService.getTagObjects(filter, User.getDefaultUser());
tagObjects.size(); tagObjects.size();
tagObjectService.delete(tagObjectResp.getId(), User.getFakeUser(), false); tagObjectService.delete(tagObjectResp.getId(), User.getDefaultUser(), false);
} }
private TagObjectReq newTagObjectReq() { private TagObjectReq newTagObjectReq() {

View File

@@ -21,7 +21,7 @@ public class TagTest extends BaseTest {
ItemValueReq itemValueReq = new ItemValueReq(); ItemValueReq itemValueReq = new ItemValueReq();
itemValueReq.setId(1L); itemValueReq.setId(1L);
ItemValueResp itemValueResp = ItemValueResp itemValueResp =
tagQueryService.queryTagValue(itemValueReq, User.getFakeUser()); tagQueryService.queryTagValue(itemValueReq, User.getDefaultUser());
Assertions.assertNotNull(itemValueResp); Assertions.assertNotNull(itemValueResp);
} }
} }

View File

@@ -19,7 +19,7 @@ public class TranslateTest extends BaseTest {
String sql = "SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 "; String sql = "SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 ";
SemanticTranslateResp explain = semanticLayerService.translate( SemanticTranslateResp explain = semanticLayerService.translate(
QueryReqBuilder.buildS2SQLReq(sql, DataUtils.getMetricAgentView()), QueryReqBuilder.buildS2SQLReq(sql, DataUtils.getMetricAgentView()),
User.getFakeUser()); User.getDefaultUser());
assertNotNull(explain); assertNotNull(explain);
assertNotNull(explain.getQuerySQL()); assertNotNull(explain.getQuerySQL());
assertTrue(explain.getQuerySQL().contains("department")); assertTrue(explain.getQuerySQL().contains("department"));
@@ -30,7 +30,7 @@ public class TranslateTest extends BaseTest {
public void testStructExplain() throws Exception { public void testStructExplain() throws Exception {
QueryStructReq queryStructReq = buildQueryStructReq(Arrays.asList("department")); QueryStructReq queryStructReq = buildQueryStructReq(Arrays.asList("department"));
SemanticTranslateResp explain = SemanticTranslateResp explain =
semanticLayerService.translate(queryStructReq, User.getFakeUser()); semanticLayerService.translate(queryStructReq, User.getDefaultUser());
assertNotNull(explain); assertNotNull(explain);
assertNotNull(explain.getQuerySQL()); assertNotNull(explain.getQuerySQL());
assertTrue(explain.getQuerySQL().contains("department")); assertTrue(explain.getQuerySQL().contains("department"));

View File

@@ -19,7 +19,7 @@ public class DataUtils {
public static final Integer tagAgentId = 2; public static final Integer tagAgentId = 2;
public static final Integer ONE_TURNS_CHAT_ID = 10; public static final Integer ONE_TURNS_CHAT_ID = 10;
public static final Integer MULTI_TURNS_CHAT_ID = 11; public static final Integer MULTI_TURNS_CHAT_ID = 11;
private static final User user_test = User.getFakeUser(); private static final User user_test = User.getDefaultUser();
public static User getUser() { public static User getUser() {
return user_test; return user_test;

View File

@@ -388,9 +388,9 @@ CREATE TABLE IF NOT EXISTS s2_agent
description varchar(500) null, description varchar(500) null,
status int null, status int null,
examples varchar(500) null, examples varchar(500) null,
config varchar(2000) null, tool_config varchar(2000) null,
llm_config varchar(2000) null, llm_config varchar(2000) null,
model_config varchar(6000) null, chat_model_config varchar(6000) null,
prompt_config varchar(5000) null, prompt_config varchar(5000) null,
multi_turn_config varchar(2000) null, multi_turn_config varchar(2000) null,
visual_config varchar(2000) null, visual_config varchar(2000) null,