mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 03:58:14 +00:00
Merge pull request #142 from lxwcodemonkey/feature/lxw
(improvement)(chat) add QueryResponder to recall history similar solved query
This commit is contained in:
@@ -21,6 +21,7 @@ public class ParseResp {
|
||||
private ParseState state;
|
||||
private List<SemanticParseInfo> selectedParses;
|
||||
private List<SemanticParseInfo> candidateParses;
|
||||
private List<SolvedQueryRecallResp> similarSolvedQuery;
|
||||
|
||||
public enum ParseState {
|
||||
COMPLETED,
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
public class SolvedQueryRecallResp {
|
||||
|
||||
private Long queryId;
|
||||
|
||||
private Integer parseId;
|
||||
|
||||
private String queryText;
|
||||
|
||||
}
|
||||
@@ -23,4 +23,13 @@ public class EmbeddingConfig {
|
||||
@Value("${embedding.nResult:1}")
|
||||
private String nResult;
|
||||
|
||||
@Value("${embedding.solvedQuery.recall.path:/solved_query_retrival}")
|
||||
private String solvedQueryRecallPath;
|
||||
|
||||
@Value("${embedding.solvedQuery.add.path:/solved_query_add}")
|
||||
private String solvedQueryAddPath;
|
||||
|
||||
@Value("${embedding.solved.query.nResult:5}")
|
||||
private String solvedQueryResultNum;
|
||||
|
||||
}
|
||||
|
||||
@@ -14,4 +14,6 @@ public class RecallRetrieval {
|
||||
|
||||
private String presetId;
|
||||
|
||||
private String query;
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,141 @@
|
||||
package com.tencent.supersonic.chat.queryresponder;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp;
|
||||
import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingConfig;
|
||||
import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingResp;
|
||||
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrieval;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.logging.log4j.util.Strings;
|
||||
import org.springframework.core.ParameterizedTypeReference;
|
||||
import org.springframework.http.HttpEntity;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
import org.springframework.web.util.UriComponentsBuilder;
|
||||
import java.net.URI;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class DefaultQueryResponder implements QueryResponder {
|
||||
|
||||
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
|
||||
private RestTemplate restTemplate;
|
||||
|
||||
public DefaultQueryResponder(EmbeddingConfig embeddingConfig,
|
||||
RestTemplate restTemplate) {
|
||||
this.embeddingConfig = embeddingConfig;
|
||||
this.restTemplate = restTemplate;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void saveSolvedQuery(String queryText, Long queryId, Integer parseId) {
|
||||
try {
|
||||
String uniqueId = generateUniqueId(queryId, parseId);
|
||||
Map<String, String> requestMap = new HashMap<>();
|
||||
requestMap.put("query", queryText);
|
||||
requestMap.put("query_id", uniqueId);
|
||||
doRequest(embeddingConfig.getSolvedQueryAddPath(),
|
||||
JSONObject.toJSONString(Lists.newArrayList(requestMap)));
|
||||
} catch (Exception e) {
|
||||
log.warn("save history question to embedding failed, queryText:{}", queryText, e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<SolvedQueryRecallResp> recallSolvedQuery(String queryText) {
|
||||
List<SolvedQueryRecallResp> solvedQueryRecallResps = Lists.newArrayList();
|
||||
try {
|
||||
String url = embeddingConfig.getUrl() + embeddingConfig.getSolvedQueryRecallPath() + "?n_results="
|
||||
+ embeddingConfig.getSolvedQueryResultNum();
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||
headers.setLocation(URI.create(url));
|
||||
URI requestUrl = UriComponentsBuilder
|
||||
.fromHttpUrl(url).build().encode().toUri();
|
||||
String jsonBody = JSONObject.toJSONString(Lists.newArrayList(queryText));
|
||||
HttpEntity<String> entity = new HttpEntity<>(jsonBody, headers);
|
||||
log.info("[embedding] request body:{}, url:{}", jsonBody, url);
|
||||
ResponseEntity<List<EmbeddingResp>> embeddingResponseEntity =
|
||||
restTemplate.exchange(requestUrl, HttpMethod.POST, entity,
|
||||
new ParameterizedTypeReference<List<EmbeddingResp>>() {
|
||||
});
|
||||
log.info("[embedding] recognize result body:{}", embeddingResponseEntity);
|
||||
List<EmbeddingResp> embeddingResps = embeddingResponseEntity.getBody();
|
||||
Set<String> querySet = new HashSet<>();
|
||||
if (CollectionUtils.isNotEmpty(embeddingResps)) {
|
||||
for (EmbeddingResp embeddingResp : embeddingResps) {
|
||||
List<RecallRetrieval> embeddingRetrievals = embeddingResp.getRetrieval();
|
||||
for (RecallRetrieval embeddingRetrieval : embeddingRetrievals) {
|
||||
if (querySet.contains(embeddingRetrieval.getQuery())) {
|
||||
continue;
|
||||
}
|
||||
String id = embeddingRetrieval.getId();
|
||||
SolvedQueryRecallResp solvedQueryRecallResp =
|
||||
SolvedQueryRecallResp.builder()
|
||||
.queryText(embeddingRetrieval.getQuery())
|
||||
.queryId(getQueryId(id)).parseId(getParseId(id))
|
||||
.build();
|
||||
solvedQueryRecallResps.add(solvedQueryRecallResp);
|
||||
querySet.add(embeddingRetrieval.getQuery());
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.warn("recall similar solved query failed, queryText:{}", queryText);
|
||||
}
|
||||
return solvedQueryRecallResps;
|
||||
}
|
||||
|
||||
private String generateUniqueId(Long queryId, Integer parseId) {
|
||||
String uniqueId = queryId + String.valueOf(parseId);
|
||||
if (parseId < 10) {
|
||||
uniqueId = queryId + String.format("0%s", parseId);
|
||||
}
|
||||
return uniqueId;
|
||||
}
|
||||
|
||||
private Long getQueryId(String uniqueId) {
|
||||
return Long.parseLong(uniqueId) / 100;
|
||||
}
|
||||
|
||||
private Integer getParseId(String uniqueId) {
|
||||
return Integer.parseInt(uniqueId) % 100;
|
||||
}
|
||||
|
||||
private ResponseEntity<String> doRequest(String path, String jsonBody) {
|
||||
if (Strings.isEmpty(embeddingConfig.getUrl())) {
|
||||
return ResponseEntity.of(Optional.empty());
|
||||
}
|
||||
String url = embeddingConfig.getUrl() + path;
|
||||
try {
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||
headers.setLocation(URI.create(url));
|
||||
URI requestUrl = UriComponentsBuilder
|
||||
.fromHttpUrl(url).build().encode().toUri();
|
||||
HttpEntity<String> entity = new HttpEntity<>(jsonBody, headers);
|
||||
log.info("[embedding] request body :{}, url:{}", jsonBody, url);
|
||||
ResponseEntity<String> responseEntity = restTemplate.exchange(requestUrl,
|
||||
HttpMethod.POST, entity, new ParameterizedTypeReference<String>() {});
|
||||
log.info("[embedding] result body:{}", responseEntity);
|
||||
return responseEntity;
|
||||
} catch (Exception e) {
|
||||
log.warn("connect to embedding service failed, url:{}", url);
|
||||
}
|
||||
return ResponseEntity.of(Optional.empty());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
package com.tencent.supersonic.chat.queryresponder;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp;
|
||||
import java.util.List;
|
||||
|
||||
public interface QueryResponder {
|
||||
|
||||
void saveSolvedQuery(String queryText, Long queryId, Integer parseId);
|
||||
|
||||
List<SolvedQueryRecallResp> recallSolvedQuery(String queryText);
|
||||
|
||||
}
|
||||
@@ -19,6 +19,7 @@ import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.CostType;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO;
|
||||
@@ -26,6 +27,7 @@ import com.tencent.supersonic.chat.query.QuerySelector;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.llm.dsl.DslQuery;
|
||||
import com.tencent.supersonic.chat.query.llm.dsl.LLMResp;
|
||||
import com.tencent.supersonic.chat.queryresponder.QueryResponder;
|
||||
import com.tencent.supersonic.chat.service.ChatService;
|
||||
import com.tencent.supersonic.chat.service.QueryService;
|
||||
import com.tencent.supersonic.chat.service.SemanticService;
|
||||
@@ -71,6 +73,8 @@ public class QueryServiceImpl implements QueryService {
|
||||
private ChatService chatService;
|
||||
@Autowired
|
||||
private StatisticsService statisticsService;
|
||||
@Autowired
|
||||
private QueryResponder queryResponder;
|
||||
|
||||
@Value("${time.threshold: 100}")
|
||||
private Integer timeThreshold;
|
||||
@@ -145,6 +149,9 @@ public class QueryServiceImpl implements QueryService {
|
||||
.state(ParseResp.ParseState.FAILED)
|
||||
.build();
|
||||
}
|
||||
List<SolvedQueryRecallResp> solvedQueryRecallResps =
|
||||
queryResponder.recallSolvedQuery(queryCtx.getRequest().getQueryText());
|
||||
parseResult.setSimilarSolvedQuery(solvedQueryRecallResps);
|
||||
return parseResult;
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
<mapper namespace="com.tencent.supersonic.chat.persistence.mapper.ChatQueryDOMapper">
|
||||
<resultMap id="BaseResultMap" type="com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO">
|
||||
<id column="question_id" jdbcType="BIGINT" property="questionId" />
|
||||
<result column="agent_id" jdbcType="INTEGER" property="agentId" />
|
||||
<result column="create_time" jdbcType="TIMESTAMP" property="createTime" />
|
||||
<result column="user_name" jdbcType="VARCHAR" property="userName" />
|
||||
<result column="query_state" jdbcType="INTEGER" property="queryState" />
|
||||
@@ -44,7 +45,7 @@
|
||||
</where>
|
||||
</sql>
|
||||
<sql id="Base_Column_List">
|
||||
question_id, create_time, user_name, query_state, chat_id, score, feedback
|
||||
question_id, agent_id, create_time, user_name, query_state, chat_id, score, feedback
|
||||
</sql>
|
||||
<sql id="Blob_Column_List">
|
||||
query_text, query_result
|
||||
@@ -65,142 +66,23 @@
|
||||
order by ${orderByClause}
|
||||
</if>
|
||||
</select>
|
||||
<select id="selectByExample" parameterType="com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDOExample" resultMap="BaseResultMap">
|
||||
select
|
||||
<if test="distinct">
|
||||
distinct
|
||||
</if>
|
||||
<include refid="Base_Column_List" />
|
||||
from s2_chat_query
|
||||
<if test="_parameter != null">
|
||||
<include refid="Example_Where_Clause" />
|
||||
</if>
|
||||
<if test="orderByClause != null">
|
||||
order by ${orderByClause}
|
||||
</if>
|
||||
<if test="limitStart != null and limitStart>=0">
|
||||
limit #{limitStart} , #{limitEnd}
|
||||
</if>
|
||||
</select>
|
||||
<select id="selectByPrimaryKey" parameterType="java.lang.Long" resultMap="ResultMapWithBLOBs">
|
||||
select
|
||||
<include refid="Base_Column_List" />
|
||||
,
|
||||
<include refid="Blob_Column_List" />
|
||||
from s2_chat_query
|
||||
where question_id = #{questionId,jdbcType=BIGINT}
|
||||
</select>
|
||||
|
||||
|
||||
<delete id="deleteByPrimaryKey" parameterType="java.lang.Long">
|
||||
delete from s2_chat_query
|
||||
where question_id = #{questionId,jdbcType=BIGINT}
|
||||
</delete>
|
||||
<insert id="insert" parameterType="com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO">
|
||||
insert into s2_chat_query (question_id, create_time, user_name,
|
||||
insert into s2_chat_query (question_id, agent_id, create_time, user_name,
|
||||
query_state, chat_id, score,
|
||||
feedback, query_text, query_result
|
||||
)
|
||||
values (#{questionId,jdbcType=BIGINT}, #{createTime,jdbcType=TIMESTAMP}, #{userName,jdbcType=VARCHAR},
|
||||
values (#{questionId,jdbcType=BIGINT}, #{agentId,jdbcType=INTEGER}, #{createTime,jdbcType=TIMESTAMP}, #{userName,jdbcType=VARCHAR},
|
||||
#{queryState,jdbcType=INTEGER}, #{chatId,jdbcType=BIGINT}, #{score,jdbcType=INTEGER},
|
||||
#{feedback,jdbcType=VARCHAR}, #{queryText,jdbcType=LONGVARCHAR}, #{queryResult,jdbcType=LONGVARCHAR}
|
||||
)
|
||||
</insert>
|
||||
<insert id="insertSelective" parameterType="com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO">
|
||||
insert into s2_chat_query
|
||||
<trim prefix="(" suffix=")" suffixOverrides=",">
|
||||
<if test="questionId != null">
|
||||
question_id,
|
||||
</if>
|
||||
<if test="createTime != null">
|
||||
create_time,
|
||||
</if>
|
||||
<if test="userName != null">
|
||||
user_name,
|
||||
</if>
|
||||
<if test="queryState != null">
|
||||
query_state,
|
||||
</if>
|
||||
<if test="chatId != null">
|
||||
chat_id,
|
||||
</if>
|
||||
<if test="score != null">
|
||||
score,
|
||||
</if>
|
||||
<if test="feedback != null">
|
||||
feedback,
|
||||
</if>
|
||||
<if test="queryText != null">
|
||||
query_text,
|
||||
</if>
|
||||
<if test="queryResult != null">
|
||||
query_result,
|
||||
</if>
|
||||
</trim>
|
||||
<trim prefix="values (" suffix=")" suffixOverrides=",">
|
||||
<if test="questionId != null">
|
||||
#{questionId,jdbcType=BIGINT},
|
||||
</if>
|
||||
<if test="createTime != null">
|
||||
#{createTime,jdbcType=TIMESTAMP},
|
||||
</if>
|
||||
<if test="userName != null">
|
||||
#{userName,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="queryState != null">
|
||||
#{queryState,jdbcType=INTEGER},
|
||||
</if>
|
||||
<if test="chatId != null">
|
||||
#{chatId,jdbcType=BIGINT},
|
||||
</if>
|
||||
<if test="score != null">
|
||||
#{score,jdbcType=INTEGER},
|
||||
</if>
|
||||
<if test="feedback != null">
|
||||
#{feedback,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="queryText != null">
|
||||
#{queryText,jdbcType=LONGVARCHAR},
|
||||
</if>
|
||||
<if test="queryResult != null">
|
||||
#{queryResult,jdbcType=LONGVARCHAR},
|
||||
</if>
|
||||
</trim>
|
||||
</insert>
|
||||
<select id="countByExample" parameterType="com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDOExample" resultType="java.lang.Long">
|
||||
select count(*) from s2_chat_query
|
||||
<if test="_parameter != null">
|
||||
<include refid="Example_Where_Clause" />
|
||||
</if>
|
||||
</select>
|
||||
<update id="updateByPrimaryKeySelective" parameterType="com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO">
|
||||
update s2_chat_query
|
||||
<set>
|
||||
<if test="createTime != null">
|
||||
create_time = #{createTime,jdbcType=TIMESTAMP},
|
||||
</if>
|
||||
<if test="userName != null">
|
||||
user_name = #{userName,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="queryState != null">
|
||||
query_state = #{queryState,jdbcType=INTEGER},
|
||||
</if>
|
||||
<if test="chatId != null">
|
||||
chat_id = #{chatId,jdbcType=BIGINT},
|
||||
</if>
|
||||
<if test="score != null">
|
||||
score = #{score,jdbcType=INTEGER},
|
||||
</if>
|
||||
<if test="feedback != null">
|
||||
feedback = #{feedback,jdbcType=VARCHAR},
|
||||
</if>
|
||||
<if test="queryText != null">
|
||||
query_text = #{queryText,jdbcType=LONGVARCHAR},
|
||||
</if>
|
||||
<if test="queryResult != null">
|
||||
query_result = #{queryResult,jdbcType=LONGVARCHAR},
|
||||
</if>
|
||||
</set>
|
||||
where question_id = #{questionId,jdbcType=BIGINT}
|
||||
</update>
|
||||
|
||||
<update id="updateByPrimaryKeyWithBLOBs" parameterType="com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO">
|
||||
update s2_chat_query
|
||||
<set>
|
||||
@@ -231,14 +113,4 @@
|
||||
</set>
|
||||
where question_id = #{questionId,jdbcType=BIGINT}
|
||||
</update>
|
||||
<update id="updateByPrimaryKey" parameterType="com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO">
|
||||
update s2_chat_query
|
||||
set create_time = #{createTime,jdbcType=TIMESTAMP},
|
||||
user_name = #{userName,jdbcType=VARCHAR},
|
||||
query_state = #{queryState,jdbcType=INTEGER},
|
||||
chat_id = #{chatId,jdbcType=BIGINT},
|
||||
score = #{score,jdbcType=INTEGER},
|
||||
feedback = #{feedback,jdbcType=VARCHAR}
|
||||
where question_id = #{questionId,jdbcType=BIGINT}
|
||||
</update>
|
||||
</mapper>
|
||||
|
||||
@@ -59,7 +59,7 @@
|
||||
join (
|
||||
select distinct chat_id
|
||||
from s2_chat_query
|
||||
where query_state = 0 and agent_id = ${agentId}
|
||||
where query_state = 1 and agent_id = ${agentId}
|
||||
order by chat_id desc
|
||||
limit #{start}, #{limit}
|
||||
) q2
|
||||
|
||||
Reference in New Issue
Block a user