(improvment)(chat) optimize parse performance (#197)

* (improvment)(chat) optimize parse performance
---------

Co-authored-by: jolunoluo
This commit is contained in:
LXW
2023-10-12 11:51:57 +08:00
committed by GitHub
parent b753eda9b9
commit e7b8c68dba
26 changed files with 214 additions and 123 deletions

View File

@@ -26,7 +26,8 @@ public class EntityInfoParseResponder implements ParseResponder {
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, queryReq.getUser());
if (QueryManager.isEntityQuery(parseInfo.getQueryMode())) {
if (QueryManager.isEntityQuery(parseInfo.getQueryMode())
|| QueryManager.isMetricQuery(parseInfo.getQueryMode())) {
parseInfo.setEntityInfo(entityInfo);
}
//2. set native value

View File

@@ -1,23 +0,0 @@
package com.tencent.supersonic.chat.responder.parse;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp;
import com.tencent.supersonic.chat.utils.SolvedQueryManager;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j;
import java.util.List;
@Slf4j
public class SolvedQueryParseResponder implements ParseResponder {
@Override
public void fillResponse(ParseResp parseResp, QueryContext queryContext) {
SolvedQueryManager solvedQueryManager = ContextUtils.getBean(SolvedQueryManager.class);
List<SolvedQueryRecallResp> solvedQueryRecallResps =
solvedQueryManager.recallSolvedQuery(queryContext.getRequest().getQueryText(),
queryContext.getRequest().getAgentId());
parseResp.setSimilarSolvedQuery(solvedQueryRecallResps);
}
}

View File

@@ -4,6 +4,7 @@ package com.tencent.supersonic.chat.rest;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp;
import com.tencent.supersonic.chat.persistence.dataobject.ChatDO;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
@@ -85,4 +86,10 @@ public class ChatController {
return chatService.queryShowCase(pageQueryInfoCommand, agentId);
}
@RequestMapping("/getSolvedQuery")
public List<SolvedQueryRecallResp> getSolvedQuery(@RequestParam(value = "queryText") String queryText,
@RequestParam(value = "agentId") Integer agentId) {
return chatService.getSolvedQuery(queryText, agentId);
}
}

View File

@@ -8,6 +8,7 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp;
import com.tencent.supersonic.chat.persistence.dataobject.ChatDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
@@ -22,23 +23,23 @@ public interface ChatService {
* @param chatId
* @return
*/
public Long getContextModel(Integer chatId);
Long getContextModel(Integer chatId);
public ChatContext getOrCreateContext(int chatId);
ChatContext getOrCreateContext(int chatId);
public void updateContext(ChatContext chatCtx);
void updateContext(ChatContext chatCtx);
public void switchContext(ChatContext chatCtx);
void switchContext(ChatContext chatCtx);
public Boolean addChat(User user, String chatName, Integer agentId);
Boolean addChat(User user, String chatName, Integer agentId);
public List<ChatDO> getAll(String userName, Integer agentId);
List<ChatDO> getAll(String userName, Integer agentId);
public boolean updateChatName(Long chatId, String chatName, String userName);
boolean updateChatName(Long chatId, String chatName, String userName);
public boolean updateFeedback(Integer id, Integer score, String feedback);
boolean updateFeedback(Integer id, Integer score, String feedback);
public boolean updateChatIsTop(Long chatId, int isTop);
boolean updateChatIsTop(Long chatId, int isTop);
Boolean deleteChat(Long chatId, String userName);
@@ -46,20 +47,22 @@ public interface ChatService {
ShowCaseResp queryShowCase(PageQueryInfoReq pageQueryInfoCommend, int agentId);
public void addQuery(QueryResult queryResult, ChatContext chatCtx);
void addQuery(QueryResult queryResult, ChatContext chatCtx);
public void batchAddParse(ChatContext chatCtx, QueryReq queryReq,
void batchAddParse(ChatContext chatCtx, QueryReq queryReq,
ParseResp parseResult,
List<SemanticParseInfo> candidateParses,
List<SemanticParseInfo> selectedParses);
public ChatQueryDO getLastQuery(long chatId);
ChatQueryDO getLastQuery(long chatId);
public int updateQuery(ChatQueryDO chatQueryDO);
int updateQuery(ChatQueryDO chatQueryDO);
public Boolean updateQuery(Long questionId, QueryResult queryResult, ChatContext chatCtx);
Boolean updateQuery(Long questionId, QueryResult queryResult, ChatContext chatCtx);
public ChatParseDO getParseInfo(Long questionId, String userName, int parseId);
ChatParseDO getParseInfo(Long questionId, String userName, int parseId);
public Boolean deleteChatQuery(Long questionId);
Boolean deleteChatQuery(Long questionId);
List<SolvedQueryRecallResp> getSolvedQuery(String queryText, Integer agentId);
}

View File

@@ -8,6 +8,7 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp;
import com.tencent.supersonic.chat.persistence.dataobject.ChatDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
@@ -25,6 +26,7 @@ import java.util.Objects;
import java.util.stream.Collectors;
import com.tencent.supersonic.chat.service.ChatService;
import com.tencent.supersonic.chat.utils.SolvedQueryManager;
import com.tencent.supersonic.common.util.JsonUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Primary;
@@ -38,12 +40,14 @@ public class ChatServiceImpl implements ChatService {
private ChatContextRepository chatContextRepository;
private ChatRepository chatRepository;
private ChatQueryRepository chatQueryRepository;
private SolvedQueryManager solvedQueryManager;
public ChatServiceImpl(ChatContextRepository chatContextRepository, ChatRepository chatRepository,
ChatQueryRepository chatQueryRepository) {
ChatQueryRepository chatQueryRepository, SolvedQueryManager solvedQueryManager) {
this.chatContextRepository = chatContextRepository;
this.chatRepository = chatRepository;
this.chatQueryRepository = chatQueryRepository;
this.solvedQueryManager = solvedQueryManager;
}
@Override
@@ -192,4 +196,9 @@ public class ChatServiceImpl implements ChatService {
return chatQueryRepository.deleteChatQuery(questionId);
}
@Override
public List<SolvedQueryRecallResp> getSolvedQuery(String queryText, Integer agentId) {
return solvedQueryManager.recallSolvedQuery(queryText, agentId);
}
}

View File

@@ -96,7 +96,7 @@ public class QueryServiceImpl implements QueryService {
// in order to support multi-turn conversation, chat context is needed
ChatContext chatCtx = chatService.getOrCreateContext(queryReq.getChatId());
List<StatisticsDO> timeCostDOList = new ArrayList<>();
schemaMappers.stream().forEach(mapper -> {
for (SchemaMapper mapper : schemaMappers) {
Long startTime = System.currentTimeMillis();
mapper.map(queryCtx);
Long endTime = System.currentTimeMillis();
@@ -104,8 +104,8 @@ public class QueryServiceImpl implements QueryService {
timeCostDOList.add(StatisticsDO.builder().cost((int) (endTime - startTime))
.interfaceName(className).type(CostType.MAPPER.getType()).build());
log.info("{} result:{}", className, JsonUtil.toString(queryCtx));
});
semanticParsers.stream().forEach(parser -> {
}
for (SemanticParser parser : semanticParsers) {
Long startTime = System.currentTimeMillis();
parser.parse(queryCtx, chatCtx);
Long endTime = System.currentTimeMillis();
@@ -113,7 +113,7 @@ public class QueryServiceImpl implements QueryService {
timeCostDOList.add(StatisticsDO.builder().cost((int) (endTime - startTime))
.interfaceName(className).type(CostType.PARSER.getType()).build());
log.info("{} result:{}", className, JsonUtil.toString(queryCtx));
});
}
ParseResp parseResult;
if (queryCtx.getCandidateQueries().size() > 0) {
log.debug("pick before [{}]", queryCtx.getCandidateQueries().stream().collect(
@@ -122,12 +122,8 @@ public class QueryServiceImpl implements QueryService {
log.debug("pick after [{}]", selectedQueries.stream().collect(
Collectors.toList()));
List<SemanticParseInfo> selectedParses = selectedQueries.stream()
.map(SemanticQuery::getParseInfo)
.sorted(Comparator.comparingDouble(SemanticParseInfo::getScore).reversed())
.collect(Collectors.toList());
List<SemanticParseInfo> candidateParses = queryCtx.getCandidateQueries().stream()
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
List<SemanticParseInfo> selectedParses = convertParseInfo(selectedQueries);
List<SemanticParseInfo> candidateParses = convertParseInfo(queryCtx.getCandidateQueries());
parseResult = ParseResp.builder()
.chatId(queryReq.getChatId())
.queryText(queryReq.getQueryText())
@@ -151,6 +147,13 @@ public class QueryServiceImpl implements QueryService {
return parseResult;
}
private List<SemanticParseInfo> convertParseInfo(List<SemanticQuery> semanticQueries) {
return semanticQueries.stream()
.map(SemanticQuery::getParseInfo)
.sorted(Comparator.comparingDouble(SemanticParseInfo::getScore).reversed())
.collect(Collectors.toList());
}
@Override
public QueryResult performExecution(ExecuteQueryReq queryReq) throws Exception {
ChatParseDO chatParseDO = chatService.getParseInfo(queryReq.getQueryId(),