(improvement)(chat) Remove QueryReq parameter from QueryContext. (#656)

This commit is contained in:
lexluo09
2024-01-19 16:17:31 +08:00
committed by GitHub
parent f017f41201
commit cbf38ed785
35 changed files with 115 additions and 152 deletions

View File

@@ -28,7 +28,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
@Override @Override
public Map<MatchText, List<T>> match(QueryContext queryContext, List<Term> terms, Set<Long> detectModelIds) { public Map<MatchText, List<T>> match(QueryContext queryContext, List<Term> terms, Set<Long> detectModelIds) {
String text = queryContext.getRequest().getQueryText(); String text = queryContext.getQueryText();
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) { if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
return null; return null;
} }
@@ -44,7 +44,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
public List<T> detect(QueryContext queryContext, List<Term> terms, Set<Long> detectModelIds) { public List<T> detect(QueryContext queryContext, List<Term> terms, Set<Long> detectModelIds) {
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(terms); Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(terms);
String text = queryContext.getRequest().getQueryText(); String text = queryContext.getQueryText();
Set<T> results = new HashSet<>(); Set<T> results = new HashSet<>();
Set<String> detectSegments = new HashSet<>(); Set<String> detectSegments = new HashSet<>();
@@ -102,7 +102,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
} }
public List<T> getMatches(QueryContext queryContext, List<Term> terms) { public List<T> getMatches(QueryContext queryContext, List<Term> terms) {
Set<Long> detectModelIds = mapperHelper.getModelIds(queryContext.getRequest(), queryContext.getAgent()); Set<Long> detectModelIds = mapperHelper.getModelIds(queryContext.getModelId(), queryContext.getAgent());
terms = filterByModelIds(terms, detectModelIds); terms = filterByModelIds(terms, detectModelIds);
Map<MatchText, List<T>> matchResult = match(queryContext, terms, detectModelIds); Map<MatchText, List<T>> matchResult = match(queryContext, terms, detectModelIds);
List<T> matches = new ArrayList<>(); List<T> matches = new ArrayList<>();

View File

@@ -55,11 +55,11 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> existResults, Set<Long> detectModelIds, public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> existResults, Set<Long> detectModelIds,
Integer startIndex, Integer index, int offset) { Integer startIndex, Integer index, int offset) {
String detectSegment = queryContext.getRequest().getQueryText().substring(startIndex, index); String detectSegment = queryContext.getQueryText().substring(startIndex, index);
if (StringUtils.isBlank(detectSegment)) { if (StringUtils.isBlank(detectSegment)) {
return; return;
} }
Set<Long> modelIds = mapperHelper.getModelIds(queryContext.getRequest(), queryContext.getAgent()); Set<Long> modelIds = mapperHelper.getModelIds(queryContext.getModelId(), queryContext.getAgent());
Double metricDimensionThresholdConfig = getThreshold(queryContext); Double metricDimensionThresholdConfig = getThreshold(queryContext);

View File

@@ -23,7 +23,7 @@ public class EmbeddingMapper extends BaseMapper {
@Override @Override
public void doMap(QueryContext queryContext) { public void doMap(QueryContext queryContext) {
//1. query from embedding by queryText //1. query from embedding by queryText
String queryText = queryContext.getRequest().getQueryText(); String queryText = queryContext.getQueryText();
List<Term> terms = HanlpHelper.getTerms(queryText); List<Term> terms = HanlpHelper.getTerms(queryText);
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class); EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);

View File

@@ -1,11 +1,10 @@
package com.tencent.supersonic.chat.core.mapper; package com.tencent.supersonic.chat.core.mapper;
import com.hankcs.hanlp.seg.common.Term; import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.core.config.OptimizationConfig; import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.chat.core.knowledge.HanlpMapResult; import com.tencent.supersonic.chat.core.knowledge.HanlpMapResult;
import com.tencent.supersonic.chat.core.knowledge.SearchService; import com.tencent.supersonic.chat.core.knowledge.SearchService;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import java.util.HashMap; import java.util.HashMap;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
@@ -38,8 +37,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
@Override @Override
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<Term> terms, public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<Term> terms,
Set<Long> detectModelIds) { Set<Long> detectModelIds) {
QueryReq queryReq = queryContext.getRequest(); String text = queryContext.getQueryText();
String text = queryReq.getQueryText();
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) { if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
return null; return null;
} }
@@ -61,9 +59,8 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectModelIds, public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectModelIds,
Integer startIndex, Integer index, int offset) { Integer startIndex, Integer index, int offset) {
QueryReq queryReq = queryContext.getRequest(); String text = queryContext.getQueryText();
String text = queryReq.getQueryText(); Integer agentId = queryContext.getAgentId();
Integer agentId = queryReq.getAgentId();
String detectSegment = text.substring(startIndex, index); String detectSegment = text.substring(startIndex, index);
// step1. pre search // step1. pre search

View File

@@ -29,7 +29,7 @@ public class KeywordMapper extends BaseMapper {
@Override @Override
public void doMap(QueryContext queryContext) { public void doMap(QueryContext queryContext) {
String queryText = queryContext.getRequest().getQueryText(); String queryText = queryContext.getQueryText();
//1.hanlpDict Match //1.hanlpDict Match
List<Term> terms = HanlpHelper.getTerms(queryText); List<Term> terms = HanlpHelper.getTerms(queryText);
HanlpDictMatchStrategy hanlpMatchStrategy = ContextUtils.getBean(HanlpDictMatchStrategy.class); HanlpDictMatchStrategy hanlpMatchStrategy = ContextUtils.getBean(HanlpDictMatchStrategy.class);

View File

@@ -2,7 +2,6 @@ package com.tencent.supersonic.chat.core.mapper;
import com.hankcs.hanlp.algorithm.EditDistance; import com.hankcs.hanlp.algorithm.EditDistance;
import com.hankcs.hanlp.seg.common.Term; import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.core.agent.Agent; import com.tencent.supersonic.chat.core.agent.Agent;
import com.tencent.supersonic.chat.core.config.OptimizationConfig; import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.chat.core.utils.NatureHelper; import com.tencent.supersonic.chat.core.utils.NatureHelper;
@@ -82,9 +81,8 @@ public class MapperHelper {
detectSegment.length()); detectSegment.length());
} }
public Set<Long> getModelIds(QueryReq request, Agent agent) { public Set<Long> getModelIds(Long modelId, Agent agent) {
Long modelId = request.getModelId();
Set<Long> detectModelIds = new HashSet<>(); Set<Long> detectModelIds = new HashSet<>();
if (Objects.nonNull(agent)) { if (Objects.nonNull(agent)) {
detectModelIds = agent.getModelIds(null); detectModelIds = agent.getModelIds(null);

View File

@@ -1,15 +1,14 @@
package com.tencent.supersonic.chat.core.mapper; package com.tencent.supersonic.chat.core.mapper;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType; import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo; import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters; import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.core.knowledge.builder.BaseWordBuilder; import com.tencent.supersonic.chat.core.knowledge.builder.BaseWordBuilder;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@@ -24,8 +23,7 @@ public class QueryFilterMapper implements SchemaMapper {
@Override @Override
public void map(QueryContext queryContext) { public void map(QueryContext queryContext) {
QueryReq queryReq = queryContext.getRequest(); Long modelId = queryContext.getModelId();
Long modelId = queryReq.getModelId();
if (modelId == null || modelId <= 0) { if (modelId == null || modelId <= 0) {
return; return;
} }
@@ -62,7 +60,7 @@ public class QueryFilterMapper implements SchemaMapper {
.name(String.valueOf(filter.getValue())) .name(String.valueOf(filter.getValue()))
.type(SchemaElementType.VALUE) .type(SchemaElementType.VALUE)
.bizName(filter.getBizName()) .bizName(filter.getBizName())
.model(queryContext.getRequest().getModelId()) .model(queryContext.getModelId())
.build(); .build();
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder() SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
.element(element) .element(element)

View File

@@ -2,7 +2,6 @@ package com.tencent.supersonic.chat.core.mapper;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.hankcs.hanlp.seg.common.Term; import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.core.knowledge.HanlpMapResult; import com.tencent.supersonic.chat.core.knowledge.HanlpMapResult;
import com.tencent.supersonic.chat.core.knowledge.SearchService; import com.tencent.supersonic.chat.core.knowledge.SearchService;
import com.tencent.supersonic.chat.core.pojo.QueryContext; import com.tencent.supersonic.chat.core.pojo.QueryContext;
@@ -29,8 +28,7 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
@Override @Override
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<Term> originals, public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<Term> originals,
Set<Long> detectModelIds) { Set<Long> detectModelIds) {
QueryReq queryReq = queryContext.getRequest(); String text = queryContext.getQueryText();
String text = queryReq.getQueryText();
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(originals); Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(originals);
List<Integer> detectIndexList = Lists.newArrayList(); List<Integer> detectIndexList = Lists.newArrayList();
@@ -54,9 +52,9 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
if (StringUtils.isNotEmpty(detectSegment)) { if (StringUtils.isNotEmpty(detectSegment)) {
List<HanlpMapResult> hanlpMapResults = SearchService.prefixSearch(detectSegment, List<HanlpMapResult> hanlpMapResults = SearchService.prefixSearch(detectSegment,
SearchService.SEARCH_SIZE, queryReq.getAgentId(), detectModelIds); SearchService.SEARCH_SIZE, queryContext.getAgentId(), detectModelIds);
List<HanlpMapResult> suffixHanlpMapResults = SearchService.suffixSearch( List<HanlpMapResult> suffixHanlpMapResults = SearchService.suffixSearch(
detectSegment, SEARCH_SIZE, queryReq.getAgentId(), detectModelIds); detectSegment, SEARCH_SIZE, queryContext.getAgentId(), detectModelIds);
hanlpMapResults.addAll(suffixHanlpMapResults); hanlpMapResults.addAll(suffixHanlpMapResults);
// remove entity name where search // remove entity name where search
hanlpMapResults = hanlpMapResults.stream().filter(entry -> { hanlpMapResults = hanlpMapResults.stream().filter(entry -> {

View File

@@ -31,7 +31,7 @@ public class QueryTypeParser implements SemanticParser {
public void parse(QueryContext queryContext, ChatContext chatContext) { public void parse(QueryContext queryContext, ChatContext chatContext) {
List<SemanticQuery> candidateQueries = queryContext.getCandidateQueries(); List<SemanticQuery> candidateQueries = queryContext.getCandidateQueries();
User user = queryContext.getRequest().getUser(); User user = queryContext.getUser();
for (SemanticQuery semanticQuery : candidateQueries) { for (SemanticQuery semanticQuery : candidateQueries) {
// 1.init S2SQL // 1.init S2SQL

View File

@@ -23,7 +23,7 @@ public class SatisfactionChecker {
if (query.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) { if (query.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) {
continue; continue;
} }
if (checkThreshold(queryContext.getRequest().getQueryText(), query.getParseInfo())) { if (checkThreshold(queryContext.getQueryText(), query.getParseInfo())) {
return true; return true;
} }
} }

View File

@@ -2,30 +2,29 @@ package com.tencent.supersonic.chat.core.parser.plugin;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.google.common.collect.Sets; import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.core.parser.SemanticParser;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType; import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.core.parser.SemanticParser;
import com.tencent.supersonic.chat.core.plugin.Plugin; import com.tencent.supersonic.chat.core.plugin.Plugin;
import com.tencent.supersonic.chat.core.plugin.PluginManager; import com.tencent.supersonic.chat.core.plugin.PluginManager;
import com.tencent.supersonic.chat.core.plugin.PluginParseResult; import com.tencent.supersonic.chat.core.plugin.PluginParseResult;
import com.tencent.supersonic.chat.core.plugin.PluginRecallResult; import com.tencent.supersonic.chat.core.plugin.PluginRecallResult;
import com.tencent.supersonic.chat.core.query.plugin.PluginSemanticQuery; import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.QueryManager; import com.tencent.supersonic.chat.core.query.QueryManager;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.core.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.ModelCluster; import com.tencent.supersonic.common.pojo.ModelCluster;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import org.springframework.util.CollectionUtils;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import org.springframework.util.CollectionUtils;
/** /**
@@ -36,7 +35,7 @@ public abstract class PluginParser implements SemanticParser {
@Override @Override
public void parse(QueryContext queryContext, ChatContext chatContext) { public void parse(QueryContext queryContext, ChatContext chatContext) {
for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) { for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) {
if (queryContext.getRequest().getQueryText().length() <= semanticQuery.getParseInfo().getScore() if (queryContext.getQueryText().length() <= semanticQuery.getParseInfo().getScore()
&& (QueryManager.getPluginQueryModes().contains(semanticQuery.getQueryMode()))) { && (QueryManager.getPluginQueryModes().contains(semanticQuery.getQueryMode()))) {
return; return;
} }
@@ -64,8 +63,7 @@ public abstract class PluginParser implements SemanticParser {
for (Long modelId : modelIds) { for (Long modelId : modelIds) {
PluginSemanticQuery pluginQuery = QueryManager.createPluginQuery(plugin.getType()); PluginSemanticQuery pluginQuery = QueryManager.createPluginQuery(plugin.getType());
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(modelId, plugin, SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(modelId, plugin,
queryContext.getRequest(), queryContext.getQueryFilters(), queryContext.getModelClusterMapInfo().getMatchedElements(modelId),
queryContext.getModelClusterMapInfo().getMatchedElements(modelId),
pluginRecallResult.getDistance()); pluginRecallResult.getDistance());
semanticParseInfo.setQueryMode(pluginQuery.getQueryMode()); semanticParseInfo.setQueryMode(pluginQuery.getQueryMode());
semanticParseInfo.setScore(pluginRecallResult.getScore()); semanticParseInfo.setScore(pluginRecallResult.getScore());
@@ -78,8 +76,8 @@ public abstract class PluginParser implements SemanticParser {
return PluginManager.getPluginAgentCanSupport(queryContext); return PluginManager.getPluginAgentCanSupport(queryContext);
} }
protected SemanticParseInfo buildSemanticParseInfo(Long modelId, Plugin plugin, QueryReq queryReq, protected SemanticParseInfo buildSemanticParseInfo(Long modelId, Plugin plugin, QueryFilters queryFilters,
List<SchemaElementMatch> schemaElementMatches, double distance) { List<SchemaElementMatch> schemaElementMatches, double distance) {
if (modelId == null && !CollectionUtils.isEmpty(plugin.getModelList())) { if (modelId == null && !CollectionUtils.isEmpty(plugin.getModelList())) {
modelId = plugin.getModelList().get(0); modelId = plugin.getModelList().get(0);
} }
@@ -92,7 +90,7 @@ public abstract class PluginParser implements SemanticParser {
Map<String, Object> properties = new HashMap<>(); Map<String, Object> properties = new HashMap<>();
PluginParseResult pluginParseResult = new PluginParseResult(); PluginParseResult pluginParseResult = new PluginParseResult();
pluginParseResult.setPlugin(plugin); pluginParseResult.setPlugin(plugin);
pluginParseResult.setRequest(queryReq); pluginParseResult.setQueryFilters(queryFilters);
pluginParseResult.setDistance(distance); pluginParseResult.setDistance(distance);
properties.put(Constants.CONTEXT, pluginParseResult); properties.put(Constants.CONTEXT, pluginParseResult);
properties.put("type", "plugin"); properties.put("type", "plugin");

View File

@@ -42,7 +42,7 @@ public class EmbeddingRecallParser extends PluginParser {
@Override @Override
public PluginRecallResult recallPlugin(QueryContext queryContext) { public PluginRecallResult recallPlugin(QueryContext queryContext) {
String text = queryContext.getRequest().getQueryText(); String text = queryContext.getQueryText();
List<Retrieval> embeddingRetrievals = embeddingRecall(text); List<Retrieval> embeddingRetrievals = embeddingRecall(text);
if (CollectionUtils.isEmpty(embeddingRetrievals)) { if (CollectionUtils.isEmpty(embeddingRetrievals)) {
return null; return null;
@@ -63,7 +63,7 @@ public class EmbeddingRecallParser extends PluginParser {
} }
plugin.setParseMode(ParseMode.EMBEDDING_RECALL); plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
double distance = embeddingRetrieval.getDistance(); double distance = embeddingRetrieval.getDistance();
double score = queryContext.getRequest().getQueryText().length() * (1 - distance); double score = queryContext.getQueryText().length() * (1 - distance);
return PluginRecallResult.builder() return PluginRecallResult.builder()
.plugin(plugin).modelIds(modelList).score(score).distance(distance).build(); .plugin(plugin).modelIds(modelList).score(score).distance(distance).build();
} }

View File

@@ -33,7 +33,7 @@ public class FunctionCallParser extends PluginParser {
String functionUrl = functionCallConfig.getUrl(); String functionUrl = functionCallConfig.getUrl();
if (StringUtils.isBlank(functionUrl) && ComponentFactory.getLLMProxy() instanceof PythonLLMProxy) { if (StringUtils.isBlank(functionUrl) && ComponentFactory.getLLMProxy() instanceof PythonLLMProxy) {
log.info("functionUrl:{}, skip function parser, queryText:{}", functionUrl, log.info("functionUrl:{}, skip function parser, queryText:{}", functionUrl,
queryContext.getRequest().getQueryText()); queryContext.getQueryText());
return false; return false;
} }
List<Plugin> plugins = getPluginList(queryContext); List<Plugin> plugins = getPluginList(queryContext);
@@ -60,7 +60,7 @@ public class FunctionCallParser extends PluginParser {
if (CollectionUtils.isEmpty(modelList)) { if (CollectionUtils.isEmpty(modelList)) {
return null; return null;
} }
double score = queryContext.getRequest().getQueryText().length(); double score = queryContext.getQueryText().length();
return PluginRecallResult.builder().plugin(plugin).modelIds(modelList).score(score).build(); return PluginRecallResult.builder().plugin(plugin).modelIds(modelList).score(score).build();
} }
return null; return null;
@@ -68,7 +68,7 @@ public class FunctionCallParser extends PluginParser {
public FunctionResp functionCall(QueryContext queryContext) { public FunctionResp functionCall(QueryContext queryContext) {
List<PluginParseConfig> pluginToFunctionCall = List<PluginParseConfig> pluginToFunctionCall =
getPluginToFunctionCall(queryContext.getRequest().getModelId(), queryContext); getPluginToFunctionCall(queryContext.getModelId(), queryContext);
if (CollectionUtils.isEmpty(pluginToFunctionCall)) { if (CollectionUtils.isEmpty(pluginToFunctionCall)) {
log.info("function call parser, plugin is empty, skip"); log.info("function call parser, plugin is empty, skip");
return null; return null;
@@ -78,7 +78,7 @@ public class FunctionCallParser extends PluginParser {
functionResp.setToolSelection(pluginToFunctionCall.iterator().next().getName()); functionResp.setToolSelection(pluginToFunctionCall.iterator().next().getName());
} else { } else {
FunctionReq functionReq = FunctionReq.builder() FunctionReq functionReq = FunctionReq.builder()
.queryText(queryContext.getRequest().getQueryText()) .queryText(queryContext.getQueryText())
.pluginConfigs(pluginToFunctionCall).build(); .pluginConfigs(pluginToFunctionCall).build();
functionResp = ComponentFactory.getLLMProxy().requestFunction(functionReq); functionResp = ComponentFactory.getLLMProxy().requestFunction(functionReq);
} }

View File

@@ -116,7 +116,7 @@ public class HeuristicModelResolver implements ModelResolver {
public String resolve(QueryContext queryContext, ChatContext chatCtx, Set<Long> restrictiveModels) { public String resolve(QueryContext queryContext, ChatContext chatCtx, Set<Long> restrictiveModels) {
SchemaModelClusterMapInfo mapInfo = queryContext.getModelClusterMapInfo(); SchemaModelClusterMapInfo mapInfo = queryContext.getModelClusterMapInfo();
Set<String> matchedModelClusters = mapInfo.getElementMatchesByModelIds(restrictiveModels).keySet(); Set<String> matchedModelClusters = mapInfo.getElementMatchesByModelIds(restrictiveModels).keySet();
Long modelId = queryContext.getRequest().getModelId(); Long modelId = queryContext.getModelId();
if (Objects.nonNull(modelId) && modelId > 0) { if (Objects.nonNull(modelId) && modelId > 0) {
if (CollectionUtils.isEmpty(restrictiveModels) || restrictiveModels.contains(modelId)) { if (CollectionUtils.isEmpty(restrictiveModels) || restrictiveModels.contains(modelId)) {
return getModelClusterByModelId(modelId, matchedModelClusters); return getModelClusterByModelId(modelId, matchedModelClusters);

View File

@@ -57,7 +57,7 @@ public class LLMRequestService {
return true; return true;
} }
if (SatisfactionChecker.isSkip(queryCtx)) { if (SatisfactionChecker.isSkip(queryCtx)) {
log.info("skip {}, queryText:{}", LLMSqlParser.class, queryCtx.getRequest().getQueryText()); log.info("skip {}, queryText:{}", LLMSqlParser.class, queryCtx.getQueryText());
return true; return true;
} }
return false; return false;
@@ -121,7 +121,7 @@ public class LLMRequestService {
public LLMReq getLlmReq(QueryContext queryCtx, SemanticSchema semanticSchema, public LLMReq getLlmReq(QueryContext queryCtx, SemanticSchema semanticSchema,
ModelCluster modelCluster, List<ElementValue> linkingValues) { ModelCluster modelCluster, List<ElementValue> linkingValues) {
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName(); Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
String queryText = queryCtx.getRequest().getQueryText(); String queryText = queryCtx.getQueryText();
LLMReq llmReq = new LLMReq(); LLMReq llmReq = new LLMReq();
llmReq.setQueryText(queryText); llmReq.setQueryText(queryText);

View File

@@ -39,7 +39,7 @@ public class LLMResponseService {
properties.put("name", commonAgentTool.getName()); properties.put("name", commonAgentTool.getName());
parseInfo.setProperties(properties); parseInfo.setProperties(properties);
parseInfo.setScore(queryCtx.getRequest().getQueryText().length() * (1 + weight)); parseInfo.setScore(queryCtx.getQueryText().length() * (1 + weight));
parseInfo.setQueryMode(semanticQuery.getQueryMode()); parseInfo.setQueryMode(semanticQuery.getQueryMode());
parseInfo.getSqlInfo().setS2SQL(s2SQL); parseInfo.getSqlInfo().setS2SQL(s2SQL);
parseInfo.setModel(parseResult.getModelCluster()); parseInfo.setModel(parseResult.getModelCluster());

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.chat.core.parser.sql.llm; package com.tencent.supersonic.chat.core.parser.sql.llm;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.core.agent.NL2SQLTool; import com.tencent.supersonic.chat.core.agent.NL2SQLTool;
import com.tencent.supersonic.chat.core.parser.SemanticParser; import com.tencent.supersonic.chat.core.parser.SemanticParser;
import com.tencent.supersonic.chat.core.pojo.ChatContext; import com.tencent.supersonic.chat.core.pojo.ChatContext;
@@ -24,7 +23,6 @@ public class LLMSqlParser implements SemanticParser {
@Override @Override
public void parse(QueryContext queryCtx, ChatContext chatCtx) { public void parse(QueryContext queryCtx, ChatContext chatCtx) {
QueryReq request = queryCtx.getRequest();
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class); LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
//1.determine whether to skip this parser. //1.determine whether to skip this parser.
if (requestService.isSkip(queryCtx)) { if (requestService.isSkip(queryCtx)) {
@@ -56,7 +54,6 @@ public class LLMSqlParser implements SemanticParser {
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class); LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
Map<String, LLMSqlResp> deduplicationSqlResp = responseService.getDeduplicationSqlResp(llmResp); Map<String, LLMSqlResp> deduplicationSqlResp = responseService.getDeduplicationSqlResp(llmResp);
ParseResult parseResult = ParseResult.builder() ParseResult parseResult = ParseResult.builder()
.request(request)
.modelCluster(modelCluster) .modelCluster(modelCluster)
.commonAgentTool(commonAgentTool) .commonAgentTool(commonAgentTool)
.llmReq(llmReq) .llmReq(llmReq)

View File

@@ -49,7 +49,7 @@ public class AggregateTypeParser implements SemanticParser {
@Override @Override
public void parse(QueryContext queryContext, ChatContext chatContext) { public void parse(QueryContext queryContext, ChatContext chatContext) {
String queryText = queryContext.getRequest().getQueryText(); String queryText = queryContext.getQueryText();
AggregateConf aggregateConf = resolveAggregateConf(queryText); AggregateConf aggregateConf = resolveAggregateConf(queryText);
for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) { for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) {

View File

@@ -42,12 +42,13 @@ public class TimeRangeParser implements SemanticParser {
@Override @Override
public void parse(QueryContext queryContext, ChatContext chatContext) { public void parse(QueryContext queryContext, ChatContext chatContext) {
DateConf dateConf = parseRecent(queryContext.getRequest().getQueryText()); String queryText = queryContext.getQueryText();
DateConf dateConf = parseRecent(queryText);
if (dateConf == null) { if (dateConf == null) {
dateConf = parseDateNumber(queryContext.getRequest().getQueryText()); dateConf = parseDateNumber(queryText);
} }
if (dateConf == null) { if (dateConf == null) {
dateConf = parseDateCN(queryContext.getRequest().getQueryText()); dateConf = parseDateCN(queryText);
} }
if (dateConf != null) { if (dateConf != null) {

View File

@@ -1,12 +1,12 @@
package com.tencent.supersonic.chat.core.plugin; package com.tencent.supersonic.chat.core.plugin;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import lombok.Data; import lombok.Data;
@Data @Data
public class PluginParseResult { public class PluginParseResult {
private Plugin plugin; private Plugin plugin;
private QueryReq request; private QueryFilters queryFilters;
private double distance; private double distance;
} }

View File

@@ -1,10 +1,10 @@
package com.tencent.supersonic.chat.core.pojo; package com.tencent.supersonic.chat.core.pojo;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo; import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaModelClusterMapInfo; import com.tencent.supersonic.chat.api.pojo.SchemaModelClusterMapInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters; import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp; import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
import com.tencent.supersonic.chat.core.agent.Agent; import com.tencent.supersonic.chat.core.agent.Agent;
import com.tencent.supersonic.chat.core.config.OptimizationConfig; import com.tencent.supersonic.chat.core.config.OptimizationConfig;
@@ -27,7 +27,12 @@ import lombok.NoArgsConstructor;
@AllArgsConstructor @AllArgsConstructor
public class QueryContext { public class QueryContext {
private QueryReq request; private String queryText;
private Integer chatId;
private Long modelId;
private User user;
private boolean saveAnswer = true;
private Integer agentId;
private QueryFilters queryFilters; private QueryFilters queryFilters;
private List<SemanticQuery> candidateQueries = new ArrayList<>(); private List<SemanticQuery> candidateQueries = new ArrayList<>();
private SchemaMapInfo mapInfo = new SchemaMapInfo(); private SchemaMapInfo mapInfo = new SchemaMapInfo();

View File

@@ -7,7 +7,6 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters; import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.core.plugin.PluginParseResult; import com.tencent.supersonic.chat.core.plugin.PluginParseResult;
import com.tencent.supersonic.chat.core.query.BaseSemanticQuery; import com.tencent.supersonic.chat.core.query.BaseSemanticQuery;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@@ -31,11 +30,10 @@ public abstract class PluginSemanticQuery extends BaseSemanticQuery {
private Map<Long, Object> getFilterMap(PluginParseResult pluginParseResult) { private Map<Long, Object> getFilterMap(PluginParseResult pluginParseResult) {
Map<Long, Object> map = new HashMap<>(); Map<Long, Object> map = new HashMap<>();
QueryReq queryReq = pluginParseResult.getRequest(); QueryFilters queryFilters = pluginParseResult.getQueryFilters();
if (queryReq == null || queryReq.getQueryFilters() == null) { if (queryFilters == null) {
return map; return map;
} }
QueryFilters queryFilters = queryReq.getQueryFilters();
List<QueryFilter> queryFilterList = queryFilters.getFilters(); List<QueryFilter> queryFilterList = queryFilters.getFilters();
if (CollectionUtils.isEmpty(queryFilterList)) { if (CollectionUtils.isEmpty(queryFilterList)) {
return map; return map;

View File

@@ -36,7 +36,7 @@ public class MetricTopNQuery extends MetricSemanticQuery {
@Override @Override
public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches, public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches,
QueryContext queryCtx) { QueryContext queryCtx) {
Matcher matcher = INTENT_PATTERN.matcher(queryCtx.getRequest().getQueryText()); Matcher matcher = INTENT_PATTERN.matcher(queryCtx.getQueryText());
if (matcher.matches()) { if (matcher.matches()) {
return super.match(candidateElementMatches, queryCtx); return super.match(candidateElementMatches, queryCtx);
} }

View File

@@ -1,15 +1,14 @@
package com.tencent.supersonic.chat.server.persistence.repository; package com.tencent.supersonic.chat.server.persistence.repository;
import com.github.pagehelper.PageInfo; import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq; import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp; import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp; import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
import java.util.List; import java.util.List;
public interface ChatQueryRepository { public interface ChatQueryRepository {
@@ -22,15 +21,12 @@ public interface ChatQueryRepository {
List<QueryResp> queryShowCase(PageQueryInfoReq pageQueryInfoCommend, int agentId); List<QueryResp> queryShowCase(PageQueryInfoReq pageQueryInfoCommend, int agentId);
void updateChatParseInfo(List<ChatParseDO> chatParseDOS);
ChatQueryDO getLastChatQuery(long chatId); ChatQueryDO getLastChatQuery(long chatId);
int updateChatQuery(ChatQueryDO chatQueryDO); int updateChatQuery(ChatQueryDO chatQueryDO);
List<ChatParseDO> batchSaveParseInfo(ChatContext chatCtx, QueryReq queryReq, List<ChatParseDO> batchSaveParseInfo(ChatContext chatCtx, QueryContext queryContext,
ParseResp parseResult, ParseResp parseResult, List<SemanticParseInfo> candidateParses);
List<SemanticParseInfo> candidateParses);
ChatParseDO getParseInfo(Long questionId, int parseId); ChatParseDO getParseInfo(Long questionId, int parseId);

View File

@@ -3,35 +3,35 @@ package com.tencent.supersonic.chat.server.persistence.repository.impl;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import com.github.pagehelper.PageHelper; import com.github.pagehelper.PageHelper;
import com.github.pagehelper.PageInfo; import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq; import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp; import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp; import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp; import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
import com.tencent.supersonic.chat.server.persistence.mapper.custom.ShowCaseCustomMapper; import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDOExample; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDOExample;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDOExample.Criteria; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDOExample.Criteria;
import com.tencent.supersonic.chat.server.persistence.mapper.ChatParseMapper; import com.tencent.supersonic.chat.server.persistence.mapper.ChatParseMapper;
import com.tencent.supersonic.chat.server.persistence.mapper.ChatQueryDOMapper; import com.tencent.supersonic.chat.server.persistence.mapper.ChatQueryDOMapper;
import com.tencent.supersonic.chat.server.persistence.mapper.custom.ShowCaseCustomMapper;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository; import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.PageUtils; import com.tencent.supersonic.common.util.PageUtils;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils; import org.springframework.beans.BeanUtils;
import org.springframework.context.annotation.Primary; import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Repository; import org.springframework.stereotype.Repository;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
@Repository @Repository
@Primary @Primary
@@ -116,13 +116,13 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
return queryResp; return queryResp;
} }
public Long createChatQuery(ParseResp parseResult, ChatContext chatCtx, QueryReq queryReq) { public Long createChatQuery(ParseResp parseResult, ChatContext chatCtx, QueryContext queryContext) {
ChatQueryDO chatQueryDO = new ChatQueryDO(); ChatQueryDO chatQueryDO = new ChatQueryDO();
chatQueryDO.setChatId(Long.valueOf(chatCtx.getChatId())); chatQueryDO.setChatId(Long.valueOf(chatCtx.getChatId()));
chatQueryDO.setCreateTime(new java.util.Date()); chatQueryDO.setCreateTime(new java.util.Date());
chatQueryDO.setUserName(queryReq.getUser().getName()); chatQueryDO.setUserName(queryContext.getUser().getName());
chatQueryDO.setQueryText(queryReq.getQueryText()); chatQueryDO.setQueryText(queryContext.getQueryText());
chatQueryDO.setAgentId(queryReq.getAgentId()); chatQueryDO.setAgentId(queryContext.getAgentId());
chatQueryDO.setQueryResult(""); chatQueryDO.setQueryResult("");
try { try {
chatQueryDOMapper.insert(chatQueryDO); chatQueryDOMapper.insert(chatQueryDO);
@@ -135,31 +135,24 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
} }
@Override @Override
public List<ChatParseDO> batchSaveParseInfo(ChatContext chatCtx, QueryReq queryReq, public List<ChatParseDO> batchSaveParseInfo(ChatContext chatCtx, QueryContext queryContext,
ParseResp parseResult, List<SemanticParseInfo> candidateParses) { ParseResp parseResult, List<SemanticParseInfo> candidateParses) {
Long queryId = createChatQuery(parseResult, chatCtx, queryReq); Long queryId = createChatQuery(parseResult, chatCtx, queryContext);
List<ChatParseDO> chatParseDOList = new ArrayList<>(); List<ChatParseDO> chatParseDOList = new ArrayList<>();
getChatParseDO(chatCtx, queryReq, queryId, candidateParses, chatParseDOList); getChatParseDO(chatCtx, queryContext, queryId, candidateParses, chatParseDOList);
if (!CollectionUtils.isEmpty(candidateParses)) { if (!CollectionUtils.isEmpty(candidateParses)) {
chatParseMapper.batchSaveParseInfo(chatParseDOList); chatParseMapper.batchSaveParseInfo(chatParseDOList);
} }
return chatParseDOList; return chatParseDOList;
} }
@Override public void getChatParseDO(ChatContext chatCtx, QueryContext queryContext, Long queryId,
public void updateChatParseInfo(List<ChatParseDO> chatParseDOS) {
for (ChatParseDO chatParseDO : chatParseDOS) {
chatParseMapper.updateParseInfo(chatParseDO);
}
}
public void getChatParseDO(ChatContext chatCtx, QueryReq queryReq, Long queryId,
List<SemanticParseInfo> parses, List<ChatParseDO> chatParseDOList) { List<SemanticParseInfo> parses, List<ChatParseDO> chatParseDOList) {
for (int i = 0; i < parses.size(); i++) { for (int i = 0; i < parses.size(); i++) {
ChatParseDO chatParseDO = new ChatParseDO(); ChatParseDO chatParseDO = new ChatParseDO();
chatParseDO.setChatId(Long.valueOf(chatCtx.getChatId())); chatParseDO.setChatId(Long.valueOf(chatCtx.getChatId()));
chatParseDO.setQuestionId(queryId); chatParseDO.setQuestionId(queryId);
chatParseDO.setQueryText(queryReq.getQueryText()); chatParseDO.setQueryText(queryContext.getQueryText());
chatParseDO.setParseInfo(JsonUtil.toString(parses.get(i))); chatParseDO.setParseInfo(JsonUtil.toString(parses.get(i)));
chatParseDO.setIsCandidate(1); chatParseDO.setIsCandidate(1);
if (i == 0) { if (i == 0) {
@@ -167,7 +160,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
} }
chatParseDO.setParseId(parses.get(i).getId()); chatParseDO.setParseId(parses.get(i).getId());
chatParseDO.setCreateTime(new java.util.Date()); chatParseDO.setCreateTime(new java.util.Date());
chatParseDO.setUserName(queryReq.getUser().getName()); chatParseDO.setUserName(queryContext.getUser().getName());
chatParseDOList.add(chatParseDO); chatParseDOList.add(chatParseDO);
} }
} }

View File

@@ -1,13 +1,12 @@
package com.tencent.supersonic.chat.server.processor.parse; package com.tencent.supersonic.chat.server.processor.parse;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo; import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp; import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.QueryManager; import com.tencent.supersonic.chat.core.query.QueryManager;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.core.query.llm.analytics.MetricAnalyzeQuery; import com.tencent.supersonic.chat.core.query.llm.analytics.MetricAnalyzeQuery;
import com.tencent.supersonic.chat.server.service.SemanticService; import com.tencent.supersonic.chat.server.service.SemanticService;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
@@ -29,7 +28,6 @@ public class EntityInfoProcessor implements ParseResultProcessor {
} }
List<SemanticParseInfo> selectedParses = semanticQueries.stream().map(SemanticQuery::getParseInfo) List<SemanticParseInfo> selectedParses = semanticQueries.stream().map(SemanticQuery::getParseInfo)
.collect(Collectors.toList()); .collect(Collectors.toList());
QueryReq queryReq = queryContext.getRequest();
selectedParses.forEach(parseInfo -> { selectedParses.forEach(parseInfo -> {
String queryMode = parseInfo.getQueryMode(); String queryMode = parseInfo.getQueryMode();
if (QueryManager.containsPluginQuery(queryMode) if (QueryManager.containsPluginQuery(queryMode)
@@ -38,7 +36,7 @@ public class EntityInfoProcessor implements ParseResultProcessor {
} }
//1. set entity info //1. set entity info
SemanticService semanticService = ContextUtils.getBean(SemanticService.class); SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, queryReq.getUser()); EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, queryContext.getUser());
if (QueryManager.isTagQuery(queryMode) if (QueryManager.isTagQuery(queryMode)
|| QueryManager.isMetricQuery(queryMode)) { || QueryManager.isMetricQuery(queryMode)) {
parseInfo.setEntityInfo(entityInfo); parseInfo.setEntityInfo(entityInfo);

View File

@@ -35,8 +35,8 @@ public class QueryRecommendProcessor implements ParseResultProcessor {
@SneakyThrows @SneakyThrows
private void doProcess(ParseResp parseResp, QueryContext queryContext) { private void doProcess(ParseResp parseResp, QueryContext queryContext) {
Long queryId = parseResp.getQueryId(); Long queryId = parseResp.getQueryId();
List<SimilarQueryRecallResp> solvedQueries = getSimilarQueries(queryContext.getRequest().getQueryText(), List<SimilarQueryRecallResp> solvedQueries = getSimilarQueries(queryContext.getQueryText(),
queryContext.getRequest().getAgentId()); queryContext.getAgentId());
ChatQueryDO chatQueryDO = getChatQuery(queryId); ChatQueryDO chatQueryDO = getChatQuery(queryId);
chatQueryDO.setSimilarQueries(JSONObject.toJSONString(solvedQueries)); chatQueryDO.setSimilarQueries(JSONObject.toJSONString(solvedQueries));
updateChatQuery(chatQueryDO); updateChatQuery(chatQueryDO);

View File

@@ -1,11 +1,10 @@
package com.tencent.supersonic.chat.server.processor.parse; package com.tencent.supersonic.chat.server.processor.parse;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.core.pojo.ChatContext; import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext; import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.core.query.SemanticQuery; import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.server.service.ChatService; import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import java.util.List; import java.util.List;
@@ -20,9 +19,8 @@ public class RespBuildProcessor implements ParseResultProcessor {
@Override @Override
public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) { public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
QueryReq queryReq = queryContext.getRequest(); parseResp.setChatId(queryContext.getChatId());
parseResp.setChatId(queryReq.getChatId()); parseResp.setQueryText(queryContext.getQueryText());
parseResp.setQueryText(queryReq.getQueryText());
List<SemanticQuery> candidateQueries = queryContext.getCandidateQueries(); List<SemanticQuery> candidateQueries = queryContext.getCandidateQueries();
ChatService chatService = ContextUtils.getBean(ChatService.class); ChatService chatService = ContextUtils.getBean(ChatService.class);
if (candidateQueries.size() > 0) { if (candidateQueries.size() > 0) {
@@ -33,7 +31,7 @@ public class RespBuildProcessor implements ParseResultProcessor {
} else { } else {
parseResp.setState(ParseResp.ParseState.FAILED); parseResp.setState(ParseResp.ParseState.FAILED);
} }
chatService.batchAddParse(chatContext, queryReq, parseResp); chatService.batchAddParse(chatContext, queryContext, parseResp);
} }
} }

View File

@@ -1,21 +1,20 @@
package com.tencent.supersonic.chat.server.processor.parse; package com.tencent.supersonic.chat.server.processor.parse;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp; import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo; import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.QueryManager; import com.tencent.supersonic.chat.core.query.QueryManager;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery; import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
/** /**
* SqlInfoProcessor adds S2SQL to the parsing results so that * SqlInfoProcessor adds S2SQL to the parsing results so that
@@ -27,7 +26,6 @@ public class SqlInfoProcessor implements ParseResultProcessor {
@Override @Override
public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) { public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
QueryReq queryReq = queryContext.getRequest();
List<SemanticQuery> semanticQueries = queryContext.getCandidateQueries(); List<SemanticQuery> semanticQueries = queryContext.getCandidateQueries();
if (CollectionUtils.isEmpty(semanticQueries)) { if (CollectionUtils.isEmpty(semanticQueries)) {
return; return;
@@ -35,26 +33,26 @@ public class SqlInfoProcessor implements ParseResultProcessor {
List<SemanticParseInfo> selectedParses = semanticQueries.stream().map(SemanticQuery::getParseInfo) List<SemanticParseInfo> selectedParses = semanticQueries.stream().map(SemanticQuery::getParseInfo)
.collect(Collectors.toList()); .collect(Collectors.toList());
long startTime = System.currentTimeMillis(); long startTime = System.currentTimeMillis();
addSqlInfo(queryReq, selectedParses); addSqlInfo(queryContext, selectedParses);
parseResp.getParseTimeCost().setSqlTime(System.currentTimeMillis() - startTime); parseResp.getParseTimeCost().setSqlTime(System.currentTimeMillis() - startTime);
} }
private void addSqlInfo(QueryReq queryReq, List<SemanticParseInfo> semanticParseInfos) { private void addSqlInfo(QueryContext queryContext, List<SemanticParseInfo> semanticParseInfos) {
if (CollectionUtils.isEmpty(semanticParseInfos)) { if (CollectionUtils.isEmpty(semanticParseInfos)) {
return; return;
} }
semanticParseInfos.forEach(parseInfo -> { semanticParseInfos.forEach(parseInfo -> {
addSqlInfo(queryReq, parseInfo); addSqlInfo(queryContext, parseInfo);
}); });
} }
private void addSqlInfo(QueryReq queryReq, SemanticParseInfo parseInfo) { private void addSqlInfo(QueryContext queryContext, SemanticParseInfo parseInfo) {
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode()); SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
if (Objects.isNull(semanticQuery)) { if (Objects.isNull(semanticQuery)) {
return; return;
} }
semanticQuery.setParseInfo(parseInfo); semanticQuery.setParseInfo(parseInfo);
String explainSql = semanticQuery.explain(queryReq.getUser()); String explainSql = semanticQuery.explain(queryContext.getUser());
if (StringUtils.isBlank(explainSql)) { if (StringUtils.isBlank(explainSql)) {
return; return;
} }

View File

@@ -4,11 +4,11 @@ import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.core.pojo.ChatContext; import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq; import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp; import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp; import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp; import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatDO; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
@@ -46,7 +46,7 @@ public interface ChatService {
ShowCaseResp queryShowCase(PageQueryInfoReq pageQueryInfoCommend, int agentId); ShowCaseResp queryShowCase(PageQueryInfoReq pageQueryInfoCommend, int agentId);
List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryReq queryReq, ParseResp parseResult); List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryContext queryContext, ParseResp parseResult);
ChatQueryDO getLastQuery(long chatId); ChatQueryDO getLastQuery(long chatId);

View File

@@ -5,11 +5,11 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.core.pojo.ChatContext; import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq; import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp; import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp; import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp; import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatDO; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
@@ -211,9 +211,9 @@ public class ChatServiceImpl implements ChatService {
} }
@Override @Override
public List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryReq queryReq, ParseResp parseResult) { public List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryContext queryContext, ParseResp parseResult) {
List<SemanticParseInfo> candidateParses = parseResult.getSelectedParses(); List<SemanticParseInfo> candidateParses = parseResult.getSelectedParses();
return chatQueryRepository.batchSaveParseInfo(chatCtx, queryReq, parseResult, candidateParses); return chatQueryRepository.batchSaveParseInfo(chatCtx, queryContext, parseResult, candidateParses);
} }
@Override @Override

View File

@@ -90,6 +90,7 @@ import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Primary; import org.springframework.context.annotation.Primary;
@@ -197,7 +198,6 @@ public class QueryServiceImpl implements QueryService {
List<Plugin> pluginList = pluginService.getPluginList(); List<Plugin> pluginList = pluginService.getPluginList();
QueryContext queryCtx = QueryContext.builder() QueryContext queryCtx = QueryContext.builder()
.request(queryReq)
.queryFilters(queryReq.getQueryFilters()) .queryFilters(queryReq.getQueryFilters())
.semanticSchema(semanticSchema) .semanticSchema(semanticSchema)
.candidateQueries(new ArrayList<>()) .candidateQueries(new ArrayList<>())
@@ -207,6 +207,7 @@ public class QueryServiceImpl implements QueryService {
.nameToPlugin(nameToPlugin) .nameToPlugin(nameToPlugin)
.pluginList(pluginList) .pluginList(pluginList)
.build(); .build();
BeanUtils.copyProperties(queryReq, queryCtx);
return queryCtx; return queryCtx;
} }

View File

@@ -42,6 +42,7 @@ import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@@ -91,10 +92,10 @@ public class SearchServiceImpl implements SearchService {
List<Term> originals = HanlpHelper.getTerms(queryText); List<Term> originals = HanlpHelper.getTerms(queryText);
log.info("hanlp parse result: {}", originals); log.info("hanlp parse result: {}", originals);
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class); MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
Set<Long> detectModelIds = mapperHelper.getModelIds(queryReq, agentService.getAgent(agentId)); Set<Long> detectModelIds = mapperHelper.getModelIds(queryReq.getModelId(), agentService.getAgent(agentId));
QueryContext queryContext = new QueryContext(); QueryContext queryContext = new QueryContext();
queryContext.setRequest(queryReq); BeanUtils.copyProperties(queryReq, queryContext);
Map<MatchText, List<HanlpMapResult>> regTextMap = Map<MatchText, List<HanlpMapResult>> regTextMap =
searchMatchStrategy.match(queryContext, originals, detectModelIds); searchMatchStrategy.match(queryContext, originals, detectModelIds);
regTextMap.entrySet().stream().forEach(m -> HanlpHelper.transLetterOriginal(m.getValue())); regTextMap.entrySet().stream().forEach(m -> HanlpHelper.transLetterOriginal(m.getValue()));

View File

@@ -1,11 +1,10 @@
package com.tencent.supersonic.headless.server.service; package com.tencent.supersonic.headless.server.service;
import com.tencent.supersonic.headless.api.request.MetricQueryReq;
import com.tencent.supersonic.headless.api.request.ParseSqlReq; import com.tencent.supersonic.headless.api.request.ParseSqlReq;
import com.tencent.supersonic.headless.api.request.QueryStructReq; import com.tencent.supersonic.headless.api.request.QueryStructReq;
import com.tencent.supersonic.headless.api.response.SemanticQueryResp; import com.tencent.supersonic.headless.api.response.SemanticQueryResp;
import com.tencent.supersonic.headless.core.pojo.QueryStatement;
import com.tencent.supersonic.headless.core.executor.QueryExecutor; import com.tencent.supersonic.headless.core.executor.QueryExecutor;
import com.tencent.supersonic.headless.core.pojo.QueryStatement;
public interface SemantciQueryEngine { public interface SemantciQueryEngine {
@@ -17,5 +16,4 @@ public interface SemantciQueryEngine {
QueryStatement physicalSql(QueryStructReq queryStructCmd, ParseSqlReq sqlCommend) throws Exception; QueryStatement physicalSql(QueryStructReq queryStructCmd, ParseSqlReq sqlCommend) throws Exception;
QueryStatement physicalSql(QueryStructReq queryStructCmd, MetricQueryReq sqlCommend) throws Exception;
} }

View File

@@ -1,13 +1,12 @@
package com.tencent.supersonic.headless.server.service.impl; package com.tencent.supersonic.headless.server.service.impl;
import com.tencent.supersonic.headless.api.request.MetricQueryReq;
import com.tencent.supersonic.headless.api.request.ParseSqlReq; import com.tencent.supersonic.headless.api.request.ParseSqlReq;
import com.tencent.supersonic.headless.api.request.QueryStructReq; import com.tencent.supersonic.headless.api.request.QueryStructReq;
import com.tencent.supersonic.headless.api.response.SemanticQueryResp; import com.tencent.supersonic.headless.api.response.SemanticQueryResp;
import com.tencent.supersonic.headless.core.executor.QueryExecutor; import com.tencent.supersonic.headless.core.executor.QueryExecutor;
import com.tencent.supersonic.headless.core.planner.QueryOptimizer;
import com.tencent.supersonic.headless.core.parser.QueryParser; import com.tencent.supersonic.headless.core.parser.QueryParser;
import com.tencent.supersonic.headless.core.parser.calcite.s2sql.SemanticModel; import com.tencent.supersonic.headless.core.parser.calcite.s2sql.SemanticModel;
import com.tencent.supersonic.headless.core.planner.QueryOptimizer;
import com.tencent.supersonic.headless.core.pojo.QueryStatement; import com.tencent.supersonic.headless.core.pojo.QueryStatement;
import com.tencent.supersonic.headless.core.utils.ComponentFactory; import com.tencent.supersonic.headless.core.utils.ComponentFactory;
import com.tencent.supersonic.headless.server.manager.SemanticSchemaManager; import com.tencent.supersonic.headless.server.manager.SemanticSchemaManager;
@@ -82,15 +81,6 @@ public class SemantciQueryEngineImpl implements SemantciQueryEngine {
return optimize(queryStructCmd, queryParser.parser(sqlCommend, queryStatement)); return optimize(queryStructCmd, queryParser.parser(sqlCommend, queryStatement));
} }
public QueryStatement physicalSql(QueryStructReq queryStructCmd, MetricQueryReq metricCommand) throws Exception {
QueryStatement queryStatement = new QueryStatement();
queryStatement.setQueryStructReq(queryStructCmd);
queryStatement.setMetricReq(metricCommand);
queryStatement.setIsS2SQL(false);
queryStatement.setSemanticModel(getSemanticModel(queryStatement));
return queryParser.parser(queryStatement);
}
private SemanticModel getSemanticModel(QueryStatement queryStatement) throws Exception { private SemanticModel getSemanticModel(QueryStatement queryStatement) throws Exception {
QueryStructReq queryStructReq = queryStatement.getQueryStructReq(); QueryStructReq queryStructReq = queryStatement.getQueryStructReq();
return semanticSchemaManager.get(queryStructReq.getModelIdStr()); return semanticSchemaManager.get(queryStructReq.getModelIdStr());