[improvement][chat]Introduce separate ChatParseResp in Chat module.

This commit is contained in:
jerryjzhang
2024-10-28 12:44:39 +08:00
parent eb28d832bc
commit 14badcd4ae
14 changed files with 77 additions and 46 deletions

View File

@@ -0,0 +1,25 @@
package com.tencent.supersonic.chat.api.pojo.response;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseTimeCostResp;
import lombok.Data;
import java.util.List;
@Data
public class ChatParseResp {
private Long queryId;
private ParseResp.ParseState state = ParseResp.ParseState.PENDING;
private String errorMsg;
private List<SemanticParseInfo> selectedParses = Lists.newArrayList();
private ParseTimeCostResp parseTimeCost = new ParseTimeCostResp();
private List<Text2SQLExemplar> usedExemplars;
public ChatParseResp(Long queryId) {
this.queryId = queryId;
}
}

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.chat.server.parser;
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.server.pojo.ChatContext;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
@@ -92,28 +93,27 @@ public class NL2SQLParser implements ChatQueryParser {
addDynamicExemplars(parseContext, queryNLReq);
}
ParseResp parseResp = parseContext.getResponse();
doParse(queryNLReq, parseResp);
doParse(queryNLReq, parseContext.getResponse());
}
private void processFeedback(ParseContext parseContext) {
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
ParseResp parseResp = parseContext.getResponse();
ChatParseResp parseResp = parseContext.getResponse();
for (MapModeEnum mode : MapModeEnum.values()) {
queryNLReq.setMapModeEnum(mode);
doParse(queryNLReq, parseResp);
}
}
private void doParse(QueryNLReq req, ParseResp resp) {
private void doParse(QueryNLReq req, ChatParseResp resp) {
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
ParseResp text2SqlParseResp = chatLayerService.parse(req);
if (text2SqlParseResp.getState().equals(ParseResp.ParseState.COMPLETED)) {
resp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
ParseResp parseResp = chatLayerService.parse(req);
if (parseResp.getState().equals(ParseResp.ParseState.COMPLETED)) {
resp.getSelectedParses().addAll(parseResp.getSelectedParses());
}
resp.setState(text2SqlParseResp.getState());
resp.setParseTimeCost(text2SqlParseResp.getParseTimeCost());
resp.setErrorMsg(text2SqlParseResp.getErrorMsg());
resp.setState(parseResp.getState());
resp.setParseTimeCost(parseResp.getParseTimeCost());
resp.setErrorMsg(parseResp.getErrorMsg());
}
private void rewriteMultiTurn(ParseContext parseContext, QueryNLReq queryNLReq) {

View File

@@ -4,6 +4,7 @@ import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
@@ -30,7 +31,7 @@ public interface ChatQueryRepository {
Long createChatQuery(ChatParseReq chatParseReq);
List<ChatParseDO> batchSaveParseInfo(ChatParseReq chatParseReq, ParseResp parseResult,
List<ChatParseDO> batchSaveParseInfo(ChatParseReq chatParseReq, ChatParseResp chatParseResp,
List<SemanticParseInfo> candidateParses);
ChatParseDO getParseInfo(Long questionId, int parseId);

View File

@@ -7,6 +7,7 @@ import com.github.pagehelper.PageHelper;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
@@ -144,10 +145,10 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
}
@Override
public List<ChatParseDO> batchSaveParseInfo(ChatParseReq chatParseReq, ParseResp parseResult,
List<SemanticParseInfo> candidateParses) {
public List<ChatParseDO> batchSaveParseInfo(ChatParseReq chatParseReq,
ChatParseResp chatParseResp, List<SemanticParseInfo> candidateParses) {
List<ChatParseDO> chatParseDOList = new ArrayList<>();
getChatParseDO(chatParseReq, parseResult.getQueryId(), candidateParses, chatParseDOList);
getChatParseDO(chatParseReq, chatParseResp.getQueryId(), candidateParses, chatParseDOList);
if (!CollectionUtils.isEmpty(candidateParses)) {
chatParseMapper.batchSaveParseInfo(chatParseDOList);
}

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.server.plugin.recognize;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
import com.tencent.supersonic.chat.server.plugin.PluginManager;
import com.tencent.supersonic.chat.server.plugin.PluginParseResult;
@@ -46,7 +47,7 @@ public abstract class PluginRecognizer {
public abstract PluginRecallResult recallPlugin(ParseContext parseContext);
public void buildQuery(ParseContext parseContext, ParseResp parseResp,
public void buildQuery(ParseContext parseContext, ChatParseResp parseResp,
PluginRecallResult pluginRecallResult) {
ChatPlugin plugin = pluginRecallResult.getPlugin();
Set<Long> dataSetIds = pluginRecallResult.getDataSetIds();

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.chat.server.pojo;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
@@ -11,13 +12,12 @@ import java.util.Objects;
@Data
public class ParseContext {
private ChatParseReq request;
private ParseResp response;
private ChatParseResp response;
private Agent agent;
private SemanticParseInfo selectedParseInfo;
public ParseContext(ChatParseReq request) {
this.request = request;
response = new ParseResp(request.getQueryText());
}
public boolean enableNL2SQL() {

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.chat.server.processor.parse;
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import lombok.extern.slf4j.Slf4j;
@@ -10,7 +11,7 @@ public class TimeCostProcessor implements ParseResultProcessor {
@Override
public void process(ParseContext parseContext) {
ParseResp parseResp = parseContext.getResponse();
ChatParseResp parseResp = parseContext.getResponse();
long parseStartTime = parseResp.getParseTimeCost().getParseStartTime();
parseResp.getParseTimeCost().setParseTime(System.currentTimeMillis() - parseStartTime
- parseResp.getParseTimeCost().getSqlTime());

View File

@@ -8,6 +8,7 @@ import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
import com.tencent.supersonic.chat.server.service.ChatQueryService;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
@@ -56,7 +57,7 @@ public class ChatQueryController {
HttpServletResponse response) throws Exception {
User user = UserHolder.findUser(request, response);
chatParseReq.setUser(user);
ParseResp parseResp = chatQueryService.parse(chatParseReq);
ChatParseResp parseResp = chatQueryService.parse(chatParseReq);
if (CollectionUtils.isEmpty(parseResp.getSelectedParses())) {
throw new InvalidArgumentException("parser error,no selectedParses");

View File

@@ -4,6 +4,7 @@ import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
@@ -45,9 +46,9 @@ public interface ChatManageService {
void deleteQuery(Long queryId);
void updateParseCostTime(ParseResp parseResp);
void updateParseCostTime(ChatParseResp chatParseResp);
List<ChatParseDO> batchAddParse(ChatParseReq chatParseReq, ParseResp parseResult);
List<ChatParseDO> batchAddParse(ChatParseReq chatParseReq, ChatParseResp chatParseResp);
SemanticParseInfo getParseInfo(Long questionId, int parseId);
}

View File

@@ -3,10 +3,10 @@ package com.tencent.supersonic.chat.server.service;
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
import java.util.List;
@@ -15,7 +15,7 @@ public interface ChatQueryService {
List<SearchResult> search(ChatParseReq chatParseReq);
ParseResp parse(ChatParseReq chatParseReq);
ChatParseResp parse(ChatParseReq chatParseReq);
QueryResult execute(ChatExecuteReq chatExecuteReq) throws Exception;

View File

@@ -5,6 +5,7 @@ import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
@@ -192,16 +193,16 @@ public class ChatManageServiceImpl implements ChatManageService {
}
@Override
public void updateParseCostTime(ParseResp parseResp) {
ChatQueryDO chatQueryDO = chatQueryRepository.getChatQueryDO(parseResp.getQueryId());
chatQueryDO.setParseTimeCost(JsonUtil.toString(parseResp.getParseTimeCost()));
public void updateParseCostTime(ChatParseResp chatParseResp) {
ChatQueryDO chatQueryDO = chatQueryRepository.getChatQueryDO(chatParseResp.getQueryId());
chatQueryDO.setParseTimeCost(JsonUtil.toString(chatParseResp.getParseTimeCost()));
updateQuery(chatQueryDO);
}
@Override
public List<ChatParseDO> batchAddParse(ChatParseReq chatParseReq, ParseResp parseResult) {
List<SemanticParseInfo> candidateParses = parseResult.getSelectedParses();
return chatQueryRepository.batchSaveParseInfo(chatParseReq, parseResult, candidateParses);
public List<ChatParseDO> batchAddParse(ChatParseReq chatParseReq, ChatParseResp chatParseResp) {
List<SemanticParseInfo> candidateParses = chatParseResp.getSelectedParses();
return chatQueryRepository.batchSaveParseInfo(chatParseReq, chatParseResp, candidateParses);
}
private String getCurrentTime() {

View File

@@ -4,6 +4,7 @@ import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.executor.ChatQueryExecutor;
@@ -104,11 +105,10 @@ public class ChatQueryServiceImpl implements ChatQueryService {
}
@Override
public ParseResp parse(ChatParseReq chatParseReq) {
public ChatParseResp parse(ChatParseReq chatParseReq) {
ParseContext parseContext = buildParseContext(chatParseReq);
ParseResp parseResp = parseContext.getResponse();
Long queryId = chatManageService.createChatQuery(chatParseReq);
parseResp.setQueryId(queryId);
parseContext.setResponse(new ChatParseResp(queryId));
for (ChatQueryParser parser : chatQueryParsers) {
parser.parse(parseContext);
@@ -118,9 +118,9 @@ public class ChatQueryServiceImpl implements ChatQueryService {
}
chatParseReq.setQueryText(parseContext.getRequest().getQueryText());
chatManageService.batchAddParse(chatParseReq, parseResp);
chatManageService.updateParseCostTime(parseResp);
return parseResp;
chatManageService.batchAddParse(chatParseReq, parseContext.getResponse());
chatManageService.updateParseCostTime(parseContext.getResponse());
return parseContext.getResponse();
}
@Override
@@ -146,7 +146,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
@Override
public QueryResult parseAndExecute(ChatParseReq chatParseReq) {
ParseResp parseResp = parse(chatParseReq);
ChatParseResp parseResp = parse(chatParseReq);
if (CollectionUtils.isEmpty(parseResp.getSelectedParses())) {
log.debug("chatId:{}, agentId:{}, queryText:{}, parseResp.getSelectedParses() is empty",
chatParseReq.getChatId(), chatParseReq.getAgentId(),

View File

@@ -12,12 +12,10 @@ import java.util.stream.Collectors;
@Data
public class ParseResp {
private final String queryText;
private Long queryId;
private ParseState state = ParseState.PENDING;
private String errorMsg;
private List<SemanticParseInfo> selectedParses = Lists.newArrayList();
private ParseTimeCostResp parseTimeCost = new ParseTimeCostResp();
private List<Text2SQLExemplar> usedExemplars;
public enum ParseState {
COMPLETED, PENDING, FAILED

View File

@@ -4,6 +4,7 @@ import com.google.common.collect.Lists;
import com.tencent.supersonic.BaseApplication;
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatQueryService;
@@ -47,13 +48,13 @@ public class BaseTest extends BaseApplication {
protected QueryResult submitMultiTurnChat(String queryText, Integer agentId, Integer chatId)
throws Exception {
ParseResp parseResp = submitParse(queryText, agentId, chatId);
ChatParseResp parseResp = submitParse(queryText, agentId, chatId);
SemanticParseInfo semanticParseInfo = parseResp.getSelectedParses().get(0);
ChatExecuteReq request = ChatExecuteReq.builder().queryText(parseResp.getQueryText())
.user(DataUtils.getUser()).parseId(semanticParseInfo.getId())
.queryId(parseResp.getQueryId()).chatId(chatId).agentId(agentId).saveAnswer(true)
.build();
ChatExecuteReq request =
ChatExecuteReq.builder().queryText(queryText).user(DataUtils.getUser())
.parseId(semanticParseInfo.getId()).queryId(parseResp.getQueryId())
.chatId(chatId).agentId(agentId).saveAnswer(true).build();
QueryResult queryResult = chatQueryService.execute(request);
queryResult.setChatContext(semanticParseInfo);
return queryResult;
@@ -61,10 +62,10 @@ public class BaseTest extends BaseApplication {
protected QueryResult submitNewChat(String queryText, Integer agentId) throws Exception {
int chatId = DataUtils.ONE_TURNS_CHAT_ID;
ParseResp parseResp = submitParse(queryText, agentId, chatId);
ChatParseResp parseResp = submitParse(queryText, agentId, chatId);
SemanticParseInfo parseInfo = parseResp.getSelectedParses().get(0);
ChatExecuteReq request = ChatExecuteReq.builder().queryText(parseResp.getQueryText())
ChatExecuteReq request = ChatExecuteReq.builder().queryText(queryText)
.user(DataUtils.getUser()).parseId(parseInfo.getId()).agentId(agentId)
.chatId(chatId).queryId(parseResp.getQueryId()).saveAnswer(false).build();
@@ -73,7 +74,7 @@ public class BaseTest extends BaseApplication {
return result;
}
protected ParseResp submitParse(String queryText, Integer agentId, Integer chatId) {
protected ChatParseResp submitParse(String queryText, Integer agentId, Integer chatId) {
ChatParseReq chatParseReq = DataUtils.getChatParseReq(chatId, queryText, enableLLM);
chatParseReq.setAgentId(agentId);