From 31c8fea2dc46aa20ec7cb5a35322c5ce43e17df5 Mon Sep 17 00:00:00 2001 From: jolunoluo Date: Tue, 19 Sep 2023 15:36:15 +0800 Subject: [PATCH] (improvement)(chat) add QueryResponder to recall history similar solved query --- .../chat/api/pojo/response/ParseResp.java | 1 + .../pojo/response/SolvedQueryRecallResp.java | 17 +++ .../plugin/embedding/EmbeddingConfig.java | 6 + .../plugin/embedding/RecallRetrieval.java | 2 + .../queryresponder/DefaultQueryResponder.java | 141 ++++++++++++++++++ .../chat/queryresponder/QueryResponder.java | 12 ++ .../chat/service/impl/QueryServiceImpl.java | 8 + 7 files changed, 187 insertions(+) create mode 100644 chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/SolvedQueryRecallResp.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/queryresponder/DefaultQueryResponder.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/queryresponder/QueryResponder.java diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/ParseResp.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/ParseResp.java index e8fc10db8..9af3946a9 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/ParseResp.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/ParseResp.java @@ -21,6 +21,7 @@ public class ParseResp { private ParseState state; private List selectedParses; private List candidateParses; + private List similarSolvedQuery; public enum ParseState { COMPLETED, diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/SolvedQueryRecallResp.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/SolvedQueryRecallResp.java new file mode 100644 index 000000000..e92dd1cb1 --- /dev/null +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/response/SolvedQueryRecallResp.java @@ -0,0 +1,17 @@ +package com.tencent.supersonic.chat.api.pojo.response; + + +import lombok.Builder; +import lombok.Data; + +@Data +@Builder +public class SolvedQueryRecallResp { + + private Long queryId; + + private Integer parseId; + + private String queryText; + +} \ No newline at end of file diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/embedding/EmbeddingConfig.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/embedding/EmbeddingConfig.java index df3bdfc0f..5725ed4e3 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/embedding/EmbeddingConfig.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/embedding/EmbeddingConfig.java @@ -23,4 +23,10 @@ public class EmbeddingConfig { @Value("${embedding.nResult:1}") private String nResult; + @Value("${embedding.solvedQuery.recall.path:/solved_query_retrival}") + private String solvedQueryRecallPath; + + @Value("${embedding.solvedQuery.add.path:/solved_query_add}") + private String solvedQueryAddPath; + } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/embedding/RecallRetrieval.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/embedding/RecallRetrieval.java index 4d5470e4f..3a61970e9 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/embedding/RecallRetrieval.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/embedding/RecallRetrieval.java @@ -14,4 +14,6 @@ public class RecallRetrieval { private String presetId; + private String query; + } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/queryresponder/DefaultQueryResponder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/queryresponder/DefaultQueryResponder.java new file mode 100644 index 000000000..5357363fe --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/queryresponder/DefaultQueryResponder.java @@ -0,0 +1,141 @@ +package com.tencent.supersonic.chat.queryresponder; + +import com.alibaba.fastjson.JSONObject; +import com.google.common.collect.Lists; +import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp; +import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingConfig; +import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingResp; +import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrieval; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.collections.CollectionUtils; +import org.apache.logging.log4j.util.Strings; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.stereotype.Component; +import org.springframework.web.client.RestTemplate; +import org.springframework.web.util.UriComponentsBuilder; +import java.net.URI; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +@Slf4j +@Component +public class DefaultQueryResponder implements QueryResponder{ + + + private EmbeddingConfig embeddingConfig; + + private RestTemplate restTemplate; + + public DefaultQueryResponder(EmbeddingConfig embeddingConfig, + RestTemplate restTemplate) { + this.embeddingConfig = embeddingConfig; + this.restTemplate = restTemplate; + } + + @Override + public void saveSolvedQuery(String queryText, Long queryId, Integer parseId) { + try { + String uniqueId = generateUniqueId(queryId, parseId); + Map requestMap = new HashMap<>(); + requestMap.put("query", queryText); + requestMap.put("query_id", uniqueId); + doRequest(embeddingConfig.getSolvedQueryAddPath(), + JSONObject.toJSONString(Lists.newArrayList(requestMap))); + } catch (Exception e) { + log.warn("save history question to embedding failed, queryText:{}", queryText, e); + } + } + + @Override + public List recallSolvedQuery(String queryText) { + List solvedQueryRecallResps = Lists.newArrayList(); + try { + String url = embeddingConfig.getUrl() + embeddingConfig.getSolvedQueryRecallPath() + "?n_results=" + + embeddingConfig.getNResult(); + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + headers.setLocation(URI.create(url)); + URI requestUrl = UriComponentsBuilder + .fromHttpUrl(url).build().encode().toUri(); + String jsonBody = JSONObject.toJSONString(Lists.newArrayList(queryText)); + HttpEntity entity = new HttpEntity<>(jsonBody, headers); + log.info("[embedding] request body:{}, url:{}", jsonBody, url); + ResponseEntity> embeddingResponseEntity = + restTemplate.exchange(requestUrl, HttpMethod.POST, entity, + new ParameterizedTypeReference>() { + }); + log.info("[embedding] recognize result body:{}", embeddingResponseEntity); + List embeddingResps = embeddingResponseEntity.getBody(); + Set querySet = new HashSet<>(); + if (CollectionUtils.isNotEmpty(embeddingResps)) { + for (EmbeddingResp embeddingResp : embeddingResps) { + List embeddingRetrievals = embeddingResp.getRetrieval(); + for (RecallRetrieval embeddingRetrieval : embeddingRetrievals) { + if (querySet.contains(embeddingRetrieval.getQuery())) { + continue; + } + String id = embeddingRetrieval.getId(); + SolvedQueryRecallResp solvedQueryRecallResp = + SolvedQueryRecallResp.builder() + .queryText(embeddingRetrieval.getQuery()) + .queryId(getQueryId(id)).parseId(getParseId(id)) + .build(); + solvedQueryRecallResps.add(solvedQueryRecallResp); + querySet.add(embeddingRetrieval.getQuery()); + } + } + } + } catch (Exception e) { + log.warn("recall similar solved query failed", e); + } + return solvedQueryRecallResps; + } + + private String generateUniqueId(Long queryId, Integer parseId) { + String uniqueId = queryId + String.valueOf(parseId); + if (parseId < 10) { + uniqueId = queryId + String.format("0%s", parseId); + } + return uniqueId; + } + + private Long getQueryId(String uniqueId) { + return Long.parseLong(uniqueId) / 100; + } + + private Integer getParseId(String uniqueId) { + return Integer.parseInt(uniqueId) % 100; + } + + private ResponseEntity doRequest(String path, String jsonBody) { + if (Strings.isEmpty(embeddingConfig.getUrl())) { + return ResponseEntity.of(Optional.empty()); + } + String url = embeddingConfig.getUrl() + path; + try { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + headers.setLocation(URI.create(url)); + URI requestUrl = UriComponentsBuilder + .fromHttpUrl(url).build().encode().toUri(); + HttpEntity entity = new HttpEntity<>(jsonBody, headers); + log.info("[embedding] request body :{}, url:{}", jsonBody, url); + ResponseEntity responseEntity = restTemplate.exchange(requestUrl, + HttpMethod.POST, entity, new ParameterizedTypeReference() {}); + log.info("[embedding] result body:{}", responseEntity); + return responseEntity; + } catch (Exception e) { + log.warn("connect to embedding service failed, url:{}", url); + } + return ResponseEntity.of(Optional.empty()); + } +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/queryresponder/QueryResponder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/queryresponder/QueryResponder.java new file mode 100644 index 000000000..2f154bd44 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/queryresponder/QueryResponder.java @@ -0,0 +1,12 @@ +package com.tencent.supersonic.chat.queryresponder; + +import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp; +import java.util.List; + +public interface QueryResponder { + + void saveSolvedQuery(String queryText, Long queryId, Integer parseId); + + List recallSolvedQuery(String queryText); + +} \ No newline at end of file diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java index 0e1a1e59f..985715350 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java @@ -15,12 +15,14 @@ import com.tencent.supersonic.chat.api.pojo.response.EntityInfo; import com.tencent.supersonic.chat.api.pojo.response.ParseResp; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryState; +import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp; import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO; import com.tencent.supersonic.chat.persistence.dataobject.CostType; import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO; import com.tencent.supersonic.chat.query.QuerySelector; import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq; import com.tencent.supersonic.chat.query.QueryManager; +import com.tencent.supersonic.chat.queryresponder.QueryResponder; import com.tencent.supersonic.chat.service.ChatService; import com.tencent.supersonic.chat.service.QueryService; import com.tencent.supersonic.chat.service.SemanticService; @@ -63,6 +65,8 @@ public class QueryServiceImpl implements QueryService { private ChatService chatService; @Autowired private StatisticsService statisticsService; + @Autowired + private QueryResponder queryResponder; private final String entity = "ENTITY"; @@ -129,10 +133,13 @@ public class QueryServiceImpl implements QueryService { saveInfo(timeCostDOList, queryReq.getQueryText(), parseResult.getQueryId(), queryReq.getUser().getName(), queryReq.getChatId().longValue()); } else { + List solvedQueryRecallResps = + queryResponder.recallSolvedQuery(queryCtx.getRequest().getQueryText()); parseResult = ParseResp.builder() .chatId(queryReq.getChatId()) .queryText(queryReq.getQueryText()) .state(ParseResp.ParseState.FAILED) + .similarSolvedQuery(solvedQueryRecallResps) .build(); } return parseResult; @@ -171,6 +178,7 @@ public class QueryServiceImpl implements QueryService { chatCtx.setUser(queryReq.getUser().getName()); //chatService.addQuery(queryResult, chatCtx); chatService.updateQuery(queryReq.getQueryId(), queryResult, chatCtx); + queryResponder.saveSolvedQuery(queryReq.getQueryText(), queryReq.getQueryId(), queryReq.getParseId()); } else { chatService.deleteChatQuery(queryReq.getQueryId()); }