mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 20:25:12 +00:00
(improvement)(chat) add QueryResponder to recall history similar solved query
This commit is contained in:
@@ -21,6 +21,7 @@ public class ParseResp {
|
|||||||
private ParseState state;
|
private ParseState state;
|
||||||
private List<SemanticParseInfo> selectedParses;
|
private List<SemanticParseInfo> selectedParses;
|
||||||
private List<SemanticParseInfo> candidateParses;
|
private List<SemanticParseInfo> candidateParses;
|
||||||
|
private List<SolvedQueryRecallResp> similarSolvedQuery;
|
||||||
|
|
||||||
public enum ParseState {
|
public enum ParseState {
|
||||||
COMPLETED,
|
COMPLETED,
|
||||||
|
|||||||
@@ -0,0 +1,17 @@
|
|||||||
|
package com.tencent.supersonic.chat.api.pojo.response;
|
||||||
|
|
||||||
|
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@Builder
|
||||||
|
public class SolvedQueryRecallResp {
|
||||||
|
|
||||||
|
private Long queryId;
|
||||||
|
|
||||||
|
private Integer parseId;
|
||||||
|
|
||||||
|
private String queryText;
|
||||||
|
|
||||||
|
}
|
||||||
@@ -23,4 +23,10 @@ public class EmbeddingConfig {
|
|||||||
@Value("${embedding.nResult:1}")
|
@Value("${embedding.nResult:1}")
|
||||||
private String nResult;
|
private String nResult;
|
||||||
|
|
||||||
|
@Value("${embedding.solvedQuery.recall.path:/solved_query_retrival}")
|
||||||
|
private String solvedQueryRecallPath;
|
||||||
|
|
||||||
|
@Value("${embedding.solvedQuery.add.path:/solved_query_add}")
|
||||||
|
private String solvedQueryAddPath;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,4 +14,6 @@ public class RecallRetrieval {
|
|||||||
|
|
||||||
private String presetId;
|
private String presetId;
|
||||||
|
|
||||||
|
private String query;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,141 @@
|
|||||||
|
package com.tencent.supersonic.chat.queryresponder;
|
||||||
|
|
||||||
|
import com.alibaba.fastjson.JSONObject;
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp;
|
||||||
|
import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingConfig;
|
||||||
|
import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingResp;
|
||||||
|
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrieval;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
|
import org.apache.logging.log4j.util.Strings;
|
||||||
|
import org.springframework.core.ParameterizedTypeReference;
|
||||||
|
import org.springframework.http.HttpEntity;
|
||||||
|
import org.springframework.http.HttpHeaders;
|
||||||
|
import org.springframework.http.HttpMethod;
|
||||||
|
import org.springframework.http.MediaType;
|
||||||
|
import org.springframework.http.ResponseEntity;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
import org.springframework.web.client.RestTemplate;
|
||||||
|
import org.springframework.web.util.UriComponentsBuilder;
|
||||||
|
import java.net.URI;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
@Component
|
||||||
|
public class DefaultQueryResponder implements QueryResponder{
|
||||||
|
|
||||||
|
|
||||||
|
private EmbeddingConfig embeddingConfig;
|
||||||
|
|
||||||
|
private RestTemplate restTemplate;
|
||||||
|
|
||||||
|
public DefaultQueryResponder(EmbeddingConfig embeddingConfig,
|
||||||
|
RestTemplate restTemplate) {
|
||||||
|
this.embeddingConfig = embeddingConfig;
|
||||||
|
this.restTemplate = restTemplate;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void saveSolvedQuery(String queryText, Long queryId, Integer parseId) {
|
||||||
|
try {
|
||||||
|
String uniqueId = generateUniqueId(queryId, parseId);
|
||||||
|
Map<String, String> requestMap = new HashMap<>();
|
||||||
|
requestMap.put("query", queryText);
|
||||||
|
requestMap.put("query_id", uniqueId);
|
||||||
|
doRequest(embeddingConfig.getSolvedQueryAddPath(),
|
||||||
|
JSONObject.toJSONString(Lists.newArrayList(requestMap)));
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.warn("save history question to embedding failed, queryText:{}", queryText, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<SolvedQueryRecallResp> recallSolvedQuery(String queryText) {
|
||||||
|
List<SolvedQueryRecallResp> solvedQueryRecallResps = Lists.newArrayList();
|
||||||
|
try {
|
||||||
|
String url = embeddingConfig.getUrl() + embeddingConfig.getSolvedQueryRecallPath() + "?n_results="
|
||||||
|
+ embeddingConfig.getNResult();
|
||||||
|
HttpHeaders headers = new HttpHeaders();
|
||||||
|
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||||
|
headers.setLocation(URI.create(url));
|
||||||
|
URI requestUrl = UriComponentsBuilder
|
||||||
|
.fromHttpUrl(url).build().encode().toUri();
|
||||||
|
String jsonBody = JSONObject.toJSONString(Lists.newArrayList(queryText));
|
||||||
|
HttpEntity<String> entity = new HttpEntity<>(jsonBody, headers);
|
||||||
|
log.info("[embedding] request body:{}, url:{}", jsonBody, url);
|
||||||
|
ResponseEntity<List<EmbeddingResp>> embeddingResponseEntity =
|
||||||
|
restTemplate.exchange(requestUrl, HttpMethod.POST, entity,
|
||||||
|
new ParameterizedTypeReference<List<EmbeddingResp>>() {
|
||||||
|
});
|
||||||
|
log.info("[embedding] recognize result body:{}", embeddingResponseEntity);
|
||||||
|
List<EmbeddingResp> embeddingResps = embeddingResponseEntity.getBody();
|
||||||
|
Set<String> querySet = new HashSet<>();
|
||||||
|
if (CollectionUtils.isNotEmpty(embeddingResps)) {
|
||||||
|
for (EmbeddingResp embeddingResp : embeddingResps) {
|
||||||
|
List<RecallRetrieval> embeddingRetrievals = embeddingResp.getRetrieval();
|
||||||
|
for (RecallRetrieval embeddingRetrieval : embeddingRetrievals) {
|
||||||
|
if (querySet.contains(embeddingRetrieval.getQuery())) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
String id = embeddingRetrieval.getId();
|
||||||
|
SolvedQueryRecallResp solvedQueryRecallResp =
|
||||||
|
SolvedQueryRecallResp.builder()
|
||||||
|
.queryText(embeddingRetrieval.getQuery())
|
||||||
|
.queryId(getQueryId(id)).parseId(getParseId(id))
|
||||||
|
.build();
|
||||||
|
solvedQueryRecallResps.add(solvedQueryRecallResp);
|
||||||
|
querySet.add(embeddingRetrieval.getQuery());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.warn("recall similar solved query failed", e);
|
||||||
|
}
|
||||||
|
return solvedQueryRecallResps;
|
||||||
|
}
|
||||||
|
|
||||||
|
private String generateUniqueId(Long queryId, Integer parseId) {
|
||||||
|
String uniqueId = queryId + String.valueOf(parseId);
|
||||||
|
if (parseId < 10) {
|
||||||
|
uniqueId = queryId + String.format("0%s", parseId);
|
||||||
|
}
|
||||||
|
return uniqueId;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Long getQueryId(String uniqueId) {
|
||||||
|
return Long.parseLong(uniqueId) / 100;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Integer getParseId(String uniqueId) {
|
||||||
|
return Integer.parseInt(uniqueId) % 100;
|
||||||
|
}
|
||||||
|
|
||||||
|
private ResponseEntity<String> doRequest(String path, String jsonBody) {
|
||||||
|
if (Strings.isEmpty(embeddingConfig.getUrl())) {
|
||||||
|
return ResponseEntity.of(Optional.empty());
|
||||||
|
}
|
||||||
|
String url = embeddingConfig.getUrl() + path;
|
||||||
|
try {
|
||||||
|
HttpHeaders headers = new HttpHeaders();
|
||||||
|
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||||
|
headers.setLocation(URI.create(url));
|
||||||
|
URI requestUrl = UriComponentsBuilder
|
||||||
|
.fromHttpUrl(url).build().encode().toUri();
|
||||||
|
HttpEntity<String> entity = new HttpEntity<>(jsonBody, headers);
|
||||||
|
log.info("[embedding] request body :{}, url:{}", jsonBody, url);
|
||||||
|
ResponseEntity<String> responseEntity = restTemplate.exchange(requestUrl,
|
||||||
|
HttpMethod.POST, entity, new ParameterizedTypeReference<String>() {});
|
||||||
|
log.info("[embedding] result body:{}", responseEntity);
|
||||||
|
return responseEntity;
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.warn("connect to embedding service failed, url:{}", url);
|
||||||
|
}
|
||||||
|
return ResponseEntity.of(Optional.empty());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
package com.tencent.supersonic.chat.queryresponder;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public interface QueryResponder {
|
||||||
|
|
||||||
|
void saveSolvedQuery(String queryText, Long queryId, Integer parseId);
|
||||||
|
|
||||||
|
List<SolvedQueryRecallResp> recallSolvedQuery(String queryText);
|
||||||
|
|
||||||
|
}
|
||||||
@@ -15,12 +15,14 @@ import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
|
|||||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp;
|
||||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
|
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
|
||||||
import com.tencent.supersonic.chat.persistence.dataobject.CostType;
|
import com.tencent.supersonic.chat.persistence.dataobject.CostType;
|
||||||
import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO;
|
import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO;
|
||||||
import com.tencent.supersonic.chat.query.QuerySelector;
|
import com.tencent.supersonic.chat.query.QuerySelector;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq;
|
||||||
import com.tencent.supersonic.chat.query.QueryManager;
|
import com.tencent.supersonic.chat.query.QueryManager;
|
||||||
|
import com.tencent.supersonic.chat.queryresponder.QueryResponder;
|
||||||
import com.tencent.supersonic.chat.service.ChatService;
|
import com.tencent.supersonic.chat.service.ChatService;
|
||||||
import com.tencent.supersonic.chat.service.QueryService;
|
import com.tencent.supersonic.chat.service.QueryService;
|
||||||
import com.tencent.supersonic.chat.service.SemanticService;
|
import com.tencent.supersonic.chat.service.SemanticService;
|
||||||
@@ -63,6 +65,8 @@ public class QueryServiceImpl implements QueryService {
|
|||||||
private ChatService chatService;
|
private ChatService chatService;
|
||||||
@Autowired
|
@Autowired
|
||||||
private StatisticsService statisticsService;
|
private StatisticsService statisticsService;
|
||||||
|
@Autowired
|
||||||
|
private QueryResponder queryResponder;
|
||||||
|
|
||||||
private final String entity = "ENTITY";
|
private final String entity = "ENTITY";
|
||||||
|
|
||||||
@@ -129,10 +133,13 @@ public class QueryServiceImpl implements QueryService {
|
|||||||
saveInfo(timeCostDOList, queryReq.getQueryText(), parseResult.getQueryId(),
|
saveInfo(timeCostDOList, queryReq.getQueryText(), parseResult.getQueryId(),
|
||||||
queryReq.getUser().getName(), queryReq.getChatId().longValue());
|
queryReq.getUser().getName(), queryReq.getChatId().longValue());
|
||||||
} else {
|
} else {
|
||||||
|
List<SolvedQueryRecallResp> solvedQueryRecallResps =
|
||||||
|
queryResponder.recallSolvedQuery(queryCtx.getRequest().getQueryText());
|
||||||
parseResult = ParseResp.builder()
|
parseResult = ParseResp.builder()
|
||||||
.chatId(queryReq.getChatId())
|
.chatId(queryReq.getChatId())
|
||||||
.queryText(queryReq.getQueryText())
|
.queryText(queryReq.getQueryText())
|
||||||
.state(ParseResp.ParseState.FAILED)
|
.state(ParseResp.ParseState.FAILED)
|
||||||
|
.similarSolvedQuery(solvedQueryRecallResps)
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
return parseResult;
|
return parseResult;
|
||||||
@@ -171,6 +178,7 @@ public class QueryServiceImpl implements QueryService {
|
|||||||
chatCtx.setUser(queryReq.getUser().getName());
|
chatCtx.setUser(queryReq.getUser().getName());
|
||||||
//chatService.addQuery(queryResult, chatCtx);
|
//chatService.addQuery(queryResult, chatCtx);
|
||||||
chatService.updateQuery(queryReq.getQueryId(), queryResult, chatCtx);
|
chatService.updateQuery(queryReq.getQueryId(), queryResult, chatCtx);
|
||||||
|
queryResponder.saveSolvedQuery(queryReq.getQueryText(), queryReq.getQueryId(), queryReq.getParseId());
|
||||||
} else {
|
} else {
|
||||||
chatService.deleteChatQuery(queryReq.getQueryId());
|
chatService.deleteChatQuery(queryReq.getQueryId());
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user