(improvement)(chat)Add SmallTalkDemo and support multi-turn conversation in PLAIN_TEXT mode.

This commit is contained in:
jerryjzhang
2024-06-26 10:19:45 +08:00
parent 82b2552d9d
commit 1e5cfb51df
4 changed files with 94 additions and 3 deletions

View File

@@ -1,10 +1,14 @@
package com.tencent.supersonic.chat.server.executor;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.parser.ParserConfig;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.S2ChatModelProvider;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import dev.langchain4j.data.message.AiMessage;
@@ -14,17 +18,29 @@ import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
public class PlainTextExecutor implements ChatExecutor {
private static final String INSTRUCTION = ""
+ "#Role: You are a nice person to talked to.\n"
+ "#Task: You will have a small talk with the user, please respond quickly and nicely."
+ "#History Conversations: %s\n"
+ "#Current User Input: %s\n"
+ "#Your response: ";
@Override
public QueryResult execute(ChatExecuteContext chatExecuteContext) {
if (!"PLAIN_TEXT".equals(chatExecuteContext.getParseInfo().getQueryMode())) {
return null;
}
Prompt prompt = PromptTemplate.from(chatExecuteContext.getQueryText())
.apply(Collections.EMPTY_MAP);
String promptStr = String.format(INSTRUCTION, getHistoryInputs(chatExecuteContext),
chatExecuteContext.getQueryText());
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
AgentService agentService = ContextUtils.getBean(AgentService.class);
Agent chatAgent = agentService.getAgent(chatExecuteContext.getAgentId());
@@ -39,4 +55,40 @@ public class PlainTextExecutor implements ChatExecutor {
return result;
}
private String getHistoryInputs(ChatExecuteContext chatExecuteContext) {
StringBuilder historyInput = new StringBuilder();
AgentService agentService = ContextUtils.getBean(AgentService.class);
Agent chatAgent = agentService.getAgent(chatExecuteContext.getAgentId());
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
MultiTurnConfig agentMultiTurnConfig = chatAgent.getMultiTurnConfig();
Boolean globalMultiTurnConfig = Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE));
Boolean multiTurnConfig = agentMultiTurnConfig != null
? agentMultiTurnConfig.isEnableMultiTurn() : globalMultiTurnConfig;
if (Boolean.TRUE.equals(multiTurnConfig)) {
List<ParseResp> parseResps = getHistoryParseResult(chatExecuteContext.getChatId(), 5);
parseResps.stream().forEach(p -> {
historyInput.append(p.getQueryText());
historyInput.append(";");
});
}
return historyInput.toString();
}
private List<ParseResp> getHistoryParseResult(int chatId, int multiNum) {
ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class);
List<ParseResp> contextualParseInfoList = chatQueryRepository.getContextualParseInfo(chatId)
.stream().filter(p -> p.getState() != ParseResp.ParseState.FAILED).collect(Collectors.toList());
List<ParseResp> contextualList = contextualParseInfoList.subList(0,
Math.min(multiNum, contextualParseInfoList.size()));
Collections.reverse(contextualList);
return contextualList;
}
}

View File

@@ -4,7 +4,9 @@ import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
public class PlainTextParser implements ChatParser {
@Override
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
if (chatParseContext.getAgent().containsAnyTool()) {
@@ -15,4 +17,5 @@ public class PlainTextParser implements ChatParser {
parseInfo.setQueryMode("PLAIN_TEXT");
parseResp.getSelectedParses().add(parseInfo);
}
}