(improvement)(chat) Add DrillDownDimensionProcessor and SimilarQueryProcessor to obtain recommended dimensions and similar queries (#511)

Co-authored-by: jolunoluo
This commit is contained in:
LXW
2023-12-15 10:27:33 +08:00
committed by GitHub
parent 7db1cc270e
commit e9a479e2df
32 changed files with 436 additions and 565 deletions

View File

@@ -10,7 +10,7 @@ import lombok.NoArgsConstructor;
@NoArgsConstructor
@AllArgsConstructor
@Builder
public class SolvedQueryReq {
public class SimilarQueryReq {
private Long queryId;

View File

@@ -6,6 +6,6 @@ import java.util.List;
@Data
public class QueryRecallResp {
private List<SolvedQueryRecallResp> solvedQueryRecallRespList;
private List<SimilarQueryRecallResp> solvedQueryRecallRespList;
private Long queryTimeCost;
}

View File

@@ -1,9 +1,9 @@
package com.tencent.supersonic.chat.api.pojo.response;
import java.util.Date;
import java.util.List;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import lombok.Data;
import java.util.Date;
import java.util.List;
@Data
public class QueryResp {
@@ -16,4 +16,5 @@ public class QueryResp {
private String queryText;
private QueryResult queryResult;
private List<SemanticParseInfo> parseInfos;
private List<SimilarQueryRecallResp> similarQueries;
}

View File

@@ -1,11 +1,12 @@
package com.tencent.supersonic.chat.api.pojo.response;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.common.pojo.QueryAuthorization;
import com.tencent.supersonic.common.pojo.QueryColumn;
import lombok.Data;
import java.util.List;
import java.util.Map;
import lombok.Data;
@Data
public class QueryResult {
@@ -22,4 +23,5 @@ public class QueryResult {
private Object response;
private List<Map<String, Object>> queryResults;
private Long queryTimeCost;
private List<SchemaElement> recommendedDimensions;
}

View File

@@ -6,7 +6,7 @@ import lombok.Data;
@Data
@Builder
public class SolvedQueryRecallResp {
public class SimilarQueryRecallResp {
private Long queryId;

View File

@@ -1,7 +1,9 @@
package com.tencent.supersonic.chat.persistence.dataobject;
import lombok.Data;
import java.util.Date;
@Data
public class ChatQueryDO {
/**
*/
@@ -43,155 +45,6 @@ public class ChatQueryDO {
*/
private String queryResult;
/**
* @return question_id
*/
public Long getQuestionId() {
return questionId;
}
private String similarQueries;
/**
* @param questionId
*/
public void setQuestionId(Long questionId) {
this.questionId = questionId;
}
/**
* @return agent_id
*/
public Integer getAgentId() {
return agentId;
}
/**
* @param agentId
*/
public void setAgentId(Integer agentId) {
this.agentId = agentId;
}
/**
* @return create_time
*/
public Date getCreateTime() {
return createTime;
}
/**
* @param createTime
*/
public void setCreateTime(Date createTime) {
this.createTime = createTime;
}
/**
* @return user_name
*/
public String getUserName() {
return userName;
}
/**
* @param userName
*/
public void setUserName(String userName) {
this.userName = userName == null ? null : userName.trim();
}
/**
*
* @return query_state
*/
public Integer getQueryState() {
return queryState;
}
/**
*
* @param queryState
*/
public void setQueryState(Integer queryState) {
this.queryState = queryState;
}
/**
*
* @return chat_id
*/
public Long getChatId() {
return chatId;
}
/**
*
* @param chatId
*/
public void setChatId(Long chatId) {
this.chatId = chatId;
}
/**
*
* @return score
*/
public Integer getScore() {
return score;
}
/**
*
* @param score
*/
public void setScore(Integer score) {
this.score = score;
}
/**
*
* @return feedback
*/
public String getFeedback() {
return feedback;
}
/**
*
* @param feedback
*/
public void setFeedback(String feedback) {
this.feedback = feedback == null ? null : feedback.trim();
}
/**
*
* @return query_text
*/
public String getQueryText() {
return queryText;
}
/**
*
* @param queryText
*/
public void setQueryText(String queryText) {
this.queryText = queryText == null ? null : queryText.trim();
}
/**
*
* @return query_result
*/
public String getQueryResult() {
return queryResult;
}
/**
*
* @param queryResult
*/
public void setQueryResult(String queryResult) {
this.queryResult = queryResult == null ? null : queryResult.trim();
}
}

View File

@@ -16,4 +16,6 @@ public interface ChatQueryDOMapper {
int updateByPrimaryKeyWithBLOBs(ChatQueryDO record);
Boolean deleteByPrimaryKey(Long questionId);
ChatQueryDO selectByPrimaryKey(Long questionId);
}

View File

@@ -3,13 +3,12 @@ package com.tencent.supersonic.chat.persistence.repository;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import java.util.List;
@@ -17,9 +16,11 @@ public interface ChatQueryRepository {
PageInfo<QueryResp> getChatQuery(PageQueryInfoReq pageQueryInfoCommend, Long chatId);
List<QueryResp> queryShowCase(PageQueryInfoReq pageQueryInfoCommend, int agentId);
QueryResp getChatQuery(Long queryId);
void createChatQuery(QueryResult queryResult, ChatContext chatCtx);
ChatQueryDO getChatQueryDO(Long queryId);
List<QueryResp> queryShowCase(PageQueryInfoReq pageQueryInfoCommend, int agentId);
void updateChatParseInfo(List<ChatParseDO> chatParseDOS);
@@ -27,13 +28,11 @@ public interface ChatQueryRepository {
int updateChatQuery(ChatQueryDO chatQueryDO);
Long createChatParse(ParseResp parseResult, ChatContext chatCtx, QueryReq queryReq);
List<ChatParseDO> batchSaveParseInfo(ChatContext chatCtx, QueryReq queryReq,
ParseResp parseResult,
List<SemanticParseInfo> candidateParses);
public ChatParseDO getParseInfo(Long questionId, int parseId);
ChatParseDO getParseInfo(Long questionId, int parseId);
List<ChatParseDO> getParseInfoList(List<Long> questionIds);

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.chat.persistence.repository.impl;
import com.alibaba.fastjson.JSONObject;
import com.github.pagehelper.PageHelper;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
@@ -9,6 +10,7 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDOExample;
@@ -28,6 +30,7 @@ import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
@Repository
@@ -75,6 +78,20 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
return chatQueryVOPageInfo;
}
@Override
public QueryResp getChatQuery(Long queryId) {
ChatQueryDO chatQueryDO = getChatQueryDO(queryId);
if (Objects.isNull(chatQueryDO)) {
return new QueryResp();
}
return convertTo(chatQueryDO);
}
@Override
public ChatQueryDO getChatQueryDO(Long queryId) {
return chatQueryDOMapper.selectByPrimaryKey(queryId);
}
@Override
public List<QueryResp> queryShowCase(PageQueryInfoReq pageQueryInfoReq, int agentId) {
return showCaseCustomMapper.queryShowCase(pageQueryInfoReq.getLimitStart(),
@@ -84,33 +101,22 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
}
private QueryResp convertTo(ChatQueryDO chatQueryDO) {
QueryResp queryResponse = new QueryResp();
BeanUtils.copyProperties(chatQueryDO, queryResponse);
QueryResp queryResp = new QueryResp();
BeanUtils.copyProperties(chatQueryDO, queryResp);
QueryResult queryResult = JsonUtil.toObject(chatQueryDO.getQueryResult(), QueryResult.class);
if (queryResult != null) {
queryResult.setQueryId(chatQueryDO.getQuestionId());
queryResponse.setQueryResult(queryResult);
queryResp.setQueryResult(queryResult);
}
return queryResponse;
if (StringUtils.isNotBlank(chatQueryDO.getSimilarQueries())) {
List<SimilarQueryRecallResp> similarQueries = JSONObject.parseArray(chatQueryDO.getSimilarQueries(),
SimilarQueryRecallResp.class);
queryResp.setSimilarQueries(similarQueries);
}
return queryResp;
}
@Override
public void createChatQuery(QueryResult queryResult, ChatContext chatCtx) {
ChatQueryDO chatQueryDO = new ChatQueryDO();
chatQueryDO.setChatId(Long.valueOf(chatCtx.getChatId()));
chatQueryDO.setCreateTime(new java.util.Date());
chatQueryDO.setUserName(chatCtx.getUser());
chatQueryDO.setQueryState(queryResult.getQueryState().ordinal());
chatQueryDO.setQueryText(chatCtx.getQueryText());
chatQueryDO.setQueryResult(JsonUtil.toString(queryResult));
chatQueryDO.setAgentId(chatCtx.getAgentId());
chatQueryDOMapper.insert(chatQueryDO);
ChatQueryDO lastChatQuery = getLastChatQuery(chatCtx.getChatId());
Long queryId = lastChatQuery.getQuestionId();
queryResult.setQueryId(queryId);
}
public Long createChatParse(ParseResp parseResult, ChatContext chatCtx, QueryReq queryReq) {
public Long createChatQuery(ParseResp parseResult, ChatContext chatCtx, QueryReq queryReq) {
ChatQueryDO chatQueryDO = new ChatQueryDO();
chatQueryDO.setChatId(Long.valueOf(chatCtx.getChatId()));
chatQueryDO.setCreateTime(new java.util.Date());
@@ -132,7 +138,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
public List<ChatParseDO> batchSaveParseInfo(ChatContext chatCtx, QueryReq queryReq,
ParseResp parseResult,
List<SemanticParseInfo> candidateParses) {
Long queryId = createChatParse(parseResult, chatCtx, queryReq);
Long queryId = createChatQuery(parseResult, chatCtx, queryReq);
List<ChatParseDO> chatParseDOList = new ArrayList<>();
getChatParseDO(chatCtx, queryReq, queryId, candidateParses, chatParseDOList);
chatParseMapper.batchSaveParseInfo(chatParseDOList);

View File

@@ -0,0 +1,73 @@
package com.tencent.supersonic.chat.processor.execute;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.RelatedSchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.util.ContextUtils;
import org.apache.commons.compress.utils.Lists;
import org.springframework.util.CollectionUtils;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
/**
* DrillDownDimensionProcessor obtains metric recommended dimensions
*/
public class DimensionRecommendProcessor implements ExecuteResultProcessor {
private static final int recommend_dimension_size = 5;
@Override
public void process(QueryResult queryResult, SemanticParseInfo semanticParseInfo, ExecuteQueryReq queryReq) {
if (!QueryType.METRIC.equals(semanticParseInfo.getQueryType())
|| CollectionUtils.isEmpty(semanticParseInfo.getMetrics())) {
return;
}
SchemaElement element = semanticParseInfo.getMetrics().iterator().next();
List<SchemaElement> dimensionRecommended = getDimensions(element.getId(), element.getModel());
queryResult.setRecommendedDimensions(dimensionRecommended);
}
private List<SchemaElement> getDimensions(Long metricId, Long modelId) {
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
ModelSchema modelSchema = semanticService.getModelSchema(modelId);
List<Long> drillDownDimensions = Lists.newArrayList();
Set<SchemaElement> metricElements = modelSchema.getMetrics();
if (!CollectionUtils.isEmpty(metricElements)) {
Optional<SchemaElement> metric = metricElements.stream().filter(schemaElement ->
metricId.equals(schemaElement.getId())
&& !CollectionUtils.isEmpty(schemaElement.getRelatedSchemaElements()))
.findFirst();
if (metric.isPresent()) {
drillDownDimensions = metric.get().getRelatedSchemaElements().stream()
.map(RelatedSchemaElement::getDimensionId).collect(Collectors.toList());
}
}
final List<Long> drillDownDimensionsFinal = drillDownDimensions;
return modelSchema.getDimensions().stream()
.filter(dim -> filterDimension(drillDownDimensionsFinal, dim))
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(recommend_dimension_size)
.collect(Collectors.toList());
}
private boolean filterDimension(List<Long> drillDownDimensions, SchemaElement dimension) {
if (Objects.isNull(dimension)) {
return false;
}
if (!CollectionUtils.isEmpty(drillDownDimensions)) {
return drillDownDimensions.contains(dimension.getId());
}
return Objects.nonNull(dimension.getUseCnt());
}
}

View File

@@ -0,0 +1,73 @@
package com.tencent.supersonic.chat.processor.execute;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.RelatedSchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.util.ContextUtils;
import org.apache.commons.compress.utils.Lists;
import org.springframework.util.CollectionUtils;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
/**
* DrillDownDimensionProcessor obtains metric recommended dimensions based on setting
*/
public class DrillDownDimensionProcessor implements ExecuteResultProcessor {
private static final int recommend_dimension_size = 5;
@Override
public void process(QueryResult queryResult, SemanticParseInfo semanticParseInfo, ExecuteQueryReq queryReq) {
if (!QueryType.METRIC.equals(semanticParseInfo.getQueryType())
|| CollectionUtils.isEmpty(semanticParseInfo.getMetrics())) {
return;
}
SchemaElement element = semanticParseInfo.getMetrics().iterator().next();
List<SchemaElement> dimensionRecommend = getDimensions(element.getId(), element.getModel());
queryResult.setRecommendDimensions(dimensionRecommend);
}
private List<SchemaElement> getDimensions(Long metricId, Long modelId) {
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
ModelSchema modelSchema = semanticService.getModelSchema(modelId);
List<Long> drillDownDimensions = Lists.newArrayList();
Set<SchemaElement> metricElements = modelSchema.getMetrics();
if (!CollectionUtils.isEmpty(metricElements)) {
Optional<SchemaElement> metric = metricElements.stream().filter(schemaElement ->
metricId.equals(schemaElement.getId())
&& !CollectionUtils.isEmpty(schemaElement.getRelatedSchemaElements()))
.findFirst();
if (metric.isPresent()) {
drillDownDimensions = metric.get().getRelatedSchemaElements().stream()
.map(RelatedSchemaElement::getDimensionId).collect(Collectors.toList());
}
}
final List<Long> drillDownDimensionsFinal = drillDownDimensions;
return modelSchema.getDimensions().stream()
.filter(dim -> filterDimension(drillDownDimensionsFinal, dim))
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(recommend_dimension_size)
.collect(Collectors.toList());
}
private boolean filterDimension(List<Long> drillDownDimensions, SchemaElement dimension) {
if (Objects.isNull(dimension)) {
return false;
}
if (!CollectionUtils.isEmpty(drillDownDimensions)) {
return drillDownDimensions.contains(dimension.getId());
}
return Objects.nonNull(dimension.getUseCnt());
}
}

View File

@@ -0,0 +1,86 @@
package com.tencent.supersonic.chat.processor.parse;
import com.alibaba.fastjson.JSONObject;
import com.github.pagehelper.PageInfo;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.utils.SimilarQueryManager;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
@Slf4j
public class SimilarQueryProcessor implements ParseResultProcessor {
@Override
public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
CompletableFuture.runAsync(() -> doProcess(parseResp, queryContext));
}
@SneakyThrows
private void doProcess(ParseResp parseResp, QueryContext queryContext) {
Long queryId = parseResp.getQueryId();
List<SimilarQueryRecallResp> solvedQueries = getSimilarQueries(queryContext.getRequest().getQueryText(),
queryContext.getRequest().getAgentId());
ChatQueryDO chatQueryDO = getChatQuery(queryId);
chatQueryDO.setSimilarQueries(JSONObject.toJSONString(solvedQueries));
updateChatQuery(chatQueryDO);
}
public List<SimilarQueryRecallResp> getSimilarQueries(String queryText, Integer agentId) {
//1. recall solved query by queryText
SimilarQueryManager solvedQueryManager = ContextUtils.getBean(SimilarQueryManager.class);
List<SimilarQueryRecallResp> similarQueries = solvedQueryManager.recallSimilarQuery(queryText, agentId);
if (CollectionUtils.isEmpty(similarQueries)) {
return Lists.newArrayList();
}
//2. remove low score query
List<Long> queryIds = similarQueries.stream()
.map(SimilarQueryRecallResp::getQueryId).collect(Collectors.toList());
int lowScoreThreshold = 3;
List<QueryResp> queryResps = getChatQuery(queryIds);
if (CollectionUtils.isEmpty(queryResps)) {
return Lists.newArrayList();
}
Set<Long> lowScoreQueryIds = queryResps.stream().filter(queryResp ->
queryResp.getScore() != null && queryResp.getScore() <= lowScoreThreshold)
.map(QueryResp::getQuestionId).collect(Collectors.toSet());
return similarQueries.stream().filter(solvedQueryRecallResp ->
!lowScoreQueryIds.contains(solvedQueryRecallResp.getQueryId()))
.collect(Collectors.toList());
}
private ChatQueryDO getChatQuery(Long queryId) {
ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class);
return chatQueryRepository.getChatQueryDO(queryId);
}
private List<QueryResp> getChatQuery(List<Long> queryIds) {
ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class);
PageQueryInfoReq pageQueryInfoReq = new PageQueryInfoReq();
pageQueryInfoReq.setIds(queryIds);
pageQueryInfoReq.setPageSize(100);
pageQueryInfoReq.setCurrent(1);
PageInfo<QueryResp> queryRespPageInfo = chatQueryRepository.getChatQuery(pageQueryInfoReq, null);
return queryRespPageInfo.getList();
}
private void updateChatQuery(ChatQueryDO chatQueryDO) {
ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class);
chatQueryRepository.updateChatQuery(chatQueryDO);
}
}

View File

@@ -3,22 +3,23 @@ package com.tencent.supersonic.chat.rest;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.api.pojo.response.QueryRecallResp;
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp;
import com.tencent.supersonic.chat.persistence.dataobject.ChatDO;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryRecallResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
import com.tencent.supersonic.chat.persistence.dataobject.ChatDO;
import com.tencent.supersonic.chat.service.ChatService;
import java.util.List;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.List;
@RestController
@RequestMapping({"/api/chat/manage", "/openapi/chat/manage"})
@@ -81,6 +82,11 @@ public class ChatController {
return chatService.queryInfo(pageQueryInfoCommand, chatId);
}
@GetMapping("/getChatQuery/{queryId}")
public QueryResp getChatQuery(@PathVariable("queryId") Long queryId) {
return chatService.getChatQuery(queryId);
}
@PostMapping("/queryShowCase")
public ShowCaseResp queryShowCase(@RequestBody PageQueryInfoReq pageQueryInfoCommand,
@RequestParam(value = "agentId") int agentId) {
@@ -88,11 +94,11 @@ public class ChatController {
}
@RequestMapping("/getSolvedQuery")
public List<SolvedQueryRecallResp> getSolvedQuery(@RequestParam(value = "queryText") String queryText,
@RequestParam(value = "agentId") Integer agentId) {
public List<SimilarQueryRecallResp> getSolvedQuery(@RequestParam(value = "queryText") String queryText,
@RequestParam(value = "agentId") Integer agentId) {
QueryRecallResp queryRecallResp = new QueryRecallResp();
Long startTime = System.currentTimeMillis();
List<SolvedQueryRecallResp> solvedQueryRecallRespList = chatService.getSolvedQuery(queryText, agentId);
List<SimilarQueryRecallResp> solvedQueryRecallRespList = chatService.getSolvedQuery(queryText, agentId);
queryRecallResp.setSolvedQueryRecallRespList(solvedQueryRecallRespList);
queryRecallResp.setQueryTimeCost(System.currentTimeMillis() - startTime);
return solvedQueryRecallRespList;

View File

@@ -9,7 +9,7 @@ import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp;
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
import com.tencent.supersonic.chat.persistence.dataobject.ChatDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
@@ -29,8 +29,6 @@ public interface ChatService {
void updateContext(ChatContext chatCtx);
void switchContext(ChatContext chatCtx);
Boolean addChat(User user, String chatName, Integer agentId);
List<ChatDO> getAll(String userName, Integer agentId);
@@ -45,14 +43,12 @@ public interface ChatService {
PageInfo<QueryResp> queryInfo(PageQueryInfoReq pageQueryInfoCommend, long chatId);
QueryResp getChatQuery(Long queryId);
ShowCaseResp queryShowCase(PageQueryInfoReq pageQueryInfoCommend, int agentId);
void addQuery(QueryResult queryResult, ChatContext chatCtx);
List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryReq queryReq, ParseResp parseResult);
void updateChatParse(List<ChatParseDO> chatParseDOS);
ChatQueryDO getLastQuery(long chatId);
int updateQuery(ChatQueryDO chatQueryDO);
@@ -63,5 +59,5 @@ public interface ChatService {
Boolean deleteChatQuery(Long questionId);
List<SolvedQueryRecallResp> getSolvedQuery(String queryText, Integer agentId);
List<SimilarQueryRecallResp> getSolvedQuery(String queryText, Integer agentId);
}

View File

@@ -10,7 +10,7 @@ import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp;
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
import com.tencent.supersonic.chat.persistence.dataobject.ChatDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
@@ -19,7 +19,7 @@ import com.tencent.supersonic.chat.persistence.repository.ChatContextRepository;
import com.tencent.supersonic.chat.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.persistence.repository.ChatRepository;
import com.tencent.supersonic.chat.service.ChatService;
import com.tencent.supersonic.chat.utils.SolvedQueryManager;
import com.tencent.supersonic.chat.utils.SimilarQueryManager;
import com.tencent.supersonic.common.util.JsonUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.compress.utils.Lists;
@@ -46,10 +46,10 @@ public class ChatServiceImpl implements ChatService {
private ChatContextRepository chatContextRepository;
private ChatRepository chatRepository;
private ChatQueryRepository chatQueryRepository;
private SolvedQueryManager solvedQueryManager;
private SimilarQueryManager solvedQueryManager;
public ChatServiceImpl(ChatContextRepository chatContextRepository, ChatRepository chatRepository,
ChatQueryRepository chatQueryRepository, SolvedQueryManager solvedQueryManager) {
ChatQueryRepository chatQueryRepository, SimilarQueryManager solvedQueryManager) {
this.chatContextRepository = chatContextRepository;
this.chatRepository = chatRepository;
this.chatQueryRepository = chatQueryRepository;
@@ -83,12 +83,6 @@ public class ChatServiceImpl implements ChatService {
chatContextRepository.updateContext(chatCtx);
}
@Override
public void switchContext(ChatContext chatCtx) {
log.debug("switchContext ChatContext {}", chatCtx);
chatCtx.setParseInfo(new SemanticParseInfo());
}
@Override
public Boolean addChat(User user, String chatName, Integer agentId) {
ChatDO chatDO = new ChatDO();
@@ -142,6 +136,11 @@ public class ChatServiceImpl implements ChatService {
return queryRespPageInfo;
}
@Override
public QueryResp getChatQuery(Long queryId) {
return chatQueryRepository.getChatQuery(queryId);
}
@Override
public ShowCaseResp queryShowCase(PageQueryInfoReq pageQueryInfoReq, int agentId) {
ShowCaseResp showCaseResp = new ShowCaseResp();
@@ -196,13 +195,6 @@ public class ChatServiceImpl implements ChatService {
}
}
@Override
public void addQuery(QueryResult queryResult, ChatContext chatCtx) {
chatQueryRepository.createChatQuery(queryResult, chatCtx);
chatRepository.updateLastQuestion(chatCtx.getChatId().longValue(),
chatCtx.getQueryText(), getCurrentTime());
}
@Override
public Boolean updateQuery(Long questionId, QueryResult queryResult, ChatContext chatCtx) {
ChatQueryDO chatQueryDO = new ChatQueryDO();
@@ -226,11 +218,6 @@ public class ChatServiceImpl implements ChatService {
return chatQueryRepository.batchSaveParseInfo(chatCtx, queryReq, parseResult, candidateParses);
}
@Override
public void updateChatParse(List<ChatParseDO> chatParseDOS) {
chatQueryRepository.updateChatParseInfo(chatParseDOS);
}
@Override
public ChatQueryDO getLastQuery(long chatId) {
return chatQueryRepository.getLastChatQuery(chatId);
@@ -250,14 +237,14 @@ public class ChatServiceImpl implements ChatService {
}
@Override
public List<SolvedQueryRecallResp> getSolvedQuery(String queryText, Integer agentId) {
public List<SimilarQueryRecallResp> getSolvedQuery(String queryText, Integer agentId) {
//1. recall solved query by queryText
List<SolvedQueryRecallResp> solvedQueryRecallResps = solvedQueryManager.recallSolvedQuery(queryText, agentId);
List<SimilarQueryRecallResp> solvedQueryRecallResps = solvedQueryManager.recallSimilarQuery(queryText, agentId);
if (CollectionUtils.isEmpty(solvedQueryRecallResps)) {
return Lists.newArrayList();
}
List<Long> queryIds = solvedQueryRecallResps.stream()
.map(SolvedQueryRecallResp::getQueryId).collect(Collectors.toList());
.map(SimilarQueryRecallResp::getQueryId).collect(Collectors.toList());
PageQueryInfoReq pageQueryInfoReq = new PageQueryInfoReq();
pageQueryInfoReq.setIds(queryIds);
pageQueryInfoReq.setPageSize(100);

View File

@@ -19,7 +19,7 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.request.SolvedQueryReq;
import com.tencent.supersonic.chat.api.pojo.request.SimilarQueryReq;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
@@ -39,7 +39,7 @@ import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.chat.service.StatisticsService;
import com.tencent.supersonic.chat.service.TimeCost;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.utils.SolvedQueryManager;
import com.tencent.supersonic.chat.utils.SimilarQueryManager;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.pojo.enums.QueryType;
@@ -104,7 +104,7 @@ public class QueryServiceImpl implements QueryService {
@Autowired
private StatisticsService statisticsService;
@Autowired
private SolvedQueryManager solvedQueryManager;
private SimilarQueryManager similarQueryManager;
@Value("${time.threshold: 100}")
private Integer timeThreshold;
@@ -240,7 +240,7 @@ public class QueryServiceImpl implements QueryService {
if (queryResult.getResponse() == null && CollectionUtils.isEmpty(queryResult.getQueryResults())) {
return;
}
solvedQueryManager.saveSolvedQuery(SolvedQueryReq.builder().parseId(queryReq.getParseId())
similarQueryManager.saveSimilarQuery(SimilarQueryReq.builder().parseId(queryReq.getParseId())
.queryId(queryReq.getQueryId())
.agentId(chatQueryDO.getAgentId())
.modelId(parseInfo.getModelClusterKey())

View File

@@ -1,8 +1,8 @@
package com.tencent.supersonic.chat.utils;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.request.SolvedQueryReq;
import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp;
import com.tencent.supersonic.chat.api.pojo.request.SimilarQueryReq;
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.ComponentFactory;
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
@@ -10,13 +10,6 @@ import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
import java.net.URI;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
@@ -31,33 +24,41 @@ import org.springframework.stereotype.Component;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;
import java.net.URI;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
@Slf4j
@Component
public class SolvedQueryManager {
public class SimilarQueryManager {
private EmbeddingConfig embeddingConfig;
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
public SolvedQueryManager(EmbeddingConfig embeddingConfig) {
public SimilarQueryManager(EmbeddingConfig embeddingConfig) {
this.embeddingConfig = embeddingConfig;
}
public void saveSolvedQuery(SolvedQueryReq solvedQueryReq) {
public void saveSimilarQuery(SimilarQueryReq similarQueryReq) {
if (StringUtils.isBlank(embeddingConfig.getUrl())) {
return;
}
String queryText = solvedQueryReq.getQueryText();
String queryText = similarQueryReq.getQueryText();
try {
String uniqueId = generateUniqueId(solvedQueryReq.getQueryId(), solvedQueryReq.getParseId());
String uniqueId = generateUniqueId(similarQueryReq.getQueryId(), similarQueryReq.getParseId());
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
embeddingQuery.setQueryId(uniqueId);
embeddingQuery.setQuery(queryText);
Map<String, Object> metaData = new HashMap<>();
metaData.put("modelId", (solvedQueryReq.getModelId()));
metaData.put("agentId", solvedQueryReq.getAgentId());
metaData.put("modelId", (similarQueryReq.getModelId()));
metaData.put("agentId", similarQueryReq.getAgentId());
embeddingQuery.setMetadata(metaData);
String solvedQueryCollection = embeddingConfig.getSolvedQueryCollection();
s2EmbeddingStore.addQuery(solvedQueryCollection, Lists.newArrayList(embeddingQuery));
@@ -66,11 +67,11 @@ public class SolvedQueryManager {
}
}
public List<SolvedQueryRecallResp> recallSolvedQuery(String queryText, Integer agentId) {
public List<SimilarQueryRecallResp> recallSimilarQuery(String queryText, Integer agentId) {
if (StringUtils.isBlank(embeddingConfig.getUrl())) {
return Lists.newArrayList();
}
List<SolvedQueryRecallResp> solvedQueryRecallResps = Lists.newArrayList();
List<SimilarQueryRecallResp> similarQueryRecallResps = Lists.newArrayList();
try {
String solvedQueryCollection = embeddingConfig.getSolvedQueryCollection();
int solvedQueryResultNum = embeddingConfig.getSolvedQueryResultNum();
@@ -97,11 +98,11 @@ public class SolvedQueryManager {
continue;
}
String id = retrieval.getId();
SolvedQueryRecallResp solvedQueryRecallResp = SolvedQueryRecallResp.builder()
SimilarQueryRecallResp similarQueryRecallResp = SimilarQueryRecallResp.builder()
.queryText(retrieval.getQuery())
.queryId(getQueryId(id)).parseId(getParseId(id))
.build();
solvedQueryRecallResps.add(solvedQueryRecallResp);
similarQueryRecallResps.add(similarQueryRecallResp);
querySet.add(retrieval.getQuery());
}
}
@@ -110,7 +111,7 @@ public class SolvedQueryManager {
} catch (Exception e) {
log.warn("recall similar solved query failed, queryText:{}", queryText);
}
return solvedQueryRecallResps;
return similarQueryRecallResps;
}
private String generateUniqueId(Long queryId, Integer parseId) {

View File

@@ -14,6 +14,7 @@
<resultMap extends="BaseResultMap" id="ResultMapWithBLOBs" type="com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO">
<result column="query_text" jdbcType="LONGVARCHAR" property="queryText" />
<result column="query_result" jdbcType="LONGVARCHAR" property="queryResult" />
<result column="similar_queries" jdbcType="LONGVARCHAR" property="similarQueries"/>
</resultMap>
<sql id="Example_Where_Clause">
<where>
@@ -48,7 +49,7 @@
question_id, agent_id, create_time, user_name, query_state, chat_id, score, feedback
</sql>
<sql id="Blob_Column_List">
query_text, query_result
query_text, query_result, similar_queries
</sql>
<select id="selectByExampleWithBLOBs" parameterType="com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDOExample" resultMap="ResultMapWithBLOBs">
select
@@ -67,6 +68,11 @@
</if>
</select>
<select id="selectByPrimaryKey" resultMap="ResultMapWithBLOBs">
select * from s2_chat_query
where question_id = #{questionId,jdbcType=BIGINT}
</select>
<delete id="deleteByPrimaryKey" parameterType="java.lang.Long">
delete from s2_chat_query
@@ -75,11 +81,12 @@
<insert id="insert" parameterType="com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO" useGeneratedKeys="true" keyProperty="questionId">
insert into s2_chat_query (agent_id, create_time, user_name,
query_state, chat_id, score,
feedback, query_text, query_result
feedback, query_text, query_result, similar_queries
)
values (#{agentId,jdbcType=INTEGER}, #{createTime,jdbcType=TIMESTAMP}, #{userName,jdbcType=VARCHAR},
#{queryState,jdbcType=INTEGER}, #{chatId,jdbcType=BIGINT}, #{score,jdbcType=INTEGER},
#{feedback,jdbcType=VARCHAR}, #{queryText,jdbcType=LONGVARCHAR}, #{queryResult,jdbcType=LONGVARCHAR}
#{feedback,jdbcType=VARCHAR}, #{queryText,jdbcType=LONGVARCHAR}, #{queryResult,jdbcType=LONGVARCHAR},
#{similarQueries, jdbcType=LONGVARCHAR}
)
</insert>
@@ -110,6 +117,9 @@
<if test="queryResult != null">
query_result = #{queryResult,jdbcType=LONGVARCHAR},
</if>
<if test="similarQueries != null">
similar_queries = #{similarQueries,jdbcType=LONGVARCHAR},
</if>
</set>
where question_id = #{questionId,jdbcType=BIGINT}
</update>