mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 12:07:42 +00:00
[improvement][chat]Incorporate Response into Context objects.
This commit is contained in:
@@ -1,9 +1,8 @@
|
|||||||
package com.tencent.supersonic.chat.server.parser;
|
package com.tencent.supersonic.chat.server.parser;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
|
||||||
|
|
||||||
public interface ChatQueryParser {
|
public interface ChatQueryParser {
|
||||||
|
|
||||||
void parse(ParseContext parseContext, ParseResp parseResp);
|
void parse(ParseContext parseContext);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer;
|
|||||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||||
import com.tencent.supersonic.chat.server.util.ComponentFactory;
|
import com.tencent.supersonic.chat.server.util.ComponentFactory;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -16,15 +15,15 @@ public class NL2PluginParser implements ChatQueryParser {
|
|||||||
ComponentFactory.getPluginRecognizers();
|
ComponentFactory.getPluginRecognizers();
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void parse(ParseContext parseContext, ParseResp parseResp) {
|
public void parse(ParseContext parseContext) {
|
||||||
if (!parseContext.getAgent().containsPluginTool()) {
|
if (!parseContext.getAgent().containsPluginTool()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
pluginRecognizers.forEach(pluginRecognizer -> {
|
pluginRecognizers.forEach(pluginRecognizer -> {
|
||||||
pluginRecognizer.recognize(parseContext, parseResp);
|
pluginRecognizer.recognize(parseContext);
|
||||||
log.info("{} recallResult:{}", pluginRecognizer.getClass().getSimpleName(),
|
log.info("{} recallResult:{}", pluginRecognizer.getClass().getSimpleName(),
|
||||||
JsonUtil.toString(parseResp));
|
JsonUtil.toString(parseContext.getResponse()));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
package com.tencent.supersonic.chat.server.parser;
|
package com.tencent.supersonic.chat.server.parser;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||||
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||||
@@ -90,7 +91,7 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void parse(ParseContext parseContext, ParseResp parseResp) {
|
public void parse(ParseContext parseContext) {
|
||||||
if (!parseContext.enableNL2SQL() || Objects.isNull(parseContext.getAgent())) {
|
if (!parseContext.enableNL2SQL() || Objects.isNull(parseContext.getAgent())) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -98,14 +99,15 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
if (Objects.isNull(queryNLReq)) {
|
if (Objects.isNull(queryNLReq)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
ParseResp parseResp = parseContext.getResponse();
|
||||||
|
ChatParseReq parseReq = parseContext.getRequest();
|
||||||
|
|
||||||
if (!parseContext.getRequest().isDisableLLM()) {
|
if (!parseContext.getRequest().isDisableLLM()) {
|
||||||
processMultiTurn(parseContext);
|
processMultiTurn(parseContext);
|
||||||
}
|
}
|
||||||
|
|
||||||
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
|
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
|
||||||
ChatContext chatCtx =
|
ChatContext chatCtx = chatContextService.getOrCreateContext(parseReq.getChatId());
|
||||||
chatContextService.getOrCreateContext(parseContext.getRequest().getChatId());
|
|
||||||
if (chatCtx != null) {
|
if (chatCtx != null) {
|
||||||
queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
|
queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
|
||||||
}
|
}
|
||||||
@@ -116,7 +118,7 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) {
|
if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) {
|
||||||
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
|
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
|
||||||
} else {
|
} else {
|
||||||
if (!parseContext.getRequest().isDisableLLM()) {
|
if (!parseReq.isDisableLLM()) {
|
||||||
parseResp.setErrorMsg(rewriteErrorMessage(parseContext,
|
parseResp.setErrorMsg(rewriteErrorMessage(parseContext,
|
||||||
text2SqlParseResp.getErrorMsg(), queryNLReq.getDynamicExemplars()));
|
text2SqlParseResp.getErrorMsg(), queryNLReq.getDynamicExemplars()));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,14 +7,14 @@ import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
|||||||
public class PlainTextParser implements ChatQueryParser {
|
public class PlainTextParser implements ChatQueryParser {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void parse(ParseContext parseContext, ParseResp parseResp) {
|
public void parse(ParseContext parseContext) {
|
||||||
if (parseContext.getAgent().containsAnyTool()) {
|
if (parseContext.getAgent().containsAnyTool()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
SemanticParseInfo parseInfo = new SemanticParseInfo();
|
SemanticParseInfo parseInfo = new SemanticParseInfo();
|
||||||
parseInfo.setQueryMode("PLAIN_TEXT");
|
parseInfo.setQueryMode("PLAIN_TEXT");
|
||||||
parseResp.getSelectedParses().add(parseInfo);
|
parseContext.getResponse().getSelectedParses().add(parseInfo);
|
||||||
parseResp.setState(ParseResp.ParseState.COMPLETED);
|
parseContext.getResponse().setState(ParseResp.ParseState.COMPLETED);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ import java.util.Set;
|
|||||||
/** PluginParser defines the basic process and common methods for recalling plugins. */
|
/** PluginParser defines the basic process and common methods for recalling plugins. */
|
||||||
public abstract class PluginRecognizer {
|
public abstract class PluginRecognizer {
|
||||||
|
|
||||||
public void recognize(ParseContext parseContext, ParseResp parseResp) {
|
public void recognize(ParseContext parseContext) {
|
||||||
if (!checkPreCondition(parseContext)) {
|
if (!checkPreCondition(parseContext)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -39,7 +39,7 @@ public abstract class PluginRecognizer {
|
|||||||
if (pluginRecallResult == null) {
|
if (pluginRecallResult == null) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
buildQuery(parseContext, parseResp, pluginRecallResult);
|
buildQuery(parseContext, parseContext.getResponse(), pluginRecallResult);
|
||||||
}
|
}
|
||||||
|
|
||||||
public abstract boolean checkPreCondition(ParseContext parseContext);
|
public abstract boolean checkPreCondition(ParseContext parseContext);
|
||||||
|
|||||||
@@ -2,15 +2,18 @@ package com.tencent.supersonic.chat.server.pojo;
|
|||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class ParseContext {
|
public class ParseContext {
|
||||||
private ChatParseReq request;
|
private ChatParseReq request;
|
||||||
|
private ParseResp response;
|
||||||
private Agent agent;
|
private Agent agent;
|
||||||
|
|
||||||
public ParseContext(ChatParseReq request) {
|
public ParseContext(ChatParseReq request) {
|
||||||
this.request = request;
|
this.request = request;
|
||||||
|
response = new ParseResp(request.getQueryText());
|
||||||
}
|
}
|
||||||
|
|
||||||
public boolean enableNL2SQL() {
|
public boolean enableNL2SQL() {
|
||||||
|
|||||||
@@ -2,10 +2,9 @@ package com.tencent.supersonic.chat.server.processor.parse;
|
|||||||
|
|
||||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||||
import com.tencent.supersonic.chat.server.processor.ResultProcessor;
|
import com.tencent.supersonic.chat.server.processor.ResultProcessor;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
|
||||||
|
|
||||||
/** A ParseResultProcessor wraps things up before returning parsing results to the users. */
|
/** A ParseResultProcessor wraps things up before returning parsing results to the users. */
|
||||||
public interface ParseResultProcessor extends ResultProcessor {
|
public interface ParseResultProcessor extends ResultProcessor {
|
||||||
|
|
||||||
void process(ParseContext parseContext, ParseResp parseResp);
|
void process(ParseContext parseContext);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import com.tencent.supersonic.common.config.EmbeddingConfig;
|
|||||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||||
import com.tencent.supersonic.common.service.ExemplarService;
|
import com.tencent.supersonic.common.service.ExemplarService;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
|
||||||
import lombok.SneakyThrows;
|
import lombok.SneakyThrows;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
@@ -23,13 +22,13 @@ import java.util.stream.Collectors;
|
|||||||
public class QueryRecommendProcessor implements ParseResultProcessor {
|
public class QueryRecommendProcessor implements ParseResultProcessor {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void process(ParseContext parseContext, ParseResp parseResp) {
|
public void process(ParseContext parseContext) {
|
||||||
CompletableFuture.runAsync(() -> doProcess(parseResp, parseContext));
|
CompletableFuture.runAsync(() -> doProcess(parseContext));
|
||||||
}
|
}
|
||||||
|
|
||||||
@SneakyThrows
|
@SneakyThrows
|
||||||
private void doProcess(ParseResp parseResp, ParseContext parseContext) {
|
private void doProcess(ParseContext parseContext) {
|
||||||
Long queryId = parseResp.getQueryId();
|
Long queryId = parseContext.getResponse().getQueryId();
|
||||||
List<SimilarQueryRecallResp> solvedQueries = getSimilarQueries(
|
List<SimilarQueryRecallResp> solvedQueries = getSimilarQueries(
|
||||||
parseContext.getRequest().getQueryText(), parseContext.getAgent().getId());
|
parseContext.getRequest().getQueryText(), parseContext.getAgent().getId());
|
||||||
ChatQueryDO chatQueryDO = getChatQuery(queryId);
|
ChatQueryDO chatQueryDO = getChatQuery(queryId);
|
||||||
|
|||||||
@@ -9,7 +9,8 @@ import lombok.extern.slf4j.Slf4j;
|
|||||||
public class TimeCostProcessor implements ParseResultProcessor {
|
public class TimeCostProcessor implements ParseResultProcessor {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void process(ParseContext parseContext, ParseResp parseResp) {
|
public void process(ParseContext parseContext) {
|
||||||
|
ParseResp parseResp = parseContext.getResponse();
|
||||||
long parseStartTime = parseResp.getParseTimeCost().getParseStartTime();
|
long parseStartTime = parseResp.getParseTimeCost().getParseStartTime();
|
||||||
parseResp.getParseTimeCost().setParseTime(System.currentTimeMillis() - parseStartTime
|
parseResp.getParseTimeCost().setParseTime(System.currentTimeMillis() - parseStartTime
|
||||||
- parseResp.getParseTimeCost().getSqlTime());
|
- parseResp.getParseTimeCost().getSqlTime());
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ public interface ChatManageService {
|
|||||||
|
|
||||||
PageInfo<QueryResp> queryInfo(PageQueryInfoReq pageQueryInfoReq, long chatId);
|
PageInfo<QueryResp> queryInfo(PageQueryInfoReq pageQueryInfoReq, long chatId);
|
||||||
|
|
||||||
void createChatQuery(ChatParseReq chatParseReq, ParseResp parseResp);
|
Long createChatQuery(ChatParseReq chatParseReq);
|
||||||
|
|
||||||
QueryResp getChatQuery(Long queryId);
|
QueryResp getChatQuery(Long queryId);
|
||||||
|
|
||||||
|
|||||||
@@ -93,9 +93,8 @@ public class ChatManageServiceImpl implements ChatManageService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void createChatQuery(ChatParseReq chatParseReq, ParseResp parseResp) {
|
public Long createChatQuery(ChatParseReq chatParseReq) {
|
||||||
Long queryId = chatQueryRepository.createChatQuery(chatParseReq);
|
return chatQueryRepository.createChatQuery(chatParseReq);
|
||||||
parseResp.setQueryId(queryId);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -105,15 +105,16 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ParseResp parse(ChatParseReq chatParseReq) {
|
public ParseResp parse(ChatParseReq chatParseReq) {
|
||||||
ParseResp parseResp = new ParseResp(chatParseReq.getQueryText());
|
|
||||||
chatManageService.createChatQuery(chatParseReq, parseResp);
|
|
||||||
ParseContext parseContext = buildParseContext(chatParseReq);
|
ParseContext parseContext = buildParseContext(chatParseReq);
|
||||||
|
ParseResp parseResp = parseContext.getResponse();
|
||||||
|
Long queryId = chatManageService.createChatQuery(chatParseReq);
|
||||||
|
parseResp.setQueryId(queryId);
|
||||||
|
|
||||||
for (ChatQueryParser parser : chatQueryParsers) {
|
for (ChatQueryParser parser : chatQueryParsers) {
|
||||||
parser.parse(parseContext, parseResp);
|
parser.parse(parseContext);
|
||||||
}
|
}
|
||||||
for (ParseResultProcessor processor : parseResultProcessors) {
|
for (ParseResultProcessor processor : parseResultProcessors) {
|
||||||
processor.process(parseContext, parseResp);
|
processor.process(parseContext);
|
||||||
}
|
}
|
||||||
|
|
||||||
chatParseReq.setQueryText(parseContext.getRequest().getQueryText());
|
chatParseReq.setQueryText(parseContext.getRequest().getQueryText());
|
||||||
|
|||||||
@@ -120,7 +120,7 @@ public class RetrieveServiceImpl implements RetrieveService {
|
|||||||
return searchResults.stream().limit(RESULT_SIZE).collect(Collectors.toList());
|
return searchResults.stream().limit(RESULT_SIZE).collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<Long> getPossibleDataSets(QueryNLReq queryCtx, List<S2Term> originals,
|
private List<Long> getPossibleDataSets(QueryNLReq queryReq, List<S2Term> originals,
|
||||||
Set<Long> dataSetIds) {
|
Set<Long> dataSetIds) {
|
||||||
if (CollectionUtils.isNotEmpty(dataSetIds)) {
|
if (CollectionUtils.isNotEmpty(dataSetIds)) {
|
||||||
return new ArrayList<>(dataSetIds);
|
return new ArrayList<>(dataSetIds);
|
||||||
@@ -128,8 +128,8 @@ public class RetrieveServiceImpl implements RetrieveService {
|
|||||||
|
|
||||||
List<Long> possibleDataSets = NatureHelper.selectPossibleDataSets(originals);
|
List<Long> possibleDataSets = NatureHelper.selectPossibleDataSets(originals);
|
||||||
if (possibleDataSets.isEmpty()) {
|
if (possibleDataSets.isEmpty()) {
|
||||||
if (Objects.nonNull(queryCtx.getContextParseInfo())) {
|
if (Objects.nonNull(queryReq.getContextParseInfo())) {
|
||||||
possibleDataSets.add(queryCtx.getContextParseInfo().getDataSetId());
|
possibleDataSets.add(queryReq.getContextParseInfo().getDataSetId());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user