[improvement][headless-chat]Incorporate Request into Context objects, removing unnecessary copy.

This commit is contained in:
jerryjzhang
2024-10-27 18:47:34 +08:00
parent bd82b0904b
commit bb363a0286
33 changed files with 87 additions and 112 deletions

View File

@@ -7,7 +7,6 @@ import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
import lombok.Data;
@@ -17,7 +16,7 @@ import java.util.Map;
import java.util.Set;
@Data
public class QueryNLReq {
public class QueryNLReq extends SemanticQueryReq {
private String queryText;
private Set<Long> dataSetIds = Sets.newHashSet();
private User user;
@@ -25,9 +24,13 @@ public class QueryNLReq {
private boolean saveAnswer = true;
private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM;
private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
private SchemaMapInfo mapInfo = new SchemaMapInfo();
private QueryDataType queryDataType = QueryDataType.ALL;
private Map<String, ChatApp> chatAppConfig;
private List<Text2SQLExemplar> dynamicExemplars = Lists.newArrayList();
private SemanticParseInfo contextParseInfo;
@Override
public String toCustomizedString() {
return "";
}
}

View File

@@ -50,12 +50,4 @@ public abstract class SemanticQueryReq {
public Set<Long> getModelIdSet() {
return modelIds;
}
public boolean isNeedAuth() {
return needAuth;
}
public void setNeedAuth(boolean needAuth) {
this.needAuth = needAuth;
}
}

View File

@@ -1,63 +1,41 @@
package com.tencent.supersonic.headless.chat;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.enums.ChatWorkflowState;
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class ChatQueryContext {
private String queryText;
private QueryNLReq request;
private String oriQueryText;
private Set<Long> dataSetIds;
private Map<Long, List<Long>> modelIdToDataSetIds;
private User user;
private boolean saveAnswer;
private QueryFilters queryFilters;
private List<SemanticQuery> candidateQueries = new ArrayList<>();
private SchemaMapInfo mapInfo = new SchemaMapInfo();
private SemanticParseInfo contextParseInfo;
@Builder.Default
private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM;
@Builder.Default
private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
@Builder.Default
private QueryDataType queryDataType = QueryDataType.ALL;
@JsonIgnore
private SemanticSchema semanticSchema;
@JsonIgnore
private ChatWorkflowState chatWorkflowState;
@JsonIgnore
private Map<String, ChatApp> chatAppConfig;
@JsonIgnore
private List<Text2SQLExemplar> dynamicExemplars;
public ChatQueryContext() {
this(new QueryNLReq());
}
public ChatQueryContext(QueryNLReq request) {
this.request = request;
}
public List<SemanticQuery> getCandidateQueries() {
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);

View File

@@ -61,8 +61,8 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
@Override
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
ChatApp chatApp = chatQueryContext.getChatAppConfig().get(APP_KEY);
if (!chatQueryContext.getText2SQLType().enableLLM() || Objects.isNull(chatApp)
ChatApp chatApp = chatQueryContext.getRequest().getChatAppConfig().get(APP_KEY);
if (!chatQueryContext.getRequest().getText2SQLType().enableLLM() || Objects.isNull(chatApp)
|| !chatApp.isEnable()) {
return;
}
@@ -71,8 +71,8 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
ModelProvider.getChatModel(chatApp.getChatModelConfig());
SemanticSqlExtractor extractor =
AiServices.create(SemanticSqlExtractor.class, chatLanguageModel);
Prompt prompt = generatePrompt(chatQueryContext.getQueryText(), semanticParseInfo,
chatApp.getPrompt());
Prompt prompt = generatePrompt(chatQueryContext.getRequest().getQueryText(),
semanticParseInfo, chatApp.getPrompt());
SemanticSql s2Sql = extractor.generateSemanticSql(prompt.toUserMessage().singleText());
keyPipelineLog.info("LLMSqlCorrector modelReq:\n{} \nmodelResp:\n{}", prompt.text(), s2Sql);
if ("NEGATIVE".equals(s2Sql.getOpinion()) && StringUtils.isNotBlank(s2Sql.getSql())) {

View File

@@ -33,7 +33,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
protected void addQueryFilter(ChatQueryContext chatQueryContext,
SemanticParseInfo semanticParseInfo) {
String queryFilter = getQueryFilter(chatQueryContext.getQueryFilters());
String queryFilter = getQueryFilter(chatQueryContext.getRequest().getQueryFilters());
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
if (StringUtils.isNotEmpty(queryFilter)) {

View File

@@ -116,12 +116,12 @@ public abstract class BaseMapper implements SchemaMapper {
public <T> List<T> getMatches(ChatQueryContext chatQueryContext,
BaseMatchStrategy matchStrategy) {
String queryText = chatQueryContext.getQueryText();
String queryText = chatQueryContext.getRequest().getQueryText();
List<S2Term> terms =
HanlpHelper.getTerms(queryText, chatQueryContext.getModelIdToDataSetIds());
terms = HanlpHelper.getTerms(terms, chatQueryContext.getDataSetIds());
Map<MatchText, List<T>> matchResult =
matchStrategy.match(chatQueryContext, terms, chatQueryContext.getDataSetIds());
terms = HanlpHelper.getTerms(terms, chatQueryContext.getRequest().getDataSetIds());
Map<MatchText, List<T>> matchResult = matchStrategy.match(chatQueryContext, terms,
chatQueryContext.getRequest().getDataSetIds());
List<T> matches = new ArrayList<>();
if (Objects.isNull(matchResult)) {
return matches;

View File

@@ -21,7 +21,7 @@ public abstract class BaseMatchStrategy<T extends MapResult> implements MatchStr
@Override
public Map<MatchText, List<T>> match(ChatQueryContext chatQueryContext, List<S2Term> terms,
Set<Long> detectDataSetIds) {
String text = chatQueryContext.getQueryText();
String text = chatQueryContext.getRequest().getQueryText();
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
return null;
}

View File

@@ -22,7 +22,7 @@ public abstract class BatchMatchStrategy<T extends MapResult> extends BaseMatchS
public List<T> detect(ChatQueryContext chatQueryContext, List<S2Term> terms,
Set<Long> detectDataSetIds) {
String text = chatQueryContext.getQueryText();
String text = chatQueryContext.getRequest().getQueryText();
Set<String> detectSegments = new HashSet<>();
int embeddingTextSize = Integer

View File

@@ -93,7 +93,8 @@ public class DatabaseMatchStrategy extends SingleMatchStrategy<DatabaseMapResult
log.debug("ModelElementMatches:{},not exist Element threshold reduce by half:{}",
modelElementMatches, threshold);
}
return getThreshold(threshold, minThreshold, chatQueryContext.getMapModeEnum());
return getThreshold(threshold, minThreshold,
chatQueryContext.getRequest().getMapModeEnum());
}
private Map<String, Set<SchemaElement>> getNameToItems(List<SchemaElement> models) {

View File

@@ -69,7 +69,7 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
double embeddingThresholdMin =
Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD_MIN));
double threshold = getThreshold(embeddingThreshold, embeddingThresholdMin,
chatQueryContext.getMapModeEnum());
chatQueryContext.getRequest().getMapModeEnum());
// step1. build query params
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();

View File

@@ -105,6 +105,7 @@ public class HanlpDictMatchStrategy extends SingleMatchStrategy<HanlpMapResult>
mapperConfig.getParameterValue(MapperConfig.MAPPER_VALUE_THRESHOLD_MIN));
}
return getThreshold(threshold, minThreshold, chatQueryContext.getMapModeEnum());
return getThreshold(threshold, minThreshold,
chatQueryContext.getRequest().getMapModeEnum());
}
}

View File

@@ -32,7 +32,7 @@ public class KeywordMapper extends BaseMapper {
@Override
public void doMap(ChatQueryContext chatQueryContext) {
String queryText = chatQueryContext.getQueryText();
String queryText = chatQueryContext.getRequest().getQueryText();
// 1.hanlpDict Match
List<S2Term> terms =
HanlpHelper.getTerms(queryText, chatQueryContext.getModelIdToDataSetIds());

View File

@@ -22,7 +22,7 @@ public class MapFilter {
filterByDataSetId(chatQueryContext);
filterByDetectWordLenLessThanOne(chatQueryContext);
twoCharactersMustEqual(chatQueryContext);
switch (chatQueryContext.getQueryDataType()) {
switch (chatQueryContext.getRequest().getQueryDataType()) {
case TAG:
filterByQueryDataType(chatQueryContext, element -> !(element.getIsTag() > 0));
break;
@@ -46,7 +46,7 @@ public class MapFilter {
}
public static void filterByDataSetId(ChatQueryContext chatQueryContext) {
Set<Long> dataSetIds = chatQueryContext.getDataSetIds();
Set<Long> dataSetIds = chatQueryContext.getRequest().getDataSetIds();
if (CollectionUtils.isEmpty(dataSetIds)) {
return;
}

View File

@@ -25,7 +25,7 @@ public class QueryFilterMapper extends BaseMapper {
@Override
public void doMap(ChatQueryContext chatQueryContext) {
Set<Long> dataSetIds = chatQueryContext.getDataSetIds();
Set<Long> dataSetIds = chatQueryContext.getRequest().getDataSetIds();
if (CollectionUtils.isEmpty(dataSetIds)) {
return;
}
@@ -53,7 +53,7 @@ public class QueryFilterMapper extends BaseMapper {
private void addValueSchemaElementMatch(Long dataSetId, ChatQueryContext chatQueryContext,
List<SchemaElementMatch> candidateElementMatches) {
QueryFilters queryFilters = chatQueryContext.getQueryFilters();
QueryFilters queryFilters = chatQueryContext.getRequest().getQueryFilters();
if (queryFilters == null || CollectionUtils.isEmpty(queryFilters.getFilters())) {
return;
}

View File

@@ -36,7 +36,7 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
@Override
public Map<MatchText, List<HanlpMapResult>> match(ChatQueryContext chatQueryContext,
List<S2Term> originals, Set<Long> detectDataSetIds) {
String text = chatQueryContext.getQueryText();
String text = chatQueryContext.getRequest().getQueryText();
Map<Integer, Integer> regOffsetToLength = mapperHelper.getRegOffsetToLength(originals);
List<Integer> detectIndexList = Lists.newArrayList();

View File

@@ -24,7 +24,7 @@ public abstract class SingleMatchStrategy<T extends MapResult> extends BaseMatch
public List<T> detect(ChatQueryContext chatQueryContext, List<S2Term> terms,
Set<Long> detectDataSetIds) {
Map<Integer, Integer> regOffsetToLength = mapperHelper.getRegOffsetToLength(terms);
String text = chatQueryContext.getQueryText();
String text = chatQueryContext.getRequest().getQueryText();
Set<T> results = new HashSet<>();
Set<String> detectSegments = new HashSet<>();

View File

@@ -20,21 +20,22 @@ public class TermDescMapper extends BaseMapper {
return;
}
if (StringUtils.isBlank(chatQueryContext.getOriQueryText())) {
chatQueryContext.setOriQueryText(chatQueryContext.getQueryText());
chatQueryContext.setOriQueryText(chatQueryContext.getRequest().getQueryText());
}
for (SchemaElement schemaElement : termDescriptionToMap) {
if (schemaElement.isDescriptionMapped()) {
continue;
}
if (chatQueryContext.getQueryText().equals(schemaElement.getDescription())) {
if (chatQueryContext.getRequest().getQueryText()
.equals(schemaElement.getDescription())) {
schemaElement.setDescriptionMapped(true);
continue;
}
chatQueryContext.setQueryText(schemaElement.getDescription());
chatQueryContext.getRequest().setQueryText(schemaElement.getDescription());
break;
}
if (CollectionUtils.isEmpty(chatQueryContext.getMapInfo().getTermDescriptionToMap())) {
chatQueryContext.setQueryText(chatQueryContext.getOriQueryText());
chatQueryContext.getRequest().setQueryText(chatQueryContext.getOriQueryText());
}
}
}

View File

@@ -31,7 +31,7 @@ public class QueryTypeParser implements SemanticParser {
public void parse(ChatQueryContext chatQueryContext) {
List<SemanticQuery> candidateQueries = chatQueryContext.getCandidateQueries();
User user = chatQueryContext.getUser();
User user = chatQueryContext.getRequest().getUser();
for (SemanticQuery semanticQuery : candidateQueries) {
// 1.init S2SQL

View File

@@ -25,7 +25,8 @@ public class SatisfactionChecker {
if (query.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) {
continue;
}
if (checkThreshold(chatQueryContext.getQueryText(), query.getParseInfo())) {
if (checkThreshold(chatQueryContext.getRequest().getQueryText(),
query.getParseInfo())) {
return true;
}
}

View File

@@ -35,7 +35,7 @@ public class LLMRequestService {
private ParserConfig parserConfig;
public boolean isSkip(ChatQueryContext queryCtx) {
if (!queryCtx.getText2SQLType().enableLLM()) {
if (!queryCtx.getRequest().getText2SQLType().enableLLM()) {
log.info("LLM disabled, skip");
return true;
}
@@ -45,12 +45,12 @@ public class LLMRequestService {
public Long getDataSetId(ChatQueryContext queryCtx) {
DataSetResolver dataSetResolver = ComponentFactory.getModelResolver();
return dataSetResolver.resolve(queryCtx, queryCtx.getDataSetIds());
return dataSetResolver.resolve(queryCtx, queryCtx.getRequest().getDataSetIds());
}
public LLMReq getLlmReq(ChatQueryContext queryCtx, Long dataSetId) {
Map<Long, String> dataSetIdToName = queryCtx.getSemanticSchema().getDataSetIdToName();
String queryText = queryCtx.getQueryText();
String queryText = queryCtx.getRequest().getQueryText();
LLMReq llmReq = new LLMReq();
llmReq.setQueryText(queryText);
@@ -74,8 +74,8 @@ public class LLMRequestService {
llmReq.setTerms(getMappedTerms(queryCtx, dataSetId));
llmReq.setSqlGenType(
LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE)));
llmReq.setChatAppConfig(queryCtx.getChatAppConfig());
llmReq.setDynamicExemplars(queryCtx.getDynamicExemplars());
llmReq.setChatAppConfig(queryCtx.getRequest().getChatAppConfig());
llmReq.setDynamicExemplars(queryCtx.getRequest().getDynamicExemplars());
return llmReq;
}

View File

@@ -39,13 +39,14 @@ public class LLMResponseService {
Map<String, Object> properties = new HashMap<>();
properties.put(Constants.CONTEXT, parseResult);
properties.put("type", "internal");
Text2SQLExemplar exemplar = Text2SQLExemplar.builder().question(queryCtx.getQueryText())
.sideInfo(parseResult.getLlmResp().getSideInfo())
.dbSchema(parseResult.getLlmResp().getSchema())
.sql(parseResult.getLlmResp().getSqlOutput()).build();
Text2SQLExemplar exemplar =
Text2SQLExemplar.builder().question(queryCtx.getRequest().getQueryText())
.sideInfo(parseResult.getLlmResp().getSideInfo())
.dbSchema(parseResult.getLlmResp().getSchema())
.sql(parseResult.getLlmResp().getSqlOutput()).build();
properties.put(Text2SQLExemplar.PROPERTY_KEY, exemplar);
parseInfo.setProperties(properties);
parseInfo.setScore(queryCtx.getQueryText().length() * (1 + weight));
parseInfo.setScore(queryCtx.getRequest().getQueryText().length() * (1 + weight));
parseInfo.setQueryMode(semanticQuery.getQueryMode());
parseInfo.getSqlInfo().setParsedS2SQL(s2SQL);
queryCtx.getCandidateQueries().add(semanticQuery);

View File

@@ -44,7 +44,7 @@ public class AggregateTypeParser implements SemanticParser {
@Override
public void parse(ChatQueryContext chatQueryContext) {
String queryText = chatQueryContext.getQueryText();
String queryText = chatQueryContext.getRequest().getQueryText();
AggregateConf aggregateConf = resolveAggregateConf(queryText);
for (SemanticQuery semanticQuery : chatQueryContext.getCandidateQueries()) {

View File

@@ -60,12 +60,12 @@ public class ContextInheritParser implements SemanticParser {
chatQueryContext.getMapInfo().getMatchedElements(dataSetId);
List<SchemaElementMatch> matchesToInherit = new ArrayList<>();
for (SchemaElementMatch match : chatQueryContext.getContextParseInfo()
for (SchemaElementMatch match : chatQueryContext.getRequest().getContextParseInfo()
.getElementMatches()) {
SchemaElementType matchType = match.getElement().getType();
// mutual exclusive element types should not be inherited
RuleSemanticQuery ruleQuery = QueryManager
.getRuleQuery(chatQueryContext.getContextParseInfo().getQueryMode());
RuleSemanticQuery ruleQuery = QueryManager.getRuleQuery(
chatQueryContext.getRequest().getContextParseInfo().getQueryMode());
if (!containsTypes(elementMatches, matchType, ruleQuery)) {
match.setInherited(true);
matchesToInherit.add(match);
@@ -121,10 +121,13 @@ public class ContextInheritParser implements SemanticParser {
}
protected Long getMatchedDataSet(ChatQueryContext chatQueryContext) {
Long dataSetId = chatQueryContext.getContextParseInfo().getDataSetId();
if (dataSetId == null) {
if (Objects.isNull(chatQueryContext)
|| Objects.isNull(chatQueryContext.getRequest().getContextParseInfo())
|| Objects.isNull(
chatQueryContext.getRequest().getContextParseInfo().getDataSetId())) {
return null;
}
Long dataSetId = chatQueryContext.getRequest().getContextParseInfo().getDataSetId();
Set<Long> queryDataSets = chatQueryContext.getMapInfo().getMatchedDataSetInfos();
if (queryDataSets.contains(dataSetId)) {
return dataSetId;

View File

@@ -22,7 +22,7 @@ public class RuleSqlParser implements SemanticParser {
@Override
public void parse(ChatQueryContext chatQueryContext) {
if (!chatQueryContext.getText2SQLType().enableRule()
if (!chatQueryContext.getRequest().getText2SQLType().enableRule()
|| !chatQueryContext.getCandidateQueries().isEmpty()) {
return;
}

View File

@@ -38,7 +38,7 @@ public class TimeRangeParser implements SemanticParser {
@Override
public void parse(ChatQueryContext queryContext) {
String queryText = queryContext.getQueryText();
String queryText = queryContext.getRequest().getQueryText();
DateConf dateConf = parseRecent(queryText);
if (dateConf == null) {
dateConf = parseDateNumber(queryText);
@@ -62,7 +62,7 @@ public class TimeRangeParser implements SemanticParser {
parseInfo.setScore(parseInfo.getScore() + dateConf.getDetectWord().length());
}
} else {
SemanticParseInfo contextParseInfo = queryContext.getContextParseInfo();
SemanticParseInfo contextParseInfo = queryContext.getRequest().getContextParseInfo();
if (QueryManager.containsRuleQuery(contextParseInfo.getQueryMode())) {
RuleSemanticQuery semanticQuery =
QueryManager.createRuleQuery(contextParseInfo.getQueryMode());

View File

@@ -69,8 +69,9 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
private void fillDateConfByInherited(SemanticParseInfo queryParseInfo,
ChatQueryContext chatQueryContext) {
SemanticParseInfo contextParseInfo = chatQueryContext.getContextParseInfo();
if (queryParseInfo.getDateInfo() != null || contextParseInfo.getDateInfo() == null
SemanticParseInfo contextParseInfo = chatQueryContext.getRequest().getContextParseInfo();
if (queryParseInfo.getDateInfo() != null || Objects.isNull(contextParseInfo)
|| Objects.isNull(contextParseInfo.getDateInfo())
|| needFillDateConf(chatQueryContext)) {
return;
}

View File

@@ -35,7 +35,7 @@ public class MetricTopNQuery extends MetricSemanticQuery {
@Override
public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches,
ChatQueryContext queryCtx) {
Matcher matcher = INTENT_PATTERN.matcher(queryCtx.getQueryText());
Matcher matcher = INTENT_PATTERN.matcher(queryCtx.getRequest().getQueryText());
if (matcher.matches()) {
return super.match(candidateElementMatches, queryCtx);
}

View File

@@ -48,7 +48,7 @@ class WhereCorrectorTest {
queryFilters.getFilters().add(filter2);
queryFilters.getFilters().add(filter3);
queryFilters.getFilters().add(filter4);
chatQueryContext.setQueryFilters(queryFilters);
chatQueryContext.getRequest().setQueryFilters(queryFilters);
WhereCorrector whereCorrector = new WhereCorrector();
whereCorrector.addQueryFilter(chatQueryContext, semanticParseInfo);

View File

@@ -93,19 +93,6 @@ public class S2ChatLayerService implements ChatLayerService {
return parseResult;
}
private ChatQueryContext buildChatQueryContext(QueryNLReq queryNLReq) {
SemanticSchema semanticSchema = schemaService.getSemanticSchema(queryNLReq.getDataSetIds());
Map<Long, List<Long>> modelIdToDataSetIds = dataSetService.getModelIdToDataSetIds();
ChatQueryContext queryCtx = ChatQueryContext.builder()
.queryFilters(queryNLReq.getQueryFilters()).semanticSchema(semanticSchema)
.candidateQueries(new ArrayList<>()).mapInfo(new SchemaMapInfo())
.modelIdToDataSetIds(modelIdToDataSetIds).text2SQLType(queryNLReq.getText2SQLType())
.mapModeEnum(queryNLReq.getMapModeEnum()).dataSetIds(queryNLReq.getDataSetIds())
.build();
BeanUtils.copyProperties(queryNLReq, queryCtx);
return queryCtx;
}
public void correct(QuerySqlReq querySqlReq, User user) {
SemanticParseInfo semanticParseInfo = correctSqlReq(querySqlReq, user);
querySqlReq.setSql(semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
@@ -122,6 +109,15 @@ public class S2ChatLayerService implements ChatLayerService {
return retrieveService.retrieve(queryNLReq);
}
private ChatQueryContext buildChatQueryContext(QueryNLReq queryNLReq) {
ChatQueryContext queryCtx = new ChatQueryContext(queryNLReq);
SemanticSchema semanticSchema = schemaService.getSemanticSchema(queryNLReq.getDataSetIds());
Map<Long, List<Long>> modelIdToDataSetIds = dataSetService.getModelIdToDataSetIds();
queryCtx.setSemanticSchema(semanticSchema);
queryCtx.setModelIdToDataSetIds(modelIdToDataSetIds);
return queryCtx;
}
private SemanticParseInfo correctSqlReq(QuerySqlReq querySqlReq, User user) {
ChatQueryContext queryCtx = new ChatQueryContext();
SemanticSchema semanticSchema =

View File

@@ -26,7 +26,7 @@ public class EntityInfoProcessor implements ResultProcessor {
DataSetSchema dataSetSchema =
semanticService.getDataSetSchema(parseInfo.getDataSetId());
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema,
chatQueryContext.getUser());
chatQueryContext.getRequest().getUser());
parseInfo.setEntityInfo(entityInfo);
});
}

View File

@@ -24,7 +24,6 @@ import com.tencent.supersonic.headless.server.service.RetrieveService;
import com.tencent.supersonic.headless.server.service.SchemaService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
@@ -75,8 +74,7 @@ public class RetrieveServiceImpl implements RetrieveService {
log.debug("originals terms: {}", originals);
Set<Long> dataSetIds = queryNLReq.getDataSetIds();
ChatQueryContext chatQueryContext = new ChatQueryContext();
BeanUtils.copyProperties(queryNLReq, chatQueryContext);
ChatQueryContext chatQueryContext = new ChatQueryContext(queryNLReq);
chatQueryContext.setModelIdToDataSetIds(dataSetService.getModelIdToDataSetIds());
Map<MatchText, List<HanlpMapResult>> regTextMap =

View File

@@ -140,7 +140,7 @@ public class ChatWorkflowEngine {
SemanticLayerService queryService =
ContextUtils.getBean(SemanticLayerService.class);
SemanticTranslateResp explain =
queryService.translate(semanticQueryReq, queryCtx.getUser());
queryService.translate(semanticQueryReq, queryCtx.getRequest().getUser());
parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL());
if (StringUtils.isNotBlank(explain.getErrMsg())) {
errorMsg.add(explain.getErrMsg());

View File

@@ -48,7 +48,6 @@ public class MultiTurnsTest extends BaseTest {
QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
expectedResult.setChatContext(expectedParseInfo);
expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE);
expectedParseInfo.setAggType(NONE);