/**
 * One recalled "solved query" entry, carrying the identifiers and text of a query
 * that was stored earlier.
 *
 * <p>Accessors, {@code equals}/{@code hashCode}/{@code toString} and a builder are
 * generated by Lombok ({@code @Data}, {@code @Builder}).
 */
@Data
@Builder
public class SolvedQueryRecallResp {

    /** Identifier of the recalled query. */
    private Long queryId;

    /** Identifier of the parse associated with the recalled query. */
    private Integer parseId;

    /** Text of the recalled query. */
    private String queryText;

}
private String solvedQueryAddPath; + + @Value("${embedding.solved.query.nResult:5}") + private String solvedQueryResultNum; + } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/embedding/RecallRetrieval.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/embedding/RecallRetrieval.java index 4d5470e4f..3a61970e9 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/embedding/RecallRetrieval.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/embedding/RecallRetrieval.java @@ -14,4 +14,6 @@ public class RecallRetrieval { private String presetId; + private String query; + } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/queryresponder/DefaultQueryResponder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/queryresponder/DefaultQueryResponder.java new file mode 100644 index 000000000..1c8030db0 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/queryresponder/DefaultQueryResponder.java @@ -0,0 +1,141 @@ +package com.tencent.supersonic.chat.queryresponder; + +import com.alibaba.fastjson.JSONObject; +import com.google.common.collect.Lists; +import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp; +import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingConfig; +import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingResp; +import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrieval; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.collections.CollectionUtils; +import org.apache.logging.log4j.util.Strings; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.stereotype.Component; +import 
org.springframework.web.client.RestTemplate; +import org.springframework.web.util.UriComponentsBuilder; +import java.net.URI; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +@Slf4j +@Component +public class DefaultQueryResponder implements QueryResponder { + + + private EmbeddingConfig embeddingConfig; + + private RestTemplate restTemplate; + + public DefaultQueryResponder(EmbeddingConfig embeddingConfig, + RestTemplate restTemplate) { + this.embeddingConfig = embeddingConfig; + this.restTemplate = restTemplate; + } + + @Override + public void saveSolvedQuery(String queryText, Long queryId, Integer parseId) { + try { + String uniqueId = generateUniqueId(queryId, parseId); + Map requestMap = new HashMap<>(); + requestMap.put("query", queryText); + requestMap.put("query_id", uniqueId); + doRequest(embeddingConfig.getSolvedQueryAddPath(), + JSONObject.toJSONString(Lists.newArrayList(requestMap))); + } catch (Exception e) { + log.warn("save history question to embedding failed, queryText:{}", queryText, e); + } + } + + @Override + public List recallSolvedQuery(String queryText) { + List solvedQueryRecallResps = Lists.newArrayList(); + try { + String url = embeddingConfig.getUrl() + embeddingConfig.getSolvedQueryRecallPath() + "?n_results=" + + embeddingConfig.getSolvedQueryResultNum(); + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + headers.setLocation(URI.create(url)); + URI requestUrl = UriComponentsBuilder + .fromHttpUrl(url).build().encode().toUri(); + String jsonBody = JSONObject.toJSONString(Lists.newArrayList(queryText)); + HttpEntity entity = new HttpEntity<>(jsonBody, headers); + log.info("[embedding] request body:{}, url:{}", jsonBody, url); + ResponseEntity> embeddingResponseEntity = + restTemplate.exchange(requestUrl, HttpMethod.POST, entity, + new ParameterizedTypeReference>() { + }); + 
log.info("[embedding] recognize result body:{}", embeddingResponseEntity); + List embeddingResps = embeddingResponseEntity.getBody(); + Set querySet = new HashSet<>(); + if (CollectionUtils.isNotEmpty(embeddingResps)) { + for (EmbeddingResp embeddingResp : embeddingResps) { + List embeddingRetrievals = embeddingResp.getRetrieval(); + for (RecallRetrieval embeddingRetrieval : embeddingRetrievals) { + if (querySet.contains(embeddingRetrieval.getQuery())) { + continue; + } + String id = embeddingRetrieval.getId(); + SolvedQueryRecallResp solvedQueryRecallResp = + SolvedQueryRecallResp.builder() + .queryText(embeddingRetrieval.getQuery()) + .queryId(getQueryId(id)).parseId(getParseId(id)) + .build(); + solvedQueryRecallResps.add(solvedQueryRecallResp); + querySet.add(embeddingRetrieval.getQuery()); + } + } + } + } catch (Exception e) { + log.warn("recall similar solved query failed, queryText:{}", queryText); + } + return solvedQueryRecallResps; + } + + private String generateUniqueId(Long queryId, Integer parseId) { + String uniqueId = queryId + String.valueOf(parseId); + if (parseId < 10) { + uniqueId = queryId + String.format("0%s", parseId); + } + return uniqueId; + } + + private Long getQueryId(String uniqueId) { + return Long.parseLong(uniqueId) / 100; + } + + private Integer getParseId(String uniqueId) { + return Integer.parseInt(uniqueId) % 100; + } + + private ResponseEntity doRequest(String path, String jsonBody) { + if (Strings.isEmpty(embeddingConfig.getUrl())) { + return ResponseEntity.of(Optional.empty()); + } + String url = embeddingConfig.getUrl() + path; + try { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + headers.setLocation(URI.create(url)); + URI requestUrl = UriComponentsBuilder + .fromHttpUrl(url).build().encode().toUri(); + HttpEntity entity = new HttpEntity<>(jsonBody, headers); + log.info("[embedding] request body :{}, url:{}", jsonBody, url); + ResponseEntity responseEntity = 
restTemplate.exchange(requestUrl, + HttpMethod.POST, entity, new ParameterizedTypeReference() {}); + log.info("[embedding] result body:{}", responseEntity); + return responseEntity; + } catch (Exception e) { + log.warn("connect to embedding service failed, url:{}", url); + } + return ResponseEntity.of(Optional.empty()); + } +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/queryresponder/QueryResponder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/queryresponder/QueryResponder.java new file mode 100644 index 000000000..2f154bd44 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/queryresponder/QueryResponder.java @@ -0,0 +1,12 @@ +package com.tencent.supersonic.chat.queryresponder; + +import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp; +import java.util.List; + +public interface QueryResponder { + + void saveSolvedQuery(String queryText, Long queryId, Integer parseId); + + List recallSolvedQuery(String queryText); + +} \ No newline at end of file diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java index 63a054454..4f4d8f0f4 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java @@ -19,6 +19,7 @@ import com.tencent.supersonic.chat.api.pojo.response.ParseResp; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryState; import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult; +import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp; import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO; import com.tencent.supersonic.chat.persistence.dataobject.CostType; import 
com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO; @@ -26,6 +27,7 @@ import com.tencent.supersonic.chat.query.QuerySelector; import com.tencent.supersonic.chat.query.QueryManager; import com.tencent.supersonic.chat.query.llm.dsl.DslQuery; import com.tencent.supersonic.chat.query.llm.dsl.LLMResp; +import com.tencent.supersonic.chat.queryresponder.QueryResponder; import com.tencent.supersonic.chat.service.ChatService; import com.tencent.supersonic.chat.service.QueryService; import com.tencent.supersonic.chat.service.SemanticService; @@ -71,6 +73,8 @@ public class QueryServiceImpl implements QueryService { private ChatService chatService; @Autowired private StatisticsService statisticsService; + @Autowired + private QueryResponder queryResponder; @Value("${time.threshold: 100}") private Integer timeThreshold; @@ -145,6 +149,9 @@ public class QueryServiceImpl implements QueryService { .state(ParseResp.ParseState.FAILED) .build(); } + List solvedQueryRecallResps = + queryResponder.recallSolvedQuery(queryCtx.getRequest().getQueryText()); + parseResult.setSimilarSolvedQuery(solvedQueryRecallResps); return parseResult; } diff --git a/chat/core/src/main/resources/mapper/ChatQueryDOMapper.xml b/chat/core/src/main/resources/mapper/ChatQueryDOMapper.xml index 72cbbc4b4..1cba2a902 100644 --- a/chat/core/src/main/resources/mapper/ChatQueryDOMapper.xml +++ b/chat/core/src/main/resources/mapper/ChatQueryDOMapper.xml @@ -3,6 +3,7 @@ + @@ -44,7 +45,7 @@ - question_id, create_time, user_name, query_state, chat_id, score, feedback + question_id, agent_id, create_time, user_name, query_state, chat_id, score, feedback query_text, query_result @@ -65,142 +66,23 @@ order by ${orderByClause} - - + + delete from s2_chat_query where question_id = #{questionId,jdbcType=BIGINT} - insert into s2_chat_query (question_id, create_time, user_name, + insert into s2_chat_query (question_id, agent_id, create_time, user_name, query_state, chat_id, score, feedback, query_text, query_result 
) - values (#{questionId,jdbcType=BIGINT}, #{createTime,jdbcType=TIMESTAMP}, #{userName,jdbcType=VARCHAR}, + values (#{questionId,jdbcType=BIGINT}, #{agentId,jdbcType=INTEGER}, #{createTime,jdbcType=TIMESTAMP}, #{userName,jdbcType=VARCHAR}, #{queryState,jdbcType=INTEGER}, #{chatId,jdbcType=BIGINT}, #{score,jdbcType=INTEGER}, #{feedback,jdbcType=VARCHAR}, #{queryText,jdbcType=LONGVARCHAR}, #{queryResult,jdbcType=LONGVARCHAR} ) - - insert into s2_chat_query - - - question_id, - - - create_time, - - - user_name, - - - query_state, - - - chat_id, - - - score, - - - feedback, - - - query_text, - - - query_result, - - - - - #{questionId,jdbcType=BIGINT}, - - - #{createTime,jdbcType=TIMESTAMP}, - - - #{userName,jdbcType=VARCHAR}, - - - #{queryState,jdbcType=INTEGER}, - - - #{chatId,jdbcType=BIGINT}, - - - #{score,jdbcType=INTEGER}, - - - #{feedback,jdbcType=VARCHAR}, - - - #{queryText,jdbcType=LONGVARCHAR}, - - - #{queryResult,jdbcType=LONGVARCHAR}, - - - - - - update s2_chat_query - - - create_time = #{createTime,jdbcType=TIMESTAMP}, - - - user_name = #{userName,jdbcType=VARCHAR}, - - - query_state = #{queryState,jdbcType=INTEGER}, - - - chat_id = #{chatId,jdbcType=BIGINT}, - - - score = #{score,jdbcType=INTEGER}, - - - feedback = #{feedback,jdbcType=VARCHAR}, - - - query_text = #{queryText,jdbcType=LONGVARCHAR}, - - - query_result = #{queryResult,jdbcType=LONGVARCHAR}, - - - where question_id = #{questionId,jdbcType=BIGINT} - + update s2_chat_query @@ -231,14 +113,4 @@ where question_id = #{questionId,jdbcType=BIGINT} - - update s2_chat_query - set create_time = #{createTime,jdbcType=TIMESTAMP}, - user_name = #{userName,jdbcType=VARCHAR}, - query_state = #{queryState,jdbcType=INTEGER}, - chat_id = #{chatId,jdbcType=BIGINT}, - score = #{score,jdbcType=INTEGER}, - feedback = #{feedback,jdbcType=VARCHAR} - where question_id = #{questionId,jdbcType=BIGINT} - diff --git a/chat/core/src/main/resources/mapper/custom/ShowCaseCustomMapper.xml 
b/chat/core/src/main/resources/mapper/custom/ShowCaseCustomMapper.xml index adaf36822..7dcb1d213 100644 --- a/chat/core/src/main/resources/mapper/custom/ShowCaseCustomMapper.xml +++ b/chat/core/src/main/resources/mapper/custom/ShowCaseCustomMapper.xml @@ -59,7 +59,7 @@ join ( select distinct chat_id from s2_chat_query - where query_state = 0 and agent_id = ${agentId} + where query_state = 1 and agent_id = ${agentId} order by chat_id desc limit #{start}, #{limit} ) q2