(improvement)(chat)Add SmallTalkDemo and support multi-turn conversation in PLAIN_TEXT mode.

This commit is contained in:
jerryjzhang
2024-06-26 10:19:45 +08:00
parent 82b2552d9d
commit 1e5cfb51df
4 changed files with 94 additions and 3 deletions

View File

@@ -1,10 +1,14 @@
package com.tencent.supersonic.chat.server.executor;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.parser.ParserConfig;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.S2ChatModelProvider;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import dev.langchain4j.data.message.AiMessage;
@@ -14,17 +18,29 @@ import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
public class PlainTextExecutor implements ChatExecutor {
private static final String INSTRUCTION = ""
+ "#Role: You are a nice person to talked to.\n"
+ "#Task: You will have a small talk with the user, please respond quickly and nicely."
+ "#History Conversations: %s\n"
+ "#Current User Input: %s\n"
+ "#Your response: ";
@Override
public QueryResult execute(ChatExecuteContext chatExecuteContext) {
if (!"PLAIN_TEXT".equals(chatExecuteContext.getParseInfo().getQueryMode())) {
return null;
}
Prompt prompt = PromptTemplate.from(chatExecuteContext.getQueryText())
.apply(Collections.EMPTY_MAP);
String promptStr = String.format(INSTRUCTION, getHistoryInputs(chatExecuteContext),
chatExecuteContext.getQueryText());
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
AgentService agentService = ContextUtils.getBean(AgentService.class);
Agent chatAgent = agentService.getAgent(chatExecuteContext.getAgentId());
@@ -39,4 +55,40 @@ public class PlainTextExecutor implements ChatExecutor {
return result;
}
private String getHistoryInputs(ChatExecuteContext chatExecuteContext) {
StringBuilder historyInput = new StringBuilder();
AgentService agentService = ContextUtils.getBean(AgentService.class);
Agent chatAgent = agentService.getAgent(chatExecuteContext.getAgentId());
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
MultiTurnConfig agentMultiTurnConfig = chatAgent.getMultiTurnConfig();
Boolean globalMultiTurnConfig = Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE));
Boolean multiTurnConfig = agentMultiTurnConfig != null
? agentMultiTurnConfig.isEnableMultiTurn() : globalMultiTurnConfig;
if (Boolean.TRUE.equals(multiTurnConfig)) {
List<ParseResp> parseResps = getHistoryParseResult(chatExecuteContext.getChatId(), 5);
parseResps.stream().forEach(p -> {
historyInput.append(p.getQueryText());
historyInput.append(";");
});
}
return historyInput.toString();
}
private List<ParseResp> getHistoryParseResult(int chatId, int multiNum) {
ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class);
List<ParseResp> contextualParseInfoList = chatQueryRepository.getContextualParseInfo(chatId)
.stream().filter(p -> p.getState() != ParseResp.ParseState.FAILED).collect(Collectors.toList());
List<ParseResp> contextualList = contextualParseInfoList.subList(0,
Math.min(multiNum, contextualParseInfoList.size()));
Collections.reverse(contextualList);
return contextualList;
}
}

View File

@@ -4,7 +4,9 @@ import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
public class PlainTextParser implements ChatParser {
@Override
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
if (chatParseContext.getAgent().containsAnyTool()) {
@@ -15,4 +17,5 @@ public class PlainTextParser implements ChatParser {
parseInfo.setQueryMode("PLAIN_TEXT");
parseResp.getSelectedParses().add(parseInfo);
}
}

View File

@@ -0,0 +1,36 @@
package com.tencent.supersonic.demo;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentConfig;
import lombok.extern.slf4j.Slf4j;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;
@Component
@Slf4j
@Order(2)
public class SmallTalkDemo extends S2BaseDemo {
public void doRun() {
Agent agent = new Agent();
agent.setName("来闲聊");
agent.setDescription("直接与大模型对话,验证连通性");
agent.setStatus(1);
agent.setEnableSearch(0);
AgentConfig agentConfig = new AgentConfig();
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
agent.setExamples(Lists.newArrayList("如何才能变帅",
"如何才能赚更多钱", "如何才能世界和平"));
agentService.createAgent(agent, User.getFakeUser());
}
@Override
boolean checkNeedToRun() {
return true;
}
}

View File

@@ -71,7 +71,7 @@ s2:
path: /tmp
demo:
names: S2VisitsDemo,S2ArtistDemo
names: S2VisitsDemo,S2ArtistDemo,SmallTalkDemo
enableLLM: true
# swagger配置