[improvement][chat]Incorporate Response into Context objects.

This commit is contained in:
jerryjzhang
2024-10-27 17:44:29 +08:00
parent 1842261dfe
commit 3e0f724e97
13 changed files with 36 additions and 34 deletions

View File

@@ -1,9 +1,8 @@
package com.tencent.supersonic.chat.server.parser;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
public interface ChatQueryParser {
void parse(ParseContext parseContext, ParseResp parseResp);
void parse(ParseContext parseContext);
}

View File

@@ -4,7 +4,6 @@ import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.chat.server.util.ComponentFactory;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import lombok.extern.slf4j.Slf4j;
import java.util.List;
@@ -16,15 +15,15 @@ public class NL2PluginParser implements ChatQueryParser {
ComponentFactory.getPluginRecognizers();
@Override
public void parse(ParseContext parseContext, ParseResp parseResp) {
public void parse(ParseContext parseContext) {
if (!parseContext.getAgent().containsPluginTool()) {
return;
}
pluginRecognizers.forEach(pluginRecognizer -> {
pluginRecognizer.recognize(parseContext, parseResp);
pluginRecognizer.recognize(parseContext);
log.info("{} recallResult:{}", pluginRecognizer.getClass().getSimpleName(),
JsonUtil.toString(parseResp));
JsonUtil.toString(parseContext.getResponse()));
});
}
}

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.chat.server.parser;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
import com.tencent.supersonic.chat.server.pojo.ChatContext;
@@ -90,7 +91,7 @@ public class NL2SQLParser implements ChatQueryParser {
}
@Override
public void parse(ParseContext parseContext, ParseResp parseResp) {
public void parse(ParseContext parseContext) {
if (!parseContext.enableNL2SQL() || Objects.isNull(parseContext.getAgent())) {
return;
}
@@ -98,14 +99,15 @@ public class NL2SQLParser implements ChatQueryParser {
if (Objects.isNull(queryNLReq)) {
return;
}
ParseResp parseResp = parseContext.getResponse();
ChatParseReq parseReq = parseContext.getRequest();
if (!parseContext.getRequest().isDisableLLM()) {
processMultiTurn(parseContext);
}
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
ChatContext chatCtx =
chatContextService.getOrCreateContext(parseContext.getRequest().getChatId());
ChatContext chatCtx = chatContextService.getOrCreateContext(parseReq.getChatId());
if (chatCtx != null) {
queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
}
@@ -116,7 +118,7 @@ public class NL2SQLParser implements ChatQueryParser {
if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) {
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
} else {
if (!parseContext.getRequest().isDisableLLM()) {
if (!parseReq.isDisableLLM()) {
parseResp.setErrorMsg(rewriteErrorMessage(parseContext,
text2SqlParseResp.getErrorMsg(), queryNLReq.getDynamicExemplars()));
}

View File

@@ -7,14 +7,14 @@ import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
public class PlainTextParser implements ChatQueryParser {
@Override
public void parse(ParseContext parseContext, ParseResp parseResp) {
public void parse(ParseContext parseContext) {
if (parseContext.getAgent().containsAnyTool()) {
return;
}
SemanticParseInfo parseInfo = new SemanticParseInfo();
parseInfo.setQueryMode("PLAIN_TEXT");
parseResp.getSelectedParses().add(parseInfo);
parseResp.setState(ParseResp.ParseState.COMPLETED);
parseContext.getResponse().getSelectedParses().add(parseInfo);
parseContext.getResponse().setState(ParseResp.ParseState.COMPLETED);
}
}

View File

@@ -31,7 +31,7 @@ import java.util.Set;
/** PluginParser defines the basic process and common methods for recalling plugins. */
public abstract class PluginRecognizer {
public void recognize(ParseContext parseContext, ParseResp parseResp) {
public void recognize(ParseContext parseContext) {
if (!checkPreCondition(parseContext)) {
return;
}
@@ -39,7 +39,7 @@ public abstract class PluginRecognizer {
if (pluginRecallResult == null) {
return;
}
buildQuery(parseContext, parseResp, pluginRecallResult);
buildQuery(parseContext, parseContext.getResponse(), pluginRecallResult);
}
public abstract boolean checkPreCondition(ParseContext parseContext);

View File

@@ -2,15 +2,18 @@ package com.tencent.supersonic.chat.server.pojo;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import lombok.Data;
@Data
public class ParseContext {
private ChatParseReq request;
private ParseResp response;
private Agent agent;
public ParseContext(ChatParseReq request) {
this.request = request;
response = new ParseResp(request.getQueryText());
}
public boolean enableNL2SQL() {

View File

@@ -2,10 +2,9 @@ package com.tencent.supersonic.chat.server.processor.parse;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.chat.server.processor.ResultProcessor;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
/** A ParseResultProcessor wraps things up before returning parsing results to the users. */
public interface ParseResultProcessor extends ResultProcessor {
void process(ParseContext parseContext, ParseResp parseResp);
void process(ParseContext parseContext);
}

View File

@@ -10,7 +10,6 @@ import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.common.service.ExemplarService;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
@@ -23,13 +22,13 @@ import java.util.stream.Collectors;
public class QueryRecommendProcessor implements ParseResultProcessor {
@Override
public void process(ParseContext parseContext, ParseResp parseResp) {
CompletableFuture.runAsync(() -> doProcess(parseResp, parseContext));
public void process(ParseContext parseContext) {
CompletableFuture.runAsync(() -> doProcess(parseContext));
}
@SneakyThrows
private void doProcess(ParseResp parseResp, ParseContext parseContext) {
Long queryId = parseResp.getQueryId();
private void doProcess(ParseContext parseContext) {
Long queryId = parseContext.getResponse().getQueryId();
List<SimilarQueryRecallResp> solvedQueries = getSimilarQueries(
parseContext.getRequest().getQueryText(), parseContext.getAgent().getId());
ChatQueryDO chatQueryDO = getChatQuery(queryId);

View File

@@ -9,7 +9,8 @@ import lombok.extern.slf4j.Slf4j;
public class TimeCostProcessor implements ParseResultProcessor {
@Override
public void process(ParseContext parseContext, ParseResp parseResp) {
public void process(ParseContext parseContext) {
ParseResp parseResp = parseContext.getResponse();
long parseStartTime = parseResp.getParseTimeCost().getParseStartTime();
parseResp.getParseTimeCost().setParseTime(System.currentTimeMillis() - parseStartTime
- parseResp.getParseTimeCost().getSqlTime());

View File

@@ -31,7 +31,7 @@ public interface ChatManageService {
PageInfo<QueryResp> queryInfo(PageQueryInfoReq pageQueryInfoReq, long chatId);
void createChatQuery(ChatParseReq chatParseReq, ParseResp parseResp);
Long createChatQuery(ChatParseReq chatParseReq);
QueryResp getChatQuery(Long queryId);

View File

@@ -93,9 +93,8 @@ public class ChatManageServiceImpl implements ChatManageService {
}
@Override
public void createChatQuery(ChatParseReq chatParseReq, ParseResp parseResp) {
Long queryId = chatQueryRepository.createChatQuery(chatParseReq);
parseResp.setQueryId(queryId);
public Long createChatQuery(ChatParseReq chatParseReq) {
return chatQueryRepository.createChatQuery(chatParseReq);
}
@Override

View File

@@ -105,15 +105,16 @@ public class ChatQueryServiceImpl implements ChatQueryService {
@Override
public ParseResp parse(ChatParseReq chatParseReq) {
ParseResp parseResp = new ParseResp(chatParseReq.getQueryText());
chatManageService.createChatQuery(chatParseReq, parseResp);
ParseContext parseContext = buildParseContext(chatParseReq);
ParseResp parseResp = parseContext.getResponse();
Long queryId = chatManageService.createChatQuery(chatParseReq);
parseResp.setQueryId(queryId);
for (ChatQueryParser parser : chatQueryParsers) {
parser.parse(parseContext, parseResp);
parser.parse(parseContext);
}
for (ParseResultProcessor processor : parseResultProcessors) {
processor.process(parseContext, parseResp);
processor.process(parseContext);
}
chatParseReq.setQueryText(parseContext.getRequest().getQueryText());

View File

@@ -120,7 +120,7 @@ public class RetrieveServiceImpl implements RetrieveService {
return searchResults.stream().limit(RESULT_SIZE).collect(Collectors.toList());
}
private List<Long> getPossibleDataSets(QueryNLReq queryCtx, List<S2Term> originals,
private List<Long> getPossibleDataSets(QueryNLReq queryReq, List<S2Term> originals,
Set<Long> dataSetIds) {
if (CollectionUtils.isNotEmpty(dataSetIds)) {
return new ArrayList<>(dataSetIds);
@@ -128,8 +128,8 @@ public class RetrieveServiceImpl implements RetrieveService {
List<Long> possibleDataSets = NatureHelper.selectPossibleDataSets(originals);
if (possibleDataSets.isEmpty()) {
if (Objects.nonNull(queryCtx.getContextParseInfo())) {
possibleDataSets.add(queryCtx.getContextParseInfo().getDataSetId());
if (Objects.nonNull(queryReq.getContextParseInfo())) {
possibleDataSets.add(queryReq.getContextParseInfo().getDataSetId());
}
}