[feature][headless-chat]Introduce ChatApp to support more flexible chat model config.#1739

This commit is contained in:
jerryjzhang
2024-10-12 15:05:47 +08:00
parent fc94a6718b
commit 7c76c69ac0
9 changed files with 25 additions and 21 deletions

View File

@@ -29,14 +29,14 @@ import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULT
public class PlainTextExecutor implements ChatQueryExecutor { public class PlainTextExecutor implements ChatQueryExecutor {
private static final String APP_KEY = "SMALL_TALK"; private static final String APP_KEY = "SMALL_TALK";
private static final String INSTRUCTION = "" + "#Role: You are a nice person to talk to.\n" private static final String INSTRUCTION = "" + "#Role: You are a nice person to talk to."
+ "#Task: Respond quickly and nicely to the user." + "\n#Task: Respond quickly and nicely to the user."
+ "#Rules: 1.ALWAYS use the same language as the input.\n" + "#History Inputs: %s\n" + "\n#Rules: 1.ALWAYS use the same language as the `#Current Input`."
+ "#Current Input: %s\n" + "#Your response: "; + "\n#History Inputs: %s" + "\n#Current Input: %s" + "\n#Response: ";
public PlainTextExecutor() { public PlainTextExecutor() {
ChatAppManager.register(ChatApp.builder().key(APP_KEY).prompt(INSTRUCTION).name("闲聊对话") ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("闲聊对话")
.description("直接将原始输入透传大模型").enable(true).build()); .description("直接将原始输入透传大模型").enable(false).build());
} }
@Override @Override

View File

@@ -49,7 +49,7 @@ public class MemoryReviewTask {
private AgentService agentService; private AgentService agentService;
public MemoryReviewTask() { public MemoryReviewTask() {
ChatAppManager.register(ChatApp.builder().key(APP_KEY).prompt(INSTRUCTION).name("记忆启用评估") ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("记忆启用评估")
.description("通过大模型对记忆做正确性评估以决定是否启用").enable(false).build()); .description("通过大模型对记忆做正确性评估以决定是否启用").enable(false).build());
} }

View File

@@ -77,13 +77,13 @@ public class NL2SQLParser implements ChatQueryParser {
+ "#Examples: {{examples}}\n" + "#Response: "; + "#Examples: {{examples}}\n" + "#Response: ";
public NL2SQLParser() { public NL2SQLParser() {
ChatAppManager.register( ChatAppManager.register(APP_KEY_MULTI_TURN,
ChatApp.builder().key(APP_KEY_MULTI_TURN).prompt(REWRITE_MULTI_TURN_INSTRUCTION) ChatApp.builder().prompt(REWRITE_MULTI_TURN_INSTRUCTION)
.name("多轮对话改写").description("通过大模型根据历史对话来改写本轮对话").enable(false).build()); .name("多轮对话改写").description("通过大模型根据历史对话来改写本轮对话").enable(false).build());
ChatAppManager.register(ChatApp.builder().key(APP_KEY_ERROR_MESSAGE) ChatAppManager.register(APP_KEY_ERROR_MESSAGE,
.prompt(REWRITE_ERROR_MESSAGE_INSTRUCTION).name("异常提示改写") ChatApp.builder().prompt(REWRITE_ERROR_MESSAGE_INSTRUCTION)
.description("通过大模型将异常信息改写为更友好和引导性的提示用语").enable(false).build()); .name("异常提示改写").description("通过大模型将异常信息改写为更友好和引导性的提示用语").enable(false).build());
} }
@Override @Override

View File

@@ -16,8 +16,8 @@ import com.tencent.supersonic.common.util.ChatAppManager;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Map;
@RestController @RestController
@RequestMapping({"/api/chat/model", "/openapi/chat/model"}) @RequestMapping({"/api/chat/model", "/openapi/chat/model"})
@@ -51,8 +51,8 @@ public class ChatModelController {
} }
@RequestMapping("/getModelAppList") @RequestMapping("/getModelAppList")
public List<ChatApp> getModelAppList() { public Map<String, ChatApp> getChatAppList() {
return new ArrayList(ChatAppManager.getAllApps().values()); return ChatAppManager.getAllApps();
} }
@RequestMapping("/getModelParameters") @RequestMapping("/getModelParameters")

View File

@@ -15,6 +15,7 @@ import com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcess
import com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor; import com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor;
import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatManageService; import com.tencent.supersonic.chat.server.service.ChatManageService;
import com.tencent.supersonic.chat.server.service.ChatModelService;
import com.tencent.supersonic.chat.server.service.ChatQueryService; import com.tencent.supersonic.chat.server.service.ChatQueryService;
import com.tencent.supersonic.chat.server.util.ComponentFactory; import com.tencent.supersonic.chat.server.util.ComponentFactory;
import com.tencent.supersonic.chat.server.util.QueryReqConverter; import com.tencent.supersonic.chat.server.util.QueryReqConverter;
@@ -86,6 +87,8 @@ public class ChatQueryServiceImpl implements ChatQueryService {
private SemanticLayerService semanticLayerService; private SemanticLayerService semanticLayerService;
@Autowired @Autowired
private AgentService agentService; private AgentService agentService;
@Autowired
private ChatModelService chatModelService;
private List<ChatQueryParser> chatQueryParsers = ComponentFactory.getChatParsers(); private List<ChatQueryParser> chatQueryParsers = ComponentFactory.getChatParsers();
private List<ChatQueryExecutor> chatQueryExecutors = ComponentFactory.getChatExecutors(); private List<ChatQueryExecutor> chatQueryExecutors = ComponentFactory.getChatExecutors();
@@ -168,6 +171,8 @@ public class ChatQueryServiceImpl implements ChatQueryService {
ParseContext parseContext = new ParseContext(); ParseContext parseContext = new ParseContext();
BeanMapper.mapper(chatParseReq, parseContext); BeanMapper.mapper(chatParseReq, parseContext);
Agent agent = agentService.getAgent(chatParseReq.getAgentId()); Agent agent = agentService.getAgent(chatParseReq.getAgentId());
agent.getChatAppConfig().values().forEach(c -> c.setChatModelConfig(
chatModelService.getChatModel(c.getChatModelId()).getConfig()));
parseContext.setAgent(agent); parseContext.setAgent(agent);
return parseContext; return parseContext;
} }

View File

@@ -11,7 +11,6 @@ import lombok.NoArgsConstructor;
@AllArgsConstructor @AllArgsConstructor
@NoArgsConstructor @NoArgsConstructor
public class ChatApp { public class ChatApp {
private String key;
private String name; private String name;
private String description; private String description;
private String prompt; private String prompt;

View File

@@ -8,8 +8,8 @@ import java.util.Map;
public class ChatAppManager { public class ChatAppManager {
private static final Map<String, ChatApp> chatApps = Maps.newConcurrentMap(); private static final Map<String, ChatApp> chatApps = Maps.newConcurrentMap();
public static void register(ChatApp chatApp) { public static void register(String key, ChatApp app) {
chatApps.put(chatApp.getKey(), chatApp); chatApps.put(key, app);
} }
public static Map<String, ChatApp> getAllApps() { public static Map<String, ChatApp> getAllApps() {

View File

@@ -39,8 +39,8 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
+ "\n#Question:{{question}} #InputSQL:{{sql}} #Response:"; + "\n#Question:{{question}} #InputSQL:{{sql}} #Response:";
public LLMSqlCorrector() { public LLMSqlCorrector() {
ChatAppManager.register(ChatApp.builder().key(APP_KEY).prompt(INSTRUCTION).name("语义SQL修正") ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("语义SQL修正")
.description("").enable(false).build()); .description("通过大模型对解析S2SQL做二次修正").enable(false).build());
} }
@Data @Data

View File

@@ -43,7 +43,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
+ "\n#Question: Question:{{question}},Schema:{{schema}},SideInfo:{{information}}"; + "\n#Question: Question:{{question}},Schema:{{schema}},SideInfo:{{information}}";
public OnePassSCSqlGenStrategy() { public OnePassSCSqlGenStrategy() {
ChatAppManager.register(ChatApp.builder().key(APP_KEY).prompt(INSTRUCTION).name("语义SQL解析") ChatAppManager.register(APP_KEY, ChatApp.builder().prompt(INSTRUCTION).name("语义SQL解析")
.description("通过大模型做语义解析生成S2SQL").enable(true).build()); .description("通过大模型做语义解析生成S2SQL").enable(true).build());
} }