mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(chat) Remove QueryReq parameter from QueryContext. (#656)
This commit is contained in:
@@ -28,7 +28,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<T>> match(QueryContext queryContext, List<Term> terms, Set<Long> detectModelIds) {
|
||||
String text = queryContext.getRequest().getQueryText();
|
||||
String text = queryContext.getQueryText();
|
||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||
return null;
|
||||
}
|
||||
@@ -44,7 +44,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
|
||||
public List<T> detect(QueryContext queryContext, List<Term> terms, Set<Long> detectModelIds) {
|
||||
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(terms);
|
||||
String text = queryContext.getRequest().getQueryText();
|
||||
String text = queryContext.getQueryText();
|
||||
Set<T> results = new HashSet<>();
|
||||
|
||||
Set<String> detectSegments = new HashSet<>();
|
||||
@@ -102,7 +102,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
}
|
||||
|
||||
public List<T> getMatches(QueryContext queryContext, List<Term> terms) {
|
||||
Set<Long> detectModelIds = mapperHelper.getModelIds(queryContext.getRequest(), queryContext.getAgent());
|
||||
Set<Long> detectModelIds = mapperHelper.getModelIds(queryContext.getModelId(), queryContext.getAgent());
|
||||
terms = filterByModelIds(terms, detectModelIds);
|
||||
Map<MatchText, List<T>> matchResult = match(queryContext, terms, detectModelIds);
|
||||
List<T> matches = new ArrayList<>();
|
||||
|
||||
@@ -55,11 +55,11 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
||||
|
||||
public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> existResults, Set<Long> detectModelIds,
|
||||
Integer startIndex, Integer index, int offset) {
|
||||
String detectSegment = queryContext.getRequest().getQueryText().substring(startIndex, index);
|
||||
String detectSegment = queryContext.getQueryText().substring(startIndex, index);
|
||||
if (StringUtils.isBlank(detectSegment)) {
|
||||
return;
|
||||
}
|
||||
Set<Long> modelIds = mapperHelper.getModelIds(queryContext.getRequest(), queryContext.getAgent());
|
||||
Set<Long> modelIds = mapperHelper.getModelIds(queryContext.getModelId(), queryContext.getAgent());
|
||||
|
||||
Double metricDimensionThresholdConfig = getThreshold(queryContext);
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ public class EmbeddingMapper extends BaseMapper {
|
||||
@Override
|
||||
public void doMap(QueryContext queryContext) {
|
||||
//1. query from embedding by queryText
|
||||
String queryText = queryContext.getRequest().getQueryText();
|
||||
String queryText = queryContext.getQueryText();
|
||||
List<Term> terms = HanlpHelper.getTerms(queryText);
|
||||
|
||||
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.core.knowledge.HanlpMapResult;
|
||||
import com.tencent.supersonic.chat.core.knowledge.SearchService;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedHashSet;
|
||||
@@ -38,8 +37,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
@Override
|
||||
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<Term> terms,
|
||||
Set<Long> detectModelIds) {
|
||||
QueryReq queryReq = queryContext.getRequest();
|
||||
String text = queryReq.getQueryText();
|
||||
String text = queryContext.getQueryText();
|
||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||
return null;
|
||||
}
|
||||
@@ -61,9 +59,8 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
|
||||
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectModelIds,
|
||||
Integer startIndex, Integer index, int offset) {
|
||||
QueryReq queryReq = queryContext.getRequest();
|
||||
String text = queryReq.getQueryText();
|
||||
Integer agentId = queryReq.getAgentId();
|
||||
String text = queryContext.getQueryText();
|
||||
Integer agentId = queryContext.getAgentId();
|
||||
String detectSegment = text.substring(startIndex, index);
|
||||
|
||||
// step1. pre search
|
||||
|
||||
@@ -29,7 +29,7 @@ public class KeywordMapper extends BaseMapper {
|
||||
|
||||
@Override
|
||||
public void doMap(QueryContext queryContext) {
|
||||
String queryText = queryContext.getRequest().getQueryText();
|
||||
String queryText = queryContext.getQueryText();
|
||||
//1.hanlpDict Match
|
||||
List<Term> terms = HanlpHelper.getTerms(queryText);
|
||||
HanlpDictMatchStrategy hanlpMatchStrategy = ContextUtils.getBean(HanlpDictMatchStrategy.class);
|
||||
|
||||
@@ -2,7 +2,6 @@ package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.hankcs.hanlp.algorithm.EditDistance;
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.core.agent.Agent;
|
||||
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.core.utils.NatureHelper;
|
||||
@@ -82,9 +81,8 @@ public class MapperHelper {
|
||||
detectSegment.length());
|
||||
}
|
||||
|
||||
public Set<Long> getModelIds(QueryReq request, Agent agent) {
|
||||
public Set<Long> getModelIds(Long modelId, Agent agent) {
|
||||
|
||||
Long modelId = request.getModelId();
|
||||
Set<Long> detectModelIds = new HashSet<>();
|
||||
if (Objects.nonNull(agent)) {
|
||||
detectModelIds = agent.getModelIds(null);
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.core.knowledge.builder.BaseWordBuilder;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -24,8 +23,7 @@ public class QueryFilterMapper implements SchemaMapper {
|
||||
|
||||
@Override
|
||||
public void map(QueryContext queryContext) {
|
||||
QueryReq queryReq = queryContext.getRequest();
|
||||
Long modelId = queryReq.getModelId();
|
||||
Long modelId = queryContext.getModelId();
|
||||
if (modelId == null || modelId <= 0) {
|
||||
return;
|
||||
}
|
||||
@@ -62,7 +60,7 @@ public class QueryFilterMapper implements SchemaMapper {
|
||||
.name(String.valueOf(filter.getValue()))
|
||||
.type(SchemaElementType.VALUE)
|
||||
.bizName(filter.getBizName())
|
||||
.model(queryContext.getRequest().getModelId())
|
||||
.model(queryContext.getModelId())
|
||||
.build();
|
||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||
.element(element)
|
||||
|
||||
@@ -2,7 +2,6 @@ package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.core.knowledge.HanlpMapResult;
|
||||
import com.tencent.supersonic.chat.core.knowledge.SearchService;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
@@ -29,8 +28,7 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
@Override
|
||||
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<Term> originals,
|
||||
Set<Long> detectModelIds) {
|
||||
QueryReq queryReq = queryContext.getRequest();
|
||||
String text = queryReq.getQueryText();
|
||||
String text = queryContext.getQueryText();
|
||||
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(originals);
|
||||
|
||||
List<Integer> detectIndexList = Lists.newArrayList();
|
||||
@@ -54,9 +52,9 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
|
||||
if (StringUtils.isNotEmpty(detectSegment)) {
|
||||
List<HanlpMapResult> hanlpMapResults = SearchService.prefixSearch(detectSegment,
|
||||
SearchService.SEARCH_SIZE, queryReq.getAgentId(), detectModelIds);
|
||||
SearchService.SEARCH_SIZE, queryContext.getAgentId(), detectModelIds);
|
||||
List<HanlpMapResult> suffixHanlpMapResults = SearchService.suffixSearch(
|
||||
detectSegment, SEARCH_SIZE, queryReq.getAgentId(), detectModelIds);
|
||||
detectSegment, SEARCH_SIZE, queryContext.getAgentId(), detectModelIds);
|
||||
hanlpMapResults.addAll(suffixHanlpMapResults);
|
||||
// remove entity name where search
|
||||
hanlpMapResults = hanlpMapResults.stream().filter(entry -> {
|
||||
|
||||
@@ -31,7 +31,7 @@ public class QueryTypeParser implements SemanticParser {
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
|
||||
List<SemanticQuery> candidateQueries = queryContext.getCandidateQueries();
|
||||
User user = queryContext.getRequest().getUser();
|
||||
User user = queryContext.getUser();
|
||||
|
||||
for (SemanticQuery semanticQuery : candidateQueries) {
|
||||
// 1.init S2SQL
|
||||
|
||||
@@ -23,7 +23,7 @@ public class SatisfactionChecker {
|
||||
if (query.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) {
|
||||
continue;
|
||||
}
|
||||
if (checkThreshold(queryContext.getRequest().getQueryText(), query.getParseInfo())) {
|
||||
if (checkThreshold(queryContext.getQueryText(), query.getParseInfo())) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,30 +2,29 @@ package com.tencent.supersonic.chat.core.parser.plugin;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.chat.core.parser.SemanticParser;
|
||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.chat.core.parser.SemanticParser;
|
||||
import com.tencent.supersonic.chat.core.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.core.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.core.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.chat.core.plugin.PluginRecallResult;
|
||||
import com.tencent.supersonic.chat.core.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.core.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
|
||||
/**
|
||||
@@ -36,7 +35,7 @@ public abstract class PluginParser implements SemanticParser {
|
||||
@Override
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) {
|
||||
if (queryContext.getRequest().getQueryText().length() <= semanticQuery.getParseInfo().getScore()
|
||||
if (queryContext.getQueryText().length() <= semanticQuery.getParseInfo().getScore()
|
||||
&& (QueryManager.getPluginQueryModes().contains(semanticQuery.getQueryMode()))) {
|
||||
return;
|
||||
}
|
||||
@@ -64,8 +63,7 @@ public abstract class PluginParser implements SemanticParser {
|
||||
for (Long modelId : modelIds) {
|
||||
PluginSemanticQuery pluginQuery = QueryManager.createPluginQuery(plugin.getType());
|
||||
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(modelId, plugin,
|
||||
queryContext.getRequest(),
|
||||
queryContext.getModelClusterMapInfo().getMatchedElements(modelId),
|
||||
queryContext.getQueryFilters(), queryContext.getModelClusterMapInfo().getMatchedElements(modelId),
|
||||
pluginRecallResult.getDistance());
|
||||
semanticParseInfo.setQueryMode(pluginQuery.getQueryMode());
|
||||
semanticParseInfo.setScore(pluginRecallResult.getScore());
|
||||
@@ -78,7 +76,7 @@ public abstract class PluginParser implements SemanticParser {
|
||||
return PluginManager.getPluginAgentCanSupport(queryContext);
|
||||
}
|
||||
|
||||
protected SemanticParseInfo buildSemanticParseInfo(Long modelId, Plugin plugin, QueryReq queryReq,
|
||||
protected SemanticParseInfo buildSemanticParseInfo(Long modelId, Plugin plugin, QueryFilters queryFilters,
|
||||
List<SchemaElementMatch> schemaElementMatches, double distance) {
|
||||
if (modelId == null && !CollectionUtils.isEmpty(plugin.getModelList())) {
|
||||
modelId = plugin.getModelList().get(0);
|
||||
@@ -92,7 +90,7 @@ public abstract class PluginParser implements SemanticParser {
|
||||
Map<String, Object> properties = new HashMap<>();
|
||||
PluginParseResult pluginParseResult = new PluginParseResult();
|
||||
pluginParseResult.setPlugin(plugin);
|
||||
pluginParseResult.setRequest(queryReq);
|
||||
pluginParseResult.setQueryFilters(queryFilters);
|
||||
pluginParseResult.setDistance(distance);
|
||||
properties.put(Constants.CONTEXT, pluginParseResult);
|
||||
properties.put("type", "plugin");
|
||||
|
||||
@@ -42,7 +42,7 @@ public class EmbeddingRecallParser extends PluginParser {
|
||||
|
||||
@Override
|
||||
public PluginRecallResult recallPlugin(QueryContext queryContext) {
|
||||
String text = queryContext.getRequest().getQueryText();
|
||||
String text = queryContext.getQueryText();
|
||||
List<Retrieval> embeddingRetrievals = embeddingRecall(text);
|
||||
if (CollectionUtils.isEmpty(embeddingRetrievals)) {
|
||||
return null;
|
||||
@@ -63,7 +63,7 @@ public class EmbeddingRecallParser extends PluginParser {
|
||||
}
|
||||
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
|
||||
double distance = embeddingRetrieval.getDistance();
|
||||
double score = queryContext.getRequest().getQueryText().length() * (1 - distance);
|
||||
double score = queryContext.getQueryText().length() * (1 - distance);
|
||||
return PluginRecallResult.builder()
|
||||
.plugin(plugin).modelIds(modelList).score(score).distance(distance).build();
|
||||
}
|
||||
|
||||
@@ -33,7 +33,7 @@ public class FunctionCallParser extends PluginParser {
|
||||
String functionUrl = functionCallConfig.getUrl();
|
||||
if (StringUtils.isBlank(functionUrl) && ComponentFactory.getLLMProxy() instanceof PythonLLMProxy) {
|
||||
log.info("functionUrl:{}, skip function parser, queryText:{}", functionUrl,
|
||||
queryContext.getRequest().getQueryText());
|
||||
queryContext.getQueryText());
|
||||
return false;
|
||||
}
|
||||
List<Plugin> plugins = getPluginList(queryContext);
|
||||
@@ -60,7 +60,7 @@ public class FunctionCallParser extends PluginParser {
|
||||
if (CollectionUtils.isEmpty(modelList)) {
|
||||
return null;
|
||||
}
|
||||
double score = queryContext.getRequest().getQueryText().length();
|
||||
double score = queryContext.getQueryText().length();
|
||||
return PluginRecallResult.builder().plugin(plugin).modelIds(modelList).score(score).build();
|
||||
}
|
||||
return null;
|
||||
@@ -68,7 +68,7 @@ public class FunctionCallParser extends PluginParser {
|
||||
|
||||
public FunctionResp functionCall(QueryContext queryContext) {
|
||||
List<PluginParseConfig> pluginToFunctionCall =
|
||||
getPluginToFunctionCall(queryContext.getRequest().getModelId(), queryContext);
|
||||
getPluginToFunctionCall(queryContext.getModelId(), queryContext);
|
||||
if (CollectionUtils.isEmpty(pluginToFunctionCall)) {
|
||||
log.info("function call parser, plugin is empty, skip");
|
||||
return null;
|
||||
@@ -78,7 +78,7 @@ public class FunctionCallParser extends PluginParser {
|
||||
functionResp.setToolSelection(pluginToFunctionCall.iterator().next().getName());
|
||||
} else {
|
||||
FunctionReq functionReq = FunctionReq.builder()
|
||||
.queryText(queryContext.getRequest().getQueryText())
|
||||
.queryText(queryContext.getQueryText())
|
||||
.pluginConfigs(pluginToFunctionCall).build();
|
||||
functionResp = ComponentFactory.getLLMProxy().requestFunction(functionReq);
|
||||
}
|
||||
|
||||
@@ -116,7 +116,7 @@ public class HeuristicModelResolver implements ModelResolver {
|
||||
public String resolve(QueryContext queryContext, ChatContext chatCtx, Set<Long> restrictiveModels) {
|
||||
SchemaModelClusterMapInfo mapInfo = queryContext.getModelClusterMapInfo();
|
||||
Set<String> matchedModelClusters = mapInfo.getElementMatchesByModelIds(restrictiveModels).keySet();
|
||||
Long modelId = queryContext.getRequest().getModelId();
|
||||
Long modelId = queryContext.getModelId();
|
||||
if (Objects.nonNull(modelId) && modelId > 0) {
|
||||
if (CollectionUtils.isEmpty(restrictiveModels) || restrictiveModels.contains(modelId)) {
|
||||
return getModelClusterByModelId(modelId, matchedModelClusters);
|
||||
|
||||
@@ -57,7 +57,7 @@ public class LLMRequestService {
|
||||
return true;
|
||||
}
|
||||
if (SatisfactionChecker.isSkip(queryCtx)) {
|
||||
log.info("skip {}, queryText:{}", LLMSqlParser.class, queryCtx.getRequest().getQueryText());
|
||||
log.info("skip {}, queryText:{}", LLMSqlParser.class, queryCtx.getQueryText());
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
@@ -121,7 +121,7 @@ public class LLMRequestService {
|
||||
public LLMReq getLlmReq(QueryContext queryCtx, SemanticSchema semanticSchema,
|
||||
ModelCluster modelCluster, List<ElementValue> linkingValues) {
|
||||
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
|
||||
String queryText = queryCtx.getRequest().getQueryText();
|
||||
String queryText = queryCtx.getQueryText();
|
||||
|
||||
LLMReq llmReq = new LLMReq();
|
||||
llmReq.setQueryText(queryText);
|
||||
|
||||
@@ -39,7 +39,7 @@ public class LLMResponseService {
|
||||
properties.put("name", commonAgentTool.getName());
|
||||
|
||||
parseInfo.setProperties(properties);
|
||||
parseInfo.setScore(queryCtx.getRequest().getQueryText().length() * (1 + weight));
|
||||
parseInfo.setScore(queryCtx.getQueryText().length() * (1 + weight));
|
||||
parseInfo.setQueryMode(semanticQuery.getQueryMode());
|
||||
parseInfo.getSqlInfo().setS2SQL(s2SQL);
|
||||
parseInfo.setModel(parseResult.getModelCluster());
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package com.tencent.supersonic.chat.core.parser.sql.llm;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.core.agent.NL2SQLTool;
|
||||
import com.tencent.supersonic.chat.core.parser.SemanticParser;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
@@ -24,7 +23,6 @@ public class LLMSqlParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryCtx, ChatContext chatCtx) {
|
||||
QueryReq request = queryCtx.getRequest();
|
||||
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
|
||||
//1.determine whether to skip this parser.
|
||||
if (requestService.isSkip(queryCtx)) {
|
||||
@@ -56,7 +54,6 @@ public class LLMSqlParser implements SemanticParser {
|
||||
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
|
||||
Map<String, LLMSqlResp> deduplicationSqlResp = responseService.getDeduplicationSqlResp(llmResp);
|
||||
ParseResult parseResult = ParseResult.builder()
|
||||
.request(request)
|
||||
.modelCluster(modelCluster)
|
||||
.commonAgentTool(commonAgentTool)
|
||||
.llmReq(llmReq)
|
||||
|
||||
@@ -49,7 +49,7 @@ public class AggregateTypeParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
String queryText = queryContext.getRequest().getQueryText();
|
||||
String queryText = queryContext.getQueryText();
|
||||
AggregateConf aggregateConf = resolveAggregateConf(queryText);
|
||||
|
||||
for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) {
|
||||
|
||||
@@ -42,12 +42,13 @@ public class TimeRangeParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
DateConf dateConf = parseRecent(queryContext.getRequest().getQueryText());
|
||||
String queryText = queryContext.getQueryText();
|
||||
DateConf dateConf = parseRecent(queryText);
|
||||
if (dateConf == null) {
|
||||
dateConf = parseDateNumber(queryContext.getRequest().getQueryText());
|
||||
dateConf = parseDateNumber(queryText);
|
||||
}
|
||||
if (dateConf == null) {
|
||||
dateConf = parseDateCN(queryContext.getRequest().getQueryText());
|
||||
dateConf = parseDateCN(queryText);
|
||||
}
|
||||
|
||||
if (dateConf != null) {
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
package com.tencent.supersonic.chat.core.plugin;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class PluginParseResult {
|
||||
|
||||
private Plugin plugin;
|
||||
private QueryReq request;
|
||||
private QueryFilters queryFilters;
|
||||
private double distance;
|
||||
}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package com.tencent.supersonic.chat.core.pojo;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaModelClusterMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
|
||||
import com.tencent.supersonic.chat.core.agent.Agent;
|
||||
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||
@@ -27,7 +27,12 @@ import lombok.NoArgsConstructor;
|
||||
@AllArgsConstructor
|
||||
public class QueryContext {
|
||||
|
||||
private QueryReq request;
|
||||
private String queryText;
|
||||
private Integer chatId;
|
||||
private Long modelId;
|
||||
private User user;
|
||||
private boolean saveAnswer = true;
|
||||
private Integer agentId;
|
||||
private QueryFilters queryFilters;
|
||||
private List<SemanticQuery> candidateQueries = new ArrayList<>();
|
||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||
|
||||
@@ -7,7 +7,6 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.core.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.chat.core.query.BaseSemanticQuery;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -31,11 +30,10 @@ public abstract class PluginSemanticQuery extends BaseSemanticQuery {
|
||||
|
||||
private Map<Long, Object> getFilterMap(PluginParseResult pluginParseResult) {
|
||||
Map<Long, Object> map = new HashMap<>();
|
||||
QueryReq queryReq = pluginParseResult.getRequest();
|
||||
if (queryReq == null || queryReq.getQueryFilters() == null) {
|
||||
QueryFilters queryFilters = pluginParseResult.getQueryFilters();
|
||||
if (queryFilters == null) {
|
||||
return map;
|
||||
}
|
||||
QueryFilters queryFilters = queryReq.getQueryFilters();
|
||||
List<QueryFilter> queryFilterList = queryFilters.getFilters();
|
||||
if (CollectionUtils.isEmpty(queryFilterList)) {
|
||||
return map;
|
||||
|
||||
@@ -36,7 +36,7 @@ public class MetricTopNQuery extends MetricSemanticQuery {
|
||||
@Override
|
||||
public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches,
|
||||
QueryContext queryCtx) {
|
||||
Matcher matcher = INTENT_PATTERN.matcher(queryCtx.getRequest().getQueryText());
|
||||
Matcher matcher = INTENT_PATTERN.matcher(queryCtx.getQueryText());
|
||||
if (matcher.matches()) {
|
||||
return super.match(candidateElementMatches, queryCtx);
|
||||
}
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.repository;
|
||||
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface ChatQueryRepository {
|
||||
@@ -22,15 +21,12 @@ public interface ChatQueryRepository {
|
||||
|
||||
List<QueryResp> queryShowCase(PageQueryInfoReq pageQueryInfoCommend, int agentId);
|
||||
|
||||
void updateChatParseInfo(List<ChatParseDO> chatParseDOS);
|
||||
|
||||
ChatQueryDO getLastChatQuery(long chatId);
|
||||
|
||||
int updateChatQuery(ChatQueryDO chatQueryDO);
|
||||
|
||||
List<ChatParseDO> batchSaveParseInfo(ChatContext chatCtx, QueryReq queryReq,
|
||||
ParseResp parseResult,
|
||||
List<SemanticParseInfo> candidateParses);
|
||||
List<ChatParseDO> batchSaveParseInfo(ChatContext chatCtx, QueryContext queryContext,
|
||||
ParseResp parseResult, List<SemanticParseInfo> candidateParses);
|
||||
|
||||
ChatParseDO getParseInfo(Long questionId, int parseId);
|
||||
|
||||
|
||||
@@ -3,35 +3,35 @@ package com.tencent.supersonic.chat.server.persistence.repository.impl;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.github.pagehelper.PageHelper;
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
|
||||
import com.tencent.supersonic.chat.server.persistence.mapper.custom.ShowCaseCustomMapper;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDOExample;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDOExample.Criteria;
|
||||
import com.tencent.supersonic.chat.server.persistence.mapper.ChatParseMapper;
|
||||
import com.tencent.supersonic.chat.server.persistence.mapper.ChatQueryDOMapper;
|
||||
import com.tencent.supersonic.chat.server.persistence.mapper.custom.ShowCaseCustomMapper;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.common.util.PageUtils;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.context.annotation.Primary;
|
||||
import org.springframework.stereotype.Repository;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Repository
|
||||
@Primary
|
||||
@@ -116,13 +116,13 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
||||
return queryResp;
|
||||
}
|
||||
|
||||
public Long createChatQuery(ParseResp parseResult, ChatContext chatCtx, QueryReq queryReq) {
|
||||
public Long createChatQuery(ParseResp parseResult, ChatContext chatCtx, QueryContext queryContext) {
|
||||
ChatQueryDO chatQueryDO = new ChatQueryDO();
|
||||
chatQueryDO.setChatId(Long.valueOf(chatCtx.getChatId()));
|
||||
chatQueryDO.setCreateTime(new java.util.Date());
|
||||
chatQueryDO.setUserName(queryReq.getUser().getName());
|
||||
chatQueryDO.setQueryText(queryReq.getQueryText());
|
||||
chatQueryDO.setAgentId(queryReq.getAgentId());
|
||||
chatQueryDO.setUserName(queryContext.getUser().getName());
|
||||
chatQueryDO.setQueryText(queryContext.getQueryText());
|
||||
chatQueryDO.setAgentId(queryContext.getAgentId());
|
||||
chatQueryDO.setQueryResult("");
|
||||
try {
|
||||
chatQueryDOMapper.insert(chatQueryDO);
|
||||
@@ -135,31 +135,24 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ChatParseDO> batchSaveParseInfo(ChatContext chatCtx, QueryReq queryReq,
|
||||
public List<ChatParseDO> batchSaveParseInfo(ChatContext chatCtx, QueryContext queryContext,
|
||||
ParseResp parseResult, List<SemanticParseInfo> candidateParses) {
|
||||
Long queryId = createChatQuery(parseResult, chatCtx, queryReq);
|
||||
Long queryId = createChatQuery(parseResult, chatCtx, queryContext);
|
||||
List<ChatParseDO> chatParseDOList = new ArrayList<>();
|
||||
getChatParseDO(chatCtx, queryReq, queryId, candidateParses, chatParseDOList);
|
||||
getChatParseDO(chatCtx, queryContext, queryId, candidateParses, chatParseDOList);
|
||||
if (!CollectionUtils.isEmpty(candidateParses)) {
|
||||
chatParseMapper.batchSaveParseInfo(chatParseDOList);
|
||||
}
|
||||
return chatParseDOList;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateChatParseInfo(List<ChatParseDO> chatParseDOS) {
|
||||
for (ChatParseDO chatParseDO : chatParseDOS) {
|
||||
chatParseMapper.updateParseInfo(chatParseDO);
|
||||
}
|
||||
}
|
||||
|
||||
public void getChatParseDO(ChatContext chatCtx, QueryReq queryReq, Long queryId,
|
||||
public void getChatParseDO(ChatContext chatCtx, QueryContext queryContext, Long queryId,
|
||||
List<SemanticParseInfo> parses, List<ChatParseDO> chatParseDOList) {
|
||||
for (int i = 0; i < parses.size(); i++) {
|
||||
ChatParseDO chatParseDO = new ChatParseDO();
|
||||
chatParseDO.setChatId(Long.valueOf(chatCtx.getChatId()));
|
||||
chatParseDO.setQuestionId(queryId);
|
||||
chatParseDO.setQueryText(queryReq.getQueryText());
|
||||
chatParseDO.setQueryText(queryContext.getQueryText());
|
||||
chatParseDO.setParseInfo(JsonUtil.toString(parses.get(i)));
|
||||
chatParseDO.setIsCandidate(1);
|
||||
if (i == 0) {
|
||||
@@ -167,7 +160,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
||||
}
|
||||
chatParseDO.setParseId(parses.get(i).getId());
|
||||
chatParseDO.setCreateTime(new java.util.Date());
|
||||
chatParseDO.setUserName(queryReq.getUser().getName());
|
||||
chatParseDO.setUserName(queryContext.getUser().getName());
|
||||
chatParseDOList.add(chatParseDO);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
package com.tencent.supersonic.chat.server.processor.parse;
|
||||
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.core.query.llm.analytics.MetricAnalyzeQuery;
|
||||
import com.tencent.supersonic.chat.server.service.SemanticService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
@@ -29,7 +28,6 @@ public class EntityInfoProcessor implements ParseResultProcessor {
|
||||
}
|
||||
List<SemanticParseInfo> selectedParses = semanticQueries.stream().map(SemanticQuery::getParseInfo)
|
||||
.collect(Collectors.toList());
|
||||
QueryReq queryReq = queryContext.getRequest();
|
||||
selectedParses.forEach(parseInfo -> {
|
||||
String queryMode = parseInfo.getQueryMode();
|
||||
if (QueryManager.containsPluginQuery(queryMode)
|
||||
@@ -38,7 +36,7 @@ public class EntityInfoProcessor implements ParseResultProcessor {
|
||||
}
|
||||
//1. set entity info
|
||||
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
||||
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, queryReq.getUser());
|
||||
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, queryContext.getUser());
|
||||
if (QueryManager.isTagQuery(queryMode)
|
||||
|| QueryManager.isMetricQuery(queryMode)) {
|
||||
parseInfo.setEntityInfo(entityInfo);
|
||||
|
||||
@@ -35,8 +35,8 @@ public class QueryRecommendProcessor implements ParseResultProcessor {
|
||||
@SneakyThrows
|
||||
private void doProcess(ParseResp parseResp, QueryContext queryContext) {
|
||||
Long queryId = parseResp.getQueryId();
|
||||
List<SimilarQueryRecallResp> solvedQueries = getSimilarQueries(queryContext.getRequest().getQueryText(),
|
||||
queryContext.getRequest().getAgentId());
|
||||
List<SimilarQueryRecallResp> solvedQueries = getSimilarQueries(queryContext.getQueryText(),
|
||||
queryContext.getAgentId());
|
||||
ChatQueryDO chatQueryDO = getChatQuery(queryId);
|
||||
chatQueryDO.setSimilarQueries(JSONObject.toJSONString(solvedQueries));
|
||||
updateChatQuery(chatQueryDO);
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
package com.tencent.supersonic.chat.server.processor.parse;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.server.service.ChatService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import java.util.List;
|
||||
@@ -20,9 +19,8 @@ public class RespBuildProcessor implements ParseResultProcessor {
|
||||
|
||||
@Override
|
||||
public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
|
||||
QueryReq queryReq = queryContext.getRequest();
|
||||
parseResp.setChatId(queryReq.getChatId());
|
||||
parseResp.setQueryText(queryReq.getQueryText());
|
||||
parseResp.setChatId(queryContext.getChatId());
|
||||
parseResp.setQueryText(queryContext.getQueryText());
|
||||
List<SemanticQuery> candidateQueries = queryContext.getCandidateQueries();
|
||||
ChatService chatService = ContextUtils.getBean(ChatService.class);
|
||||
if (candidateQueries.size() > 0) {
|
||||
@@ -33,7 +31,7 @@ public class RespBuildProcessor implements ParseResultProcessor {
|
||||
} else {
|
||||
parseResp.setState(ParseResp.ParseState.FAILED);
|
||||
}
|
||||
chatService.batchAddParse(chatContext, queryReq, parseResp);
|
||||
chatService.batchAddParse(chatContext, queryContext, parseResp);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,21 +1,20 @@
|
||||
package com.tencent.supersonic.chat.server.processor.parse;
|
||||
|
||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* SqlInfoProcessor adds S2SQL to the parsing results so that
|
||||
@@ -27,7 +26,6 @@ public class SqlInfoProcessor implements ParseResultProcessor {
|
||||
|
||||
@Override
|
||||
public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
|
||||
QueryReq queryReq = queryContext.getRequest();
|
||||
List<SemanticQuery> semanticQueries = queryContext.getCandidateQueries();
|
||||
if (CollectionUtils.isEmpty(semanticQueries)) {
|
||||
return;
|
||||
@@ -35,26 +33,26 @@ public class SqlInfoProcessor implements ParseResultProcessor {
|
||||
List<SemanticParseInfo> selectedParses = semanticQueries.stream().map(SemanticQuery::getParseInfo)
|
||||
.collect(Collectors.toList());
|
||||
long startTime = System.currentTimeMillis();
|
||||
addSqlInfo(queryReq, selectedParses);
|
||||
addSqlInfo(queryContext, selectedParses);
|
||||
parseResp.getParseTimeCost().setSqlTime(System.currentTimeMillis() - startTime);
|
||||
}
|
||||
|
||||
private void addSqlInfo(QueryReq queryReq, List<SemanticParseInfo> semanticParseInfos) {
|
||||
private void addSqlInfo(QueryContext queryContext, List<SemanticParseInfo> semanticParseInfos) {
|
||||
if (CollectionUtils.isEmpty(semanticParseInfos)) {
|
||||
return;
|
||||
}
|
||||
semanticParseInfos.forEach(parseInfo -> {
|
||||
addSqlInfo(queryReq, parseInfo);
|
||||
addSqlInfo(queryContext, parseInfo);
|
||||
});
|
||||
}
|
||||
|
||||
private void addSqlInfo(QueryReq queryReq, SemanticParseInfo parseInfo) {
|
||||
private void addSqlInfo(QueryContext queryContext, SemanticParseInfo parseInfo) {
|
||||
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
|
||||
if (Objects.isNull(semanticQuery)) {
|
||||
return;
|
||||
}
|
||||
semanticQuery.setParseInfo(parseInfo);
|
||||
String explainSql = semanticQuery.explain(queryReq.getUser());
|
||||
String explainSql = semanticQuery.explain(queryContext.getUser());
|
||||
if (StringUtils.isBlank(explainSql)) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -4,11 +4,11 @@ import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
||||
@@ -46,7 +46,7 @@ public interface ChatService {
|
||||
|
||||
ShowCaseResp queryShowCase(PageQueryInfoReq pageQueryInfoCommend, int agentId);
|
||||
|
||||
List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryReq queryReq, ParseResp parseResult);
|
||||
List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryContext queryContext, ParseResp parseResult);
|
||||
|
||||
ChatQueryDO getLastQuery(long chatId);
|
||||
|
||||
|
||||
@@ -5,11 +5,11 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
||||
@@ -211,9 +211,9 @@ public class ChatServiceImpl implements ChatService {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryReq queryReq, ParseResp parseResult) {
|
||||
public List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryContext queryContext, ParseResp parseResult) {
|
||||
List<SemanticParseInfo> candidateParses = parseResult.getSelectedParses();
|
||||
return chatQueryRepository.batchSaveParseInfo(chatCtx, queryReq, parseResult, candidateParses);
|
||||
return chatQueryRepository.batchSaveParseInfo(chatCtx, queryContext, parseResult, candidateParses);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -90,6 +90,7 @@ import org.apache.calcite.sql.parser.SqlParseException;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.context.annotation.Primary;
|
||||
@@ -197,7 +198,6 @@ public class QueryServiceImpl implements QueryService {
|
||||
List<Plugin> pluginList = pluginService.getPluginList();
|
||||
|
||||
QueryContext queryCtx = QueryContext.builder()
|
||||
.request(queryReq)
|
||||
.queryFilters(queryReq.getQueryFilters())
|
||||
.semanticSchema(semanticSchema)
|
||||
.candidateQueries(new ArrayList<>())
|
||||
@@ -207,6 +207,7 @@ public class QueryServiceImpl implements QueryService {
|
||||
.nameToPlugin(nameToPlugin)
|
||||
.pluginList(pluginList)
|
||||
.build();
|
||||
BeanUtils.copyProperties(queryReq, queryCtx);
|
||||
return queryCtx;
|
||||
}
|
||||
|
||||
|
||||
@@ -42,6 +42,7 @@ import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Qualifier;
|
||||
import org.springframework.stereotype.Service;
|
||||
@@ -91,10 +92,10 @@ public class SearchServiceImpl implements SearchService {
|
||||
List<Term> originals = HanlpHelper.getTerms(queryText);
|
||||
log.info("hanlp parse result: {}", originals);
|
||||
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
|
||||
Set<Long> detectModelIds = mapperHelper.getModelIds(queryReq, agentService.getAgent(agentId));
|
||||
Set<Long> detectModelIds = mapperHelper.getModelIds(queryReq.getModelId(), agentService.getAgent(agentId));
|
||||
|
||||
QueryContext queryContext = new QueryContext();
|
||||
queryContext.setRequest(queryReq);
|
||||
BeanUtils.copyProperties(queryReq, queryContext);
|
||||
Map<MatchText, List<HanlpMapResult>> regTextMap =
|
||||
searchMatchStrategy.match(queryContext, originals, detectModelIds);
|
||||
regTextMap.entrySet().stream().forEach(m -> HanlpHelper.transLetterOriginal(m.getValue()));
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
package com.tencent.supersonic.headless.server.service;
|
||||
|
||||
import com.tencent.supersonic.headless.api.request.MetricQueryReq;
|
||||
import com.tencent.supersonic.headless.api.request.ParseSqlReq;
|
||||
import com.tencent.supersonic.headless.api.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryStatement;
|
||||
import com.tencent.supersonic.headless.core.executor.QueryExecutor;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryStatement;
|
||||
|
||||
public interface SemantciQueryEngine {
|
||||
|
||||
@@ -17,5 +16,4 @@ public interface SemantciQueryEngine {
|
||||
|
||||
QueryStatement physicalSql(QueryStructReq queryStructCmd, ParseSqlReq sqlCommend) throws Exception;
|
||||
|
||||
QueryStatement physicalSql(QueryStructReq queryStructCmd, MetricQueryReq sqlCommend) throws Exception;
|
||||
}
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
package com.tencent.supersonic.headless.server.service.impl;
|
||||
|
||||
import com.tencent.supersonic.headless.api.request.MetricQueryReq;
|
||||
import com.tencent.supersonic.headless.api.request.ParseSqlReq;
|
||||
import com.tencent.supersonic.headless.api.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.core.executor.QueryExecutor;
|
||||
import com.tencent.supersonic.headless.core.planner.QueryOptimizer;
|
||||
import com.tencent.supersonic.headless.core.parser.QueryParser;
|
||||
import com.tencent.supersonic.headless.core.parser.calcite.s2sql.SemanticModel;
|
||||
import com.tencent.supersonic.headless.core.planner.QueryOptimizer;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryStatement;
|
||||
import com.tencent.supersonic.headless.core.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.headless.server.manager.SemanticSchemaManager;
|
||||
@@ -82,15 +81,6 @@ public class SemantciQueryEngineImpl implements SemantciQueryEngine {
|
||||
return optimize(queryStructCmd, queryParser.parser(sqlCommend, queryStatement));
|
||||
}
|
||||
|
||||
public QueryStatement physicalSql(QueryStructReq queryStructCmd, MetricQueryReq metricCommand) throws Exception {
|
||||
QueryStatement queryStatement = new QueryStatement();
|
||||
queryStatement.setQueryStructReq(queryStructCmd);
|
||||
queryStatement.setMetricReq(metricCommand);
|
||||
queryStatement.setIsS2SQL(false);
|
||||
queryStatement.setSemanticModel(getSemanticModel(queryStatement));
|
||||
return queryParser.parser(queryStatement);
|
||||
}
|
||||
|
||||
private SemanticModel getSemanticModel(QueryStatement queryStatement) throws Exception {
|
||||
QueryStructReq queryStructReq = queryStatement.getQueryStructReq();
|
||||
return semanticSchemaManager.get(queryStructReq.getModelIdStr());
|
||||
|
||||
Reference in New Issue
Block a user