mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 02:46:56 +00:00
(feature)(chat)Introduce new plain_text mode to allow users to talk to LLM directly.
This commit is contained in:
@@ -12,6 +12,7 @@ import lombok.NoArgsConstructor;
|
||||
@AllArgsConstructor
|
||||
public class ChatExecuteReq {
|
||||
private User user;
|
||||
private Integer agentId;
|
||||
private Long queryId;
|
||||
private Integer chatId;
|
||||
private int parseId;
|
||||
|
||||
@@ -71,6 +71,10 @@ public class Agent extends RecordInfo {
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public boolean containsPluginTool() {
|
||||
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.PLUGIN));
|
||||
}
|
||||
|
||||
public boolean containsLLMParserTool() {
|
||||
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM));
|
||||
}
|
||||
@@ -84,6 +88,19 @@ public class Agent extends RecordInfo {
|
||||
|| !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_RULE));
|
||||
}
|
||||
|
||||
public boolean containsAnyTool() {
|
||||
Map map = JSONObject.parseObject(agentConfig, Map.class);
|
||||
if (CollectionUtils.isEmpty(map)) {
|
||||
return false;
|
||||
}
|
||||
List<Map> toolList = (List) map.get("tools");
|
||||
if (CollectionUtils.isEmpty(toolList)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
public Set<Long> getDataSetIds() {
|
||||
Set<Long> dataSetIds = getDataSetIds(null);
|
||||
if (containsAllModel(dataSetIds)) {
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
package com.tencent.supersonic.chat.server.executor;
|
||||
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.S2ChatModelProvider;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
|
||||
import java.util.Collections;
|
||||
|
||||
public class PlainTextExecutor implements ChatExecutor {
|
||||
|
||||
@Override
|
||||
public QueryResult execute(ChatExecuteContext chatExecuteContext) {
|
||||
if (!"PLAIN_TEXT".equals(chatExecuteContext.getParseInfo().getQueryMode())) {
|
||||
return null;
|
||||
}
|
||||
|
||||
Prompt prompt = PromptTemplate.from(chatExecuteContext.getQueryText())
|
||||
.apply(Collections.EMPTY_MAP);
|
||||
|
||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||
Agent chatAgent = agentService.getAgent(chatExecuteContext.getAgentId());
|
||||
|
||||
ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(chatAgent.getLlmConfig());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||
|
||||
QueryResult result = new QueryResult();
|
||||
result.setQueryState(QueryState.SUCCESS);
|
||||
result.setQueryMode(chatExecuteContext.getParseInfo().getQueryMode());
|
||||
result.setTextResult(response.content().text());
|
||||
|
||||
return result;
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.chat.server.executor;
|
||||
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.util.ResultFormatter;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
@@ -15,16 +14,15 @@ public class SqlExecutor implements ChatExecutor {
|
||||
@SneakyThrows
|
||||
@Override
|
||||
public QueryResult execute(ChatExecuteContext chatExecuteContext) {
|
||||
SemanticParseInfo parseInfo = chatExecuteContext.getParseInfo();
|
||||
if (PluginQueryManager.isPluginQuery(parseInfo.getQueryMode())) {
|
||||
return null;
|
||||
}
|
||||
ExecuteQueryReq executeQueryReq = buildExecuteReq(chatExecuteContext);
|
||||
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
|
||||
QueryResult queryResult = chatQueryService.performExecution(executeQueryReq);
|
||||
String textResult = ResultFormatter.transform2TextNew(queryResult.getQueryColumns(),
|
||||
queryResult.getQueryResults());
|
||||
queryResult.setTextResult(textResult);
|
||||
if (queryResult != null) {
|
||||
String textResult = ResultFormatter.transform2TextNew(queryResult.getQueryColumns(),
|
||||
queryResult.getQueryResults());
|
||||
queryResult.setTextResult(textResult);
|
||||
}
|
||||
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
|
||||
@@ -52,6 +52,10 @@ public class MultiTurnParser implements ChatParser {
|
||||
|
||||
@Override
|
||||
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
if (!chatParseContext.getAgent().containsAnyTool()) {
|
||||
return;
|
||||
}
|
||||
|
||||
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
||||
MultiTurnConfig agentMultiTurnConfig = chatParseContext.getAgent().getMultiTurnConfig();
|
||||
Boolean globalMultiTurnConfig = Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE));
|
||||
|
||||
@@ -15,6 +15,10 @@ public class NL2PluginParser implements ChatParser {
|
||||
|
||||
@Override
|
||||
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
if (!chatParseContext.getAgent().containsPluginTool()) {
|
||||
return;
|
||||
}
|
||||
|
||||
pluginRecognizers.forEach(pluginRecognizer -> {
|
||||
pluginRecognizer.recognize(chatParseContext, parseResp);
|
||||
log.info("{} recallResult:{}", pluginRecognizer.getClass().getSimpleName(),
|
||||
|
||||
@@ -23,14 +23,11 @@ public class NL2SQLParser implements ChatParser {
|
||||
|
||||
@Override
|
||||
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
if (!chatParseContext.enableNL2SQL()) {
|
||||
if (!chatParseContext.enableNL2SQL() || checkSkip(parseResp)) {
|
||||
return;
|
||||
}
|
||||
if (checkSkip(parseResp)) {
|
||||
return;
|
||||
}
|
||||
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||
|
||||
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
|
||||
ParseResp text2SqlParseResp = chatQueryService.performParsing(queryReq);
|
||||
if (!ParseResp.ParseState.FAILED.equals(text2SqlParseResp.getState())) {
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
package com.tencent.supersonic.chat.server.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
|
||||
public class PlainTextParser implements ChatParser {
|
||||
@Override
|
||||
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
if (chatParseContext.getAgent().containsAnyTool()) {
|
||||
return;
|
||||
}
|
||||
|
||||
SemanticParseInfo parseInfo = new SemanticParseInfo();
|
||||
parseInfo.setQueryMode("PLAIN_TEXT");
|
||||
parseResp.getSelectedParses().add(parseInfo);
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import lombok.Data;
|
||||
@Data
|
||||
public class ChatExecuteContext {
|
||||
private User user;
|
||||
private Integer agentId;
|
||||
private Long queryId;
|
||||
private Integer chatId;
|
||||
private int parseId;
|
||||
|
||||
@@ -26,9 +26,10 @@ public class EntityInfoProcessor implements ParseResultProcessor {
|
||||
}
|
||||
selectedParses.forEach(parseInfo -> {
|
||||
String queryMode = parseInfo.getQueryMode();
|
||||
if (QueryManager.containsRuleQuery(queryMode)) {
|
||||
if (QueryManager.containsRuleQuery(queryMode) || "PLAIN".equals(queryMode)) {
|
||||
return;
|
||||
}
|
||||
|
||||
//1. set entity info
|
||||
SemanticLayerService semanticService = ContextUtils.getBean(SemanticLayerService.class);
|
||||
DataSetSchema dataSetSchema = semanticService.getDataSetSchema(parseInfo.getDataSetId());
|
||||
|
||||
@@ -90,10 +90,14 @@ public class ChatServiceImpl implements ChatService {
|
||||
break;
|
||||
}
|
||||
}
|
||||
for (ExecuteResultProcessor processor : executeResultProcessors) {
|
||||
processor.process(chatExecuteContext, queryResult);
|
||||
|
||||
if (queryResult != null) {
|
||||
for (ExecuteResultProcessor processor : executeResultProcessors) {
|
||||
processor.process(chatExecuteContext, queryResult);
|
||||
}
|
||||
saveQueryResult(chatExecuteReq, queryResult);
|
||||
}
|
||||
saveQueryResult(chatExecuteReq, queryResult);
|
||||
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ import java.util.TreeSet;
|
||||
public class SemanticParseInfo {
|
||||
|
||||
private Integer id;
|
||||
private String queryMode;
|
||||
private String queryMode = "PLAIN_TEXT";
|
||||
private SchemaElement dataSet;
|
||||
private Set<SchemaElement> metrics = new TreeSet<>(new SchemaNameLengthComparator());
|
||||
private Set<SchemaElement> dimensions = new LinkedHashSet();
|
||||
|
||||
@@ -43,12 +43,12 @@ public class QueryManager {
|
||||
|
||||
private static SemanticQuery getSemanticQuery(String queryMode, SemanticQuery semanticQuery) {
|
||||
if (Objects.isNull(semanticQuery)) {
|
||||
throw new RuntimeException("no supported queryMode :" + queryMode);
|
||||
return null;
|
||||
}
|
||||
try {
|
||||
return semanticQuery.getClass().getDeclaredConstructor().newInstance();
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("no supported queryMode :" + queryMode);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -56,11 +56,13 @@ com.tencent.supersonic.headless.server.processor.ResultProcessor=\
|
||||
com.tencent.supersonic.chat.server.parser.ChatParser=\
|
||||
com.tencent.supersonic.chat.server.parser.NL2PluginParser, \
|
||||
com.tencent.supersonic.chat.server.parser.MultiTurnParser,\
|
||||
com.tencent.supersonic.chat.server.parser.NL2SQLParser
|
||||
com.tencent.supersonic.chat.server.parser.NL2SQLParser,\
|
||||
com.tencent.supersonic.chat.server.parser.PlainTextParser
|
||||
|
||||
com.tencent.supersonic.chat.server.executor.ChatExecutor=\
|
||||
com.tencent.supersonic.chat.server.executor.PluginExecutor, \
|
||||
com.tencent.supersonic.chat.server.executor.SqlExecutor
|
||||
com.tencent.supersonic.chat.server.executor.SqlExecutor,\
|
||||
com.tencent.supersonic.chat.server.executor.PlainTextExecutor
|
||||
|
||||
com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer=\
|
||||
com.tencent.supersonic.chat.server.plugin.recognize.embedding.EmbeddingRecallRecognizer
|
||||
|
||||
@@ -5,7 +5,6 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatService;
|
||||
import com.tencent.supersonic.chat.server.service.ConfigService;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
@@ -31,8 +30,6 @@ public class BaseTest extends BaseApplication {
|
||||
@Autowired
|
||||
protected ChatService chatService;
|
||||
@Autowired
|
||||
protected ConfigService configService;
|
||||
@Autowired
|
||||
protected AgentService agentService;
|
||||
|
||||
protected QueryResult submitMultiTurnChat(String queryText, Integer agentId, Integer chatId) throws Exception {
|
||||
@@ -61,6 +58,7 @@ public class BaseTest extends BaseApplication {
|
||||
.queryText(parseResp.getQueryText())
|
||||
.user(DataUtils.getUser())
|
||||
.parseId(parseInfo.getId())
|
||||
.agentId(agentId)
|
||||
.queryId(parseResp.getQueryId())
|
||||
.saveAnswer(false)
|
||||
.build();
|
||||
|
||||
Reference in New Issue
Block a user