mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
[improvement][chat]Introduce separate ChatParseResp in Chat module.
This commit is contained in:
@@ -0,0 +1,25 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseTimeCostResp;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class ChatParseResp {
|
||||
|
||||
private Long queryId;
|
||||
private ParseResp.ParseState state = ParseResp.ParseState.PENDING;
|
||||
private String errorMsg;
|
||||
private List<SemanticParseInfo> selectedParses = Lists.newArrayList();
|
||||
private ParseTimeCostResp parseTimeCost = new ParseTimeCostResp();
|
||||
private List<Text2SQLExemplar> usedExemplars;
|
||||
|
||||
public ChatParseResp(Long queryId) {
|
||||
this.queryId = queryId;
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.tencent.supersonic.chat.server.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
@@ -92,28 +93,27 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
addDynamicExemplars(parseContext, queryNLReq);
|
||||
}
|
||||
|
||||
ParseResp parseResp = parseContext.getResponse();
|
||||
doParse(queryNLReq, parseResp);
|
||||
doParse(queryNLReq, parseContext.getResponse());
|
||||
}
|
||||
|
||||
private void processFeedback(ParseContext parseContext) {
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
|
||||
ParseResp parseResp = parseContext.getResponse();
|
||||
ChatParseResp parseResp = parseContext.getResponse();
|
||||
for (MapModeEnum mode : MapModeEnum.values()) {
|
||||
queryNLReq.setMapModeEnum(mode);
|
||||
doParse(queryNLReq, parseResp);
|
||||
}
|
||||
}
|
||||
|
||||
private void doParse(QueryNLReq req, ParseResp resp) {
|
||||
private void doParse(QueryNLReq req, ChatParseResp resp) {
|
||||
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
|
||||
ParseResp text2SqlParseResp = chatLayerService.parse(req);
|
||||
if (text2SqlParseResp.getState().equals(ParseResp.ParseState.COMPLETED)) {
|
||||
resp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
|
||||
ParseResp parseResp = chatLayerService.parse(req);
|
||||
if (parseResp.getState().equals(ParseResp.ParseState.COMPLETED)) {
|
||||
resp.getSelectedParses().addAll(parseResp.getSelectedParses());
|
||||
}
|
||||
resp.setState(text2SqlParseResp.getState());
|
||||
resp.setParseTimeCost(text2SqlParseResp.getParseTimeCost());
|
||||
resp.setErrorMsg(text2SqlParseResp.getErrorMsg());
|
||||
resp.setState(parseResp.getState());
|
||||
resp.setParseTimeCost(parseResp.getParseTimeCost());
|
||||
resp.setErrorMsg(parseResp.getErrorMsg());
|
||||
}
|
||||
|
||||
private void rewriteMultiTurn(ParseContext parseContext, QueryNLReq queryNLReq) {
|
||||
|
||||
@@ -4,6 +4,7 @@ import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
||||
@@ -30,7 +31,7 @@ public interface ChatQueryRepository {
|
||||
|
||||
Long createChatQuery(ChatParseReq chatParseReq);
|
||||
|
||||
List<ChatParseDO> batchSaveParseInfo(ChatParseReq chatParseReq, ParseResp parseResult,
|
||||
List<ChatParseDO> batchSaveParseInfo(ChatParseReq chatParseReq, ChatParseResp chatParseResp,
|
||||
List<SemanticParseInfo> candidateParses);
|
||||
|
||||
ChatParseDO getParseInfo(Long questionId, int parseId);
|
||||
|
||||
@@ -7,6 +7,7 @@ import com.github.pagehelper.PageHelper;
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
|
||||
@@ -144,10 +145,10 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ChatParseDO> batchSaveParseInfo(ChatParseReq chatParseReq, ParseResp parseResult,
|
||||
List<SemanticParseInfo> candidateParses) {
|
||||
public List<ChatParseDO> batchSaveParseInfo(ChatParseReq chatParseReq,
|
||||
ChatParseResp chatParseResp, List<SemanticParseInfo> candidateParses) {
|
||||
List<ChatParseDO> chatParseDOList = new ArrayList<>();
|
||||
getChatParseDO(chatParseReq, parseResult.getQueryId(), candidateParses, chatParseDOList);
|
||||
getChatParseDO(chatParseReq, chatParseResp.getQueryId(), candidateParses, chatParseDOList);
|
||||
if (!CollectionUtils.isEmpty(candidateParses)) {
|
||||
chatParseMapper.batchSaveParseInfo(chatParseDOList);
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.server.plugin.recognize;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
|
||||
import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginParseResult;
|
||||
@@ -46,7 +47,7 @@ public abstract class PluginRecognizer {
|
||||
|
||||
public abstract PluginRecallResult recallPlugin(ParseContext parseContext);
|
||||
|
||||
public void buildQuery(ParseContext parseContext, ParseResp parseResp,
|
||||
public void buildQuery(ParseContext parseContext, ChatParseResp parseResp,
|
||||
PluginRecallResult pluginRecallResult) {
|
||||
ChatPlugin plugin = pluginRecallResult.getPlugin();
|
||||
Set<Long> dataSetIds = pluginRecallResult.getDataSetIds();
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.chat.server.pojo;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
@@ -11,13 +12,12 @@ import java.util.Objects;
|
||||
@Data
|
||||
public class ParseContext {
|
||||
private ChatParseReq request;
|
||||
private ParseResp response;
|
||||
private ChatParseResp response;
|
||||
private Agent agent;
|
||||
private SemanticParseInfo selectedParseInfo;
|
||||
|
||||
public ParseContext(ChatParseReq request) {
|
||||
this.request = request;
|
||||
response = new ParseResp(request.getQueryText());
|
||||
}
|
||||
|
||||
public boolean enableNL2SQL() {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.tencent.supersonic.chat.server.processor.parse;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -10,7 +11,7 @@ public class TimeCostProcessor implements ParseResultProcessor {
|
||||
|
||||
@Override
|
||||
public void process(ParseContext parseContext) {
|
||||
ParseResp parseResp = parseContext.getResponse();
|
||||
ChatParseResp parseResp = parseContext.getResponse();
|
||||
long parseStartTime = parseResp.getParseTimeCost().getParseStartTime();
|
||||
parseResp.getParseTimeCost().setParseTime(System.currentTimeMillis() - parseStartTime
|
||||
- parseResp.getParseTimeCost().getSqlTime());
|
||||
|
||||
@@ -8,6 +8,7 @@ import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
|
||||
import com.tencent.supersonic.chat.server.service.ChatQueryService;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
|
||||
@@ -56,7 +57,7 @@ public class ChatQueryController {
|
||||
HttpServletResponse response) throws Exception {
|
||||
User user = UserHolder.findUser(request, response);
|
||||
chatParseReq.setUser(user);
|
||||
ParseResp parseResp = chatQueryService.parse(chatParseReq);
|
||||
ChatParseResp parseResp = chatQueryService.parse(chatParseReq);
|
||||
|
||||
if (CollectionUtils.isEmpty(parseResp.getSelectedParses())) {
|
||||
throw new InvalidArgumentException("parser error,no selectedParses");
|
||||
|
||||
@@ -4,6 +4,7 @@ import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
|
||||
@@ -45,9 +46,9 @@ public interface ChatManageService {
|
||||
|
||||
void deleteQuery(Long queryId);
|
||||
|
||||
void updateParseCostTime(ParseResp parseResp);
|
||||
void updateParseCostTime(ChatParseResp chatParseResp);
|
||||
|
||||
List<ChatParseDO> batchAddParse(ChatParseReq chatParseReq, ParseResp parseResult);
|
||||
List<ChatParseDO> batchAddParse(ChatParseReq chatParseReq, ChatParseResp chatParseResp);
|
||||
|
||||
SemanticParseInfo getParseInfo(Long questionId, int parseId);
|
||||
}
|
||||
|
||||
@@ -3,10 +3,10 @@ package com.tencent.supersonic.chat.server.service;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
|
||||
|
||||
import java.util.List;
|
||||
@@ -15,7 +15,7 @@ public interface ChatQueryService {
|
||||
|
||||
List<SearchResult> search(ChatParseReq chatParseReq);
|
||||
|
||||
ParseResp parse(ChatParseReq chatParseReq);
|
||||
ChatParseResp parse(ChatParseReq chatParseReq);
|
||||
|
||||
QueryResult execute(ChatExecuteReq chatExecuteReq) throws Exception;
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
|
||||
@@ -192,16 +193,16 @@ public class ChatManageServiceImpl implements ChatManageService {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateParseCostTime(ParseResp parseResp) {
|
||||
ChatQueryDO chatQueryDO = chatQueryRepository.getChatQueryDO(parseResp.getQueryId());
|
||||
chatQueryDO.setParseTimeCost(JsonUtil.toString(parseResp.getParseTimeCost()));
|
||||
public void updateParseCostTime(ChatParseResp chatParseResp) {
|
||||
ChatQueryDO chatQueryDO = chatQueryRepository.getChatQueryDO(chatParseResp.getQueryId());
|
||||
chatQueryDO.setParseTimeCost(JsonUtil.toString(chatParseResp.getParseTimeCost()));
|
||||
updateQuery(chatQueryDO);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ChatParseDO> batchAddParse(ChatParseReq chatParseReq, ParseResp parseResult) {
|
||||
List<SemanticParseInfo> candidateParses = parseResult.getSelectedParses();
|
||||
return chatQueryRepository.batchSaveParseInfo(chatParseReq, parseResult, candidateParses);
|
||||
public List<ChatParseDO> batchAddParse(ChatParseReq chatParseReq, ChatParseResp chatParseResp) {
|
||||
List<SemanticParseInfo> candidateParses = chatParseResp.getSelectedParses();
|
||||
return chatQueryRepository.batchSaveParseInfo(chatParseReq, chatParseResp, candidateParses);
|
||||
}
|
||||
|
||||
private String getCurrentTime() {
|
||||
|
||||
@@ -4,6 +4,7 @@ import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.executor.ChatQueryExecutor;
|
||||
@@ -104,11 +105,10 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
}
|
||||
|
||||
@Override
|
||||
public ParseResp parse(ChatParseReq chatParseReq) {
|
||||
public ChatParseResp parse(ChatParseReq chatParseReq) {
|
||||
ParseContext parseContext = buildParseContext(chatParseReq);
|
||||
ParseResp parseResp = parseContext.getResponse();
|
||||
Long queryId = chatManageService.createChatQuery(chatParseReq);
|
||||
parseResp.setQueryId(queryId);
|
||||
parseContext.setResponse(new ChatParseResp(queryId));
|
||||
|
||||
for (ChatQueryParser parser : chatQueryParsers) {
|
||||
parser.parse(parseContext);
|
||||
@@ -118,9 +118,9 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
}
|
||||
|
||||
chatParseReq.setQueryText(parseContext.getRequest().getQueryText());
|
||||
chatManageService.batchAddParse(chatParseReq, parseResp);
|
||||
chatManageService.updateParseCostTime(parseResp);
|
||||
return parseResp;
|
||||
chatManageService.batchAddParse(chatParseReq, parseContext.getResponse());
|
||||
chatManageService.updateParseCostTime(parseContext.getResponse());
|
||||
return parseContext.getResponse();
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -146,7 +146,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
|
||||
@Override
|
||||
public QueryResult parseAndExecute(ChatParseReq chatParseReq) {
|
||||
ParseResp parseResp = parse(chatParseReq);
|
||||
ChatParseResp parseResp = parse(chatParseReq);
|
||||
if (CollectionUtils.isEmpty(parseResp.getSelectedParses())) {
|
||||
log.debug("chatId:{}, agentId:{}, queryText:{}, parseResp.getSelectedParses() is empty",
|
||||
chatParseReq.getChatId(), chatParseReq.getAgentId(),
|
||||
|
||||
@@ -12,12 +12,10 @@ import java.util.stream.Collectors;
|
||||
@Data
|
||||
public class ParseResp {
|
||||
private final String queryText;
|
||||
private Long queryId;
|
||||
private ParseState state = ParseState.PENDING;
|
||||
private String errorMsg;
|
||||
private List<SemanticParseInfo> selectedParses = Lists.newArrayList();
|
||||
private ParseTimeCostResp parseTimeCost = new ParseTimeCostResp();
|
||||
private List<Text2SQLExemplar> usedExemplars;
|
||||
|
||||
public enum ParseState {
|
||||
COMPLETED, PENDING, FAILED
|
||||
|
||||
@@ -4,6 +4,7 @@ import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.BaseApplication;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatQueryService;
|
||||
@@ -47,13 +48,13 @@ public class BaseTest extends BaseApplication {
|
||||
|
||||
protected QueryResult submitMultiTurnChat(String queryText, Integer agentId, Integer chatId)
|
||||
throws Exception {
|
||||
ParseResp parseResp = submitParse(queryText, agentId, chatId);
|
||||
ChatParseResp parseResp = submitParse(queryText, agentId, chatId);
|
||||
|
||||
SemanticParseInfo semanticParseInfo = parseResp.getSelectedParses().get(0);
|
||||
ChatExecuteReq request = ChatExecuteReq.builder().queryText(parseResp.getQueryText())
|
||||
.user(DataUtils.getUser()).parseId(semanticParseInfo.getId())
|
||||
.queryId(parseResp.getQueryId()).chatId(chatId).agentId(agentId).saveAnswer(true)
|
||||
.build();
|
||||
ChatExecuteReq request =
|
||||
ChatExecuteReq.builder().queryText(queryText).user(DataUtils.getUser())
|
||||
.parseId(semanticParseInfo.getId()).queryId(parseResp.getQueryId())
|
||||
.chatId(chatId).agentId(agentId).saveAnswer(true).build();
|
||||
QueryResult queryResult = chatQueryService.execute(request);
|
||||
queryResult.setChatContext(semanticParseInfo);
|
||||
return queryResult;
|
||||
@@ -61,10 +62,10 @@ public class BaseTest extends BaseApplication {
|
||||
|
||||
protected QueryResult submitNewChat(String queryText, Integer agentId) throws Exception {
|
||||
int chatId = DataUtils.ONE_TURNS_CHAT_ID;
|
||||
ParseResp parseResp = submitParse(queryText, agentId, chatId);
|
||||
ChatParseResp parseResp = submitParse(queryText, agentId, chatId);
|
||||
|
||||
SemanticParseInfo parseInfo = parseResp.getSelectedParses().get(0);
|
||||
ChatExecuteReq request = ChatExecuteReq.builder().queryText(parseResp.getQueryText())
|
||||
ChatExecuteReq request = ChatExecuteReq.builder().queryText(queryText)
|
||||
.user(DataUtils.getUser()).parseId(parseInfo.getId()).agentId(agentId)
|
||||
.chatId(chatId).queryId(parseResp.getQueryId()).saveAnswer(false).build();
|
||||
|
||||
@@ -73,7 +74,7 @@ public class BaseTest extends BaseApplication {
|
||||
return result;
|
||||
}
|
||||
|
||||
protected ParseResp submitParse(String queryText, Integer agentId, Integer chatId) {
|
||||
protected ChatParseResp submitParse(String queryText, Integer agentId, Integer chatId) {
|
||||
|
||||
ChatParseReq chatParseReq = DataUtils.getChatParseReq(chatId, queryText, enableLLM);
|
||||
chatParseReq.setAgentId(agentId);
|
||||
|
||||
Reference in New Issue
Block a user