mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 12:37:55 +00:00
(improvement)(chat|common|headless|webapp) 结果分析,改写伪流式输出,加快响应速度 (#2395)
This commit is contained in:
@@ -32,6 +32,7 @@ import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.provider.ModelProvider;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
@@ -171,10 +172,6 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
return;
|
||||
}
|
||||
|
||||
// derive mapping result of current question and parsing result of last question.
|
||||
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
|
||||
MapResp currentMapResult = chatLayerService.map(queryNLReq);
|
||||
|
||||
List<QueryResp> historyQueries =
|
||||
getHistoryQueries(parseContext.getRequest().getChatId(), 1);
|
||||
if (historyQueries.isEmpty()) {
|
||||
@@ -182,12 +179,18 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
}
|
||||
QueryResp lastQuery = historyQueries.get(0);
|
||||
SemanticParseInfo lastParseInfo = lastQuery.getParseInfos().get(0);
|
||||
Long dataId = lastParseInfo.getDataSetId();
|
||||
String histSQL = lastParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
if (StringUtils.isBlank(histSQL)) // 优化性能,如果问答不是chat bi 则无需重写,因为数据都不全
|
||||
return;
|
||||
|
||||
// derive mapping result of current question and parsing result of last question.
|
||||
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
|
||||
MapResp currentMapResult = chatLayerService.map(queryNLReq); // 优化性能 ,只有满足条件才mapping
|
||||
|
||||
Long dataId = lastParseInfo.getDataSetId();
|
||||
String curtMapStr =
|
||||
generateSchemaPrompt(currentMapResult.getMapInfo().getMatchedElements(dataId));
|
||||
String histMapStr = generateSchemaPrompt(lastParseInfo.getElementMatches());
|
||||
String histSQL = lastParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
|
||||
Map<String, Object> variables = new HashMap<>();
|
||||
variables.put("current_question", currentMapResult.getQueryText());
|
||||
|
||||
@@ -1,13 +1,20 @@
|
||||
package com.tencent.supersonic.chat.server.processor.execute;
|
||||
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.enums.AppModule;
|
||||
import com.tencent.supersonic.common.util.ChatAppManager;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.StreamingResponseHandler;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
@@ -24,9 +31,11 @@ import java.util.Objects;
|
||||
* DataInterpretProcessor interprets query result to make it more readable to the users.
|
||||
*/
|
||||
public class DataInterpretProcessor implements ExecuteResultProcessor {
|
||||
|
||||
public static String tip = "AI 回答中...\r\n";
|
||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
|
||||
private static Map<Long, StringBuffer> resultCache = new HashMap<>();
|
||||
|
||||
public static final String APP_KEY = "DATA_INTERPRETER";
|
||||
private static final String INSTRUCTION = ""
|
||||
+ "#Role: You are a data expert who communicates with business users everyday."
|
||||
@@ -41,6 +50,16 @@ public class DataInterpretProcessor implements ExecuteResultProcessor {
|
||||
.appModule(AppModule.CHAT).description("通过大模型对结果数据做提炼总结").enable(false).build());
|
||||
}
|
||||
|
||||
public static String getTextSummary(Long queryId) {
|
||||
if (resultCache.get(queryId) != null) {
|
||||
return resultCache.get(queryId).toString();
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
public static Map<Long, StringBuffer> getResultCache() {
|
||||
return resultCache;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean accept(ExecuteContext executeContext) {
|
||||
@@ -71,14 +90,49 @@ public class DataInterpretProcessor implements ExecuteResultProcessor {
|
||||
variable.put("data", queryResult.getTextResult());
|
||||
|
||||
Prompt prompt = PromptTemplate.from(chatApp.getPrompt()).apply(variable);
|
||||
ChatLanguageModel chatLanguageModel =
|
||||
ModelProvider.getChatModel(chatApp.getChatModelConfig());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||
String anwser = response.content().text();
|
||||
keyPipelineLog.info("DataInterpretProcessor modelReq:\n{} \nmodelResp:\n{}", prompt.text(),
|
||||
anwser);
|
||||
if (StringUtils.isNotBlank(anwser)) {
|
||||
queryResult.setTextSummary(anwser);
|
||||
if (executeContext.getRequest().isStreamingResult()) {
|
||||
StreamingChatLanguageModel chatLanguageModel =
|
||||
ModelProvider.getChatStreamingModel(chatApp.getChatModelConfig());
|
||||
final Long queryId = executeContext.getRequest().getQueryId();
|
||||
resultCache.put(queryId, new StringBuffer(tip));
|
||||
chatLanguageModel.generate(prompt.toUserMessage(),
|
||||
new StreamingResponseHandler<AiMessage>() {
|
||||
@Override
|
||||
public void onNext(String token) {
|
||||
resultCache.get(queryId).append(token);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onComplete(Response<AiMessage> response) {
|
||||
ChatQueryRepository chatQueryRepository =
|
||||
ContextUtils.getBean(ChatQueryRepository.class);
|
||||
ChatQueryDO chatQueryDO = chatQueryRepository.getChatQueryDO(queryId);
|
||||
JSONObject queryResult = JSON.parseObject(chatQueryDO.getQueryResult());
|
||||
queryResult.put("textSummary",
|
||||
resultCache.get(queryId).toString().substring(tip.length()));
|
||||
chatQueryDO.setQueryResult(queryResult.toJSONString());
|
||||
chatQueryRepository.updateChatQuery(chatQueryDO);
|
||||
resultCache.remove(queryId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(Throwable error) {
|
||||
error.printStackTrace();
|
||||
resultCache.remove(queryId);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
ChatLanguageModel chatLanguageModel =
|
||||
ModelProvider.getChatModel(chatApp.getChatModelConfig());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||
String anwser = response.content().text();
|
||||
keyPipelineLog.info("DataInterpretProcessor modelReq:\n{} \nmodelResp:\n{}",
|
||||
prompt.text(), anwser);
|
||||
if (StringUtils.isNotBlank(anwser)) {
|
||||
queryResult.setTextSummary(anwser);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.service.ChatQueryService;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
|
||||
@@ -50,6 +51,14 @@ public class ChatQueryController {
|
||||
return chatQueryService.execute(chatExecuteReq);
|
||||
}
|
||||
|
||||
@PostMapping("getExecuteSummary")
|
||||
public Object getExecuteSummary(@RequestBody ChatExecuteReq chatExecuteReq,
|
||||
HttpServletRequest request, HttpServletResponse response) {
|
||||
chatExecuteReq.setUser(UserHolder.findUser(request, response));
|
||||
QueryResult res = chatQueryService.getTextSummary(chatExecuteReq);
|
||||
return res;
|
||||
}
|
||||
|
||||
@PostMapping("/")
|
||||
public Object query(@RequestBody ChatParseReq chatParseReq, HttpServletRequest request,
|
||||
HttpServletResponse response) throws Exception {
|
||||
|
||||
@@ -35,6 +35,8 @@ public interface ChatManageService {
|
||||
|
||||
QueryResp getChatQuery(Long queryId);
|
||||
|
||||
ChatQueryDO getChatQueryDO(Long queryId);
|
||||
|
||||
List<QueryResp> getChatQueries(Integer chatId);
|
||||
|
||||
ShowCaseResp queryShowCase(PageQueryInfoReq pageQueryInfoReq, int agentId);
|
||||
|
||||
@@ -19,6 +19,8 @@ public interface ChatQueryService {
|
||||
|
||||
QueryResult execute(ChatExecuteReq chatExecuteReq) throws Exception;
|
||||
|
||||
QueryResult getTextSummary(ChatExecuteReq chatExecuteReq);
|
||||
|
||||
QueryResult parseAndExecute(ChatParseReq chatParseReq);
|
||||
|
||||
Object queryData(ChatQueryDataReq chatQueryDataReq, User user) throws Exception;
|
||||
|
||||
@@ -123,6 +123,11 @@ public class ChatManageServiceImpl implements ChatManageService {
|
||||
return chatQueryRepository.getChatQuery(queryId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatQueryDO getChatQueryDO(Long queryId) {
|
||||
return chatQueryRepository.getChatQueryDO(queryId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<QueryResp> getChatQueries(Integer chatId) {
|
||||
List<QueryResp> queries = chatQueryRepository.getChatQueries(chatId);
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package com.tencent.supersonic.chat.server.service.impl;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.alibaba.fastjson2.JSON;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
@@ -9,8 +11,10 @@ import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.executor.ChatQueryExecutor;
|
||||
import com.tencent.supersonic.chat.server.parser.ChatQueryParser;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.chat.server.processor.execute.DataInterpretProcessor;
|
||||
import com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor;
|
||||
import com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
@@ -143,6 +147,21 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
@Override
|
||||
public QueryResult getTextSummary(ChatExecuteReq chatExecuteReq) {
|
||||
String text = DataInterpretProcessor.getTextSummary(chatExecuteReq.getQueryId());
|
||||
if (StringUtils.isNotBlank(text)) {
|
||||
QueryResult res = new QueryResult();
|
||||
res.setTextSummary(text);
|
||||
res.setQueryId(chatExecuteReq.getQueryId());
|
||||
return res;
|
||||
} else {
|
||||
ChatQueryDO chatQueryDo = chatManageService.getChatQueryDO(chatExecuteReq.getQueryId());
|
||||
QueryResult res = JSON.parseObject(chatQueryDo.getQueryResult(), QueryResult.class);
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public QueryResult parseAndExecute(ChatParseReq chatParseReq) {
|
||||
ChatParseResp parseResp = parse(chatParseReq);
|
||||
|
||||
@@ -108,6 +108,7 @@ public class PluginServiceImpl implements PluginService {
|
||||
if (StringUtils.isNotBlank(pluginQueryReq.getCreatedBy())) {
|
||||
queryWrapper.lambda().eq(PluginDO::getCreatedBy, pluginQueryReq.getCreatedBy());
|
||||
}
|
||||
queryWrapper.orderByAsc("name");
|
||||
List<PluginDO> pluginDOS = pluginRepository.query(queryWrapper);
|
||||
if (StringUtils.isNotBlank(pluginQueryReq.getPattern())) {
|
||||
pluginDOS = pluginDOS.stream()
|
||||
|
||||
Reference in New Issue
Block a user