(improvement)(chat) Remove QueryReq parameter from QueryContext. (#656)

This commit is contained in:
lexluo09
2024-01-19 16:17:31 +08:00
committed by GitHub
parent f017f41201
commit cbf38ed785
35 changed files with 115 additions and 152 deletions

View File

@@ -28,7 +28,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
@Override
public Map<MatchText, List<T>> match(QueryContext queryContext, List<Term> terms, Set<Long> detectModelIds) {
String text = queryContext.getRequest().getQueryText();
String text = queryContext.getQueryText();
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
return null;
}
@@ -44,7 +44,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
public List<T> detect(QueryContext queryContext, List<Term> terms, Set<Long> detectModelIds) {
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(terms);
String text = queryContext.getRequest().getQueryText();
String text = queryContext.getQueryText();
Set<T> results = new HashSet<>();
Set<String> detectSegments = new HashSet<>();
@@ -102,7 +102,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
}
public List<T> getMatches(QueryContext queryContext, List<Term> terms) {
Set<Long> detectModelIds = mapperHelper.getModelIds(queryContext.getRequest(), queryContext.getAgent());
Set<Long> detectModelIds = mapperHelper.getModelIds(queryContext.getModelId(), queryContext.getAgent());
terms = filterByModelIds(terms, detectModelIds);
Map<MatchText, List<T>> matchResult = match(queryContext, terms, detectModelIds);
List<T> matches = new ArrayList<>();

View File

@@ -55,11 +55,11 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> existResults, Set<Long> detectModelIds,
Integer startIndex, Integer index, int offset) {
String detectSegment = queryContext.getRequest().getQueryText().substring(startIndex, index);
String detectSegment = queryContext.getQueryText().substring(startIndex, index);
if (StringUtils.isBlank(detectSegment)) {
return;
}
Set<Long> modelIds = mapperHelper.getModelIds(queryContext.getRequest(), queryContext.getAgent());
Set<Long> modelIds = mapperHelper.getModelIds(queryContext.getModelId(), queryContext.getAgent());
Double metricDimensionThresholdConfig = getThreshold(queryContext);

View File

@@ -23,7 +23,7 @@ public class EmbeddingMapper extends BaseMapper {
@Override
public void doMap(QueryContext queryContext) {
//1. query from embedding by queryText
String queryText = queryContext.getRequest().getQueryText();
String queryText = queryContext.getQueryText();
List<Term> terms = HanlpHelper.getTerms(queryText);
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);

View File

@@ -1,11 +1,10 @@
package com.tencent.supersonic.chat.core.mapper;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.chat.core.knowledge.HanlpMapResult;
import com.tencent.supersonic.chat.core.knowledge.SearchService;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.Constants;
import java.util.HashMap;
import java.util.LinkedHashSet;
@@ -38,8 +37,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
@Override
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<Term> terms,
Set<Long> detectModelIds) {
QueryReq queryReq = queryContext.getRequest();
String text = queryReq.getQueryText();
String text = queryContext.getQueryText();
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
return null;
}
@@ -61,9 +59,8 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectModelIds,
Integer startIndex, Integer index, int offset) {
QueryReq queryReq = queryContext.getRequest();
String text = queryReq.getQueryText();
Integer agentId = queryReq.getAgentId();
String text = queryContext.getQueryText();
Integer agentId = queryContext.getAgentId();
String detectSegment = text.substring(startIndex, index);
// step1. pre search

View File

@@ -29,7 +29,7 @@ public class KeywordMapper extends BaseMapper {
@Override
public void doMap(QueryContext queryContext) {
String queryText = queryContext.getRequest().getQueryText();
String queryText = queryContext.getQueryText();
//1.hanlpDict Match
List<Term> terms = HanlpHelper.getTerms(queryText);
HanlpDictMatchStrategy hanlpMatchStrategy = ContextUtils.getBean(HanlpDictMatchStrategy.class);

View File

@@ -2,7 +2,6 @@ package com.tencent.supersonic.chat.core.mapper;
import com.hankcs.hanlp.algorithm.EditDistance;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.core.agent.Agent;
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.chat.core.utils.NatureHelper;
@@ -82,9 +81,8 @@ public class MapperHelper {
detectSegment.length());
}
public Set<Long> getModelIds(QueryReq request, Agent agent) {
public Set<Long> getModelIds(Long modelId, Agent agent) {
Long modelId = request.getModelId();
Set<Long> detectModelIds = new HashSet<>();
if (Objects.nonNull(agent)) {
detectModelIds = agent.getModelIds(null);

View File

@@ -1,15 +1,14 @@
package com.tencent.supersonic.chat.core.mapper;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.core.knowledge.builder.BaseWordBuilder;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.Constants;
import java.util.List;
import java.util.Map;
@@ -24,8 +23,7 @@ public class QueryFilterMapper implements SchemaMapper {
@Override
public void map(QueryContext queryContext) {
QueryReq queryReq = queryContext.getRequest();
Long modelId = queryReq.getModelId();
Long modelId = queryContext.getModelId();
if (modelId == null || modelId <= 0) {
return;
}
@@ -62,7 +60,7 @@ public class QueryFilterMapper implements SchemaMapper {
.name(String.valueOf(filter.getValue()))
.type(SchemaElementType.VALUE)
.bizName(filter.getBizName())
.model(queryContext.getRequest().getModelId())
.model(queryContext.getModelId())
.build();
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
.element(element)

View File

@@ -2,7 +2,6 @@ package com.tencent.supersonic.chat.core.mapper;
import com.google.common.collect.Lists;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.core.knowledge.HanlpMapResult;
import com.tencent.supersonic.chat.core.knowledge.SearchService;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
@@ -29,8 +28,7 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
@Override
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<Term> originals,
Set<Long> detectModelIds) {
QueryReq queryReq = queryContext.getRequest();
String text = queryReq.getQueryText();
String text = queryContext.getQueryText();
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(originals);
List<Integer> detectIndexList = Lists.newArrayList();
@@ -54,9 +52,9 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
if (StringUtils.isNotEmpty(detectSegment)) {
List<HanlpMapResult> hanlpMapResults = SearchService.prefixSearch(detectSegment,
SearchService.SEARCH_SIZE, queryReq.getAgentId(), detectModelIds);
SearchService.SEARCH_SIZE, queryContext.getAgentId(), detectModelIds);
List<HanlpMapResult> suffixHanlpMapResults = SearchService.suffixSearch(
detectSegment, SEARCH_SIZE, queryReq.getAgentId(), detectModelIds);
detectSegment, SEARCH_SIZE, queryContext.getAgentId(), detectModelIds);
hanlpMapResults.addAll(suffixHanlpMapResults);
// remove entity name where search
hanlpMapResults = hanlpMapResults.stream().filter(entry -> {

View File

@@ -31,7 +31,7 @@ public class QueryTypeParser implements SemanticParser {
public void parse(QueryContext queryContext, ChatContext chatContext) {
List<SemanticQuery> candidateQueries = queryContext.getCandidateQueries();
User user = queryContext.getRequest().getUser();
User user = queryContext.getUser();
for (SemanticQuery semanticQuery : candidateQueries) {
// 1.init S2SQL

View File

@@ -23,7 +23,7 @@ public class SatisfactionChecker {
if (query.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) {
continue;
}
if (checkThreshold(queryContext.getRequest().getQueryText(), query.getParseInfo())) {
if (checkThreshold(queryContext.getQueryText(), query.getParseInfo())) {
return true;
}
}

View File

@@ -2,30 +2,29 @@ package com.tencent.supersonic.chat.core.parser.plugin;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.core.parser.SemanticParser;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.core.parser.SemanticParser;
import com.tencent.supersonic.chat.core.plugin.Plugin;
import com.tencent.supersonic.chat.core.plugin.PluginManager;
import com.tencent.supersonic.chat.core.plugin.PluginParseResult;
import com.tencent.supersonic.chat.core.plugin.PluginRecallResult;
import com.tencent.supersonic.chat.core.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.QueryManager;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.core.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.ModelCluster;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import org.springframework.util.CollectionUtils;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.springframework.util.CollectionUtils;
/**
@@ -36,7 +35,7 @@ public abstract class PluginParser implements SemanticParser {
@Override
public void parse(QueryContext queryContext, ChatContext chatContext) {
for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) {
if (queryContext.getRequest().getQueryText().length() <= semanticQuery.getParseInfo().getScore()
if (queryContext.getQueryText().length() <= semanticQuery.getParseInfo().getScore()
&& (QueryManager.getPluginQueryModes().contains(semanticQuery.getQueryMode()))) {
return;
}
@@ -64,8 +63,7 @@ public abstract class PluginParser implements SemanticParser {
for (Long modelId : modelIds) {
PluginSemanticQuery pluginQuery = QueryManager.createPluginQuery(plugin.getType());
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(modelId, plugin,
queryContext.getRequest(),
queryContext.getModelClusterMapInfo().getMatchedElements(modelId),
queryContext.getQueryFilters(), queryContext.getModelClusterMapInfo().getMatchedElements(modelId),
pluginRecallResult.getDistance());
semanticParseInfo.setQueryMode(pluginQuery.getQueryMode());
semanticParseInfo.setScore(pluginRecallResult.getScore());
@@ -78,7 +76,7 @@ public abstract class PluginParser implements SemanticParser {
return PluginManager.getPluginAgentCanSupport(queryContext);
}
protected SemanticParseInfo buildSemanticParseInfo(Long modelId, Plugin plugin, QueryReq queryReq,
protected SemanticParseInfo buildSemanticParseInfo(Long modelId, Plugin plugin, QueryFilters queryFilters,
List<SchemaElementMatch> schemaElementMatches, double distance) {
if (modelId == null && !CollectionUtils.isEmpty(plugin.getModelList())) {
modelId = plugin.getModelList().get(0);
@@ -92,7 +90,7 @@ public abstract class PluginParser implements SemanticParser {
Map<String, Object> properties = new HashMap<>();
PluginParseResult pluginParseResult = new PluginParseResult();
pluginParseResult.setPlugin(plugin);
pluginParseResult.setRequest(queryReq);
pluginParseResult.setQueryFilters(queryFilters);
pluginParseResult.setDistance(distance);
properties.put(Constants.CONTEXT, pluginParseResult);
properties.put("type", "plugin");

View File

@@ -42,7 +42,7 @@ public class EmbeddingRecallParser extends PluginParser {
@Override
public PluginRecallResult recallPlugin(QueryContext queryContext) {
String text = queryContext.getRequest().getQueryText();
String text = queryContext.getQueryText();
List<Retrieval> embeddingRetrievals = embeddingRecall(text);
if (CollectionUtils.isEmpty(embeddingRetrievals)) {
return null;
@@ -63,7 +63,7 @@ public class EmbeddingRecallParser extends PluginParser {
}
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
double distance = embeddingRetrieval.getDistance();
double score = queryContext.getRequest().getQueryText().length() * (1 - distance);
double score = queryContext.getQueryText().length() * (1 - distance);
return PluginRecallResult.builder()
.plugin(plugin).modelIds(modelList).score(score).distance(distance).build();
}

View File

@@ -33,7 +33,7 @@ public class FunctionCallParser extends PluginParser {
String functionUrl = functionCallConfig.getUrl();
if (StringUtils.isBlank(functionUrl) && ComponentFactory.getLLMProxy() instanceof PythonLLMProxy) {
log.info("functionUrl:{}, skip function parser, queryText:{}", functionUrl,
queryContext.getRequest().getQueryText());
queryContext.getQueryText());
return false;
}
List<Plugin> plugins = getPluginList(queryContext);
@@ -60,7 +60,7 @@ public class FunctionCallParser extends PluginParser {
if (CollectionUtils.isEmpty(modelList)) {
return null;
}
double score = queryContext.getRequest().getQueryText().length();
double score = queryContext.getQueryText().length();
return PluginRecallResult.builder().plugin(plugin).modelIds(modelList).score(score).build();
}
return null;
@@ -68,7 +68,7 @@ public class FunctionCallParser extends PluginParser {
public FunctionResp functionCall(QueryContext queryContext) {
List<PluginParseConfig> pluginToFunctionCall =
getPluginToFunctionCall(queryContext.getRequest().getModelId(), queryContext);
getPluginToFunctionCall(queryContext.getModelId(), queryContext);
if (CollectionUtils.isEmpty(pluginToFunctionCall)) {
log.info("function call parser, plugin is empty, skip");
return null;
@@ -78,7 +78,7 @@ public class FunctionCallParser extends PluginParser {
functionResp.setToolSelection(pluginToFunctionCall.iterator().next().getName());
} else {
FunctionReq functionReq = FunctionReq.builder()
.queryText(queryContext.getRequest().getQueryText())
.queryText(queryContext.getQueryText())
.pluginConfigs(pluginToFunctionCall).build();
functionResp = ComponentFactory.getLLMProxy().requestFunction(functionReq);
}

View File

@@ -116,7 +116,7 @@ public class HeuristicModelResolver implements ModelResolver {
public String resolve(QueryContext queryContext, ChatContext chatCtx, Set<Long> restrictiveModels) {
SchemaModelClusterMapInfo mapInfo = queryContext.getModelClusterMapInfo();
Set<String> matchedModelClusters = mapInfo.getElementMatchesByModelIds(restrictiveModels).keySet();
Long modelId = queryContext.getRequest().getModelId();
Long modelId = queryContext.getModelId();
if (Objects.nonNull(modelId) && modelId > 0) {
if (CollectionUtils.isEmpty(restrictiveModels) || restrictiveModels.contains(modelId)) {
return getModelClusterByModelId(modelId, matchedModelClusters);

View File

@@ -57,7 +57,7 @@ public class LLMRequestService {
return true;
}
if (SatisfactionChecker.isSkip(queryCtx)) {
log.info("skip {}, queryText:{}", LLMSqlParser.class, queryCtx.getRequest().getQueryText());
log.info("skip {}, queryText:{}", LLMSqlParser.class, queryCtx.getQueryText());
return true;
}
return false;
@@ -121,7 +121,7 @@ public class LLMRequestService {
public LLMReq getLlmReq(QueryContext queryCtx, SemanticSchema semanticSchema,
ModelCluster modelCluster, List<ElementValue> linkingValues) {
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
String queryText = queryCtx.getRequest().getQueryText();
String queryText = queryCtx.getQueryText();
LLMReq llmReq = new LLMReq();
llmReq.setQueryText(queryText);

View File

@@ -39,7 +39,7 @@ public class LLMResponseService {
properties.put("name", commonAgentTool.getName());
parseInfo.setProperties(properties);
parseInfo.setScore(queryCtx.getRequest().getQueryText().length() * (1 + weight));
parseInfo.setScore(queryCtx.getQueryText().length() * (1 + weight));
parseInfo.setQueryMode(semanticQuery.getQueryMode());
parseInfo.getSqlInfo().setS2SQL(s2SQL);
parseInfo.setModel(parseResult.getModelCluster());

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.chat.core.parser.sql.llm;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.core.agent.NL2SQLTool;
import com.tencent.supersonic.chat.core.parser.SemanticParser;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
@@ -24,7 +23,6 @@ public class LLMSqlParser implements SemanticParser {
@Override
public void parse(QueryContext queryCtx, ChatContext chatCtx) {
QueryReq request = queryCtx.getRequest();
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
//1.determine whether to skip this parser.
if (requestService.isSkip(queryCtx)) {
@@ -56,7 +54,6 @@ public class LLMSqlParser implements SemanticParser {
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
Map<String, LLMSqlResp> deduplicationSqlResp = responseService.getDeduplicationSqlResp(llmResp);
ParseResult parseResult = ParseResult.builder()
.request(request)
.modelCluster(modelCluster)
.commonAgentTool(commonAgentTool)
.llmReq(llmReq)

View File

@@ -49,7 +49,7 @@ public class AggregateTypeParser implements SemanticParser {
@Override
public void parse(QueryContext queryContext, ChatContext chatContext) {
String queryText = queryContext.getRequest().getQueryText();
String queryText = queryContext.getQueryText();
AggregateConf aggregateConf = resolveAggregateConf(queryText);
for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) {

View File

@@ -42,12 +42,13 @@ public class TimeRangeParser implements SemanticParser {
@Override
public void parse(QueryContext queryContext, ChatContext chatContext) {
DateConf dateConf = parseRecent(queryContext.getRequest().getQueryText());
String queryText = queryContext.getQueryText();
DateConf dateConf = parseRecent(queryText);
if (dateConf == null) {
dateConf = parseDateNumber(queryContext.getRequest().getQueryText());
dateConf = parseDateNumber(queryText);
}
if (dateConf == null) {
dateConf = parseDateCN(queryContext.getRequest().getQueryText());
dateConf = parseDateCN(queryText);
}
if (dateConf != null) {

View File

@@ -1,12 +1,12 @@
package com.tencent.supersonic.chat.core.plugin;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import lombok.Data;
@Data
public class PluginParseResult {
private Plugin plugin;
private QueryReq request;
private QueryFilters queryFilters;
private double distance;
}

View File

@@ -1,10 +1,10 @@
package com.tencent.supersonic.chat.core.pojo;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaModelClusterMapInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
import com.tencent.supersonic.chat.core.agent.Agent;
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
@@ -27,7 +27,12 @@ import lombok.NoArgsConstructor;
@AllArgsConstructor
public class QueryContext {
private QueryReq request;
private String queryText;
private Integer chatId;
private Long modelId;
private User user;
private boolean saveAnswer = true;
private Integer agentId;
private QueryFilters queryFilters;
private List<SemanticQuery> candidateQueries = new ArrayList<>();
private SchemaMapInfo mapInfo = new SchemaMapInfo();

View File

@@ -7,7 +7,6 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.core.plugin.PluginParseResult;
import com.tencent.supersonic.chat.core.query.BaseSemanticQuery;
import lombok.extern.slf4j.Slf4j;
@@ -31,11 +30,10 @@ public abstract class PluginSemanticQuery extends BaseSemanticQuery {
private Map<Long, Object> getFilterMap(PluginParseResult pluginParseResult) {
Map<Long, Object> map = new HashMap<>();
QueryReq queryReq = pluginParseResult.getRequest();
if (queryReq == null || queryReq.getQueryFilters() == null) {
QueryFilters queryFilters = pluginParseResult.getQueryFilters();
if (queryFilters == null) {
return map;
}
QueryFilters queryFilters = queryReq.getQueryFilters();
List<QueryFilter> queryFilterList = queryFilters.getFilters();
if (CollectionUtils.isEmpty(queryFilterList)) {
return map;

View File

@@ -36,7 +36,7 @@ public class MetricTopNQuery extends MetricSemanticQuery {
@Override
public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches,
QueryContext queryCtx) {
Matcher matcher = INTENT_PATTERN.matcher(queryCtx.getRequest().getQueryText());
Matcher matcher = INTENT_PATTERN.matcher(queryCtx.getQueryText());
if (matcher.matches()) {
return super.match(candidateElementMatches, queryCtx);
}

View File

@@ -1,15 +1,14 @@
package com.tencent.supersonic.chat.server.persistence.repository;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
import java.util.List;
public interface ChatQueryRepository {
@@ -22,15 +21,12 @@ public interface ChatQueryRepository {
List<QueryResp> queryShowCase(PageQueryInfoReq pageQueryInfoCommend, int agentId);
void updateChatParseInfo(List<ChatParseDO> chatParseDOS);
ChatQueryDO getLastChatQuery(long chatId);
int updateChatQuery(ChatQueryDO chatQueryDO);
List<ChatParseDO> batchSaveParseInfo(ChatContext chatCtx, QueryReq queryReq,
ParseResp parseResult,
List<SemanticParseInfo> candidateParses);
List<ChatParseDO> batchSaveParseInfo(ChatContext chatCtx, QueryContext queryContext,
ParseResp parseResult, List<SemanticParseInfo> candidateParses);
ChatParseDO getParseInfo(Long questionId, int parseId);

View File

@@ -3,35 +3,35 @@ package com.tencent.supersonic.chat.server.persistence.repository.impl;
import com.alibaba.fastjson.JSONObject;
import com.github.pagehelper.PageHelper;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
import com.tencent.supersonic.chat.server.persistence.mapper.custom.ShowCaseCustomMapper;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDOExample;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDOExample.Criteria;
import com.tencent.supersonic.chat.server.persistence.mapper.ChatParseMapper;
import com.tencent.supersonic.chat.server.persistence.mapper.ChatQueryDOMapper;
import com.tencent.supersonic.chat.server.persistence.mapper.custom.ShowCaseCustomMapper;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.PageUtils;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Repository;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
@Repository
@Primary
@@ -116,13 +116,13 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
return queryResp;
}
public Long createChatQuery(ParseResp parseResult, ChatContext chatCtx, QueryReq queryReq) {
public Long createChatQuery(ParseResp parseResult, ChatContext chatCtx, QueryContext queryContext) {
ChatQueryDO chatQueryDO = new ChatQueryDO();
chatQueryDO.setChatId(Long.valueOf(chatCtx.getChatId()));
chatQueryDO.setCreateTime(new java.util.Date());
chatQueryDO.setUserName(queryReq.getUser().getName());
chatQueryDO.setQueryText(queryReq.getQueryText());
chatQueryDO.setAgentId(queryReq.getAgentId());
chatQueryDO.setUserName(queryContext.getUser().getName());
chatQueryDO.setQueryText(queryContext.getQueryText());
chatQueryDO.setAgentId(queryContext.getAgentId());
chatQueryDO.setQueryResult("");
try {
chatQueryDOMapper.insert(chatQueryDO);
@@ -135,31 +135,24 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
}
@Override
public List<ChatParseDO> batchSaveParseInfo(ChatContext chatCtx, QueryReq queryReq,
public List<ChatParseDO> batchSaveParseInfo(ChatContext chatCtx, QueryContext queryContext,
ParseResp parseResult, List<SemanticParseInfo> candidateParses) {
Long queryId = createChatQuery(parseResult, chatCtx, queryReq);
Long queryId = createChatQuery(parseResult, chatCtx, queryContext);
List<ChatParseDO> chatParseDOList = new ArrayList<>();
getChatParseDO(chatCtx, queryReq, queryId, candidateParses, chatParseDOList);
getChatParseDO(chatCtx, queryContext, queryId, candidateParses, chatParseDOList);
if (!CollectionUtils.isEmpty(candidateParses)) {
chatParseMapper.batchSaveParseInfo(chatParseDOList);
}
return chatParseDOList;
}
@Override
public void updateChatParseInfo(List<ChatParseDO> chatParseDOS) {
for (ChatParseDO chatParseDO : chatParseDOS) {
chatParseMapper.updateParseInfo(chatParseDO);
}
}
public void getChatParseDO(ChatContext chatCtx, QueryReq queryReq, Long queryId,
public void getChatParseDO(ChatContext chatCtx, QueryContext queryContext, Long queryId,
List<SemanticParseInfo> parses, List<ChatParseDO> chatParseDOList) {
for (int i = 0; i < parses.size(); i++) {
ChatParseDO chatParseDO = new ChatParseDO();
chatParseDO.setChatId(Long.valueOf(chatCtx.getChatId()));
chatParseDO.setQuestionId(queryId);
chatParseDO.setQueryText(queryReq.getQueryText());
chatParseDO.setQueryText(queryContext.getQueryText());
chatParseDO.setParseInfo(JsonUtil.toString(parses.get(i)));
chatParseDO.setIsCandidate(1);
if (i == 0) {
@@ -167,7 +160,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
}
chatParseDO.setParseId(parses.get(i).getId());
chatParseDO.setCreateTime(new java.util.Date());
chatParseDO.setUserName(queryReq.getUser().getName());
chatParseDO.setUserName(queryContext.getUser().getName());
chatParseDOList.add(chatParseDO);
}
}

View File

@@ -1,13 +1,12 @@
package com.tencent.supersonic.chat.server.processor.parse;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.QueryManager;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.core.query.llm.analytics.MetricAnalyzeQuery;
import com.tencent.supersonic.chat.server.service.SemanticService;
import com.tencent.supersonic.common.util.ContextUtils;
@@ -29,7 +28,6 @@ public class EntityInfoProcessor implements ParseResultProcessor {
}
List<SemanticParseInfo> selectedParses = semanticQueries.stream().map(SemanticQuery::getParseInfo)
.collect(Collectors.toList());
QueryReq queryReq = queryContext.getRequest();
selectedParses.forEach(parseInfo -> {
String queryMode = parseInfo.getQueryMode();
if (QueryManager.containsPluginQuery(queryMode)
@@ -38,7 +36,7 @@ public class EntityInfoProcessor implements ParseResultProcessor {
}
//1. set entity info
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, queryReq.getUser());
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, queryContext.getUser());
if (QueryManager.isTagQuery(queryMode)
|| QueryManager.isMetricQuery(queryMode)) {
parseInfo.setEntityInfo(entityInfo);

View File

@@ -35,8 +35,8 @@ public class QueryRecommendProcessor implements ParseResultProcessor {
@SneakyThrows
private void doProcess(ParseResp parseResp, QueryContext queryContext) {
Long queryId = parseResp.getQueryId();
List<SimilarQueryRecallResp> solvedQueries = getSimilarQueries(queryContext.getRequest().getQueryText(),
queryContext.getRequest().getAgentId());
List<SimilarQueryRecallResp> solvedQueries = getSimilarQueries(queryContext.getQueryText(),
queryContext.getAgentId());
ChatQueryDO chatQueryDO = getChatQuery(queryId);
chatQueryDO.setSimilarQueries(JSONObject.toJSONString(solvedQueries));
updateChatQuery(chatQueryDO);

View File

@@ -1,11 +1,10 @@
package com.tencent.supersonic.chat.server.processor.parse;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.common.util.ContextUtils;
import java.util.List;
@@ -20,9 +19,8 @@ public class RespBuildProcessor implements ParseResultProcessor {
@Override
public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
QueryReq queryReq = queryContext.getRequest();
parseResp.setChatId(queryReq.getChatId());
parseResp.setQueryText(queryReq.getQueryText());
parseResp.setChatId(queryContext.getChatId());
parseResp.setQueryText(queryContext.getQueryText());
List<SemanticQuery> candidateQueries = queryContext.getCandidateQueries();
ChatService chatService = ContextUtils.getBean(ChatService.class);
if (candidateQueries.size() > 0) {
@@ -33,7 +31,7 @@ public class RespBuildProcessor implements ParseResultProcessor {
} else {
parseResp.setState(ParseResp.ParseState.FAILED);
}
chatService.batchAddParse(chatContext, queryReq, parseResp);
chatService.batchAddParse(chatContext, queryContext, parseResp);
}
}

View File

@@ -1,21 +1,20 @@
package com.tencent.supersonic.chat.server.processor.parse;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.QueryManager;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
/**
* SqlInfoProcessor adds S2SQL to the parsing results so that
@@ -27,7 +26,6 @@ public class SqlInfoProcessor implements ParseResultProcessor {
@Override
public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
QueryReq queryReq = queryContext.getRequest();
List<SemanticQuery> semanticQueries = queryContext.getCandidateQueries();
if (CollectionUtils.isEmpty(semanticQueries)) {
return;
@@ -35,26 +33,26 @@ public class SqlInfoProcessor implements ParseResultProcessor {
List<SemanticParseInfo> selectedParses = semanticQueries.stream().map(SemanticQuery::getParseInfo)
.collect(Collectors.toList());
long startTime = System.currentTimeMillis();
addSqlInfo(queryReq, selectedParses);
addSqlInfo(queryContext, selectedParses);
parseResp.getParseTimeCost().setSqlTime(System.currentTimeMillis() - startTime);
}
private void addSqlInfo(QueryReq queryReq, List<SemanticParseInfo> semanticParseInfos) {
private void addSqlInfo(QueryContext queryContext, List<SemanticParseInfo> semanticParseInfos) {
if (CollectionUtils.isEmpty(semanticParseInfos)) {
return;
}
semanticParseInfos.forEach(parseInfo -> {
addSqlInfo(queryReq, parseInfo);
addSqlInfo(queryContext, parseInfo);
});
}
private void addSqlInfo(QueryReq queryReq, SemanticParseInfo parseInfo) {
private void addSqlInfo(QueryContext queryContext, SemanticParseInfo parseInfo) {
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
if (Objects.isNull(semanticQuery)) {
return;
}
semanticQuery.setParseInfo(parseInfo);
String explainSql = semanticQuery.explain(queryReq.getUser());
String explainSql = semanticQuery.explain(queryContext.getUser());
if (StringUtils.isBlank(explainSql)) {
return;
}

View File

@@ -4,11 +4,11 @@ import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
@@ -46,7 +46,7 @@ public interface ChatService {
ShowCaseResp queryShowCase(PageQueryInfoReq pageQueryInfoCommend, int agentId);
List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryReq queryReq, ParseResp parseResult);
List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryContext queryContext, ParseResp parseResult);
ChatQueryDO getLastQuery(long chatId);

View File

@@ -5,11 +5,11 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
@@ -211,9 +211,9 @@ public class ChatServiceImpl implements ChatService {
}
@Override
public List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryReq queryReq, ParseResp parseResult) {
public List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryContext queryContext, ParseResp parseResult) {
List<SemanticParseInfo> candidateParses = parseResult.getSelectedParses();
return chatQueryRepository.batchSaveParseInfo(chatCtx, queryReq, parseResult, candidateParses);
return chatQueryRepository.batchSaveParseInfo(chatCtx, queryContext, parseResult, candidateParses);
}
@Override

View File

@@ -90,6 +90,7 @@ import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Primary;
@@ -197,7 +198,6 @@ public class QueryServiceImpl implements QueryService {
List<Plugin> pluginList = pluginService.getPluginList();
QueryContext queryCtx = QueryContext.builder()
.request(queryReq)
.queryFilters(queryReq.getQueryFilters())
.semanticSchema(semanticSchema)
.candidateQueries(new ArrayList<>())
@@ -207,6 +207,7 @@ public class QueryServiceImpl implements QueryService {
.nameToPlugin(nameToPlugin)
.pluginList(pluginList)
.build();
BeanUtils.copyProperties(queryReq, queryCtx);
return queryCtx;
}

View File

@@ -42,6 +42,7 @@ import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service;
@@ -91,10 +92,10 @@ public class SearchServiceImpl implements SearchService {
List<Term> originals = HanlpHelper.getTerms(queryText);
log.info("hanlp parse result: {}", originals);
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
Set<Long> detectModelIds = mapperHelper.getModelIds(queryReq, agentService.getAgent(agentId));
Set<Long> detectModelIds = mapperHelper.getModelIds(queryReq.getModelId(), agentService.getAgent(agentId));
QueryContext queryContext = new QueryContext();
queryContext.setRequest(queryReq);
BeanUtils.copyProperties(queryReq, queryContext);
Map<MatchText, List<HanlpMapResult>> regTextMap =
searchMatchStrategy.match(queryContext, originals, detectModelIds);
regTextMap.entrySet().stream().forEach(m -> HanlpHelper.transLetterOriginal(m.getValue()));

View File

@@ -1,11 +1,10 @@
package com.tencent.supersonic.headless.server.service;
import com.tencent.supersonic.headless.api.request.MetricQueryReq;
import com.tencent.supersonic.headless.api.request.ParseSqlReq;
import com.tencent.supersonic.headless.api.request.QueryStructReq;
import com.tencent.supersonic.headless.api.response.SemanticQueryResp;
import com.tencent.supersonic.headless.core.pojo.QueryStatement;
import com.tencent.supersonic.headless.core.executor.QueryExecutor;
import com.tencent.supersonic.headless.core.pojo.QueryStatement;
public interface SemantciQueryEngine {
@@ -17,5 +16,4 @@ public interface SemantciQueryEngine {
QueryStatement physicalSql(QueryStructReq queryStructCmd, ParseSqlReq sqlCommend) throws Exception;
QueryStatement physicalSql(QueryStructReq queryStructCmd, MetricQueryReq sqlCommend) throws Exception;
}

View File

@@ -1,13 +1,12 @@
package com.tencent.supersonic.headless.server.service.impl;
import com.tencent.supersonic.headless.api.request.MetricQueryReq;
import com.tencent.supersonic.headless.api.request.ParseSqlReq;
import com.tencent.supersonic.headless.api.request.QueryStructReq;
import com.tencent.supersonic.headless.api.response.SemanticQueryResp;
import com.tencent.supersonic.headless.core.executor.QueryExecutor;
import com.tencent.supersonic.headless.core.planner.QueryOptimizer;
import com.tencent.supersonic.headless.core.parser.QueryParser;
import com.tencent.supersonic.headless.core.parser.calcite.s2sql.SemanticModel;
import com.tencent.supersonic.headless.core.planner.QueryOptimizer;
import com.tencent.supersonic.headless.core.pojo.QueryStatement;
import com.tencent.supersonic.headless.core.utils.ComponentFactory;
import com.tencent.supersonic.headless.server.manager.SemanticSchemaManager;
@@ -82,15 +81,6 @@ public class SemantciQueryEngineImpl implements SemantciQueryEngine {
return optimize(queryStructCmd, queryParser.parser(sqlCommend, queryStatement));
}
public QueryStatement physicalSql(QueryStructReq queryStructCmd, MetricQueryReq metricCommand) throws Exception {
QueryStatement queryStatement = new QueryStatement();
queryStatement.setQueryStructReq(queryStructCmd);
queryStatement.setMetricReq(metricCommand);
queryStatement.setIsS2SQL(false);
queryStatement.setSemanticModel(getSemanticModel(queryStatement));
return queryParser.parser(queryStatement);
}
private SemanticModel getSemanticModel(QueryStatement queryStatement) throws Exception {
QueryStructReq queryStructReq = queryStatement.getQueryStructReq();
return semanticSchemaManager.get(queryStructReq.getModelIdStr());