(feature)(chat)Introduce new plain_text mode to allow users to talk to LLM directly.

This commit is contained in:
jerryjzhang
2024-06-25 21:14:19 +08:00
parent db9a3fa056
commit d4cc53acae
15 changed files with 112 additions and 25 deletions

View File

@@ -12,6 +12,7 @@ import lombok.NoArgsConstructor;
@AllArgsConstructor
public class ChatExecuteReq {
private User user;
private Integer agentId;
private Long queryId;
private Integer chatId;
private int parseId;

View File

@@ -71,6 +71,10 @@ public class Agent extends RecordInfo {
.collect(Collectors.toList());
}
public boolean containsPluginTool() {
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.PLUGIN));
}
public boolean containsLLMParserTool() {
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM));
}
@@ -84,6 +88,19 @@ public class Agent extends RecordInfo {
|| !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_RULE));
}
public boolean containsAnyTool() {
Map map = JSONObject.parseObject(agentConfig, Map.class);
if (CollectionUtils.isEmpty(map)) {
return false;
}
List<Map> toolList = (List) map.get("tools");
if (CollectionUtils.isEmpty(toolList)) {
return false;
}
return true;
}
public Set<Long> getDataSetIds() {
Set<Long> dataSetIds = getDataSetIds(null);
if (containsAllModel(dataSetIds)) {

View File

@@ -0,0 +1,42 @@
package com.tencent.supersonic.chat.server.executor;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.S2ChatModelProvider;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import java.util.Collections;
public class PlainTextExecutor implements ChatExecutor {
@Override
public QueryResult execute(ChatExecuteContext chatExecuteContext) {
if (!"PLAIN_TEXT".equals(chatExecuteContext.getParseInfo().getQueryMode())) {
return null;
}
Prompt prompt = PromptTemplate.from(chatExecuteContext.getQueryText())
.apply(Collections.EMPTY_MAP);
AgentService agentService = ContextUtils.getBean(AgentService.class);
Agent chatAgent = agentService.getAgent(chatExecuteContext.getAgentId());
ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(chatAgent.getLlmConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
QueryResult result = new QueryResult();
result.setQueryState(QueryState.SUCCESS);
result.setQueryMode(chatExecuteContext.getParseInfo().getQueryMode());
result.setTextResult(response.content().text());
return result;
}
}

View File

@@ -1,6 +1,5 @@
package com.tencent.supersonic.chat.server.executor;
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
import com.tencent.supersonic.chat.server.util.ResultFormatter;
import com.tencent.supersonic.common.util.ContextUtils;
@@ -15,16 +14,15 @@ public class SqlExecutor implements ChatExecutor {
@SneakyThrows
@Override
public QueryResult execute(ChatExecuteContext chatExecuteContext) {
SemanticParseInfo parseInfo = chatExecuteContext.getParseInfo();
if (PluginQueryManager.isPluginQuery(parseInfo.getQueryMode())) {
return null;
}
ExecuteQueryReq executeQueryReq = buildExecuteReq(chatExecuteContext);
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
QueryResult queryResult = chatQueryService.performExecution(executeQueryReq);
String textResult = ResultFormatter.transform2TextNew(queryResult.getQueryColumns(),
queryResult.getQueryResults());
queryResult.setTextResult(textResult);
if (queryResult != null) {
String textResult = ResultFormatter.transform2TextNew(queryResult.getQueryColumns(),
queryResult.getQueryResults());
queryResult.setTextResult(textResult);
}
return queryResult;
}

View File

@@ -52,6 +52,10 @@ public class MultiTurnParser implements ChatParser {
@Override
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
if (!chatParseContext.getAgent().containsAnyTool()) {
return;
}
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
MultiTurnConfig agentMultiTurnConfig = chatParseContext.getAgent().getMultiTurnConfig();
Boolean globalMultiTurnConfig = Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE));

View File

@@ -15,6 +15,10 @@ public class NL2PluginParser implements ChatParser {
@Override
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
if (!chatParseContext.getAgent().containsPluginTool()) {
return;
}
pluginRecognizers.forEach(pluginRecognizer -> {
pluginRecognizer.recognize(chatParseContext, parseResp);
log.info("{} recallResult:{}", pluginRecognizer.getClass().getSimpleName(),

View File

@@ -23,14 +23,11 @@ public class NL2SQLParser implements ChatParser {
@Override
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
if (!chatParseContext.enableNL2SQL()) {
if (!chatParseContext.enableNL2SQL() || checkSkip(parseResp)) {
return;
}
if (checkSkip(parseResp)) {
return;
}
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
ParseResp text2SqlParseResp = chatQueryService.performParsing(queryReq);
if (!ParseResp.ParseState.FAILED.equals(text2SqlParseResp.getState())) {

View File

@@ -0,0 +1,18 @@
package com.tencent.supersonic.chat.server.parser;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
public class PlainTextParser implements ChatParser {
@Override
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
if (chatParseContext.getAgent().containsAnyTool()) {
return;
}
SemanticParseInfo parseInfo = new SemanticParseInfo();
parseInfo.setQueryMode("PLAIN_TEXT");
parseResp.getSelectedParses().add(parseInfo);
}
}

View File

@@ -7,6 +7,7 @@ import lombok.Data;
@Data
public class ChatExecuteContext {
private User user;
private Integer agentId;
private Long queryId;
private Integer chatId;
private int parseId;

View File

@@ -26,9 +26,10 @@ public class EntityInfoProcessor implements ParseResultProcessor {
}
selectedParses.forEach(parseInfo -> {
String queryMode = parseInfo.getQueryMode();
if (QueryManager.containsRuleQuery(queryMode)) {
if (QueryManager.containsRuleQuery(queryMode) || "PLAIN".equals(queryMode)) {
return;
}
//1. set entity info
SemanticLayerService semanticService = ContextUtils.getBean(SemanticLayerService.class);
DataSetSchema dataSetSchema = semanticService.getDataSetSchema(parseInfo.getDataSetId());

View File

@@ -90,10 +90,14 @@ public class ChatServiceImpl implements ChatService {
break;
}
}
for (ExecuteResultProcessor processor : executeResultProcessors) {
processor.process(chatExecuteContext, queryResult);
if (queryResult != null) {
for (ExecuteResultProcessor processor : executeResultProcessors) {
processor.process(chatExecuteContext, queryResult);
}
saveQueryResult(chatExecuteReq, queryResult);
}
saveQueryResult(chatExecuteReq, queryResult);
return queryResult;
}

View File

@@ -21,7 +21,7 @@ import java.util.TreeSet;
public class SemanticParseInfo {
private Integer id;
private String queryMode;
private String queryMode = "PLAIN_TEXT";
private SchemaElement dataSet;
private Set<SchemaElement> metrics = new TreeSet<>(new SchemaNameLengthComparator());
private Set<SchemaElement> dimensions = new LinkedHashSet();

View File

@@ -43,12 +43,12 @@ public class QueryManager {
private static SemanticQuery getSemanticQuery(String queryMode, SemanticQuery semanticQuery) {
if (Objects.isNull(semanticQuery)) {
throw new RuntimeException("no supported queryMode :" + queryMode);
return null;
}
try {
return semanticQuery.getClass().getDeclaredConstructor().newInstance();
} catch (Exception e) {
throw new RuntimeException("no supported queryMode :" + queryMode);
return null;
}
}

View File

@@ -56,11 +56,13 @@ com.tencent.supersonic.headless.server.processor.ResultProcessor=\
com.tencent.supersonic.chat.server.parser.ChatParser=\
com.tencent.supersonic.chat.server.parser.NL2PluginParser, \
com.tencent.supersonic.chat.server.parser.MultiTurnParser,\
com.tencent.supersonic.chat.server.parser.NL2SQLParser
com.tencent.supersonic.chat.server.parser.NL2SQLParser,\
com.tencent.supersonic.chat.server.parser.PlainTextParser
com.tencent.supersonic.chat.server.executor.ChatExecutor=\
com.tencent.supersonic.chat.server.executor.PluginExecutor, \
com.tencent.supersonic.chat.server.executor.SqlExecutor
com.tencent.supersonic.chat.server.executor.SqlExecutor,\
com.tencent.supersonic.chat.server.executor.PlainTextExecutor
com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer=\
com.tencent.supersonic.chat.server.plugin.recognize.embedding.EmbeddingRecallRecognizer

View File

@@ -5,7 +5,6 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.chat.server.service.ConfigService;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
@@ -31,8 +30,6 @@ public class BaseTest extends BaseApplication {
@Autowired
protected ChatService chatService;
@Autowired
protected ConfigService configService;
@Autowired
protected AgentService agentService;
protected QueryResult submitMultiTurnChat(String queryText, Integer agentId, Integer chatId) throws Exception {
@@ -61,6 +58,7 @@ public class BaseTest extends BaseApplication {
.queryText(parseResp.getQueryText())
.user(DataUtils.getUser())
.parseId(parseInfo.getId())
.agentId(agentId)
.queryId(parseResp.getQueryId())
.saveAnswer(false)
.build();