(improvement)(Chat) Integrate chat with execute parse result processor (#825)

Co-authored-by: jolunoluo
This commit is contained in:
LXW
2024-03-15 16:08:34 +08:00
committed by GitHub
parent 7f3cb5812c
commit 4291ec7bd7
20 changed files with 265 additions and 81 deletions

View File

@@ -0,0 +1,22 @@
package com.tencent.supersonic.chat.api.pojo.request;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import lombok.Data;
import java.util.HashSet;
import java.util.Set;
@Data
public class ChatQueryDataReq {
private User user;
private Set<SchemaElement> metrics = new HashSet<>();
private Set<SchemaElement> dimensions = new HashSet<>();
private Set<QueryFilter> dimensionFilters = new HashSet<>();
private Set<QueryFilter> metricFilters = new HashSet<>();
private DateConf dateInfo;
private Long queryId;
private Integer parseId;
}

View File

@@ -1,10 +0,0 @@
package com.tencent.supersonic.chat.api.pojo.response;
import java.util.ArrayList;
import java.util.List;
import lombok.Data;
@Data
public class AggregateInfo {
private List<MetricInfo> metricInfos = new ArrayList<>();
}

View File

@@ -1,15 +0,0 @@
package com.tencent.supersonic.chat.api.pojo.response;
import java.util.Map;
import lombok.Data;
@Data
public class MetricInfo {
private String name;
private String dimension;
private String value;
private String date;
private Map<String, String> statistics;
}

View File

@@ -0,0 +1,23 @@
package com.tencent.supersonic.chat.api.pojo.response;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import lombok.Data;
import java.util.Date;
import java.util.List;
@Data
public class QueryResp {
private Long questionId;
private Date createTime;
private Long chatId;
private Integer score;
private String feedback;
private String queryText;
private QueryResult queryResult;
private List<SemanticParseInfo> parseInfos;
private List<SimilarQueryRecallResp> similarQueries;
}

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.chat.api.pojo.response;
import com.tencent.supersonic.headless.api.pojo.response.QueryResp;
import lombok.Data;
import java.util.List;

View File

@@ -10,8 +10,6 @@ public class SimilarQueryRecallResp {
private Long queryId;
private Integer parseId;
private String queryText;
}

View File

@@ -7,7 +7,7 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import java.util.List;
@@ -25,6 +25,8 @@ public interface ChatQueryRepository {
int updateChatQuery(ChatQueryDO chatQueryDO);
Long createChatQuery(ChatParseReq chatParseReq);
List<ChatParseDO> batchSaveParseInfo(ChatParseReq chatParseReq, ParseResp parseResult,
List<SemanticParseInfo> candidateParses);

View File

@@ -1,9 +1,11 @@
package com.tencent.supersonic.chat.server.persistence.repository.impl;
import com.alibaba.fastjson.JSONObject;
import com.github.pagehelper.PageHelper;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDOExample;
@@ -16,7 +18,7 @@ import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.PageUtils;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
@@ -24,7 +26,6 @@ import org.springframework.beans.BeanUtils;
import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Repository;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
@@ -106,10 +107,13 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
queryResult.setQueryId(chatQueryDO.getQuestionId());
queryResp.setQueryResult(queryResult);
}
queryResp.setSimilarQueries(JSONObject.parseArray(chatQueryDO.getSimilarQueries(),
SimilarQueryRecallResp.class));
return queryResp;
}
public Long createChatQuery(ParseResp parseResult, ChatParseReq chatParseReq) {
@Override
public Long createChatQuery(ChatParseReq chatParseReq) {
ChatQueryDO chatQueryDO = new ChatQueryDO();
chatQueryDO.setChatId(Long.valueOf(chatParseReq.getChatId()));
chatQueryDO.setCreateTime(new java.util.Date());
@@ -122,17 +126,14 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
} catch (Exception e) {
log.info("database insert has an exception:{}", e.toString());
}
Long queryId = chatQueryDO.getQuestionId();
parseResult.setQueryId(queryId);
return queryId;
return chatQueryDO.getQuestionId();
}
@Override
public List<ChatParseDO> batchSaveParseInfo(ChatParseReq chatParseReq,
ParseResp parseResult, List<SemanticParseInfo> candidateParses) {
Long queryId = createChatQuery(parseResult, chatParseReq);
List<ChatParseDO> chatParseDOList = new ArrayList<>();
getChatParseDO(chatParseReq, queryId, candidateParses, chatParseDOList);
getChatParseDO(chatParseReq, parseResult.getQueryId(), candidateParses, chatParseDOList);
if (!CollectionUtils.isEmpty(candidateParses)) {
chatParseMapper.batchSaveParseInfo(chatParseDOList);
}

View File

@@ -1,15 +1,15 @@
package com.tencent.supersonic.chat.server.processor.execute;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.RelatedSchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.server.service.impl.SemanticService;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.util.ContextUtils;
import org.springframework.util.CollectionUtils;
import java.util.Comparator;
@@ -28,14 +28,15 @@ public class DimensionRecommendProcessor implements ExecuteResultProcessor {
private static final int recommend_dimension_size = 5;
@Override
public void process(QueryResult queryResult, SemanticParseInfo semanticParseInfo, ExecuteQueryReq queryReq) {
public void process(ChatExecuteContext chatExecuteContext, QueryResult queryResult) {
SemanticParseInfo semanticParseInfo = chatExecuteContext.getParseInfo();
if (!QueryType.METRIC.equals(semanticParseInfo.getQueryType())
|| CollectionUtils.isEmpty(semanticParseInfo.getMetrics())) {
return;
}
//SchemaElement element = semanticParseInfo.getMetrics().iterator().next();
//List<SchemaElement> dimensionRecommended = getDimensions(element.getId(), element.getDataSet());
//queryResult.setRecommendedDimensions(dimensionRecommended);
SchemaElement element = semanticParseInfo.getMetrics().iterator().next();
List<SchemaElement> dimensionRecommended = getDimensions(element.getId(), element.getDataSet());
queryResult.setRecommendedDimensions(dimensionRecommended);
}
private List<SchemaElement> getDimensions(Long metricId, Long dataSetId) {

View File

@@ -1,15 +1,14 @@
package com.tencent.supersonic.chat.server.processor.execute;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
import com.tencent.supersonic.chat.server.processor.ResultProcessor;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
/**
* A ExecuteResultProcessor wraps things up before returning results to users in execute stage.
*/
public interface ExecuteResultProcessor extends ResultProcessor {
void process(QueryResult queryResult, SemanticParseInfo semanticParseInfo, ExecuteQueryReq queryReq);
void process(ChatExecuteContext chatExecuteContext, QueryResult queryResult);
}

View File

@@ -1,8 +1,7 @@
package com.tencent.supersonic.chat.server.processor.execute;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.response.AggregateInfo;
import com.tencent.supersonic.chat.api.pojo.response.MetricInfo;
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
import com.tencent.supersonic.common.pojo.QueryColumn;
@@ -11,14 +10,17 @@ import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.RatioOverType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.headless.api.pojo.AggregateInfo;
import com.tencent.supersonic.headless.api.pojo.MetricInfo;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.core.config.AggregatorConfig;
import com.tencent.supersonic.headless.core.utils.QueryReqBuilder;
import com.tencent.supersonic.headless.server.service.QueryService;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
@@ -56,19 +58,18 @@ import static com.tencent.supersonic.common.pojo.Constants.WEEK;
@Slf4j
public class MetricRatioProcessor implements ExecuteResultProcessor {
//private SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
@Override
public void process(QueryResult queryResult, SemanticParseInfo semanticParseInfo, ExecuteQueryReq queryReq) {
public void process(ChatExecuteContext chatExecuteContext, QueryResult queryResult) {
SemanticParseInfo semanticParseInfo = chatExecuteContext.getParseInfo();
AggregatorConfig aggregatorConfig = ContextUtils.getBean(AggregatorConfig.class);
if (CollectionUtils.isEmpty(semanticParseInfo.getMetrics())
|| !aggregatorConfig.getEnableRatio()
|| !QueryType.METRIC.equals(semanticParseInfo.getQueryType())) {
return;
}
//AggregateInfo aggregateInfo = getAggregateInfo(queryReq.getUser(), semanticParseInfo, queryResult);
//queryResult.setAggregateInfo(aggregateInfo);
AggregateInfo aggregateInfo = getAggregateInfo(chatExecuteContext.getUser(),
semanticParseInfo, queryResult);
queryResult.setAggregateInfo(aggregateInfo);
}
public AggregateInfo getAggregateInfo(User user, SemanticParseInfo semanticParseInfo, QueryResult queryResult) {
@@ -123,16 +124,17 @@ public class MetricRatioProcessor implements ExecuteResultProcessor {
return aggregateInfo;
}
@SneakyThrows
private MetricInfo queryRatio(User user, SemanticParseInfo semanticParseInfo, SchemaElement metric,
AggOperatorEnum aggOperatorEnum, QueryResult queryResult) {
AggOperatorEnum aggOperatorEnum, QueryResult queryResult) {
QueryStructReq queryStructReq = QueryReqBuilder.buildStructRatioReq(semanticParseInfo, metric, aggOperatorEnum);
String dateField = QueryReqBuilder.getDateField(semanticParseInfo.getDateInfo());
queryStructReq.setGroups(new ArrayList<>(Arrays.asList(dateField)));
queryStructReq.setDateInfo(getRatioDateConf(aggOperatorEnum, semanticParseInfo, queryResult));
queryStructReq.setConvertToSql(false);
SemanticQueryResp queryResp = null;
QueryService queryService = ContextUtils.getBean(QueryService.class);
SemanticQueryResp queryResp = queryService.queryByReq(queryStructReq, user);
MetricInfo metricInfo = new MetricInfo();
metricInfo.setStatistics(new HashMap<>());
if (Objects.isNull(queryResp) || CollectionUtils.isEmpty(queryResp.getResultList())) {

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.chat.server.processor.execute;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.util.ContextUtils;
@@ -10,9 +11,10 @@ import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.core.chat.knowledge.MetaEmbeddingService;
import org.springframework.util.CollectionUtils;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
@@ -21,7 +23,6 @@ import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.springframework.util.CollectionUtils;
/**
* MetricRecommendProcessor fills recommended metrics based on embedding similarity.
@@ -31,7 +32,7 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
private static final int METRIC_RECOMMEND_SIZE = 5;
@Override
public void process(QueryResult queryResult, SemanticParseInfo semanticParseInfo, ExecuteQueryReq queryReq) {
public void process(ChatExecuteContext chatExecuteContext, QueryResult queryResult) {
fillSimilarMetric(queryResult.getChatContext());
}

View File

@@ -8,9 +8,10 @@ import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.util.SimilarQueryManager;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
@@ -35,7 +36,7 @@ public class QueryRecommendProcessor implements ParseResultProcessor {
private void doProcess(ParseResp parseResp, ChatParseContext chatParseContext) {
Long queryId = parseResp.getQueryId();
List<SimilarQueryRecallResp> solvedQueries = getSimilarQueries(chatParseContext.getQueryText(),
null);
chatParseContext.getAgent().getId());
ChatQueryDO chatQueryDO = getChatQuery(queryId);
chatQueryDO.setSimilarQueries(JSONObject.toJSONString(solvedQueries));
updateChatQuery(chatQueryDO);
@@ -43,8 +44,8 @@ public class QueryRecommendProcessor implements ParseResultProcessor {
public List<SimilarQueryRecallResp> getSimilarQueries(String queryText, Integer agentId) {
//1. recall solved query by queryText
//SimilarQueryManager solvedQueryManager = ContextUtils.getBean(SimilarQueryManager.class);
List<SimilarQueryRecallResp> similarQueries = Lists.newArrayList();
SimilarQueryManager solvedQueryManager = ContextUtils.getBean(SimilarQueryManager.class);
List<SimilarQueryRecallResp> similarQueries = solvedQueryManager.recallSimilarQuery(queryText, agentId);
if (CollectionUtils.isEmpty(similarQueries)) {
return Lists.newArrayList();
}

View File

@@ -4,7 +4,7 @@ package com.tencent.supersonic.chat.server.rest;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.headless.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatDO;
import com.tencent.supersonic.chat.server.service.ChatService;

View File

@@ -4,9 +4,9 @@ package com.tencent.supersonic.chat.server.rest;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.PostMapping;
@@ -58,10 +58,10 @@ public class ChatQueryController {
}
@PostMapping("queryData")
public Object queryData(@RequestBody QueryDataReq queryData,
public Object queryData(@RequestBody ChatQueryDataReq chatQueryDataReq,
HttpServletRequest request, HttpServletResponse response) throws Exception {
queryData.setUser(UserHolder.findUser(request, response));
return chatService.queryData(queryData, UserHolder.findUser(request, response));
chatQueryDataReq.setUser(UserHolder.findUser(request, response));
return chatService.queryData(chatQueryDataReq, UserHolder.findUser(request, response));
}
@PostMapping("queryDimensionValue")

View File

@@ -4,16 +4,16 @@ import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
@@ -27,7 +27,7 @@ public interface ChatService {
QueryResult performExecution(ChatExecuteReq chatExecuteReq) throws Exception;
Object queryData(QueryDataReq queryData, User user) throws Exception;
Object queryData(ChatQueryDataReq chatQueryDataReq, User user) throws Exception;
SemanticParseInfo queryContext(Integer chatId);
@@ -53,13 +53,7 @@ public interface ChatService {
List<ChatParseDO> batchAddParse(ChatParseReq chatParseReq, ParseResp parseResult);
ChatQueryDO getLastQuery(long chatId);
int updateQuery(ChatQueryDO chatQueryDO);
void saveQueryResult(ChatExecuteReq chatExecuteReq, QueryResult queryResult);
ChatParseDO getParseInfo(Long questionId, int parseId);
Boolean deleteChatQuery(Long questionId);
}

View File

@@ -5,7 +5,9 @@ import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.request.SimilarQueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.executor.ChatExecutor;
@@ -18,11 +20,13 @@ import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryReposi
import com.tencent.supersonic.chat.server.persistence.repository.ChatRepository;
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor;
import com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.chat.server.util.ComponentFactory;
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.chat.server.util.SimilarQueryManager;
import com.tencent.supersonic.common.util.BeanMapper;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
@@ -32,7 +36,7 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
import com.tencent.supersonic.headless.server.service.ChatQueryService;
@@ -63,9 +67,12 @@ public class ChatServiceImpl implements ChatService {
private ChatQueryService chatQueryService;
@Autowired
private SearchService searchService;
@Autowired
private SimilarQueryManager similarQueryManager;
private List<ChatParser> chatParsers = ComponentFactory.getChatParsers();
private List<ChatExecutor> chatExecutors = ComponentFactory.getChatExecutors();
private List<ParseResultProcessor> parseResultProcessors = ComponentFactory.getParseProcessors();
private List<ExecuteResultProcessor> executeResultProcessors = ComponentFactory.getExecuteProcessors();
@Override
public List<SearchResult> search(ChatParseReq chatParseReq) {
@@ -77,6 +84,7 @@ public class ChatServiceImpl implements ChatService {
@Override
public ParseResp performParsing(ChatParseReq chatParseReq) {
ParseResp parseResp = new ParseResp(chatParseReq.getChatId(), chatParseReq.getQueryText());
createChatQuery(chatParseReq, parseResp);
ChatParseContext chatParseContext = buildParseContext(chatParseReq);
for (ChatParser chatParser : chatParsers) {
chatParser.parse(chatParseContext, parseResp);
@@ -98,6 +106,9 @@ public class ChatServiceImpl implements ChatService {
break;
}
}
for (ExecuteResultProcessor processor : executeResultProcessors) {
processor.process(chatExecuteContext, queryResult);
}
saveQueryResult(chatExecuteReq, queryResult);
return queryResult;
}
@@ -117,15 +128,18 @@ public class ChatServiceImpl implements ChatService {
private ChatExecuteContext buildExecuteContext(ChatExecuteReq chatExecuteReq) {
ChatExecuteContext chatExecuteContext = new ChatExecuteContext();
BeanMapper.mapper(chatExecuteReq, chatExecuteContext);
ChatParseDO chatParseDO = getParseInfo(chatExecuteReq.getQueryId(), chatExecuteReq.getParseId());
SemanticParseInfo semanticParseInfo = JSONObject.parseObject(chatParseDO.getParseInfo(),
SemanticParseInfo.class);
chatExecuteContext.setParseInfo(semanticParseInfo);
SemanticParseInfo parseInfo = getParseInfo(chatExecuteReq.getQueryId(), chatExecuteReq.getParseId());
chatExecuteContext.setParseInfo(parseInfo);
return chatExecuteContext;
}
@Override
public Object queryData(QueryDataReq queryData, User user) throws Exception {
public Object queryData(ChatQueryDataReq chatQueryDataReq, User user) throws Exception {
Integer parseId = chatQueryDataReq.getParseId();
SemanticParseInfo parseInfo = getParseInfo(chatQueryDataReq.getQueryId(), parseId);
QueryDataReq queryData = new QueryDataReq();
BeanMapper.mapper(chatQueryDataReq, queryData);
queryData.setParseInfo(parseInfo);
return chatQueryService.executeDirectQuery(queryData, user);
}
@@ -192,6 +206,11 @@ public class ChatServiceImpl implements ChatService {
return queryRespPageInfo;
}
public void createChatQuery(ChatParseReq chatParseReq, ParseResp parseResp) {
Long queryId = chatQueryRepository.createChatQuery(chatParseReq);
parseResp.setQueryId(queryId);
}
@Override
public QueryResp getChatQuery(Long queryId) {
return chatQueryRepository.getChatQuery(queryId);
@@ -257,13 +276,16 @@ public class ChatServiceImpl implements ChatService {
if (chatExecuteReq.getParseId() > 1) {
return;
}
ChatQueryDO chatQueryDO = new ChatQueryDO();
ChatQueryDO chatQueryDO = chatQueryRepository.getChatQueryDO(chatExecuteReq.getQueryId());
chatQueryDO.setQuestionId(chatExecuteReq.getQueryId());
chatQueryDO.setQueryResult(JsonUtil.toString(queryResult));
chatQueryDO.setQueryState(1);
updateQuery(chatQueryDO);
chatRepository.updateLastQuestion(chatExecuteReq.getChatId().longValue(),
chatExecuteReq.getQueryText(), getCurrentTime());
SimilarQueryReq similarQueryReq = SimilarQueryReq.builder().queryId(chatExecuteReq.getQueryId())
.queryText(chatQueryDO.getQueryText()).agentId(chatQueryDO.getAgentId()).build();
similarQueryManager.saveSimilarQuery(similarQueryReq);
}
@Override
@@ -277,22 +299,14 @@ public class ChatServiceImpl implements ChatService {
return chatQueryRepository.batchSaveParseInfo(chatParseReq, parseResult, candidateParses);
}
@Override
public ChatQueryDO getLastQuery(long chatId) {
return chatQueryRepository.getLastChatQuery(chatId);
}
private String getCurrentTime() {
SimpleDateFormat tempDate = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
return tempDate.format(new java.util.Date());
}
public ChatParseDO getParseInfo(Long questionId, int parseId) {
return chatQueryRepository.getParseInfo(questionId, parseId);
}
public Boolean deleteChatQuery(Long questionId) {
return chatQueryRepository.deleteChatQuery(questionId);
public SemanticParseInfo getParseInfo(Long questionId, int parseId) {
ChatParseDO chatParseDO = chatQueryRepository.getParseInfo(questionId, parseId);
return JSONObject.parseObject(chatParseDO.getParseInfo(), SemanticParseInfo.class);
}
}

View File

@@ -0,0 +1,142 @@
package com.tencent.supersonic.chat.server.util;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.request.SimilarQueryReq;
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.ComponentFactory;
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.util.Strings;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;
import java.net.URI;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
@Slf4j
@Component
public class SimilarQueryManager {
private EmbeddingConfig embeddingConfig;
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
public SimilarQueryManager(EmbeddingConfig embeddingConfig) {
this.embeddingConfig = embeddingConfig;
}
public void saveSimilarQuery(SimilarQueryReq similarQueryReq) {
if (StringUtils.isBlank(embeddingConfig.getUrl())) {
return;
}
String queryText = similarQueryReq.getQueryText();
try {
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
embeddingQuery.setQueryId(String.valueOf(similarQueryReq.getQueryId()));
embeddingQuery.setQuery(queryText);
Map<String, Object> metaData = new HashMap<>();
metaData.put("agentId", similarQueryReq.getAgentId());
embeddingQuery.setMetadata(metaData);
String solvedQueryCollection = embeddingConfig.getSolvedQueryCollection();
s2EmbeddingStore.addQuery(solvedQueryCollection, Lists.newArrayList(embeddingQuery));
} catch (Exception e) {
log.warn("save history question to embedding failed, queryText:{}", queryText, e);
}
}
public List<SimilarQueryRecallResp> recallSimilarQuery(String queryText, Integer agentId) {
if (StringUtils.isBlank(embeddingConfig.getUrl())) {
return Lists.newArrayList();
}
List<SimilarQueryRecallResp> similarQueryRecallResps = Lists.newArrayList();
try {
String solvedQueryCollection = embeddingConfig.getSolvedQueryCollection();
int solvedQueryResultNum = embeddingConfig.getSolvedQueryResultNum();
Map<String, String> filterCondition = new HashMap<>();
filterCondition.put("agentId", String.valueOf(agentId));
RetrieveQuery retrieveQuery = RetrieveQuery.builder()
.queryTextsList(Lists.newArrayList(queryText))
.filterCondition(filterCondition)
.build();
List<RetrieveQueryResult> resultList = s2EmbeddingStore.retrieveQuery(solvedQueryCollection, retrieveQuery,
solvedQueryResultNum * 20);
log.info("[embedding] recognize result body:{}", resultList);
Set<String> querySet = new HashSet<>();
if (CollectionUtils.isNotEmpty(resultList)) {
for (RetrieveQueryResult retrieveQueryResult : resultList) {
List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
for (Retrieval retrieval : retrievals) {
if (queryText.equalsIgnoreCase(retrieval.getQuery())) {
continue;
}
if (querySet.contains(retrieval.getQuery())) {
continue;
}
String id = retrieval.getId();
SimilarQueryRecallResp similarQueryRecallResp = SimilarQueryRecallResp.builder()
.queryText(retrieval.getQuery())
.queryId(Long.parseLong(id))
.build();
similarQueryRecallResps.add(similarQueryRecallResp);
querySet.add(retrieval.getQuery());
}
}
}
} catch (Exception e) {
log.warn("recall similar solved query failed, queryText:{}", queryText, e);
}
return similarQueryRecallResps.stream()
.limit(embeddingConfig.getSolvedQueryResultNum()).collect(Collectors.toList());
}
private ResponseEntity<String> doRequest(String path, String jsonBody) {
if (Strings.isEmpty(embeddingConfig.getUrl())) {
return ResponseEntity.of(Optional.empty());
}
String url = embeddingConfig.getUrl() + path;
try {
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
headers.setLocation(URI.create(url));
URI requestUrl = UriComponentsBuilder
.fromHttpUrl(url).build().encode().toUri();
HttpEntity<String> entity = new HttpEntity<>(jsonBody, headers);
log.info("[embedding] request body :{}, url:{}", jsonBody, url);
RestTemplate restTemplate = new RestTemplate();
ResponseEntity<String> responseEntity = restTemplate.exchange(requestUrl,
HttpMethod.POST, entity, new ParameterizedTypeReference<String>() {
});
log.info("[embedding] result body:{}", responseEntity);
return responseEntity;
} catch (Exception e) {
log.warn("connect to embedding service failed, url:{}", url);
}
return ResponseEntity.of(Optional.empty());
}
}