mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
[improvement][chat]Incorporate Response into Context objects.
This commit is contained in:
@@ -1,9 +1,8 @@
|
||||
package com.tencent.supersonic.chat.server.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
|
||||
public interface ChatQueryParser {
|
||||
|
||||
void parse(ParseContext parseContext, ParseResp parseResp);
|
||||
void parse(ParseContext parseContext);
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.chat.server.util.ComponentFactory;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.List;
|
||||
@@ -16,15 +15,15 @@ public class NL2PluginParser implements ChatQueryParser {
|
||||
ComponentFactory.getPluginRecognizers();
|
||||
|
||||
@Override
|
||||
public void parse(ParseContext parseContext, ParseResp parseResp) {
|
||||
public void parse(ParseContext parseContext) {
|
||||
if (!parseContext.getAgent().containsPluginTool()) {
|
||||
return;
|
||||
}
|
||||
|
||||
pluginRecognizers.forEach(pluginRecognizer -> {
|
||||
pluginRecognizer.recognize(parseContext, parseResp);
|
||||
pluginRecognizer.recognize(parseContext);
|
||||
log.info("{} recallResult:{}", pluginRecognizer.getClass().getSimpleName(),
|
||||
JsonUtil.toString(parseResp));
|
||||
JsonUtil.toString(parseContext.getResponse()));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.tencent.supersonic.chat.server.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
@@ -90,7 +91,7 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void parse(ParseContext parseContext, ParseResp parseResp) {
|
||||
public void parse(ParseContext parseContext) {
|
||||
if (!parseContext.enableNL2SQL() || Objects.isNull(parseContext.getAgent())) {
|
||||
return;
|
||||
}
|
||||
@@ -98,14 +99,15 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
if (Objects.isNull(queryNLReq)) {
|
||||
return;
|
||||
}
|
||||
ParseResp parseResp = parseContext.getResponse();
|
||||
ChatParseReq parseReq = parseContext.getRequest();
|
||||
|
||||
if (!parseContext.getRequest().isDisableLLM()) {
|
||||
processMultiTurn(parseContext);
|
||||
}
|
||||
|
||||
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
|
||||
ChatContext chatCtx =
|
||||
chatContextService.getOrCreateContext(parseContext.getRequest().getChatId());
|
||||
ChatContext chatCtx = chatContextService.getOrCreateContext(parseReq.getChatId());
|
||||
if (chatCtx != null) {
|
||||
queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
|
||||
}
|
||||
@@ -116,7 +118,7 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) {
|
||||
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
|
||||
} else {
|
||||
if (!parseContext.getRequest().isDisableLLM()) {
|
||||
if (!parseReq.isDisableLLM()) {
|
||||
parseResp.setErrorMsg(rewriteErrorMessage(parseContext,
|
||||
text2SqlParseResp.getErrorMsg(), queryNLReq.getDynamicExemplars()));
|
||||
}
|
||||
|
||||
@@ -7,14 +7,14 @@ import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
public class PlainTextParser implements ChatQueryParser {
|
||||
|
||||
@Override
|
||||
public void parse(ParseContext parseContext, ParseResp parseResp) {
|
||||
public void parse(ParseContext parseContext) {
|
||||
if (parseContext.getAgent().containsAnyTool()) {
|
||||
return;
|
||||
}
|
||||
|
||||
SemanticParseInfo parseInfo = new SemanticParseInfo();
|
||||
parseInfo.setQueryMode("PLAIN_TEXT");
|
||||
parseResp.getSelectedParses().add(parseInfo);
|
||||
parseResp.setState(ParseResp.ParseState.COMPLETED);
|
||||
parseContext.getResponse().getSelectedParses().add(parseInfo);
|
||||
parseContext.getResponse().setState(ParseResp.ParseState.COMPLETED);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ import java.util.Set;
|
||||
/** PluginParser defines the basic process and common methods for recalling plugins. */
|
||||
public abstract class PluginRecognizer {
|
||||
|
||||
public void recognize(ParseContext parseContext, ParseResp parseResp) {
|
||||
public void recognize(ParseContext parseContext) {
|
||||
if (!checkPreCondition(parseContext)) {
|
||||
return;
|
||||
}
|
||||
@@ -39,7 +39,7 @@ public abstract class PluginRecognizer {
|
||||
if (pluginRecallResult == null) {
|
||||
return;
|
||||
}
|
||||
buildQuery(parseContext, parseResp, pluginRecallResult);
|
||||
buildQuery(parseContext, parseContext.getResponse(), pluginRecallResult);
|
||||
}
|
||||
|
||||
public abstract boolean checkPreCondition(ParseContext parseContext);
|
||||
|
||||
@@ -2,15 +2,18 @@ package com.tencent.supersonic.chat.server.pojo;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class ParseContext {
|
||||
private ChatParseReq request;
|
||||
private ParseResp response;
|
||||
private Agent agent;
|
||||
|
||||
public ParseContext(ChatParseReq request) {
|
||||
this.request = request;
|
||||
response = new ParseResp(request.getQueryText());
|
||||
}
|
||||
|
||||
public boolean enableNL2SQL() {
|
||||
|
||||
@@ -2,10 +2,9 @@ package com.tencent.supersonic.chat.server.processor.parse;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.chat.server.processor.ResultProcessor;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
|
||||
/** A ParseResultProcessor wraps things up before returning parsing results to the users. */
|
||||
public interface ParseResultProcessor extends ResultProcessor {
|
||||
|
||||
void process(ParseContext parseContext, ParseResp parseResp);
|
||||
void process(ParseContext parseContext);
|
||||
}
|
||||
|
||||
@@ -10,7 +10,6 @@ import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.common.service.ExemplarService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@@ -23,13 +22,13 @@ import java.util.stream.Collectors;
|
||||
public class QueryRecommendProcessor implements ParseResultProcessor {
|
||||
|
||||
@Override
|
||||
public void process(ParseContext parseContext, ParseResp parseResp) {
|
||||
CompletableFuture.runAsync(() -> doProcess(parseResp, parseContext));
|
||||
public void process(ParseContext parseContext) {
|
||||
CompletableFuture.runAsync(() -> doProcess(parseContext));
|
||||
}
|
||||
|
||||
@SneakyThrows
|
||||
private void doProcess(ParseResp parseResp, ParseContext parseContext) {
|
||||
Long queryId = parseResp.getQueryId();
|
||||
private void doProcess(ParseContext parseContext) {
|
||||
Long queryId = parseContext.getResponse().getQueryId();
|
||||
List<SimilarQueryRecallResp> solvedQueries = getSimilarQueries(
|
||||
parseContext.getRequest().getQueryText(), parseContext.getAgent().getId());
|
||||
ChatQueryDO chatQueryDO = getChatQuery(queryId);
|
||||
|
||||
@@ -9,7 +9,8 @@ import lombok.extern.slf4j.Slf4j;
|
||||
public class TimeCostProcessor implements ParseResultProcessor {
|
||||
|
||||
@Override
|
||||
public void process(ParseContext parseContext, ParseResp parseResp) {
|
||||
public void process(ParseContext parseContext) {
|
||||
ParseResp parseResp = parseContext.getResponse();
|
||||
long parseStartTime = parseResp.getParseTimeCost().getParseStartTime();
|
||||
parseResp.getParseTimeCost().setParseTime(System.currentTimeMillis() - parseStartTime
|
||||
- parseResp.getParseTimeCost().getSqlTime());
|
||||
|
||||
@@ -31,7 +31,7 @@ public interface ChatManageService {
|
||||
|
||||
PageInfo<QueryResp> queryInfo(PageQueryInfoReq pageQueryInfoReq, long chatId);
|
||||
|
||||
void createChatQuery(ChatParseReq chatParseReq, ParseResp parseResp);
|
||||
Long createChatQuery(ChatParseReq chatParseReq);
|
||||
|
||||
QueryResp getChatQuery(Long queryId);
|
||||
|
||||
|
||||
@@ -93,9 +93,8 @@ public class ChatManageServiceImpl implements ChatManageService {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void createChatQuery(ChatParseReq chatParseReq, ParseResp parseResp) {
|
||||
Long queryId = chatQueryRepository.createChatQuery(chatParseReq);
|
||||
parseResp.setQueryId(queryId);
|
||||
public Long createChatQuery(ChatParseReq chatParseReq) {
|
||||
return chatQueryRepository.createChatQuery(chatParseReq);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -105,15 +105,16 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
|
||||
@Override
|
||||
public ParseResp parse(ChatParseReq chatParseReq) {
|
||||
ParseResp parseResp = new ParseResp(chatParseReq.getQueryText());
|
||||
chatManageService.createChatQuery(chatParseReq, parseResp);
|
||||
ParseContext parseContext = buildParseContext(chatParseReq);
|
||||
ParseResp parseResp = parseContext.getResponse();
|
||||
Long queryId = chatManageService.createChatQuery(chatParseReq);
|
||||
parseResp.setQueryId(queryId);
|
||||
|
||||
for (ChatQueryParser parser : chatQueryParsers) {
|
||||
parser.parse(parseContext, parseResp);
|
||||
parser.parse(parseContext);
|
||||
}
|
||||
for (ParseResultProcessor processor : parseResultProcessors) {
|
||||
processor.process(parseContext, parseResp);
|
||||
processor.process(parseContext);
|
||||
}
|
||||
|
||||
chatParseReq.setQueryText(parseContext.getRequest().getQueryText());
|
||||
|
||||
@@ -120,7 +120,7 @@ public class RetrieveServiceImpl implements RetrieveService {
|
||||
return searchResults.stream().limit(RESULT_SIZE).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private List<Long> getPossibleDataSets(QueryNLReq queryCtx, List<S2Term> originals,
|
||||
private List<Long> getPossibleDataSets(QueryNLReq queryReq, List<S2Term> originals,
|
||||
Set<Long> dataSetIds) {
|
||||
if (CollectionUtils.isNotEmpty(dataSetIds)) {
|
||||
return new ArrayList<>(dataSetIds);
|
||||
@@ -128,8 +128,8 @@ public class RetrieveServiceImpl implements RetrieveService {
|
||||
|
||||
List<Long> possibleDataSets = NatureHelper.selectPossibleDataSets(originals);
|
||||
if (possibleDataSets.isEmpty()) {
|
||||
if (Objects.nonNull(queryCtx.getContextParseInfo())) {
|
||||
possibleDataSets.add(queryCtx.getContextParseInfo().getDataSetId());
|
||||
if (Objects.nonNull(queryReq.getContextParseInfo())) {
|
||||
possibleDataSets.add(queryReq.getContextParseInfo().getDataSetId());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user