mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(feature)(chat)Introduce new plain_text mode to allow users to talk to LLM directly.
This commit is contained in:
@@ -12,6 +12,7 @@ import lombok.NoArgsConstructor;
|
|||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
public class ChatExecuteReq {
|
public class ChatExecuteReq {
|
||||||
private User user;
|
private User user;
|
||||||
|
private Integer agentId;
|
||||||
private Long queryId;
|
private Long queryId;
|
||||||
private Integer chatId;
|
private Integer chatId;
|
||||||
private int parseId;
|
private int parseId;
|
||||||
|
|||||||
@@ -71,6 +71,10 @@ public class Agent extends RecordInfo {
|
|||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public boolean containsPluginTool() {
|
||||||
|
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.PLUGIN));
|
||||||
|
}
|
||||||
|
|
||||||
public boolean containsLLMParserTool() {
|
public boolean containsLLMParserTool() {
|
||||||
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM));
|
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM));
|
||||||
}
|
}
|
||||||
@@ -84,6 +88,19 @@ public class Agent extends RecordInfo {
|
|||||||
|| !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_RULE));
|
|| !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_RULE));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public boolean containsAnyTool() {
|
||||||
|
Map map = JSONObject.parseObject(agentConfig, Map.class);
|
||||||
|
if (CollectionUtils.isEmpty(map)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
List<Map> toolList = (List) map.get("tools");
|
||||||
|
if (CollectionUtils.isEmpty(toolList)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
public Set<Long> getDataSetIds() {
|
public Set<Long> getDataSetIds() {
|
||||||
Set<Long> dataSetIds = getDataSetIds(null);
|
Set<Long> dataSetIds = getDataSetIds(null);
|
||||||
if (containsAllModel(dataSetIds)) {
|
if (containsAllModel(dataSetIds)) {
|
||||||
|
|||||||
@@ -0,0 +1,42 @@
|
|||||||
|
package com.tencent.supersonic.chat.server.executor;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||||
|
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||||
|
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import com.tencent.supersonic.common.util.S2ChatModelProvider;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||||
|
import dev.langchain4j.data.message.AiMessage;
|
||||||
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
|
import dev.langchain4j.model.input.Prompt;
|
||||||
|
import dev.langchain4j.model.input.PromptTemplate;
|
||||||
|
import dev.langchain4j.model.output.Response;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
|
||||||
|
public class PlainTextExecutor implements ChatExecutor {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public QueryResult execute(ChatExecuteContext chatExecuteContext) {
|
||||||
|
if (!"PLAIN_TEXT".equals(chatExecuteContext.getParseInfo().getQueryMode())) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
Prompt prompt = PromptTemplate.from(chatExecuteContext.getQueryText())
|
||||||
|
.apply(Collections.EMPTY_MAP);
|
||||||
|
|
||||||
|
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||||
|
Agent chatAgent = agentService.getAgent(chatExecuteContext.getAgentId());
|
||||||
|
|
||||||
|
ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(chatAgent.getLlmConfig());
|
||||||
|
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||||
|
|
||||||
|
QueryResult result = new QueryResult();
|
||||||
|
result.setQueryState(QueryState.SUCCESS);
|
||||||
|
result.setQueryMode(chatExecuteContext.getParseInfo().getQueryMode());
|
||||||
|
result.setTextResult(response.content().text());
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
package com.tencent.supersonic.chat.server.executor;
|
package com.tencent.supersonic.chat.server.executor;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||||
import com.tencent.supersonic.chat.server.util.ResultFormatter;
|
import com.tencent.supersonic.chat.server.util.ResultFormatter;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
@@ -15,16 +14,15 @@ public class SqlExecutor implements ChatExecutor {
|
|||||||
@SneakyThrows
|
@SneakyThrows
|
||||||
@Override
|
@Override
|
||||||
public QueryResult execute(ChatExecuteContext chatExecuteContext) {
|
public QueryResult execute(ChatExecuteContext chatExecuteContext) {
|
||||||
SemanticParseInfo parseInfo = chatExecuteContext.getParseInfo();
|
|
||||||
if (PluginQueryManager.isPluginQuery(parseInfo.getQueryMode())) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
ExecuteQueryReq executeQueryReq = buildExecuteReq(chatExecuteContext);
|
ExecuteQueryReq executeQueryReq = buildExecuteReq(chatExecuteContext);
|
||||||
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
|
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
|
||||||
QueryResult queryResult = chatQueryService.performExecution(executeQueryReq);
|
QueryResult queryResult = chatQueryService.performExecution(executeQueryReq);
|
||||||
String textResult = ResultFormatter.transform2TextNew(queryResult.getQueryColumns(),
|
if (queryResult != null) {
|
||||||
queryResult.getQueryResults());
|
String textResult = ResultFormatter.transform2TextNew(queryResult.getQueryColumns(),
|
||||||
queryResult.setTextResult(textResult);
|
queryResult.getQueryResults());
|
||||||
|
queryResult.setTextResult(textResult);
|
||||||
|
}
|
||||||
|
|
||||||
return queryResult;
|
return queryResult;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -52,6 +52,10 @@ public class MultiTurnParser implements ChatParser {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||||
|
if (!chatParseContext.getAgent().containsAnyTool()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
||||||
MultiTurnConfig agentMultiTurnConfig = chatParseContext.getAgent().getMultiTurnConfig();
|
MultiTurnConfig agentMultiTurnConfig = chatParseContext.getAgent().getMultiTurnConfig();
|
||||||
Boolean globalMultiTurnConfig = Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE));
|
Boolean globalMultiTurnConfig = Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE));
|
||||||
|
|||||||
@@ -15,6 +15,10 @@ public class NL2PluginParser implements ChatParser {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||||
|
if (!chatParseContext.getAgent().containsPluginTool()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
pluginRecognizers.forEach(pluginRecognizer -> {
|
pluginRecognizers.forEach(pluginRecognizer -> {
|
||||||
pluginRecognizer.recognize(chatParseContext, parseResp);
|
pluginRecognizer.recognize(chatParseContext, parseResp);
|
||||||
log.info("{} recallResult:{}", pluginRecognizer.getClass().getSimpleName(),
|
log.info("{} recallResult:{}", pluginRecognizer.getClass().getSimpleName(),
|
||||||
|
|||||||
@@ -23,14 +23,11 @@ public class NL2SQLParser implements ChatParser {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||||
if (!chatParseContext.enableNL2SQL()) {
|
if (!chatParseContext.enableNL2SQL() || checkSkip(parseResp)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (checkSkip(parseResp)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
|
||||||
|
|
||||||
|
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||||
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
|
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
|
||||||
ParseResp text2SqlParseResp = chatQueryService.performParsing(queryReq);
|
ParseResp text2SqlParseResp = chatQueryService.performParsing(queryReq);
|
||||||
if (!ParseResp.ParseState.FAILED.equals(text2SqlParseResp.getState())) {
|
if (!ParseResp.ParseState.FAILED.equals(text2SqlParseResp.getState())) {
|
||||||
|
|||||||
@@ -0,0 +1,18 @@
|
|||||||
|
package com.tencent.supersonic.chat.server.parser;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||||
|
|
||||||
|
public class PlainTextParser implements ChatParser {
|
||||||
|
@Override
|
||||||
|
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||||
|
if (chatParseContext.getAgent().containsAnyTool()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
SemanticParseInfo parseInfo = new SemanticParseInfo();
|
||||||
|
parseInfo.setQueryMode("PLAIN_TEXT");
|
||||||
|
parseResp.getSelectedParses().add(parseInfo);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -7,6 +7,7 @@ import lombok.Data;
|
|||||||
@Data
|
@Data
|
||||||
public class ChatExecuteContext {
|
public class ChatExecuteContext {
|
||||||
private User user;
|
private User user;
|
||||||
|
private Integer agentId;
|
||||||
private Long queryId;
|
private Long queryId;
|
||||||
private Integer chatId;
|
private Integer chatId;
|
||||||
private int parseId;
|
private int parseId;
|
||||||
|
|||||||
@@ -26,9 +26,10 @@ public class EntityInfoProcessor implements ParseResultProcessor {
|
|||||||
}
|
}
|
||||||
selectedParses.forEach(parseInfo -> {
|
selectedParses.forEach(parseInfo -> {
|
||||||
String queryMode = parseInfo.getQueryMode();
|
String queryMode = parseInfo.getQueryMode();
|
||||||
if (QueryManager.containsRuleQuery(queryMode)) {
|
if (QueryManager.containsRuleQuery(queryMode) || "PLAIN".equals(queryMode)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
//1. set entity info
|
//1. set entity info
|
||||||
SemanticLayerService semanticService = ContextUtils.getBean(SemanticLayerService.class);
|
SemanticLayerService semanticService = ContextUtils.getBean(SemanticLayerService.class);
|
||||||
DataSetSchema dataSetSchema = semanticService.getDataSetSchema(parseInfo.getDataSetId());
|
DataSetSchema dataSetSchema = semanticService.getDataSetSchema(parseInfo.getDataSetId());
|
||||||
|
|||||||
@@ -90,10 +90,14 @@ public class ChatServiceImpl implements ChatService {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (ExecuteResultProcessor processor : executeResultProcessors) {
|
|
||||||
processor.process(chatExecuteContext, queryResult);
|
if (queryResult != null) {
|
||||||
|
for (ExecuteResultProcessor processor : executeResultProcessors) {
|
||||||
|
processor.process(chatExecuteContext, queryResult);
|
||||||
|
}
|
||||||
|
saveQueryResult(chatExecuteReq, queryResult);
|
||||||
}
|
}
|
||||||
saveQueryResult(chatExecuteReq, queryResult);
|
|
||||||
return queryResult;
|
return queryResult;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ import java.util.TreeSet;
|
|||||||
public class SemanticParseInfo {
|
public class SemanticParseInfo {
|
||||||
|
|
||||||
private Integer id;
|
private Integer id;
|
||||||
private String queryMode;
|
private String queryMode = "PLAIN_TEXT";
|
||||||
private SchemaElement dataSet;
|
private SchemaElement dataSet;
|
||||||
private Set<SchemaElement> metrics = new TreeSet<>(new SchemaNameLengthComparator());
|
private Set<SchemaElement> metrics = new TreeSet<>(new SchemaNameLengthComparator());
|
||||||
private Set<SchemaElement> dimensions = new LinkedHashSet();
|
private Set<SchemaElement> dimensions = new LinkedHashSet();
|
||||||
|
|||||||
@@ -43,12 +43,12 @@ public class QueryManager {
|
|||||||
|
|
||||||
private static SemanticQuery getSemanticQuery(String queryMode, SemanticQuery semanticQuery) {
|
private static SemanticQuery getSemanticQuery(String queryMode, SemanticQuery semanticQuery) {
|
||||||
if (Objects.isNull(semanticQuery)) {
|
if (Objects.isNull(semanticQuery)) {
|
||||||
throw new RuntimeException("no supported queryMode :" + queryMode);
|
return null;
|
||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
return semanticQuery.getClass().getDeclaredConstructor().newInstance();
|
return semanticQuery.getClass().getDeclaredConstructor().newInstance();
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
throw new RuntimeException("no supported queryMode :" + queryMode);
|
return null;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -56,11 +56,13 @@ com.tencent.supersonic.headless.server.processor.ResultProcessor=\
|
|||||||
com.tencent.supersonic.chat.server.parser.ChatParser=\
|
com.tencent.supersonic.chat.server.parser.ChatParser=\
|
||||||
com.tencent.supersonic.chat.server.parser.NL2PluginParser, \
|
com.tencent.supersonic.chat.server.parser.NL2PluginParser, \
|
||||||
com.tencent.supersonic.chat.server.parser.MultiTurnParser,\
|
com.tencent.supersonic.chat.server.parser.MultiTurnParser,\
|
||||||
com.tencent.supersonic.chat.server.parser.NL2SQLParser
|
com.tencent.supersonic.chat.server.parser.NL2SQLParser,\
|
||||||
|
com.tencent.supersonic.chat.server.parser.PlainTextParser
|
||||||
|
|
||||||
com.tencent.supersonic.chat.server.executor.ChatExecutor=\
|
com.tencent.supersonic.chat.server.executor.ChatExecutor=\
|
||||||
com.tencent.supersonic.chat.server.executor.PluginExecutor, \
|
com.tencent.supersonic.chat.server.executor.PluginExecutor, \
|
||||||
com.tencent.supersonic.chat.server.executor.SqlExecutor
|
com.tencent.supersonic.chat.server.executor.SqlExecutor,\
|
||||||
|
com.tencent.supersonic.chat.server.executor.PlainTextExecutor
|
||||||
|
|
||||||
com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer=\
|
com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer=\
|
||||||
com.tencent.supersonic.chat.server.plugin.recognize.embedding.EmbeddingRecallRecognizer
|
com.tencent.supersonic.chat.server.plugin.recognize.embedding.EmbeddingRecallRecognizer
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
|||||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||||
import com.tencent.supersonic.chat.server.service.ChatService;
|
import com.tencent.supersonic.chat.server.service.ChatService;
|
||||||
import com.tencent.supersonic.chat.server.service.ConfigService;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||||
@@ -31,8 +30,6 @@ public class BaseTest extends BaseApplication {
|
|||||||
@Autowired
|
@Autowired
|
||||||
protected ChatService chatService;
|
protected ChatService chatService;
|
||||||
@Autowired
|
@Autowired
|
||||||
protected ConfigService configService;
|
|
||||||
@Autowired
|
|
||||||
protected AgentService agentService;
|
protected AgentService agentService;
|
||||||
|
|
||||||
protected QueryResult submitMultiTurnChat(String queryText, Integer agentId, Integer chatId) throws Exception {
|
protected QueryResult submitMultiTurnChat(String queryText, Integer agentId, Integer chatId) throws Exception {
|
||||||
@@ -61,6 +58,7 @@ public class BaseTest extends BaseApplication {
|
|||||||
.queryText(parseResp.getQueryText())
|
.queryText(parseResp.getQueryText())
|
||||||
.user(DataUtils.getUser())
|
.user(DataUtils.getUser())
|
||||||
.parseId(parseInfo.getId())
|
.parseId(parseInfo.getId())
|
||||||
|
.agentId(agentId)
|
||||||
.queryId(parseResp.getQueryId())
|
.queryId(parseResp.getQueryId())
|
||||||
.saveAnswer(false)
|
.saveAnswer(false)
|
||||||
.build();
|
.build();
|
||||||
|
|||||||
Reference in New Issue
Block a user