mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
[improvement][headless-chat]Incorporate Request into Context objects, removing unnecessary copy.
This commit is contained in:
@@ -7,7 +7,6 @@ import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||
import lombok.Data;
|
||||
@@ -17,7 +16,7 @@ import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
@Data
|
||||
public class QueryNLReq {
|
||||
public class QueryNLReq extends SemanticQueryReq {
|
||||
private String queryText;
|
||||
private Set<Long> dataSetIds = Sets.newHashSet();
|
||||
private User user;
|
||||
@@ -25,9 +24,13 @@ public class QueryNLReq {
|
||||
private boolean saveAnswer = true;
|
||||
private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM;
|
||||
private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
|
||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||
private QueryDataType queryDataType = QueryDataType.ALL;
|
||||
private Map<String, ChatApp> chatAppConfig;
|
||||
private List<Text2SQLExemplar> dynamicExemplars = Lists.newArrayList();
|
||||
private SemanticParseInfo contextParseInfo;
|
||||
|
||||
@Override
|
||||
public String toCustomizedString() {
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,12 +50,4 @@ public abstract class SemanticQueryReq {
|
||||
public Set<Long> getModelIdSet() {
|
||||
return modelIds;
|
||||
}
|
||||
|
||||
public boolean isNeedAuth() {
|
||||
return needAuth;
|
||||
}
|
||||
|
||||
public void setNeedAuth(boolean needAuth) {
|
||||
this.needAuth = needAuth;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,63 +1,41 @@
|
||||
package com.tencent.supersonic.headless.chat;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.ChatWorkflowState;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
|
||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class ChatQueryContext {
|
||||
|
||||
private String queryText;
|
||||
private QueryNLReq request;
|
||||
private String oriQueryText;
|
||||
private Set<Long> dataSetIds;
|
||||
private Map<Long, List<Long>> modelIdToDataSetIds;
|
||||
private User user;
|
||||
private boolean saveAnswer;
|
||||
private QueryFilters queryFilters;
|
||||
private List<SemanticQuery> candidateQueries = new ArrayList<>();
|
||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||
private SemanticParseInfo contextParseInfo;
|
||||
@Builder.Default
|
||||
private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM;
|
||||
@Builder.Default
|
||||
private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
|
||||
@Builder.Default
|
||||
private QueryDataType queryDataType = QueryDataType.ALL;
|
||||
@JsonIgnore
|
||||
private SemanticSchema semanticSchema;
|
||||
@JsonIgnore
|
||||
private ChatWorkflowState chatWorkflowState;
|
||||
@JsonIgnore
|
||||
private Map<String, ChatApp> chatAppConfig;
|
||||
@JsonIgnore
|
||||
private List<Text2SQLExemplar> dynamicExemplars;
|
||||
|
||||
public ChatQueryContext() {
|
||||
this(new QueryNLReq());
|
||||
}
|
||||
|
||||
public ChatQueryContext(QueryNLReq request) {
|
||||
this.request = request;
|
||||
}
|
||||
|
||||
public List<SemanticQuery> getCandidateQueries() {
|
||||
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
||||
|
||||
@@ -61,8 +61,8 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
ChatApp chatApp = chatQueryContext.getChatAppConfig().get(APP_KEY);
|
||||
if (!chatQueryContext.getText2SQLType().enableLLM() || Objects.isNull(chatApp)
|
||||
ChatApp chatApp = chatQueryContext.getRequest().getChatAppConfig().get(APP_KEY);
|
||||
if (!chatQueryContext.getRequest().getText2SQLType().enableLLM() || Objects.isNull(chatApp)
|
||||
|| !chatApp.isEnable()) {
|
||||
return;
|
||||
}
|
||||
@@ -71,8 +71,8 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
|
||||
ModelProvider.getChatModel(chatApp.getChatModelConfig());
|
||||
SemanticSqlExtractor extractor =
|
||||
AiServices.create(SemanticSqlExtractor.class, chatLanguageModel);
|
||||
Prompt prompt = generatePrompt(chatQueryContext.getQueryText(), semanticParseInfo,
|
||||
chatApp.getPrompt());
|
||||
Prompt prompt = generatePrompt(chatQueryContext.getRequest().getQueryText(),
|
||||
semanticParseInfo, chatApp.getPrompt());
|
||||
SemanticSql s2Sql = extractor.generateSemanticSql(prompt.toUserMessage().singleText());
|
||||
keyPipelineLog.info("LLMSqlCorrector modelReq:\n{} \nmodelResp:\n{}", prompt.text(), s2Sql);
|
||||
if ("NEGATIVE".equals(s2Sql.getOpinion()) && StringUtils.isNotBlank(s2Sql.getSql())) {
|
||||
|
||||
@@ -33,7 +33,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
|
||||
protected void addQueryFilter(ChatQueryContext chatQueryContext,
|
||||
SemanticParseInfo semanticParseInfo) {
|
||||
String queryFilter = getQueryFilter(chatQueryContext.getQueryFilters());
|
||||
String queryFilter = getQueryFilter(chatQueryContext.getRequest().getQueryFilters());
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
|
||||
if (StringUtils.isNotEmpty(queryFilter)) {
|
||||
|
||||
@@ -116,12 +116,12 @@ public abstract class BaseMapper implements SchemaMapper {
|
||||
|
||||
public <T> List<T> getMatches(ChatQueryContext chatQueryContext,
|
||||
BaseMatchStrategy matchStrategy) {
|
||||
String queryText = chatQueryContext.getQueryText();
|
||||
String queryText = chatQueryContext.getRequest().getQueryText();
|
||||
List<S2Term> terms =
|
||||
HanlpHelper.getTerms(queryText, chatQueryContext.getModelIdToDataSetIds());
|
||||
terms = HanlpHelper.getTerms(terms, chatQueryContext.getDataSetIds());
|
||||
Map<MatchText, List<T>> matchResult =
|
||||
matchStrategy.match(chatQueryContext, terms, chatQueryContext.getDataSetIds());
|
||||
terms = HanlpHelper.getTerms(terms, chatQueryContext.getRequest().getDataSetIds());
|
||||
Map<MatchText, List<T>> matchResult = matchStrategy.match(chatQueryContext, terms,
|
||||
chatQueryContext.getRequest().getDataSetIds());
|
||||
List<T> matches = new ArrayList<>();
|
||||
if (Objects.isNull(matchResult)) {
|
||||
return matches;
|
||||
|
||||
@@ -21,7 +21,7 @@ public abstract class BaseMatchStrategy<T extends MapResult> implements MatchStr
|
||||
@Override
|
||||
public Map<MatchText, List<T>> match(ChatQueryContext chatQueryContext, List<S2Term> terms,
|
||||
Set<Long> detectDataSetIds) {
|
||||
String text = chatQueryContext.getQueryText();
|
||||
String text = chatQueryContext.getRequest().getQueryText();
|
||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ public abstract class BatchMatchStrategy<T extends MapResult> extends BaseMatchS
|
||||
public List<T> detect(ChatQueryContext chatQueryContext, List<S2Term> terms,
|
||||
Set<Long> detectDataSetIds) {
|
||||
|
||||
String text = chatQueryContext.getQueryText();
|
||||
String text = chatQueryContext.getRequest().getQueryText();
|
||||
Set<String> detectSegments = new HashSet<>();
|
||||
|
||||
int embeddingTextSize = Integer
|
||||
|
||||
@@ -93,7 +93,8 @@ public class DatabaseMatchStrategy extends SingleMatchStrategy<DatabaseMapResult
|
||||
log.debug("ModelElementMatches:{},not exist Element threshold reduce by half:{}",
|
||||
modelElementMatches, threshold);
|
||||
}
|
||||
return getThreshold(threshold, minThreshold, chatQueryContext.getMapModeEnum());
|
||||
return getThreshold(threshold, minThreshold,
|
||||
chatQueryContext.getRequest().getMapModeEnum());
|
||||
}
|
||||
|
||||
private Map<String, Set<SchemaElement>> getNameToItems(List<SchemaElement> models) {
|
||||
|
||||
@@ -69,7 +69,7 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
|
||||
double embeddingThresholdMin =
|
||||
Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD_MIN));
|
||||
double threshold = getThreshold(embeddingThreshold, embeddingThresholdMin,
|
||||
chatQueryContext.getMapModeEnum());
|
||||
chatQueryContext.getRequest().getMapModeEnum());
|
||||
|
||||
// step1. build query params
|
||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
|
||||
|
||||
@@ -105,6 +105,7 @@ public class HanlpDictMatchStrategy extends SingleMatchStrategy<HanlpMapResult>
|
||||
mapperConfig.getParameterValue(MapperConfig.MAPPER_VALUE_THRESHOLD_MIN));
|
||||
}
|
||||
|
||||
return getThreshold(threshold, minThreshold, chatQueryContext.getMapModeEnum());
|
||||
return getThreshold(threshold, minThreshold,
|
||||
chatQueryContext.getRequest().getMapModeEnum());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,7 +32,7 @@ public class KeywordMapper extends BaseMapper {
|
||||
|
||||
@Override
|
||||
public void doMap(ChatQueryContext chatQueryContext) {
|
||||
String queryText = chatQueryContext.getQueryText();
|
||||
String queryText = chatQueryContext.getRequest().getQueryText();
|
||||
// 1.hanlpDict Match
|
||||
List<S2Term> terms =
|
||||
HanlpHelper.getTerms(queryText, chatQueryContext.getModelIdToDataSetIds());
|
||||
|
||||
@@ -22,7 +22,7 @@ public class MapFilter {
|
||||
filterByDataSetId(chatQueryContext);
|
||||
filterByDetectWordLenLessThanOne(chatQueryContext);
|
||||
twoCharactersMustEqual(chatQueryContext);
|
||||
switch (chatQueryContext.getQueryDataType()) {
|
||||
switch (chatQueryContext.getRequest().getQueryDataType()) {
|
||||
case TAG:
|
||||
filterByQueryDataType(chatQueryContext, element -> !(element.getIsTag() > 0));
|
||||
break;
|
||||
@@ -46,7 +46,7 @@ public class MapFilter {
|
||||
}
|
||||
|
||||
public static void filterByDataSetId(ChatQueryContext chatQueryContext) {
|
||||
Set<Long> dataSetIds = chatQueryContext.getDataSetIds();
|
||||
Set<Long> dataSetIds = chatQueryContext.getRequest().getDataSetIds();
|
||||
if (CollectionUtils.isEmpty(dataSetIds)) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ public class QueryFilterMapper extends BaseMapper {
|
||||
|
||||
@Override
|
||||
public void doMap(ChatQueryContext chatQueryContext) {
|
||||
Set<Long> dataSetIds = chatQueryContext.getDataSetIds();
|
||||
Set<Long> dataSetIds = chatQueryContext.getRequest().getDataSetIds();
|
||||
if (CollectionUtils.isEmpty(dataSetIds)) {
|
||||
return;
|
||||
}
|
||||
@@ -53,7 +53,7 @@ public class QueryFilterMapper extends BaseMapper {
|
||||
|
||||
private void addValueSchemaElementMatch(Long dataSetId, ChatQueryContext chatQueryContext,
|
||||
List<SchemaElementMatch> candidateElementMatches) {
|
||||
QueryFilters queryFilters = chatQueryContext.getQueryFilters();
|
||||
QueryFilters queryFilters = chatQueryContext.getRequest().getQueryFilters();
|
||||
if (queryFilters == null || CollectionUtils.isEmpty(queryFilters.getFilters())) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -36,7 +36,7 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
@Override
|
||||
public Map<MatchText, List<HanlpMapResult>> match(ChatQueryContext chatQueryContext,
|
||||
List<S2Term> originals, Set<Long> detectDataSetIds) {
|
||||
String text = chatQueryContext.getQueryText();
|
||||
String text = chatQueryContext.getRequest().getQueryText();
|
||||
Map<Integer, Integer> regOffsetToLength = mapperHelper.getRegOffsetToLength(originals);
|
||||
|
||||
List<Integer> detectIndexList = Lists.newArrayList();
|
||||
|
||||
@@ -24,7 +24,7 @@ public abstract class SingleMatchStrategy<T extends MapResult> extends BaseMatch
|
||||
public List<T> detect(ChatQueryContext chatQueryContext, List<S2Term> terms,
|
||||
Set<Long> detectDataSetIds) {
|
||||
Map<Integer, Integer> regOffsetToLength = mapperHelper.getRegOffsetToLength(terms);
|
||||
String text = chatQueryContext.getQueryText();
|
||||
String text = chatQueryContext.getRequest().getQueryText();
|
||||
Set<T> results = new HashSet<>();
|
||||
|
||||
Set<String> detectSegments = new HashSet<>();
|
||||
|
||||
@@ -20,21 +20,22 @@ public class TermDescMapper extends BaseMapper {
|
||||
return;
|
||||
}
|
||||
if (StringUtils.isBlank(chatQueryContext.getOriQueryText())) {
|
||||
chatQueryContext.setOriQueryText(chatQueryContext.getQueryText());
|
||||
chatQueryContext.setOriQueryText(chatQueryContext.getRequest().getQueryText());
|
||||
}
|
||||
for (SchemaElement schemaElement : termDescriptionToMap) {
|
||||
if (schemaElement.isDescriptionMapped()) {
|
||||
continue;
|
||||
}
|
||||
if (chatQueryContext.getQueryText().equals(schemaElement.getDescription())) {
|
||||
if (chatQueryContext.getRequest().getQueryText()
|
||||
.equals(schemaElement.getDescription())) {
|
||||
schemaElement.setDescriptionMapped(true);
|
||||
continue;
|
||||
}
|
||||
chatQueryContext.setQueryText(schemaElement.getDescription());
|
||||
chatQueryContext.getRequest().setQueryText(schemaElement.getDescription());
|
||||
break;
|
||||
}
|
||||
if (CollectionUtils.isEmpty(chatQueryContext.getMapInfo().getTermDescriptionToMap())) {
|
||||
chatQueryContext.setQueryText(chatQueryContext.getOriQueryText());
|
||||
chatQueryContext.getRequest().setQueryText(chatQueryContext.getOriQueryText());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ public class QueryTypeParser implements SemanticParser {
|
||||
public void parse(ChatQueryContext chatQueryContext) {
|
||||
|
||||
List<SemanticQuery> candidateQueries = chatQueryContext.getCandidateQueries();
|
||||
User user = chatQueryContext.getUser();
|
||||
User user = chatQueryContext.getRequest().getUser();
|
||||
|
||||
for (SemanticQuery semanticQuery : candidateQueries) {
|
||||
// 1.init S2SQL
|
||||
|
||||
@@ -25,7 +25,8 @@ public class SatisfactionChecker {
|
||||
if (query.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) {
|
||||
continue;
|
||||
}
|
||||
if (checkThreshold(chatQueryContext.getQueryText(), query.getParseInfo())) {
|
||||
if (checkThreshold(chatQueryContext.getRequest().getQueryText(),
|
||||
query.getParseInfo())) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,7 +35,7 @@ public class LLMRequestService {
|
||||
private ParserConfig parserConfig;
|
||||
|
||||
public boolean isSkip(ChatQueryContext queryCtx) {
|
||||
if (!queryCtx.getText2SQLType().enableLLM()) {
|
||||
if (!queryCtx.getRequest().getText2SQLType().enableLLM()) {
|
||||
log.info("LLM disabled, skip");
|
||||
return true;
|
||||
}
|
||||
@@ -45,12 +45,12 @@ public class LLMRequestService {
|
||||
|
||||
public Long getDataSetId(ChatQueryContext queryCtx) {
|
||||
DataSetResolver dataSetResolver = ComponentFactory.getModelResolver();
|
||||
return dataSetResolver.resolve(queryCtx, queryCtx.getDataSetIds());
|
||||
return dataSetResolver.resolve(queryCtx, queryCtx.getRequest().getDataSetIds());
|
||||
}
|
||||
|
||||
public LLMReq getLlmReq(ChatQueryContext queryCtx, Long dataSetId) {
|
||||
Map<Long, String> dataSetIdToName = queryCtx.getSemanticSchema().getDataSetIdToName();
|
||||
String queryText = queryCtx.getQueryText();
|
||||
String queryText = queryCtx.getRequest().getQueryText();
|
||||
|
||||
LLMReq llmReq = new LLMReq();
|
||||
llmReq.setQueryText(queryText);
|
||||
@@ -74,8 +74,8 @@ public class LLMRequestService {
|
||||
llmReq.setTerms(getMappedTerms(queryCtx, dataSetId));
|
||||
llmReq.setSqlGenType(
|
||||
LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE)));
|
||||
llmReq.setChatAppConfig(queryCtx.getChatAppConfig());
|
||||
llmReq.setDynamicExemplars(queryCtx.getDynamicExemplars());
|
||||
llmReq.setChatAppConfig(queryCtx.getRequest().getChatAppConfig());
|
||||
llmReq.setDynamicExemplars(queryCtx.getRequest().getDynamicExemplars());
|
||||
|
||||
return llmReq;
|
||||
}
|
||||
|
||||
@@ -39,13 +39,14 @@ public class LLMResponseService {
|
||||
Map<String, Object> properties = new HashMap<>();
|
||||
properties.put(Constants.CONTEXT, parseResult);
|
||||
properties.put("type", "internal");
|
||||
Text2SQLExemplar exemplar = Text2SQLExemplar.builder().question(queryCtx.getQueryText())
|
||||
.sideInfo(parseResult.getLlmResp().getSideInfo())
|
||||
.dbSchema(parseResult.getLlmResp().getSchema())
|
||||
.sql(parseResult.getLlmResp().getSqlOutput()).build();
|
||||
Text2SQLExemplar exemplar =
|
||||
Text2SQLExemplar.builder().question(queryCtx.getRequest().getQueryText())
|
||||
.sideInfo(parseResult.getLlmResp().getSideInfo())
|
||||
.dbSchema(parseResult.getLlmResp().getSchema())
|
||||
.sql(parseResult.getLlmResp().getSqlOutput()).build();
|
||||
properties.put(Text2SQLExemplar.PROPERTY_KEY, exemplar);
|
||||
parseInfo.setProperties(properties);
|
||||
parseInfo.setScore(queryCtx.getQueryText().length() * (1 + weight));
|
||||
parseInfo.setScore(queryCtx.getRequest().getQueryText().length() * (1 + weight));
|
||||
parseInfo.setQueryMode(semanticQuery.getQueryMode());
|
||||
parseInfo.getSqlInfo().setParsedS2SQL(s2SQL);
|
||||
queryCtx.getCandidateQueries().add(semanticQuery);
|
||||
|
||||
@@ -44,7 +44,7 @@ public class AggregateTypeParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(ChatQueryContext chatQueryContext) {
|
||||
String queryText = chatQueryContext.getQueryText();
|
||||
String queryText = chatQueryContext.getRequest().getQueryText();
|
||||
AggregateConf aggregateConf = resolveAggregateConf(queryText);
|
||||
|
||||
for (SemanticQuery semanticQuery : chatQueryContext.getCandidateQueries()) {
|
||||
|
||||
@@ -60,12 +60,12 @@ public class ContextInheritParser implements SemanticParser {
|
||||
chatQueryContext.getMapInfo().getMatchedElements(dataSetId);
|
||||
|
||||
List<SchemaElementMatch> matchesToInherit = new ArrayList<>();
|
||||
for (SchemaElementMatch match : chatQueryContext.getContextParseInfo()
|
||||
for (SchemaElementMatch match : chatQueryContext.getRequest().getContextParseInfo()
|
||||
.getElementMatches()) {
|
||||
SchemaElementType matchType = match.getElement().getType();
|
||||
// mutual exclusive element types should not be inherited
|
||||
RuleSemanticQuery ruleQuery = QueryManager
|
||||
.getRuleQuery(chatQueryContext.getContextParseInfo().getQueryMode());
|
||||
RuleSemanticQuery ruleQuery = QueryManager.getRuleQuery(
|
||||
chatQueryContext.getRequest().getContextParseInfo().getQueryMode());
|
||||
if (!containsTypes(elementMatches, matchType, ruleQuery)) {
|
||||
match.setInherited(true);
|
||||
matchesToInherit.add(match);
|
||||
@@ -121,10 +121,13 @@ public class ContextInheritParser implements SemanticParser {
|
||||
}
|
||||
|
||||
protected Long getMatchedDataSet(ChatQueryContext chatQueryContext) {
|
||||
Long dataSetId = chatQueryContext.getContextParseInfo().getDataSetId();
|
||||
if (dataSetId == null) {
|
||||
if (Objects.isNull(chatQueryContext)
|
||||
|| Objects.isNull(chatQueryContext.getRequest().getContextParseInfo())
|
||||
|| Objects.isNull(
|
||||
chatQueryContext.getRequest().getContextParseInfo().getDataSetId())) {
|
||||
return null;
|
||||
}
|
||||
Long dataSetId = chatQueryContext.getRequest().getContextParseInfo().getDataSetId();
|
||||
Set<Long> queryDataSets = chatQueryContext.getMapInfo().getMatchedDataSetInfos();
|
||||
if (queryDataSets.contains(dataSetId)) {
|
||||
return dataSetId;
|
||||
|
||||
@@ -22,7 +22,7 @@ public class RuleSqlParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(ChatQueryContext chatQueryContext) {
|
||||
if (!chatQueryContext.getText2SQLType().enableRule()
|
||||
if (!chatQueryContext.getRequest().getText2SQLType().enableRule()
|
||||
|| !chatQueryContext.getCandidateQueries().isEmpty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -38,7 +38,7 @@ public class TimeRangeParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(ChatQueryContext queryContext) {
|
||||
String queryText = queryContext.getQueryText();
|
||||
String queryText = queryContext.getRequest().getQueryText();
|
||||
DateConf dateConf = parseRecent(queryText);
|
||||
if (dateConf == null) {
|
||||
dateConf = parseDateNumber(queryText);
|
||||
@@ -62,7 +62,7 @@ public class TimeRangeParser implements SemanticParser {
|
||||
parseInfo.setScore(parseInfo.getScore() + dateConf.getDetectWord().length());
|
||||
}
|
||||
} else {
|
||||
SemanticParseInfo contextParseInfo = queryContext.getContextParseInfo();
|
||||
SemanticParseInfo contextParseInfo = queryContext.getRequest().getContextParseInfo();
|
||||
if (QueryManager.containsRuleQuery(contextParseInfo.getQueryMode())) {
|
||||
RuleSemanticQuery semanticQuery =
|
||||
QueryManager.createRuleQuery(contextParseInfo.getQueryMode());
|
||||
|
||||
@@ -69,8 +69,9 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
|
||||
private void fillDateConfByInherited(SemanticParseInfo queryParseInfo,
|
||||
ChatQueryContext chatQueryContext) {
|
||||
SemanticParseInfo contextParseInfo = chatQueryContext.getContextParseInfo();
|
||||
if (queryParseInfo.getDateInfo() != null || contextParseInfo.getDateInfo() == null
|
||||
SemanticParseInfo contextParseInfo = chatQueryContext.getRequest().getContextParseInfo();
|
||||
if (queryParseInfo.getDateInfo() != null || Objects.isNull(contextParseInfo)
|
||||
|| Objects.isNull(contextParseInfo.getDateInfo())
|
||||
|| needFillDateConf(chatQueryContext)) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -35,7 +35,7 @@ public class MetricTopNQuery extends MetricSemanticQuery {
|
||||
@Override
|
||||
public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches,
|
||||
ChatQueryContext queryCtx) {
|
||||
Matcher matcher = INTENT_PATTERN.matcher(queryCtx.getQueryText());
|
||||
Matcher matcher = INTENT_PATTERN.matcher(queryCtx.getRequest().getQueryText());
|
||||
if (matcher.matches()) {
|
||||
return super.match(candidateElementMatches, queryCtx);
|
||||
}
|
||||
|
||||
@@ -48,7 +48,7 @@ class WhereCorrectorTest {
|
||||
queryFilters.getFilters().add(filter2);
|
||||
queryFilters.getFilters().add(filter3);
|
||||
queryFilters.getFilters().add(filter4);
|
||||
chatQueryContext.setQueryFilters(queryFilters);
|
||||
chatQueryContext.getRequest().setQueryFilters(queryFilters);
|
||||
|
||||
WhereCorrector whereCorrector = new WhereCorrector();
|
||||
whereCorrector.addQueryFilter(chatQueryContext, semanticParseInfo);
|
||||
|
||||
@@ -93,19 +93,6 @@ public class S2ChatLayerService implements ChatLayerService {
|
||||
return parseResult;
|
||||
}
|
||||
|
||||
private ChatQueryContext buildChatQueryContext(QueryNLReq queryNLReq) {
|
||||
SemanticSchema semanticSchema = schemaService.getSemanticSchema(queryNLReq.getDataSetIds());
|
||||
Map<Long, List<Long>> modelIdToDataSetIds = dataSetService.getModelIdToDataSetIds();
|
||||
ChatQueryContext queryCtx = ChatQueryContext.builder()
|
||||
.queryFilters(queryNLReq.getQueryFilters()).semanticSchema(semanticSchema)
|
||||
.candidateQueries(new ArrayList<>()).mapInfo(new SchemaMapInfo())
|
||||
.modelIdToDataSetIds(modelIdToDataSetIds).text2SQLType(queryNLReq.getText2SQLType())
|
||||
.mapModeEnum(queryNLReq.getMapModeEnum()).dataSetIds(queryNLReq.getDataSetIds())
|
||||
.build();
|
||||
BeanUtils.copyProperties(queryNLReq, queryCtx);
|
||||
return queryCtx;
|
||||
}
|
||||
|
||||
public void correct(QuerySqlReq querySqlReq, User user) {
|
||||
SemanticParseInfo semanticParseInfo = correctSqlReq(querySqlReq, user);
|
||||
querySqlReq.setSql(semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
|
||||
@@ -122,6 +109,15 @@ public class S2ChatLayerService implements ChatLayerService {
|
||||
return retrieveService.retrieve(queryNLReq);
|
||||
}
|
||||
|
||||
private ChatQueryContext buildChatQueryContext(QueryNLReq queryNLReq) {
|
||||
ChatQueryContext queryCtx = new ChatQueryContext(queryNLReq);
|
||||
SemanticSchema semanticSchema = schemaService.getSemanticSchema(queryNLReq.getDataSetIds());
|
||||
Map<Long, List<Long>> modelIdToDataSetIds = dataSetService.getModelIdToDataSetIds();
|
||||
queryCtx.setSemanticSchema(semanticSchema);
|
||||
queryCtx.setModelIdToDataSetIds(modelIdToDataSetIds);
|
||||
return queryCtx;
|
||||
}
|
||||
|
||||
private SemanticParseInfo correctSqlReq(QuerySqlReq querySqlReq, User user) {
|
||||
ChatQueryContext queryCtx = new ChatQueryContext();
|
||||
SemanticSchema semanticSchema =
|
||||
|
||||
@@ -26,7 +26,7 @@ public class EntityInfoProcessor implements ResultProcessor {
|
||||
DataSetSchema dataSetSchema =
|
||||
semanticService.getDataSetSchema(parseInfo.getDataSetId());
|
||||
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema,
|
||||
chatQueryContext.getUser());
|
||||
chatQueryContext.getRequest().getUser());
|
||||
parseInfo.setEntityInfo(entityInfo);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -24,7 +24,6 @@ import com.tencent.supersonic.headless.server.service.RetrieveService;
|
||||
import com.tencent.supersonic.headless.server.service.SchemaService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@@ -75,8 +74,7 @@ public class RetrieveServiceImpl implements RetrieveService {
|
||||
log.debug("originals terms: {}", originals);
|
||||
Set<Long> dataSetIds = queryNLReq.getDataSetIds();
|
||||
|
||||
ChatQueryContext chatQueryContext = new ChatQueryContext();
|
||||
BeanUtils.copyProperties(queryNLReq, chatQueryContext);
|
||||
ChatQueryContext chatQueryContext = new ChatQueryContext(queryNLReq);
|
||||
chatQueryContext.setModelIdToDataSetIds(dataSetService.getModelIdToDataSetIds());
|
||||
|
||||
Map<MatchText, List<HanlpMapResult>> regTextMap =
|
||||
|
||||
@@ -140,7 +140,7 @@ public class ChatWorkflowEngine {
|
||||
SemanticLayerService queryService =
|
||||
ContextUtils.getBean(SemanticLayerService.class);
|
||||
SemanticTranslateResp explain =
|
||||
queryService.translate(semanticQueryReq, queryCtx.getUser());
|
||||
queryService.translate(semanticQueryReq, queryCtx.getRequest().getUser());
|
||||
parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL());
|
||||
if (StringUtils.isNotBlank(explain.getErrMsg())) {
|
||||
errorMsg.add(explain.getErrMsg());
|
||||
|
||||
@@ -48,7 +48,6 @@ public class MultiTurnsTest extends BaseTest {
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
|
||||
expectedResult.setChatContext(expectedParseInfo);
|
||||
|
||||
expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE);
|
||||
expectedParseInfo.setAggType(NONE);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user