(improvement)(chat|common|headless|webapp) Rewrite result interpretation as pseudo-streaming output to speed up response (#2395)
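When the execute request sets streamingResult, the data-interpretation step now calls a streaming chat model and appends tokens into an in-memory cache keyed by queryId (DataInterpretProcessor.resultCache). A new endpoint, POST /chat/query/getExecuteSummary, serves the partial summary from that cache while generation is in flight and returns the persisted QueryResult once it completes; the webapp polls it every 500ms until queryMode is non-null. Streaming model creation is exposed through ModelFactory.createChatStreamingModel and ModelProvider.getChatStreamingModel, currently implemented only by the OpenAI factory.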

guilinlewis
2025-10-22 15:37:50 +08:00
committed by GitHub
parent 9857256488
commit 04b1edb2e2
20 changed files with 217 additions and 23 deletions

View File

@@ -18,4 +18,5 @@ public class ChatExecuteReq {
private int parseId;
private String queryText;
private boolean saveAnswer;
private boolean streamingResult;
}
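A minimal caller sketch for the new flag, assuming ChatExecuteReq carries Lombok-generated setters (not shown in this hunk); the question text is illustrative:

ChatExecuteReq req = new ChatExecuteReq();
req.setQueryText("top 10 products by sales last week"); // illustrative question
req.setParseId(1);
req.setSaveAnswer(true);
req.setStreamingResult(true); // opt into the pseudo-streaming interpretation path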

View File

@@ -32,6 +32,7 @@ import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.provider.ModelProvider;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -171,10 +172,6 @@ public class NL2SQLParser implements ChatQueryParser {
return;
}
// derive mapping result of current question and parsing result of last question.
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
MapResp currentMapResult = chatLayerService.map(queryNLReq);
List<QueryResp> historyQueries =
getHistoryQueries(parseContext.getRequest().getChatId(), 1);
if (historyQueries.isEmpty()) {
@@ -182,12 +179,18 @@ public class NL2SQLParser implements ChatQueryParser {
}
QueryResp lastQuery = historyQueries.get(0);
SemanticParseInfo lastParseInfo = lastQuery.getParseInfos().get(0);
Long dataId = lastParseInfo.getDataSetId();
String histSQL = lastParseInfo.getSqlInfo().getCorrectedS2SQL();
if (StringUtils.isBlank(histSQL)) // performance optimization: if the Q&A did not come from chat BI, skip rewriting since the data is incomplete
return;
// derive mapping result of current question and parsing result of last question.
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
MapResp currentMapResult = chatLayerService.map(queryNLReq); // performance optimization: only run mapping when the condition above is met
Long dataId = lastParseInfo.getDataSetId();
String curtMapStr =
generateSchemaPrompt(currentMapResult.getMapInfo().getMatchedElements(dataId));
String histMapStr = generateSchemaPrompt(lastParseInfo.getElementMatches());
String histSQL = lastParseInfo.getSqlInfo().getCorrectedS2SQL();
Map<String, Object> variables = new HashMap<>();
variables.put("current_question", currentMapResult.getQueryText());

View File

@@ -1,13 +1,20 @@
package com.tencent.supersonic.chat.server.processor.execute;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.enums.AppModule;
import com.tencent.supersonic.common.util.ChatAppManager;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
@@ -24,9 +31,11 @@ import java.util.Objects;
* DataInterpretProcessor interprets query result to make it more readable to the users.
*/
public class DataInterpretProcessor implements ExecuteResultProcessor {
public static String tip = "AI 回答中...\r\n";
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
private static Map<Long, StringBuffer> resultCache = new HashMap<>();
public static final String APP_KEY = "DATA_INTERPRETER";
private static final String INSTRUCTION = ""
+ "#Role: You are a data expert who communicates with business users everyday."
@@ -41,6 +50,16 @@ public class DataInterpretProcessor implements ExecuteResultProcessor {
.appModule(AppModule.CHAT).description("通过大模型对结果数据做提炼总结").enable(false).build());
}
public static String getTextSummary(Long queryId) {
if (resultCache.get(queryId) != null) {
return resultCache.get(queryId).toString();
}
return "";
}
public static Map<Long, StringBuffer> getResultCache() {
return resultCache;
}
@Override
public boolean accept(ExecuteContext executeContext) {
@@ -71,14 +90,49 @@ public class DataInterpretProcessor implements ExecuteResultProcessor {
variable.put("data", queryResult.getTextResult());
Prompt prompt = PromptTemplate.from(chatApp.getPrompt()).apply(variable);
ChatLanguageModel chatLanguageModel =
ModelProvider.getChatModel(chatApp.getChatModelConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
String anwser = response.content().text();
keyPipelineLog.info("DataInterpretProcessor modelReq:\n{} \nmodelResp:\n{}", prompt.text(),
anwser);
if (StringUtils.isNotBlank(anwser)) {
queryResult.setTextSummary(anwser);
if (executeContext.getRequest().isStreamingResult()) {
StreamingChatLanguageModel chatLanguageModel =
ModelProvider.getChatStreamingModel(chatApp.getChatModelConfig());
final Long queryId = executeContext.getRequest().getQueryId();
resultCache.put(queryId, new StringBuffer(tip));
chatLanguageModel.generate(prompt.toUserMessage(),
new StreamingResponseHandler<AiMessage>() {
@Override
public void onNext(String token) {
resultCache.get(queryId).append(token);
}
@Override
public void onComplete(Response<AiMessage> response) {
ChatQueryRepository chatQueryRepository =
ContextUtils.getBean(ChatQueryRepository.class);
ChatQueryDO chatQueryDO = chatQueryRepository.getChatQueryDO(queryId);
JSONObject queryResult = JSON.parseObject(chatQueryDO.getQueryResult());
queryResult.put("textSummary",
resultCache.get(queryId).toString().substring(tip.length()));
chatQueryDO.setQueryResult(queryResult.toJSONString());
chatQueryRepository.updateChatQuery(chatQueryDO);
resultCache.remove(queryId);
}
@Override
public void onError(Throwable error) {
error.printStackTrace();
resultCache.remove(queryId);
}
});
} else {
ChatLanguageModel chatLanguageModel =
ModelProvider.getChatModel(chatApp.getChatModelConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
String anwser = response.content().text();
keyPipelineLog.info("DataInterpretProcessor modelReq:\n{} \nmodelResp:\n{}",
prompt.text(), anwser);
if (StringUtils.isNotBlank(anwser)) {
queryResult.setTextSummary(anwser);
}
}
}
}
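Read side of the cache, as a minimal sketch: while the model is still streaming, DataInterpretProcessor.getTextSummary returns the progress tip plus the tokens received so far; after onComplete persists the final answer and evicts the entry, it returns an empty string and the summary must be read from the persisted query result instead. The queryId value below is illustrative.

Long queryId = 42L; // illustrative id
String partial = DataInterpretProcessor.getTextSummary(queryId);
if (!partial.isEmpty()) {
    System.out.println("streaming in progress: " + partial);
} else {
    System.out.println("stream finished; read the persisted QueryResult");
}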

View File

@@ -5,6 +5,7 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.service.ChatQueryService;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
@@ -50,6 +51,14 @@ public class ChatQueryController {
return chatQueryService.execute(chatExecuteReq);
}
@PostMapping("getExecuteSummary")
public Object getExecuteSummary(@RequestBody ChatExecuteReq chatExecuteReq,
HttpServletRequest request, HttpServletResponse response) {
chatExecuteReq.setUser(UserHolder.findUser(request, response));
QueryResult res = chatQueryService.getTextSummary(chatExecuteReq);
return res;
}
@PostMapping("/")
public Object query(@RequestBody ChatParseReq chatParseReq, HttpServletRequest request,
HttpServletResponse response) throws Exception {

View File

@@ -35,6 +35,8 @@ public interface ChatManageService {
QueryResp getChatQuery(Long queryId);
ChatQueryDO getChatQueryDO(Long queryId);
List<QueryResp> getChatQueries(Integer chatId);
ShowCaseResp queryShowCase(PageQueryInfoReq pageQueryInfoReq, int agentId);

View File

@@ -19,6 +19,8 @@ public interface ChatQueryService {
QueryResult execute(ChatExecuteReq chatExecuteReq) throws Exception;
QueryResult getTextSummary(ChatExecuteReq chatExecuteReq);
QueryResult parseAndExecute(ChatParseReq chatParseReq);
Object queryData(ChatQueryDataReq chatQueryDataReq, User user) throws Exception;

View File

@@ -123,6 +123,11 @@ public class ChatManageServiceImpl implements ChatManageService {
return chatQueryRepository.getChatQuery(queryId);
}
@Override
public ChatQueryDO getChatQueryDO(Long queryId) {
return chatQueryRepository.getChatQueryDO(queryId);
}
@Override
public List<QueryResp> getChatQueries(Integer chatId) {
List<QueryResp> queries = chatQueryRepository.getChatQueries(chatId);

View File

@@ -1,5 +1,7 @@
package com.tencent.supersonic.chat.server.service.impl;
import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson2.JSON;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
@@ -9,8 +11,10 @@ import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.executor.ChatQueryExecutor;
import com.tencent.supersonic.chat.server.parser.ChatQueryParser;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.chat.server.processor.execute.DataInterpretProcessor;
import com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor;
import com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor;
import com.tencent.supersonic.chat.server.service.AgentService;
@@ -143,6 +147,21 @@ public class ChatQueryServiceImpl implements ChatQueryService {
return queryResult;
}
@Override
public QueryResult getTextSummary(ChatExecuteReq chatExecuteReq) {
String text = DataInterpretProcessor.getTextSummary(chatExecuteReq.getQueryId());
if (StringUtils.isNotBlank(text)) {
QueryResult res = new QueryResult();
res.setTextSummary(text);
res.setQueryId(chatExecuteReq.getQueryId());
return res;
} else {
ChatQueryDO chatQueryDo = chatManageService.getChatQueryDO(chatExecuteReq.getQueryId());
QueryResult res = JSON.parseObject(chatQueryDo.getQueryResult(), QueryResult.class);
return res;
}
}
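The polling contract this method gives the webapp, sketched under the assumption that QueryResult exposes getTextSummary and getQueryMode via Lombok getters: while the stream is in flight only textSummary (and queryId) is populated, so callers poll until queryMode becomes non-null; the 500ms interval mirrors the frontend change. The helper below is hypothetical.

static QueryResult pollSummary(ChatQueryService service, ChatExecuteReq req)
        throws InterruptedException {
    QueryResult result = service.getTextSummary(req);
    while (result.getQueryMode() == null) { // null while streaming is still in progress
        java.util.concurrent.TimeUnit.MILLISECONDS.sleep(500); // same interval as the webapp
        result = service.getTextSummary(req);
    }
    return result; // final, persisted result
}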
@Override
public QueryResult parseAndExecute(ChatParseReq chatParseReq) {
ChatParseResp parseResp = parse(chatParseReq);

View File

@@ -108,6 +108,7 @@ public class PluginServiceImpl implements PluginService {
if (StringUtils.isNotBlank(pluginQueryReq.getCreatedBy())) {
queryWrapper.lambda().eq(PluginDO::getCreatedBy, pluginQueryReq.getCreatedBy());
}
queryWrapper.orderByAsc("name");
List<PluginDO> pluginDOS = pluginRepository.query(queryWrapper);
if (StringUtils.isNotBlank(pluginQueryReq.getPattern())) {
pluginDOS = pluginDOS.stream()

View File

@@ -16,6 +16,7 @@ import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
@Service
public class SystemConfigServiceImpl extends ServiceImpl<SystemConfigMapper, SystemConfigDO>
@@ -38,8 +39,8 @@ public class SystemConfigServiceImpl extends ServiceImpl<SystemConfigMapper, Sys
return systemConfigDb;
}
private SystemConfig getSystemConfigFromDB() {
List<SystemConfigDO> list = list();
private SystemConfig getSystemConfigFromDB() { // query by id: if multiple records exist, the unfiltered query would go wrong
List<SystemConfigDO> list = this.lambdaQuery().eq(SystemConfigDO::getId, 1).list();
if (CollectionUtils.isEmpty(list)) {
SystemConfig systemConfig = new SystemConfig();
systemConfig.setId(1);

View File

@@ -7,6 +7,7 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.dify.DifyAiChatModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Service;
@@ -25,6 +26,11 @@ public class DifyModelFactory implements ModelFactory, InitializingBean {
.modelName(modelConfig.getModelName()).timeOut(modelConfig.getTimeOut()).build();
}
@Override
public OpenAiStreamingChatModel createChatStreamingModel(ChatModelConfig modelConfig) {
throw new RuntimeException("待开发");
}
@Override
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
return OpenAiEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())

View File

@@ -5,6 +5,7 @@ import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.embedding.S2OnnxEmbeddingModel;
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Service;
@@ -35,6 +36,11 @@ public class InMemoryModelFactory implements ModelFactory, InitializingBean {
return EmbeddingModelConstant.BGE_SMALL_ZH_MODEL;
}
@Override
public OpenAiStreamingChatModel createChatStreamingModel(ChatModelConfig modelConfig) {
throw new RuntimeException("待开发");
}
@Override
public void afterPropertiesSet() {
ModelProvider.add(PROVIDER, this);

View File

@@ -6,6 +6,7 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.localai.LocalAiChatModel;
import dev.langchain4j.model.localai.LocalAiEmbeddingModel;
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Service;
@@ -27,6 +28,11 @@ public class LocalAiModelFactory implements ModelFactory, InitializingBean {
.build();
}
@Override
public OpenAiStreamingChatModel createChatStreamingModel(ChatModelConfig modelConfig) {
throw new RuntimeException("待开发");
}
@Override
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModel) {
return LocalAiEmbeddingModel.builder().baseUrl(embeddingModel.getBaseUrl())

View File

@@ -4,9 +4,12 @@ import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
public interface ModelFactory {
ChatLanguageModel createChatModel(ChatModelConfig modelConfig);
OpenAiStreamingChatModel createChatStreamingModel(ChatModelConfig modelConfig);
EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModel);
}

View File

@@ -5,7 +5,9 @@ import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.EmbeddingModelConfig;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
import org.apache.commons.lang3.StringUtils;
import java.util.HashMap;
@@ -41,6 +43,20 @@ public class ModelProvider {
"Unsupported ChatLanguageModel provider: " + modelConfig.getProvider());
}
public static StreamingChatLanguageModel getChatStreamingModel(ChatModelConfig modelConfig) {
if (modelConfig == null || StringUtils.isBlank(modelConfig.getProvider())
|| StringUtils.isBlank(modelConfig.getBaseUrl())) {
modelConfig = DEMO_CHAT_MODEL;
}
ModelFactory modelFactory = factories.get(modelConfig.getProvider().toUpperCase());
if (modelFactory != null) {
return modelFactory.createChatStreamingModel(modelConfig);
}
throw new RuntimeException(
"Unsupported ChatLanguageModel provider: " + modelConfig.getProvider());
}
public static EmbeddingModel getEmbeddingModel() {
return getEmbeddingModel(null);
}
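Usage sketch for the new streaming entry point, assuming a populated ChatModelConfig named chatModelConfig; the prompt text is illustrative, and generate(String, StreamingResponseHandler) is the langchain4j convenience overload for a single user message.

StreamingChatLanguageModel model = ModelProvider.getChatStreamingModel(chatModelConfig);
StringBuffer buffer = new StringBuffer();
model.generate("Summarize this result set in one paragraph.",
        new StreamingResponseHandler<AiMessage>() {
            @Override
            public void onNext(String token) {
                buffer.append(token); // incremental tokens as they arrive
            }

            @Override
            public void onComplete(Response<AiMessage> response) {
                System.out.println("final: " + buffer);
            }

            @Override
            public void onError(Throwable error) {
                error.printStackTrace();
            }
        });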

View File

@@ -6,6 +6,7 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.ollama.OllamaChatModel;
import dev.langchain4j.model.ollama.OllamaEmbeddingModel;
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Service;
@@ -28,6 +29,11 @@ public class OllamaModelFactory implements ModelFactory, InitializingBean {
.logResponses(modelConfig.getLogResponses()).build();
}
@Override
public OpenAiStreamingChatModel createChatStreamingModel(ChatModelConfig modelConfig) {
throw new RuntimeException("待开发");
}
@Override
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
return OllamaEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl())

View File

@@ -6,6 +6,7 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
import dev.langchain4j.model.openai.OpenAiStreamingChatModel;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Service;
@@ -37,6 +38,16 @@ public class OpenAiModelFactory implements ModelFactory, InitializingBean {
return openAiChatModelBuilder.build();
}
@Override
public OpenAiStreamingChatModel createChatStreamingModel(ChatModelConfig modelConfig) {
return OpenAiStreamingChatModel.builder().baseUrl(modelConfig.getBaseUrl())
.modelName(modelConfig.getModelName()).apiKey(modelConfig.keyDecrypt())
.temperature(modelConfig.getTemperature()).topP(modelConfig.getTopP())
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
.logRequests(modelConfig.getLogRequests())
.logResponses(modelConfig.getLogResponses()).build();
}
@Override
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModel) {
return OpenAiEmbeddingModel.builder().baseUrl(embeddingModel.getBaseUrl())

View File

@@ -117,9 +117,17 @@ public class MultiCustomDictionary extends DynamicCustomDictionary {
dictWord.setAlias(word.toLowerCase());
String[] split = nature.split(DictWordType.NATURE_SPILT);
if (split.length >= 2) {
Long dimId = Long.parseLong(
nature.split(DictWordType.NATURE_SPILT)[split.length - 1]);
KnowledgeBaseService.addDimValueAlias(dimId, Arrays.asList(dictWord));
try {
Long dimId = Long.parseLong(
nature.split(DictWordType.NATURE_SPILT)[split.length - 1]);
KnowledgeBaseService.addDimValueAlias(dimId,
Arrays.asList(dictWord));
} catch (NumberFormatException e) {
logger.warning(path + " : 非标准文件不存入KnowledgeBaseService");
return true;
}
}
}
}

View File

@@ -10,7 +10,14 @@ import {
SimilarQuestionType,
} from '../../common/type';
import { createContext, useEffect, useRef, useState } from 'react';
import { chatExecute, chatParse, queryData, deleteQuery, switchEntity } from '../../service';
import {
chatExecute,
chatParse,
queryData,
deleteQuery,
switchEntity,
getExecuteSummary,
} from '../../service';
import { PARSE_ERROR_TIP, PREFIX_CLS, SEARCH_EXCEPTION_TIP } from '../../common/constants';
import { message, Spin } from 'antd';
import IconFont from '../IconFont';
@@ -169,7 +176,7 @@ const ChatItem: React.FC<Props> = ({
setExecuteLoading(true);
}
try {
const res: any = await chatExecute(msg, conversationId!, parseInfoValue, agentId);
const res: any = await chatExecute(msg, conversationId!, parseInfoValue, agentId, true);
const valid = updateData(res);
onMsgDataLoaded?.(
{
@@ -180,6 +187,20 @@ const ChatItem: React.FC<Props> = ({
valid,
isRefresh
);
const queryId = parseInfoValue.queryId; // pseudo-streaming LLM output
if (queryId != undefined && res.data.queryState != 'INVALID') {
const getSummary = async (data: any, queryId: number) => {
const res2: any = await getExecuteSummary(queryId);
if (res2.data.queryMode == null) {
res2.data = { ...data, textSummary: res2.data.textSummary };
setData(res2.data);
setTimeout(() => getSummary(data, queryId), 500);
} else {
setData(res2.data);
}
};
setTimeout(() => getSummary(res.data, queryId), 500);
}
} catch (e) {
const tip = SEARCH_EXCEPTION_TIP;
setExecuteTip(SEARCH_EXCEPTION_TIP);
@@ -423,6 +444,10 @@ const ChatItem: React.FC<Props> = ({
return result;
}, {});
});
if (exportData.length === 0) {
message.error('该条消息暂不支持该操作');
return;
}
exportCsvFile(exportData);
}
};

View File

@@ -79,7 +79,8 @@ export function chatExecute(
queryText: string,
chatId: number,
parseInfo: ChatContextType,
agentId?: number
agentId?: number,
streamingResult?: boolean
) {
return axios.post<MsgDataType>(`${prefix}/chat/query/execute`, {
queryText,
@@ -87,6 +88,15 @@ export function chatExecute(
chatId: chatId || DEFAULT_CHAT_ID,
queryId: parseInfo.queryId,
parseId: parseInfo.id,
streamingResult: streamingResult
});
}
export function getExecuteSummary(
queryId: number
) {
return axios.post<MsgDataType>(`${prefix}/chat/query/getExecuteSummary`, {
queryId: queryId,
});
}