diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryNLReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryNLReq.java index c94da6d0a..6172761a4 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryNLReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryNLReq.java @@ -7,7 +7,6 @@ import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.headless.api.pojo.QueryDataType; -import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum; import lombok.Data; @@ -17,7 +16,7 @@ import java.util.Map; import java.util.Set; @Data -public class QueryNLReq { +public class QueryNLReq extends SemanticQueryReq { private String queryText; private Set dataSetIds = Sets.newHashSet(); private User user; @@ -25,9 +24,13 @@ public class QueryNLReq { private boolean saveAnswer = true; private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM; private MapModeEnum mapModeEnum = MapModeEnum.STRICT; - private SchemaMapInfo mapInfo = new SchemaMapInfo(); private QueryDataType queryDataType = QueryDataType.ALL; private Map chatAppConfig; private List dynamicExemplars = Lists.newArrayList(); private SemanticParseInfo contextParseInfo; + + @Override + public String toCustomizedString() { + return ""; + } } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/SemanticQueryReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/SemanticQueryReq.java index 41fb26d6e..e1aa05a02 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/SemanticQueryReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/SemanticQueryReq.java @@ -50,12 +50,4 @@ public abstract class SemanticQueryReq { public Set getModelIdSet() { return modelIds; } - - public boolean isNeedAuth() { - return needAuth; - } - - public void setNeedAuth(boolean needAuth) { - this.needAuth = needAuth; - } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java index bbc84ce07..a5601f7fa 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java @@ -1,63 +1,41 @@ package com.tencent.supersonic.headless.chat; import com.fasterxml.jackson.annotation.JsonIgnore; -import com.tencent.supersonic.common.pojo.ChatApp; -import com.tencent.supersonic.common.pojo.Text2SQLExemplar; -import com.tencent.supersonic.common.pojo.User; -import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.api.pojo.DataSetSchema; -import com.tencent.supersonic.headless.api.pojo.QueryDataType; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; -import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.enums.ChatWorkflowState; -import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum; -import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; +import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq; import com.tencent.supersonic.headless.chat.parser.ParserConfig; import com.tencent.supersonic.headless.chat.query.SemanticQuery; -import lombok.AllArgsConstructor; -import lombok.Builder; import lombok.Data; -import lombok.NoArgsConstructor; import java.util.ArrayList; import java.util.Comparator; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.stream.Collectors; @Data -@Builder -@NoArgsConstructor -@AllArgsConstructor public class ChatQueryContext { - private String queryText; + private QueryNLReq request; private String oriQueryText; - private Set dataSetIds; private Map> modelIdToDataSetIds; - private User user; - private boolean saveAnswer; - private QueryFilters queryFilters; private List candidateQueries = new ArrayList<>(); private SchemaMapInfo mapInfo = new SchemaMapInfo(); - private SemanticParseInfo contextParseInfo; - @Builder.Default - private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM; - @Builder.Default - private MapModeEnum mapModeEnum = MapModeEnum.STRICT; - @Builder.Default - private QueryDataType queryDataType = QueryDataType.ALL; @JsonIgnore private SemanticSchema semanticSchema; - @JsonIgnore private ChatWorkflowState chatWorkflowState; - @JsonIgnore - private Map chatAppConfig; - @JsonIgnore - private List dynamicExemplars; + + public ChatQueryContext() { + this(new QueryNLReq()); + } + + public ChatQueryContext(QueryNLReq request) { + this.request = request; + } public List getCandidateQueries() { ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/LLMSqlCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/LLMSqlCorrector.java index c5dbb0f28..5f8335cc8 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/LLMSqlCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/LLMSqlCorrector.java @@ -61,8 +61,8 @@ public class LLMSqlCorrector extends BaseSemanticCorrector { @Override public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { - ChatApp chatApp = chatQueryContext.getChatAppConfig().get(APP_KEY); - if (!chatQueryContext.getText2SQLType().enableLLM() || Objects.isNull(chatApp) + ChatApp chatApp = chatQueryContext.getRequest().getChatAppConfig().get(APP_KEY); + if (!chatQueryContext.getRequest().getText2SQLType().enableLLM() || Objects.isNull(chatApp) || !chatApp.isEnable()) { return; } @@ -71,8 +71,8 @@ public class LLMSqlCorrector extends BaseSemanticCorrector { ModelProvider.getChatModel(chatApp.getChatModelConfig()); SemanticSqlExtractor extractor = AiServices.create(SemanticSqlExtractor.class, chatLanguageModel); - Prompt prompt = generatePrompt(chatQueryContext.getQueryText(), semanticParseInfo, - chatApp.getPrompt()); + Prompt prompt = generatePrompt(chatQueryContext.getRequest().getQueryText(), + semanticParseInfo, chatApp.getPrompt()); SemanticSql s2Sql = extractor.generateSemanticSql(prompt.toUserMessage().singleText()); keyPipelineLog.info("LLMSqlCorrector modelReq:\n{} \nmodelResp:\n{}", prompt.text(), s2Sql); if ("NEGATIVE".equals(s2Sql.getOpinion()) && StringUtils.isNotBlank(s2Sql.getSql())) { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrector.java index 6bae88e39..e411b7e03 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrector.java @@ -33,7 +33,7 @@ public class WhereCorrector extends BaseSemanticCorrector { protected void addQueryFilter(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { - String queryFilter = getQueryFilter(chatQueryContext.getQueryFilters()); + String queryFilter = getQueryFilter(chatQueryContext.getRequest().getQueryFilters()); String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); if (StringUtils.isNotEmpty(queryFilter)) { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMapper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMapper.java index 4198a2424..f3cc64f9f 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMapper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMapper.java @@ -116,12 +116,12 @@ public abstract class BaseMapper implements SchemaMapper { public List getMatches(ChatQueryContext chatQueryContext, BaseMatchStrategy matchStrategy) { - String queryText = chatQueryContext.getQueryText(); + String queryText = chatQueryContext.getRequest().getQueryText(); List terms = HanlpHelper.getTerms(queryText, chatQueryContext.getModelIdToDataSetIds()); - terms = HanlpHelper.getTerms(terms, chatQueryContext.getDataSetIds()); - Map> matchResult = - matchStrategy.match(chatQueryContext, terms, chatQueryContext.getDataSetIds()); + terms = HanlpHelper.getTerms(terms, chatQueryContext.getRequest().getDataSetIds()); + Map> matchResult = matchStrategy.match(chatQueryContext, terms, + chatQueryContext.getRequest().getDataSetIds()); List matches = new ArrayList<>(); if (Objects.isNull(matchResult)) { return matches; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java index 2aebb83f4..1fbc754d2 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java @@ -21,7 +21,7 @@ public abstract class BaseMatchStrategy implements MatchStr @Override public Map> match(ChatQueryContext chatQueryContext, List terms, Set detectDataSetIds) { - String text = chatQueryContext.getQueryText(); + String text = chatQueryContext.getRequest().getQueryText(); if (Objects.isNull(terms) || StringUtils.isEmpty(text)) { return null; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BatchMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BatchMatchStrategy.java index 7f6a58e38..2da8b624d 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BatchMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BatchMatchStrategy.java @@ -22,7 +22,7 @@ public abstract class BatchMatchStrategy extends BaseMatchS public List detect(ChatQueryContext chatQueryContext, List terms, Set detectDataSetIds) { - String text = chatQueryContext.getQueryText(); + String text = chatQueryContext.getRequest().getQueryText(); Set detectSegments = new HashSet<>(); int embeddingTextSize = Integer diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/DatabaseMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/DatabaseMatchStrategy.java index eaaf662d7..d1a406b5f 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/DatabaseMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/DatabaseMatchStrategy.java @@ -93,7 +93,8 @@ public class DatabaseMatchStrategy extends SingleMatchStrategy> getNameToItems(List models) { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java index ccef43890..ee313e193 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java @@ -69,7 +69,7 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy double embeddingThresholdMin = Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD_MIN)); double threshold = getThreshold(embeddingThreshold, embeddingThresholdMin, - chatQueryContext.getMapModeEnum()); + chatQueryContext.getRequest().getMapModeEnum()); // step1. build query params RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build(); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/HanlpDictMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/HanlpDictMatchStrategy.java index 9a680f269..139c23351 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/HanlpDictMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/HanlpDictMatchStrategy.java @@ -105,6 +105,7 @@ public class HanlpDictMatchStrategy extends SingleMatchStrategy mapperConfig.getParameterValue(MapperConfig.MAPPER_VALUE_THRESHOLD_MIN)); } - return getThreshold(threshold, minThreshold, chatQueryContext.getMapModeEnum()); + return getThreshold(threshold, minThreshold, + chatQueryContext.getRequest().getMapModeEnum()); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/KeywordMapper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/KeywordMapper.java index 85df0b970..0af886413 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/KeywordMapper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/KeywordMapper.java @@ -32,7 +32,7 @@ public class KeywordMapper extends BaseMapper { @Override public void doMap(ChatQueryContext chatQueryContext) { - String queryText = chatQueryContext.getQueryText(); + String queryText = chatQueryContext.getRequest().getQueryText(); // 1.hanlpDict Match List terms = HanlpHelper.getTerms(queryText, chatQueryContext.getModelIdToDataSetIds()); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapFilter.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapFilter.java index dc02fbccf..02fa9691b 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapFilter.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapFilter.java @@ -22,7 +22,7 @@ public class MapFilter { filterByDataSetId(chatQueryContext); filterByDetectWordLenLessThanOne(chatQueryContext); twoCharactersMustEqual(chatQueryContext); - switch (chatQueryContext.getQueryDataType()) { + switch (chatQueryContext.getRequest().getQueryDataType()) { case TAG: filterByQueryDataType(chatQueryContext, element -> !(element.getIsTag() > 0)); break; @@ -46,7 +46,7 @@ public class MapFilter { } public static void filterByDataSetId(ChatQueryContext chatQueryContext) { - Set dataSetIds = chatQueryContext.getDataSetIds(); + Set dataSetIds = chatQueryContext.getRequest().getDataSetIds(); if (CollectionUtils.isEmpty(dataSetIds)) { return; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/QueryFilterMapper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/QueryFilterMapper.java index e9c168207..8065e2561 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/QueryFilterMapper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/QueryFilterMapper.java @@ -25,7 +25,7 @@ public class QueryFilterMapper extends BaseMapper { @Override public void doMap(ChatQueryContext chatQueryContext) { - Set dataSetIds = chatQueryContext.getDataSetIds(); + Set dataSetIds = chatQueryContext.getRequest().getDataSetIds(); if (CollectionUtils.isEmpty(dataSetIds)) { return; } @@ -53,7 +53,7 @@ public class QueryFilterMapper extends BaseMapper { private void addValueSchemaElementMatch(Long dataSetId, ChatQueryContext chatQueryContext, List candidateElementMatches) { - QueryFilters queryFilters = chatQueryContext.getQueryFilters(); + QueryFilters queryFilters = chatQueryContext.getRequest().getQueryFilters(); if (queryFilters == null || CollectionUtils.isEmpty(queryFilters.getFilters())) { return; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SearchMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SearchMatchStrategy.java index 913c0e619..b076305dc 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SearchMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SearchMatchStrategy.java @@ -36,7 +36,7 @@ public class SearchMatchStrategy extends BaseMatchStrategy { @Override public Map> match(ChatQueryContext chatQueryContext, List originals, Set detectDataSetIds) { - String text = chatQueryContext.getQueryText(); + String text = chatQueryContext.getRequest().getQueryText(); Map regOffsetToLength = mapperHelper.getRegOffsetToLength(originals); List detectIndexList = Lists.newArrayList(); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SingleMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SingleMatchStrategy.java index 9b3077ce9..0de3df0a1 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SingleMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SingleMatchStrategy.java @@ -24,7 +24,7 @@ public abstract class SingleMatchStrategy extends BaseMatch public List detect(ChatQueryContext chatQueryContext, List terms, Set detectDataSetIds) { Map regOffsetToLength = mapperHelper.getRegOffsetToLength(terms); - String text = chatQueryContext.getQueryText(); + String text = chatQueryContext.getRequest().getQueryText(); Set results = new HashSet<>(); Set detectSegments = new HashSet<>(); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/TermDescMapper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/TermDescMapper.java index 49488e8d2..e37136327 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/TermDescMapper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/TermDescMapper.java @@ -20,21 +20,22 @@ public class TermDescMapper extends BaseMapper { return; } if (StringUtils.isBlank(chatQueryContext.getOriQueryText())) { - chatQueryContext.setOriQueryText(chatQueryContext.getQueryText()); + chatQueryContext.setOriQueryText(chatQueryContext.getRequest().getQueryText()); } for (SchemaElement schemaElement : termDescriptionToMap) { if (schemaElement.isDescriptionMapped()) { continue; } - if (chatQueryContext.getQueryText().equals(schemaElement.getDescription())) { + if (chatQueryContext.getRequest().getQueryText() + .equals(schemaElement.getDescription())) { schemaElement.setDescriptionMapped(true); continue; } - chatQueryContext.setQueryText(schemaElement.getDescription()); + chatQueryContext.getRequest().setQueryText(schemaElement.getDescription()); break; } if (CollectionUtils.isEmpty(chatQueryContext.getMapInfo().getTermDescriptionToMap())) { - chatQueryContext.setQueryText(chatQueryContext.getOriQueryText()); + chatQueryContext.getRequest().setQueryText(chatQueryContext.getOriQueryText()); } } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java index c6b2db543..9bdde7571 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java @@ -31,7 +31,7 @@ public class QueryTypeParser implements SemanticParser { public void parse(ChatQueryContext chatQueryContext) { List candidateQueries = chatQueryContext.getCandidateQueries(); - User user = chatQueryContext.getUser(); + User user = chatQueryContext.getRequest().getUser(); for (SemanticQuery semanticQuery : candidateQueries) { // 1.init S2SQL diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/SatisfactionChecker.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/SatisfactionChecker.java index 0bfd40e22..8aa3092cb 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/SatisfactionChecker.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/SatisfactionChecker.java @@ -25,7 +25,8 @@ public class SatisfactionChecker { if (query.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) { continue; } - if (checkThreshold(chatQueryContext.getQueryText(), query.getParseInfo())) { + if (checkThreshold(chatQueryContext.getRequest().getQueryText(), + query.getParseInfo())) { return true; } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java index 429d02e45..dc6f87405 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java @@ -35,7 +35,7 @@ public class LLMRequestService { private ParserConfig parserConfig; public boolean isSkip(ChatQueryContext queryCtx) { - if (!queryCtx.getText2SQLType().enableLLM()) { + if (!queryCtx.getRequest().getText2SQLType().enableLLM()) { log.info("LLM disabled, skip"); return true; } @@ -45,12 +45,12 @@ public class LLMRequestService { public Long getDataSetId(ChatQueryContext queryCtx) { DataSetResolver dataSetResolver = ComponentFactory.getModelResolver(); - return dataSetResolver.resolve(queryCtx, queryCtx.getDataSetIds()); + return dataSetResolver.resolve(queryCtx, queryCtx.getRequest().getDataSetIds()); } public LLMReq getLlmReq(ChatQueryContext queryCtx, Long dataSetId) { Map dataSetIdToName = queryCtx.getSemanticSchema().getDataSetIdToName(); - String queryText = queryCtx.getQueryText(); + String queryText = queryCtx.getRequest().getQueryText(); LLMReq llmReq = new LLMReq(); llmReq.setQueryText(queryText); @@ -74,8 +74,8 @@ public class LLMRequestService { llmReq.setTerms(getMappedTerms(queryCtx, dataSetId)); llmReq.setSqlGenType( LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE))); - llmReq.setChatAppConfig(queryCtx.getChatAppConfig()); - llmReq.setDynamicExemplars(queryCtx.getDynamicExemplars()); + llmReq.setChatAppConfig(queryCtx.getRequest().getChatAppConfig()); + llmReq.setDynamicExemplars(queryCtx.getRequest().getDynamicExemplars()); return llmReq; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java index b0556747e..9b0aeb554 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java @@ -39,13 +39,14 @@ public class LLMResponseService { Map properties = new HashMap<>(); properties.put(Constants.CONTEXT, parseResult); properties.put("type", "internal"); - Text2SQLExemplar exemplar = Text2SQLExemplar.builder().question(queryCtx.getQueryText()) - .sideInfo(parseResult.getLlmResp().getSideInfo()) - .dbSchema(parseResult.getLlmResp().getSchema()) - .sql(parseResult.getLlmResp().getSqlOutput()).build(); + Text2SQLExemplar exemplar = + Text2SQLExemplar.builder().question(queryCtx.getRequest().getQueryText()) + .sideInfo(parseResult.getLlmResp().getSideInfo()) + .dbSchema(parseResult.getLlmResp().getSchema()) + .sql(parseResult.getLlmResp().getSqlOutput()).build(); properties.put(Text2SQLExemplar.PROPERTY_KEY, exemplar); parseInfo.setProperties(properties); - parseInfo.setScore(queryCtx.getQueryText().length() * (1 + weight)); + parseInfo.setScore(queryCtx.getRequest().getQueryText().length() * (1 + weight)); parseInfo.setQueryMode(semanticQuery.getQueryMode()); parseInfo.getSqlInfo().setParsedS2SQL(s2SQL); queryCtx.getCandidateQueries().add(semanticQuery); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/AggregateTypeParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/AggregateTypeParser.java index a5e0a05e1..24676c06b 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/AggregateTypeParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/AggregateTypeParser.java @@ -44,7 +44,7 @@ public class AggregateTypeParser implements SemanticParser { @Override public void parse(ChatQueryContext chatQueryContext) { - String queryText = chatQueryContext.getQueryText(); + String queryText = chatQueryContext.getRequest().getQueryText(); AggregateConf aggregateConf = resolveAggregateConf(queryText); for (SemanticQuery semanticQuery : chatQueryContext.getCandidateQueries()) { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/ContextInheritParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/ContextInheritParser.java index 79e22a17e..0996bccf2 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/ContextInheritParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/ContextInheritParser.java @@ -60,12 +60,12 @@ public class ContextInheritParser implements SemanticParser { chatQueryContext.getMapInfo().getMatchedElements(dataSetId); List matchesToInherit = new ArrayList<>(); - for (SchemaElementMatch match : chatQueryContext.getContextParseInfo() + for (SchemaElementMatch match : chatQueryContext.getRequest().getContextParseInfo() .getElementMatches()) { SchemaElementType matchType = match.getElement().getType(); // mutual exclusive element types should not be inherited - RuleSemanticQuery ruleQuery = QueryManager - .getRuleQuery(chatQueryContext.getContextParseInfo().getQueryMode()); + RuleSemanticQuery ruleQuery = QueryManager.getRuleQuery( + chatQueryContext.getRequest().getContextParseInfo().getQueryMode()); if (!containsTypes(elementMatches, matchType, ruleQuery)) { match.setInherited(true); matchesToInherit.add(match); @@ -121,10 +121,13 @@ public class ContextInheritParser implements SemanticParser { } protected Long getMatchedDataSet(ChatQueryContext chatQueryContext) { - Long dataSetId = chatQueryContext.getContextParseInfo().getDataSetId(); - if (dataSetId == null) { + if (Objects.isNull(chatQueryContext) + || Objects.isNull(chatQueryContext.getRequest().getContextParseInfo()) + || Objects.isNull( + chatQueryContext.getRequest().getContextParseInfo().getDataSetId())) { return null; } + Long dataSetId = chatQueryContext.getRequest().getContextParseInfo().getDataSetId(); Set queryDataSets = chatQueryContext.getMapInfo().getMatchedDataSetInfos(); if (queryDataSets.contains(dataSetId)) { return dataSetId; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java index d5b1c3956..79e0408fc 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java @@ -22,7 +22,7 @@ public class RuleSqlParser implements SemanticParser { @Override public void parse(ChatQueryContext chatQueryContext) { - if (!chatQueryContext.getText2SQLType().enableRule() + if (!chatQueryContext.getRequest().getText2SQLType().enableRule() || !chatQueryContext.getCandidateQueries().isEmpty()) { return; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/TimeRangeParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/TimeRangeParser.java index 7a56aa5f1..d7ba6466e 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/TimeRangeParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/TimeRangeParser.java @@ -38,7 +38,7 @@ public class TimeRangeParser implements SemanticParser { @Override public void parse(ChatQueryContext queryContext) { - String queryText = queryContext.getQueryText(); + String queryText = queryContext.getRequest().getQueryText(); DateConf dateConf = parseRecent(queryText); if (dateConf == null) { dateConf = parseDateNumber(queryText); @@ -62,7 +62,7 @@ public class TimeRangeParser implements SemanticParser { parseInfo.setScore(parseInfo.getScore() + dateConf.getDetectWord().length()); } } else { - SemanticParseInfo contextParseInfo = queryContext.getContextParseInfo(); + SemanticParseInfo contextParseInfo = queryContext.getRequest().getContextParseInfo(); if (QueryManager.containsRuleQuery(contextParseInfo.getQueryMode())) { RuleSemanticQuery semanticQuery = QueryManager.createRuleQuery(contextParseInfo.getQueryMode()); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java index 47913f282..94aa9678f 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java @@ -69,8 +69,9 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery { private void fillDateConfByInherited(SemanticParseInfo queryParseInfo, ChatQueryContext chatQueryContext) { - SemanticParseInfo contextParseInfo = chatQueryContext.getContextParseInfo(); - if (queryParseInfo.getDateInfo() != null || contextParseInfo.getDateInfo() == null + SemanticParseInfo contextParseInfo = chatQueryContext.getRequest().getContextParseInfo(); + if (queryParseInfo.getDateInfo() != null || Objects.isNull(contextParseInfo) + || Objects.isNull(contextParseInfo.getDateInfo()) || needFillDateConf(chatQueryContext)) { return; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricTopNQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricTopNQuery.java index bb02442e2..10f2daffc 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricTopNQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricTopNQuery.java @@ -35,7 +35,7 @@ public class MetricTopNQuery extends MetricSemanticQuery { @Override public List match(List candidateElementMatches, ChatQueryContext queryCtx) { - Matcher matcher = INTENT_PATTERN.matcher(queryCtx.getQueryText()); + Matcher matcher = INTENT_PATTERN.matcher(queryCtx.getRequest().getQueryText()); if (matcher.matches()) { return super.match(candidateElementMatches, queryCtx); } diff --git a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrectorTest.java b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrectorTest.java index cabd5677e..52c42e3b6 100644 --- a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrectorTest.java +++ b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrectorTest.java @@ -48,7 +48,7 @@ class WhereCorrectorTest { queryFilters.getFilters().add(filter2); queryFilters.getFilters().add(filter3); queryFilters.getFilters().add(filter4); - chatQueryContext.setQueryFilters(queryFilters); + chatQueryContext.getRequest().setQueryFilters(queryFilters); WhereCorrector whereCorrector = new WhereCorrector(); whereCorrector.addQueryFilter(chatQueryContext, semanticParseInfo); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2ChatLayerService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2ChatLayerService.java index d9c0e942f..fa85ef3b2 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2ChatLayerService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2ChatLayerService.java @@ -93,19 +93,6 @@ public class S2ChatLayerService implements ChatLayerService { return parseResult; } - private ChatQueryContext buildChatQueryContext(QueryNLReq queryNLReq) { - SemanticSchema semanticSchema = schemaService.getSemanticSchema(queryNLReq.getDataSetIds()); - Map> modelIdToDataSetIds = dataSetService.getModelIdToDataSetIds(); - ChatQueryContext queryCtx = ChatQueryContext.builder() - .queryFilters(queryNLReq.getQueryFilters()).semanticSchema(semanticSchema) - .candidateQueries(new ArrayList<>()).mapInfo(new SchemaMapInfo()) - .modelIdToDataSetIds(modelIdToDataSetIds).text2SQLType(queryNLReq.getText2SQLType()) - .mapModeEnum(queryNLReq.getMapModeEnum()).dataSetIds(queryNLReq.getDataSetIds()) - .build(); - BeanUtils.copyProperties(queryNLReq, queryCtx); - return queryCtx; - } - public void correct(QuerySqlReq querySqlReq, User user) { SemanticParseInfo semanticParseInfo = correctSqlReq(querySqlReq, user); querySqlReq.setSql(semanticParseInfo.getSqlInfo().getCorrectedS2SQL()); @@ -122,6 +109,15 @@ public class S2ChatLayerService implements ChatLayerService { return retrieveService.retrieve(queryNLReq); } + private ChatQueryContext buildChatQueryContext(QueryNLReq queryNLReq) { + ChatQueryContext queryCtx = new ChatQueryContext(queryNLReq); + SemanticSchema semanticSchema = schemaService.getSemanticSchema(queryNLReq.getDataSetIds()); + Map> modelIdToDataSetIds = dataSetService.getModelIdToDataSetIds(); + queryCtx.setSemanticSchema(semanticSchema); + queryCtx.setModelIdToDataSetIds(modelIdToDataSetIds); + return queryCtx; + } + private SemanticParseInfo correctSqlReq(QuerySqlReq querySqlReq, User user) { ChatQueryContext queryCtx = new ChatQueryContext(); SemanticSchema semanticSchema = diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/EntityInfoProcessor.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/EntityInfoProcessor.java index 1e81baac2..0bbb73e5d 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/EntityInfoProcessor.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/EntityInfoProcessor.java @@ -26,7 +26,7 @@ public class EntityInfoProcessor implements ResultProcessor { DataSetSchema dataSetSchema = semanticService.getDataSetSchema(parseInfo.getDataSetId()); EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, - chatQueryContext.getUser()); + chatQueryContext.getRequest().getUser()); parseInfo.setEntityInfo(entityInfo); }); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/RetrieveServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/RetrieveServiceImpl.java index 7dec0670f..016705f3a 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/RetrieveServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/RetrieveServiceImpl.java @@ -24,7 +24,6 @@ import com.tencent.supersonic.headless.server.service.RetrieveService; import com.tencent.supersonic.headless.server.service.SchemaService; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; -import org.springframework.beans.BeanUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; @@ -75,8 +74,7 @@ public class RetrieveServiceImpl implements RetrieveService { log.debug("originals terms: {}", originals); Set dataSetIds = queryNLReq.getDataSetIds(); - ChatQueryContext chatQueryContext = new ChatQueryContext(); - BeanUtils.copyProperties(queryNLReq, chatQueryContext); + ChatQueryContext chatQueryContext = new ChatQueryContext(queryNLReq); chatQueryContext.setModelIdToDataSetIds(dataSetService.getModelIdToDataSetIds()); Map> regTextMap = diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java index 84c497125..a03a4b6d5 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java @@ -140,7 +140,7 @@ public class ChatWorkflowEngine { SemanticLayerService queryService = ContextUtils.getBean(SemanticLayerService.class); SemanticTranslateResp explain = - queryService.translate(semanticQueryReq, queryCtx.getUser()); + queryService.translate(semanticQueryReq, queryCtx.getRequest().getUser()); parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL()); if (StringUtils.isNotBlank(explain.getErrMsg())) { errorMsg.add(explain.getErrMsg()); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MultiTurnsTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MultiTurnsTest.java index 35a27bb41..a5a46403d 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MultiTurnsTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MultiTurnsTest.java @@ -48,7 +48,6 @@ public class MultiTurnsTest extends BaseTest { QueryResult expectedResult = new QueryResult(); SemanticParseInfo expectedParseInfo = new SemanticParseInfo(); expectedResult.setChatContext(expectedParseInfo); - expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE); expectedParseInfo.setAggType(NONE);