diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/WorkflowState.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/ChatWorkflowState.java similarity index 82% rename from headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/WorkflowState.java rename to headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/ChatWorkflowState.java index 3c7fd0a87..27b2ce4cf 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/WorkflowState.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/ChatWorkflowState.java @@ -1,6 +1,6 @@ package com.tencent.supersonic.headless.api.pojo.enums; -public enum WorkflowState { +public enum ChatWorkflowState { MAPPING, PARSING, CORRECTING, diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/QueryContext.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java similarity index 93% rename from headless/chat/src/main/java/com/tencent/supersonic/headless/chat/QueryContext.java rename to headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java index fdd1bd114..127a42e3d 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/QueryContext.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java @@ -4,14 +4,14 @@ import com.fasterxml.jackson.annotation.JsonIgnore; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.common.config.ModelConfig; import com.tencent.supersonic.common.config.PromptConfig; -import com.tencent.supersonic.common.pojo.SqlExemplar; import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.api.pojo.QueryDataType; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; +import com.tencent.supersonic.common.pojo.SqlExemplar; import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum; -import com.tencent.supersonic.headless.api.pojo.enums.WorkflowState; +import com.tencent.supersonic.headless.api.pojo.enums.ChatWorkflowState; import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; import com.tencent.supersonic.headless.chat.parser.ParserConfig; import com.tencent.supersonic.headless.chat.query.SemanticQuery; @@ -20,6 +20,7 @@ import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; + import java.util.ArrayList; import java.util.Comparator; import java.util.List; @@ -31,7 +32,7 @@ import java.util.stream.Collectors; @Builder @NoArgsConstructor @AllArgsConstructor -public class QueryContext { +public class ChatQueryContext { private String queryText; private Integer chatId; @@ -39,6 +40,7 @@ public class QueryContext { private Map> modelIdToDataSetIds; private User user; private boolean saveAnswer; + @Builder.Default private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM; private QueryFilters queryFilters; private List candidateQueries = new ArrayList<>(); @@ -47,7 +49,7 @@ public class QueryContext { @JsonIgnore private SemanticSchema semanticSchema; @JsonIgnore - private WorkflowState workflowState; + private ChatWorkflowState chatWorkflowState; private QueryDataType queryDataType = QueryDataType.ALL; private ModelConfig modelConfig; private PromptConfig promptConfig; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/AggCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/AggCorrector.java index cb4eb5e66..3079df3af 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/AggCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/AggCorrector.java @@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.chat.corrector; import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import lombok.extern.slf4j.Slf4j; import org.springframework.util.CollectionUtils; @@ -16,17 +16,17 @@ import java.util.List; public class AggCorrector extends BaseSemanticCorrector { @Override - public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { - addAggregate(queryContext, semanticParseInfo); + public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { + addAggregate(chatQueryContext, semanticParseInfo); } - private void addAggregate(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { + private void addAggregate(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { List sqlGroupByFields = SqlSelectHelper.getGroupByFields( semanticParseInfo.getSqlInfo().getCorrectS2SQL()); if (CollectionUtils.isEmpty(sqlGroupByFields)) { return; } - addAggregateToMetric(queryContext, semanticParseInfo); + addAggregateToMetric(chatQueryContext, semanticParseInfo); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/BaseSemanticCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/BaseSemanticCorrector.java index c0ee5c3d3..bca8709c2 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/BaseSemanticCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/BaseSemanticCorrector.java @@ -6,7 +6,8 @@ import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; + import java.util.ArrayList; import java.util.HashSet; import java.util.List; @@ -26,23 +27,23 @@ import org.springframework.util.CollectionUtils; @Slf4j public abstract class BaseSemanticCorrector implements SemanticCorrector { - public void correct(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { + public void correct(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { try { if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectS2SQL())) { return; } - doCorrect(queryContext, semanticParseInfo); + doCorrect(chatQueryContext, semanticParseInfo); log.debug("sqlCorrection:{} sql:{}", this.getClass().getSimpleName(), semanticParseInfo.getSqlInfo()); } catch (Exception e) { log.error(String.format("correct error,sqlInfo:%s", semanticParseInfo.getSqlInfo()), e); } } - public abstract void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo); + public abstract void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo); - protected Map getFieldNameMap(QueryContext queryContext, Long dataSetId) { + protected Map getFieldNameMap(ChatQueryContext chatQueryContext, Long dataSetId) { - SemanticSchema semanticSchema = queryContext.getSemanticSchema(); + SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema(); List dbAllFields = new ArrayList<>(); dbAllFields.addAll(semanticSchema.getMetrics()); @@ -71,11 +72,11 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector { return result; } - protected void addAggregateToMetric(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { + protected void addAggregateToMetric(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { //add aggregate to all metric String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); Long dataSetId = semanticParseInfo.getDataSet().getDataSet(); - List metrics = getMetricElements(queryContext, dataSetId); + List metrics = getMetricElements(chatQueryContext, dataSetId); Map metricToAggregate = metrics.stream() .map(schemaElement -> { @@ -100,8 +101,8 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector { semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql); } - protected List getMetricElements(QueryContext queryContext, Long dataSetId) { - SemanticSchema semanticSchema = queryContext.getSemanticSchema(); + protected List getMetricElements(ChatQueryContext chatQueryContext, Long dataSetId) { + SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema(); return semanticSchema.getMetrics(dataSetId); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/GrammarCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/GrammarCorrector.java index 3240eb09c..5bfd09400 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/GrammarCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/GrammarCorrector.java @@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.chat.corrector; import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import lombok.extern.slf4j.Slf4j; import java.util.ArrayList; @@ -26,9 +26,9 @@ public class GrammarCorrector extends BaseSemanticCorrector { } @Override - public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { + public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { for (BaseSemanticCorrector corrector : correctors) { - corrector.correct(queryContext, semanticParseInfo); + corrector.correct(chatQueryContext, semanticParseInfo); } removeSameFieldFromSelect(semanticParseInfo); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/GroupByCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/GroupByCorrector.java index e0ccc96ca..539044af6 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/GroupByCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/GroupByCorrector.java @@ -7,7 +7,7 @@ import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.SqlInfo; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.springframework.core.env.Environment; @@ -23,20 +23,20 @@ import java.util.stream.Collectors; public class GroupByCorrector extends BaseSemanticCorrector { @Override - public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { - Boolean needAddGroupBy = needAddGroupBy(queryContext, semanticParseInfo); + public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { + Boolean needAddGroupBy = needAddGroupBy(chatQueryContext, semanticParseInfo); if (!needAddGroupBy) { return; } - addGroupByFields(queryContext, semanticParseInfo); + addGroupByFields(chatQueryContext, semanticParseInfo); } - private Boolean needAddGroupBy(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { + private Boolean needAddGroupBy(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { Long dataSetId = semanticParseInfo.getDataSetId(); //add dimension group by SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); String correctS2SQL = sqlInfo.getCorrectS2SQL(); - SemanticSchema semanticSchema = queryContext.getSemanticSchema(); + SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema(); // check has distinct if (SqlSelectHelper.hasDistinct(correctS2SQL)) { log.debug("no need to add groupby ,existed distinct in s2sql:{}", correctS2SQL); @@ -64,12 +64,12 @@ public class GroupByCorrector extends BaseSemanticCorrector { return true; } - private void addGroupByFields(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { + private void addGroupByFields(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { Long dataSetId = semanticParseInfo.getDataSetId(); //add dimension group by SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); String correctS2SQL = sqlInfo.getCorrectS2SQL(); - SemanticSchema semanticSchema = queryContext.getSemanticSchema(); + SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema(); //add alias field name Set dimensions = getDimensions(dataSetId, semanticSchema); List selectFields = SqlSelectHelper.getSelectFields(correctS2SQL); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/HavingCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/HavingCorrector.java index 0c61d49e5..73da6623e 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/HavingCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/HavingCorrector.java @@ -6,7 +6,7 @@ import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper; import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.expression.Expression; import org.apache.commons.lang3.StringUtils; @@ -24,10 +24,10 @@ import java.util.stream.Collectors; public class HavingCorrector extends BaseSemanticCorrector { @Override - public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { + public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { //add aggregate to all metric - addHaving(queryContext, semanticParseInfo); + addHaving(chatQueryContext, semanticParseInfo); //decide whether add having expression field to select Environment environment = ContextUtils.getBean(Environment.class); @@ -38,10 +38,10 @@ public class HavingCorrector extends BaseSemanticCorrector { } - private void addHaving(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { + private void addHaving(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { Long dataSet = semanticParseInfo.getDataSet().getDataSet(); - SemanticSchema semanticSchema = queryContext.getSemanticSchema(); + SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema(); Set metrics = semanticSchema.getMetrics(dataSet).stream() .map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet()); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/S2SqlDateHelper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/S2SqlDateHelper.java index fa3bc2a9f..3431535eb 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/S2SqlDateHelper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/S2SqlDateHelper.java @@ -6,19 +6,19 @@ import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum; import com.tencent.supersonic.common.util.DateUtils; import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import org.apache.commons.lang3.tuple.Pair; import java.util.Objects; public class S2SqlDateHelper { - public static String getReferenceDate(QueryContext queryContext, Long dataSetId) { + public static String getReferenceDate(ChatQueryContext chatQueryContext, Long dataSetId) { String defaultDate = DateUtils.getBeforeDate(0); if (Objects.isNull(dataSetId)) { return defaultDate; } - DataSetSchema dataSetSchema = queryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId); + DataSetSchema dataSetSchema = chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId); if (dataSetSchema == null || dataSetSchema.getTagTypeTimeDefaultConfig() == null) { return defaultDate; } @@ -26,13 +26,13 @@ public class S2SqlDateHelper { return getDefaultDate(defaultDate, tagTypeTimeDefaultConfig).getLeft(); } - public static Pair getStartEndDate(QueryContext queryContext, Long dataSetId, - QueryType queryType) { + public static Pair getStartEndDate(ChatQueryContext chatQueryContext, Long dataSetId, + QueryType queryType) { String defaultDate = DateUtils.getBeforeDate(0); if (Objects.isNull(dataSetId)) { return Pair.of(defaultDate, defaultDate); } - DataSetSchema dataSetSchema = queryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId); + DataSetSchema dataSetSchema = chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId); if (dataSetSchema == null) { return Pair.of(defaultDate, defaultDate); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrector.java index d1d7ca464..ad80a6b11 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrector.java @@ -13,7 +13,7 @@ import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.SqlInfo; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.chat.parser.llm.ParseResult; import lombok.extern.slf4j.Slf4j; @@ -34,7 +34,7 @@ import java.util.stream.Collectors; public class SchemaCorrector extends BaseSemanticCorrector { @Override - public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { + public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { correctAggFunction(semanticParseInfo); @@ -44,7 +44,7 @@ public class SchemaCorrector extends BaseSemanticCorrector { updateFieldValueByLinkingValue(semanticParseInfo); - correctFieldName(queryContext, semanticParseInfo); + correctFieldName(chatQueryContext, semanticParseInfo); } private void correctAggFunction(SemanticParseInfo semanticParseInfo) { @@ -60,8 +60,8 @@ public class SchemaCorrector extends BaseSemanticCorrector { sqlInfo.setCorrectS2SQL(replaceAlias); } - private void correctFieldName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { - Map fieldNameMap = getFieldNameMap(queryContext, semanticParseInfo.getDataSetId()); + private void correctFieldName(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { + Map fieldNameMap = getFieldNameMap(chatQueryContext, semanticParseInfo.getDataSetId()); SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); String sql = SqlReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap); sqlInfo.setCorrectS2SQL(sql); @@ -115,7 +115,8 @@ public class SchemaCorrector extends BaseSemanticCorrector { sqlInfo.setCorrectS2SQL(sql); } - public void removeFilterIfNotInLinkingValue(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { + public void removeFilterIfNotInLinkingValue(ChatQueryContext chatQueryContext, + SemanticParseInfo semanticParseInfo) { SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); String correctS2SQL = sqlInfo.getCorrectS2SQL(); List whereExpressionList = SqlSelectHelper.getWhereExpressions(correctS2SQL); @@ -123,7 +124,7 @@ public class SchemaCorrector extends BaseSemanticCorrector { return; } List linkingValues = getLinkingValues(semanticParseInfo); - SemanticSchema semanticSchema = queryContext.getSemanticSchema(); + SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema(); Set dimensions = getDimensions(semanticParseInfo.getDataSetId(), semanticSchema); if (CollectionUtils.isEmpty(linkingValues)) { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java index 87826a8db..c48e2e7b5 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java @@ -10,7 +10,7 @@ import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.springframework.core.env.Environment; @@ -32,7 +32,7 @@ public class SelectCorrector extends BaseSemanticCorrector { public static final String ADDITIONAL_INFORMATION = "s2.corrector.additional.information"; @Override - public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { + public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); List aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL); List selectFields = SqlSelectHelper.getSelectFields(correctS2SQL); @@ -42,14 +42,14 @@ public class SelectCorrector extends BaseSemanticCorrector { && aggregateFields.size() == selectFields.size()) { return; } - correctS2SQL = addFieldsToSelect(queryContext, semanticParseInfo, correctS2SQL); + correctS2SQL = addFieldsToSelect(chatQueryContext, semanticParseInfo, correctS2SQL); String querySql = SqlReplaceHelper.dealAliasToOrderBy(correctS2SQL); semanticParseInfo.getSqlInfo().setCorrectS2SQL(querySql); } - protected String addFieldsToSelect(QueryContext queryContext, SemanticParseInfo semanticParseInfo, + protected String addFieldsToSelect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo, String correctS2SQL) { - correctS2SQL = addTagDefaultFields(queryContext, semanticParseInfo, correctS2SQL); + correctS2SQL = addTagDefaultFields(chatQueryContext, semanticParseInfo, correctS2SQL); Set selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL)); Set needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL)); @@ -69,7 +69,7 @@ public class SelectCorrector extends BaseSemanticCorrector { return addFieldsToSelectSql; } - private String addTagDefaultFields(QueryContext queryContext, SemanticParseInfo semanticParseInfo, + private String addTagDefaultFields(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo, String correctS2SQL) { //If it is in DETAIL mode and select *, add default metrics and dimensions. boolean hasAsterisk = SqlSelectFunctionHelper.hasAsterisk(correctS2SQL); @@ -77,7 +77,7 @@ public class SelectCorrector extends BaseSemanticCorrector { return correctS2SQL; } Long dataSetId = semanticParseInfo.getDataSetId(); - DataSetSchema dataSetSchema = queryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId); + DataSetSchema dataSetSchema = chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId); Set needAddDefaultFields = new HashSet<>(); if (Objects.nonNull(dataSetSchema)) { if (!CollectionUtils.isEmpty(dataSetSchema.getTagDefaultMetrics())) { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SemanticCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SemanticCorrector.java index 212ed0a72..2ddbb68a9 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SemanticCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SemanticCorrector.java @@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.chat.corrector; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; /** * A semantic corrector checks validity of extracted semantic information and @@ -10,5 +10,5 @@ import com.tencent.supersonic.headless.chat.QueryContext; */ public interface SemanticCorrector { - void correct(QueryContext queryContext, SemanticParseInfo semanticParseInfo); + void correct(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java index 230aeb0a3..fc892c3a5 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java @@ -10,7 +10,7 @@ import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper; import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper; import com.tencent.supersonic.common.jsqlparser.DateVisitor.DateBoundInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.expression.Expression; @@ -32,11 +32,11 @@ import java.util.Set; public class TimeCorrector extends BaseSemanticCorrector { @Override - public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { + public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { - addDateIfNotExist(queryContext, semanticParseInfo); + addDateIfNotExist(chatQueryContext, semanticParseInfo); - removeDateIfExist(queryContext, semanticParseInfo); + removeDateIfExist(chatQueryContext, semanticParseInfo); parserDateDiffFunction(semanticParseInfo); @@ -44,7 +44,7 @@ public class TimeCorrector extends BaseSemanticCorrector { } - private void removeDateIfExist(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { + private void removeDateIfExist(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); //decide whether remove date field from where Environment environment = ContextUtils.getBean(Environment.class); @@ -59,7 +59,7 @@ public class TimeCorrector extends BaseSemanticCorrector { } } - private void addDateIfNotExist(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { + private void addDateIfNotExist(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); List whereFields = SqlSelectHelper.getWhereFields(correctS2SQL); @@ -71,7 +71,7 @@ public class TimeCorrector extends BaseSemanticCorrector { } if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) { - Pair startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, + Pair startEndDate = S2SqlDateHelper.getStartEndDate(chatQueryContext, semanticParseInfo.getDataSetId(), semanticParseInfo.getQueryType()); if (StringUtils.isNotBlank(startEndDate.getLeft()) diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrector.java index 3ae09c4d7..bfd2d94e9 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrector.java @@ -8,7 +8,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaValueMap; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.utils.QueryFilterParser; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.JSQLParserException; @@ -29,15 +29,15 @@ import java.util.Objects; public class WhereCorrector extends BaseSemanticCorrector { @Override - public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { + public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { - addQueryFilter(queryContext, semanticParseInfo); + addQueryFilter(chatQueryContext, semanticParseInfo); - updateFieldValueByTechName(queryContext, semanticParseInfo); + updateFieldValueByTechName(chatQueryContext, semanticParseInfo); } - protected void addQueryFilter(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { - String queryFilter = getQueryFilter(queryContext.getQueryFilters()); + protected void addQueryFilter(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { + String queryFilter = getQueryFilter(chatQueryContext.getQueryFilters()); String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); @@ -61,8 +61,8 @@ public class WhereCorrector extends BaseSemanticCorrector { return QueryFilterParser.parse(queryFilters); } - private void updateFieldValueByTechName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) { - SemanticSchema semanticSchema = queryContext.getSemanticSchema(); + private void updateFieldValueByTechName(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { + SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema(); Long dataSetId = semanticParseInfo.getDataSetId(); List dimensions = semanticSchema.getDimensions(dataSetId); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMapper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMapper.java index 52a0fdb35..2fc0713c8 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMapper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMapper.java @@ -6,7 +6,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.BeanUtils; @@ -26,37 +26,37 @@ import java.util.stream.Collectors; public abstract class BaseMapper implements SchemaMapper { @Override - public void map(QueryContext queryContext) { + public void map(ChatQueryContext chatQueryContext) { String simpleName = this.getClass().getSimpleName(); long startTime = System.currentTimeMillis(); log.debug("before {},mapInfo:{}", simpleName, - queryContext.getMapInfo().getDataSetElementMatches()); + chatQueryContext.getMapInfo().getDataSetElementMatches()); try { - doMap(queryContext); - filter(queryContext); + doMap(chatQueryContext); + filter(chatQueryContext); } catch (Exception e) { log.error("work error", e); } long cost = System.currentTimeMillis() - startTime; log.debug("after {},cost:{},mapInfo:{}", simpleName, cost, - queryContext.getMapInfo().getDataSetElementMatches()); + chatQueryContext.getMapInfo().getDataSetElementMatches()); } - private void filter(QueryContext queryContext) { - filterByDataSetId(queryContext); - filterByDetectWordLenLessThanOne(queryContext); - switch (queryContext.getQueryDataType()) { + private void filter(ChatQueryContext chatQueryContext) { + filterByDataSetId(chatQueryContext); + filterByDetectWordLenLessThanOne(chatQueryContext); + switch (chatQueryContext.getQueryDataType()) { case TAG: - filterByQueryDataType(queryContext, element -> !(element.getIsTag() > 0)); + filterByQueryDataType(chatQueryContext, element -> !(element.getIsTag() > 0)); break; case METRIC: - filterByQueryDataType(queryContext, element -> !SchemaElementType.METRIC.equals(element.getType())); + filterByQueryDataType(chatQueryContext, element -> !SchemaElementType.METRIC.equals(element.getType())); break; case DIMENSION: - filterByQueryDataType(queryContext, element -> { + filterByQueryDataType(chatQueryContext, element -> { boolean isDimensionOrValue = SchemaElementType.DIMENSION.equals(element.getType()) || SchemaElementType.VALUE.equals(element.getType()); return !isDimensionOrValue; @@ -68,22 +68,22 @@ public abstract class BaseMapper implements SchemaMapper { } } - private static void filterByDataSetId(QueryContext queryContext) { - Set dataSetIds = queryContext.getDataSetIds(); + private static void filterByDataSetId(ChatQueryContext chatQueryContext) { + Set dataSetIds = chatQueryContext.getDataSetIds(); if (CollectionUtils.isEmpty(dataSetIds)) { return; } - Set dataSetIdInMapInfo = new HashSet<>(queryContext.getMapInfo().getDataSetElementMatches().keySet()); + Set dataSetIdInMapInfo = new HashSet<>(chatQueryContext.getMapInfo().getDataSetElementMatches().keySet()); for (Long dataSetId : dataSetIdInMapInfo) { if (!dataSetIds.contains(dataSetId)) { - queryContext.getMapInfo().getDataSetElementMatches().remove(dataSetId); + chatQueryContext.getMapInfo().getDataSetElementMatches().remove(dataSetId); } } } - private static void filterByDetectWordLenLessThanOne(QueryContext queryContext) { + private static void filterByDetectWordLenLessThanOne(ChatQueryContext chatQueryContext) { Map> dataSetElementMatches = - queryContext.getMapInfo().getDataSetElementMatches(); + chatQueryContext.getMapInfo().getDataSetElementMatches(); for (Map.Entry> entry : dataSetElementMatches.entrySet()) { List value = entry.getValue(); if (!CollectionUtils.isEmpty(value)) { @@ -93,8 +93,9 @@ public abstract class BaseMapper implements SchemaMapper { } } - private static void filterByQueryDataType(QueryContext queryContext, Predicate needRemovePredicate) { - queryContext.getMapInfo().getDataSetElementMatches().values().stream().forEach( + private static void filterByQueryDataType(ChatQueryContext chatQueryContext, + Predicate needRemovePredicate) { + chatQueryContext.getMapInfo().getDataSetElementMatches().values().stream().forEach( schemaElementMatches -> schemaElementMatches.removeIf( schemaElementMatch -> { SchemaElement element = schemaElementMatch.getElement(); @@ -108,7 +109,7 @@ public abstract class BaseMapper implements SchemaMapper { )); } - public abstract void doMap(QueryContext queryContext); + public abstract void doMap(ChatQueryContext chatQueryContext); public void addToSchemaMap(SchemaMapInfo schemaMap, Long dataSetId, SchemaElementMatch newElementMatch) { Map> dataSetElementMatches = schemaMap.getDataSetElementMatches(); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java index 461b621ba..38ac2908e 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java @@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.chat.mapper; import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum; import com.tencent.supersonic.headless.api.pojo.response.S2Term; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; @@ -33,25 +33,25 @@ public abstract class BaseMatchStrategy implements MatchStrategy { protected MapperConfig mapperConfig; @Override - public Map> match(QueryContext queryContext, List terms, + public Map> match(ChatQueryContext chatQueryContext, List terms, Set detectDataSetIds) { - String text = queryContext.getQueryText(); + String text = chatQueryContext.getQueryText(); if (Objects.isNull(terms) || StringUtils.isEmpty(text)) { return null; } log.debug("terms:{},,detectDataSetIds:{}", terms, detectDataSetIds); - List detects = detect(queryContext, terms, detectDataSetIds); + List detects = detect(chatQueryContext, terms, detectDataSetIds); Map> result = new HashMap<>(); result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects); return result; } - public List detect(QueryContext queryContext, List terms, Set detectDataSetIds) { + public List detect(ChatQueryContext chatQueryContext, List terms, Set detectDataSetIds) { Map regOffsetToLength = getRegOffsetToLength(terms); - String text = queryContext.getQueryText(); + String text = chatQueryContext.getQueryText(); Set results = new HashSet<>(); Set detectSegments = new HashSet<>(); @@ -64,16 +64,16 @@ public abstract class BaseMatchStrategy implements MatchStrategy { if (index <= text.length()) { String detectSegment = text.substring(startIndex, index).trim(); detectSegments.add(detectSegment); - detectByStep(queryContext, results, detectDataSetIds, detectSegment, offset); + detectByStep(chatQueryContext, results, detectDataSetIds, detectSegment, offset); } } startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex); } - detectByBatch(queryContext, results, detectDataSetIds, detectSegments); + detectByBatch(chatQueryContext, results, detectDataSetIds, detectSegments); return new ArrayList<>(results); } - protected void detectByBatch(QueryContext queryContext, Set results, Set detectDataSetIds, + protected void detectByBatch(ChatQueryContext chatQueryContext, Set results, Set detectDataSetIds, Set detectSegments) { } @@ -108,10 +108,10 @@ public abstract class BaseMatchStrategy implements MatchStrategy { } } - public List getMatches(QueryContext queryContext, List terms) { - Set dataSetIds = queryContext.getDataSetIds(); + public List getMatches(ChatQueryContext chatQueryContext, List terms) { + Set dataSetIds = chatQueryContext.getDataSetIds(); terms = filterByDataSetId(terms, dataSetIds); - Map> matchResult = match(queryContext, terms, dataSetIds); + Map> matchResult = match(chatQueryContext, terms, dataSetIds); List matches = new ArrayList<>(); if (Objects.isNull(matchResult)) { return matches; @@ -155,8 +155,8 @@ public abstract class BaseMatchStrategy implements MatchStrategy { public abstract String getMapKey(T a); - public abstract void detectByStep(QueryContext queryContext, Set existResults, Set detectDataSetIds, - String detectSegment, int offset); + public abstract void detectByStep(ChatQueryContext chatQueryContext, Set existResults, + Set detectDataSetIds, String detectSegment, int offset); public double getThreshold(Double threshold, Double minThreshold, MapModeEnum mapModeEnum) { double decreaseAmount = (threshold - minThreshold) / 4; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/DatabaseMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/DatabaseMatchStrategy.java index 72eff2092..552cb56c5 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/DatabaseMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/DatabaseMatchStrategy.java @@ -5,7 +5,7 @@ import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.response.S2Term; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.knowledge.DatabaseMapResult; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; @@ -31,10 +31,10 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy private List allElements; @Override - public Map> match(QueryContext queryContext, List terms, + public Map> match(ChatQueryContext chatQueryContext, List terms, Set detectDataSetIds) { - this.allElements = getSchemaElements(queryContext); - return super.match(queryContext, terms, detectDataSetIds); + this.allElements = getSchemaElements(chatQueryContext); + return super.match(chatQueryContext, terms, detectDataSetIds); } @Override @@ -49,13 +49,13 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy + Constants.UNDERLINE + a.getSchemaElement().getName(); } - public void detectByStep(QueryContext queryContext, Set existResults, Set detectDataSetIds, - String detectSegment, int offset) { + public void detectByStep(ChatQueryContext chatQueryContext, Set existResults, + Set detectDataSetIds, String detectSegment, int offset) { if (StringUtils.isBlank(detectSegment)) { return; } - Double metricDimensionThresholdConfig = getThreshold(queryContext); + Double metricDimensionThresholdConfig = getThreshold(chatQueryContext); Map> nameToItems = getNameToItems(allElements); for (Entry> entry : nameToItems.entrySet()) { @@ -80,18 +80,19 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy } } - private List getSchemaElements(QueryContext queryContext) { + private List getSchemaElements(ChatQueryContext chatQueryContext) { List allElements = new ArrayList<>(); - allElements.addAll(queryContext.getSemanticSchema().getDimensions()); - allElements.addAll(queryContext.getSemanticSchema().getMetrics()); + allElements.addAll(chatQueryContext.getSemanticSchema().getDimensions()); + allElements.addAll(chatQueryContext.getSemanticSchema().getMetrics()); return allElements; } - private Double getThreshold(QueryContext queryContext) { + private Double getThreshold(ChatQueryContext chatQueryContext) { Double threshold = Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD)); Double minThreshold = Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD_MIN)); - Map> modelElementMatches = queryContext.getMapInfo().getDataSetElementMatches(); + Map> modelElementMatches = chatQueryContext.getMapInfo() + .getDataSetElementMatches(); boolean existElement = modelElementMatches.entrySet().stream().anyMatch(entry -> entry.getValue().size() >= 1); @@ -100,7 +101,7 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy log.debug("ModelElementMatches:{},not exist Element threshold reduce by half:{}", modelElementMatches, threshold); } - return getThreshold(threshold, minThreshold, queryContext.getMapModeEnum()); + return getThreshold(threshold, minThreshold, chatQueryContext.getMapModeEnum()); } private Map> getNameToItems(List models) { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMapper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMapper.java index ed1ca7e58..196343a48 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMapper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMapper.java @@ -10,7 +10,7 @@ import com.tencent.supersonic.headless.api.pojo.response.S2Term; import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult; import com.tencent.supersonic.headless.chat.knowledge.builder.BaseWordBuilder; import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import lombok.extern.slf4j.Slf4j; import java.util.List; @@ -23,13 +23,13 @@ import java.util.Objects; public class EmbeddingMapper extends BaseMapper { @Override - public void doMap(QueryContext queryContext) { + public void doMap(ChatQueryContext chatQueryContext) { //1. query from embedding by queryText - String queryText = queryContext.getQueryText(); - List terms = HanlpHelper.getTerms(queryText, queryContext.getModelIdToDataSetIds()); + String queryText = chatQueryContext.getQueryText(); + List terms = HanlpHelper.getTerms(queryText, chatQueryContext.getModelIdToDataSetIds()); EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class); - List matchResults = matchStrategy.getMatches(queryContext, terms); + List matchResults = matchStrategy.getMatches(chatQueryContext, terms); HanlpHelper.transLetterOriginal(matchResults); @@ -42,7 +42,7 @@ public class EmbeddingMapper extends BaseMapper { } SchemaElementType elementType = SchemaElementType.valueOf(matchResult.getMetadata().get("type")); SchemaElement schemaElement = getSchemaElement(dataSetId, elementType, elementId, - queryContext.getSemanticSchema()); + chatQueryContext.getSemanticSchema()); if (schemaElement == null) { continue; } @@ -54,7 +54,7 @@ public class EmbeddingMapper extends BaseMapper { .detectWord(matchResult.getDetectWord()) .build(); //3. add to mapInfo - addToSchemaMap(queryContext.getMapInfo(), dataSetId, schemaElementMatch); + addToSchemaMap(chatQueryContext.getMapInfo(), dataSetId, schemaElementMatch); } } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java index 64555b203..4faa118ee 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java @@ -5,7 +5,7 @@ import com.tencent.supersonic.common.pojo.Constants; import dev.langchain4j.store.embedding.Retrieval; import dev.langchain4j.store.embedding.RetrieveQuery; import dev.langchain4j.store.embedding.RetrieveQueryResult; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult; import com.tencent.supersonic.headless.chat.knowledge.MetaEmbeddingService; import lombok.extern.slf4j.Slf4j; @@ -49,13 +49,13 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy { } @Override - public void detectByStep(QueryContext queryContext, Set existResults, + public void detectByStep(ChatQueryContext chatQueryContext, Set existResults, Set detectDataSetIds, String detectSegment, int offset) { } @Override - protected void detectByBatch(QueryContext queryContext, Set results, + protected void detectByBatch(ChatQueryContext chatQueryContext, Set results, Set detectDataSetIds, Set detectSegments) { int embedddingMapperMin = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_MIN)); int embedddingMapperMax = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_MAX)); @@ -72,16 +72,16 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy { embeddingMapperBatch); for (List queryTextsSub : queryTextsSubList) { - detectByQueryTextsSub(results, detectDataSetIds, queryTextsSub, queryContext); + detectByQueryTextsSub(results, detectDataSetIds, queryTextsSub, chatQueryContext); } } private void detectByQueryTextsSub(Set results, Set detectDataSetIds, - List queryTextsSub, QueryContext queryContext) { - Map> modelIdToDataSetIds = queryContext.getModelIdToDataSetIds(); + List queryTextsSub, ChatQueryContext chatQueryContext) { + Map> modelIdToDataSetIds = chatQueryContext.getModelIdToDataSetIds(); double embeddingThreshold = Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD)); double embeddingThresholdMin = Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD_MIN)); - double threshold = getThreshold(embeddingThreshold, embeddingThresholdMin, queryContext.getMapModeEnum()); + double threshold = getThreshold(embeddingThreshold, embeddingThresholdMin, chatQueryContext.getMapModeEnum()); // step1. build query params RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build(); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EntityMapper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EntityMapper.java index fd3a87a0e..9137784b8 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EntityMapper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EntityMapper.java @@ -6,7 +6,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.BeanUtils; import org.springframework.util.CollectionUtils; @@ -21,14 +21,14 @@ import java.util.stream.Collectors; public class EntityMapper extends BaseMapper { @Override - public void doMap(QueryContext queryContext) { - SchemaMapInfo schemaMapInfo = queryContext.getMapInfo(); + public void doMap(ChatQueryContext chatQueryContext) { + SchemaMapInfo schemaMapInfo = chatQueryContext.getMapInfo(); for (Long dataSetId : schemaMapInfo.getMatchedDataSetInfos()) { List schemaElementMatchList = schemaMapInfo.getMatchedElements(dataSetId); if (CollectionUtils.isEmpty(schemaElementMatchList)) { continue; } - SchemaElement entity = getEntity(dataSetId, queryContext); + SchemaElement entity = getEntity(dataSetId, chatQueryContext); if (entity == null || entity.getId() == null) { continue; } @@ -64,8 +64,8 @@ public class EntityMapper extends BaseMapper { return false; } - private SchemaElement getEntity(Long dataSetId, QueryContext queryContext) { - SemanticSchema semanticSchema = queryContext.getSemanticSchema(); + private SchemaElement getEntity(Long dataSetId, ChatQueryContext chatQueryContext) { + SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema(); DataSetSchema modelSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId); if (modelSchema != null && modelSchema.getEntity() != null) { return modelSchema.getEntity(); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/HanlpDictMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/HanlpDictMatchStrategy.java index 14c1f7e80..f0e2cac07 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/HanlpDictMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/HanlpDictMatchStrategy.java @@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.chat.mapper; import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.headless.api.pojo.response.S2Term; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.knowledge.HanlpMapResult; import com.tencent.supersonic.headless.chat.knowledge.KnowledgeBaseService; import lombok.extern.slf4j.Slf4j; @@ -37,16 +37,16 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy { private KnowledgeBaseService knowledgeBaseService; @Override - public Map> match(QueryContext queryContext, List terms, + public Map> match(ChatQueryContext chatQueryContext, List terms, Set detectDataSetIds) { - String text = queryContext.getQueryText(); + String text = chatQueryContext.getQueryText(); if (Objects.isNull(terms) || StringUtils.isEmpty(text)) { return null; } log.debug("terms:{},detectModelIds:{}", terms, detectDataSetIds); - List detects = detect(queryContext, terms, detectDataSetIds); + List detects = detect(chatQueryContext, terms, detectDataSetIds); Map> result = new HashMap<>(); result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects); @@ -59,16 +59,17 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy { && existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length(); } - public void detectByStep(QueryContext queryContext, Set existResults, Set detectDataSetIds, - String detectSegment, int offset) { + public void detectByStep(ChatQueryContext chatQueryContext, Set existResults, + Set detectDataSetIds, + String detectSegment, int offset) { // step1. pre search Integer oneDetectionMaxSize = Integer.valueOf(mapperConfig.getParameterValue(MAPPER_DETECTION_MAX_SIZE)); LinkedHashSet hanlpMapResults = knowledgeBaseService.prefixSearch(detectSegment, - oneDetectionMaxSize, queryContext.getModelIdToDataSetIds(), detectDataSetIds) + oneDetectionMaxSize, chatQueryContext.getModelIdToDataSetIds(), detectDataSetIds) .stream().collect(Collectors.toCollection(LinkedHashSet::new)); // step2. suffix search LinkedHashSet suffixHanlpMapResults = knowledgeBaseService.suffixSearch(detectSegment, - oneDetectionMaxSize, queryContext.getModelIdToDataSetIds(), detectDataSetIds) + oneDetectionMaxSize, chatQueryContext.getModelIdToDataSetIds(), detectDataSetIds) .stream().collect(Collectors.toCollection(LinkedHashSet::new)); hanlpMapResults.addAll(suffixHanlpMapResults); @@ -83,7 +84,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy { // step4. filter by similarity hanlpMapResults = hanlpMapResults.stream() .filter(term -> mapperHelper.getSimilarity(detectSegment, term.getName()) - >= getThresholdMatch(term.getNatures(), queryContext)) + >= getThresholdMatch(term.getNatures(), chatQueryContext)) .filter(term -> CollectionUtils.isNotEmpty(term.getNatures())) .collect(Collectors.toCollection(LinkedHashSet::new)); @@ -126,7 +127,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy { return a.getName() + Constants.UNDERLINE + String.join(Constants.UNDERLINE, a.getNatures()); } - public double getThresholdMatch(List natures, QueryContext queryContext) { + public double getThresholdMatch(List natures, ChatQueryContext chatQueryContext) { Double threshold = Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD)); Double minThreshold = Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD_MIN)); if (mapperHelper.existDimensionValues(natures)) { @@ -134,7 +135,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy { minThreshold = Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_VALUE_THRESHOLD_MIN)); } - return getThreshold(threshold, minThreshold, queryContext.getMapModeEnum()); + return getThreshold(threshold, minThreshold, chatQueryContext.getMapModeEnum()); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/KeywordMapper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/KeywordMapper.java index 4f38af0b7..d9581e54c 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/KeywordMapper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/KeywordMapper.java @@ -6,12 +6,12 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.response.S2Term; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.knowledge.DatabaseMapResult; import com.tencent.supersonic.headless.chat.knowledge.HanlpMapResult; import com.tencent.supersonic.headless.chat.knowledge.builder.BaseWordBuilder; import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper; import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper; -import com.tencent.supersonic.headless.chat.QueryContext; import lombok.extern.slf4j.Slf4j; import org.springframework.util.CollectionUtils; @@ -30,23 +30,23 @@ import java.util.stream.Collectors; public class KeywordMapper extends BaseMapper { @Override - public void doMap(QueryContext queryContext) { - String queryText = queryContext.getQueryText(); + public void doMap(ChatQueryContext chatQueryContext) { + String queryText = chatQueryContext.getQueryText(); //1.hanlpDict Match - List terms = HanlpHelper.getTerms(queryText, queryContext.getModelIdToDataSetIds()); + List terms = HanlpHelper.getTerms(queryText, chatQueryContext.getModelIdToDataSetIds()); HanlpDictMatchStrategy hanlpMatchStrategy = ContextUtils.getBean(HanlpDictMatchStrategy.class); - List hanlpMapResults = hanlpMatchStrategy.getMatches(queryContext, terms); - convertHanlpMapResultToMapInfo(hanlpMapResults, queryContext, terms); + List hanlpMapResults = hanlpMatchStrategy.getMatches(chatQueryContext, terms); + convertHanlpMapResultToMapInfo(hanlpMapResults, chatQueryContext, terms); //2.database Match DatabaseMatchStrategy databaseMatchStrategy = ContextUtils.getBean(DatabaseMatchStrategy.class); - List databaseResults = databaseMatchStrategy.getMatches(queryContext, terms); - convertDatabaseMapResultToMapInfo(queryContext, databaseResults); + List databaseResults = databaseMatchStrategy.getMatches(chatQueryContext, terms); + convertDatabaseMapResultToMapInfo(chatQueryContext, databaseResults); } - private void convertHanlpMapResultToMapInfo(List mapResults, QueryContext queryContext, + private void convertHanlpMapResultToMapInfo(List mapResults, ChatQueryContext chatQueryContext, List terms) { if (CollectionUtils.isEmpty(mapResults)) { return; @@ -68,7 +68,7 @@ public class KeywordMapper extends BaseMapper { } Long elementID = NatureHelper.getElementID(nature); SchemaElement element = getSchemaElement(dataSetId, elementType, - elementID, queryContext.getSemanticSchema()); + elementID, chatQueryContext.getSemanticSchema()); if (element == null) { continue; } @@ -81,16 +81,17 @@ public class KeywordMapper extends BaseMapper { .detectWord(hanlpMapResult.getDetectWord()) .build(); - addToSchemaMap(queryContext.getMapInfo(), dataSetId, schemaElementMatch); + addToSchemaMap(chatQueryContext.getMapInfo(), dataSetId, schemaElementMatch); } } } - private void convertDatabaseMapResultToMapInfo(QueryContext queryContext, List mapResults) { + private void convertDatabaseMapResultToMapInfo(ChatQueryContext chatQueryContext, + List mapResults) { MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class); for (DatabaseMapResult match : mapResults) { SchemaElement schemaElement = match.getSchemaElement(); - Set regElementSet = getRegElementSet(queryContext.getMapInfo(), schemaElement); + Set regElementSet = getRegElementSet(chatQueryContext.getMapInfo(), schemaElement); if (regElementSet.contains(schemaElement.getId())) { continue; } @@ -102,7 +103,7 @@ public class KeywordMapper extends BaseMapper { .similarity(mapperHelper.getSimilarity(match.getDetectWord(), schemaElement.getName())) .build(); log.info("add to schema, elementMatch {}", schemaElementMatch); - addToSchemaMap(queryContext.getMapInfo(), schemaElement.getDataSet(), schemaElementMatch); + addToSchemaMap(chatQueryContext.getMapInfo(), schemaElement.getDataSet(), schemaElementMatch); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MatchStrategy.java index 0612eaae0..96289798b 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MatchStrategy.java @@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.chat.mapper; import com.tencent.supersonic.headless.api.pojo.response.S2Term; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import java.util.List; import java.util.Map; @@ -14,6 +14,6 @@ import java.util.Set; */ public interface MatchStrategy { - Map> match(QueryContext queryContext, List terms, Set detectDataSetIds); + Map> match(ChatQueryContext chatQueryContext, List terms, Set detectDataSetIds); } \ No newline at end of file diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/QueryFilterMapper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/QueryFilterMapper.java index 94db33bd1..76a2d44f5 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/QueryFilterMapper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/QueryFilterMapper.java @@ -8,8 +8,8 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.knowledge.builder.BaseWordBuilder; -import com.tencent.supersonic.headless.chat.QueryContext; import lombok.extern.slf4j.Slf4j; import org.springframework.util.CollectionUtils; @@ -24,12 +24,12 @@ public class QueryFilterMapper extends BaseMapper { private double similarity = 1.0; @Override - public void doMap(QueryContext queryContext) { - Set dataSetIds = queryContext.getDataSetIds(); + public void doMap(ChatQueryContext chatQueryContext) { + Set dataSetIds = chatQueryContext.getDataSetIds(); if (CollectionUtils.isEmpty(dataSetIds)) { return; } - SchemaMapInfo schemaMapInfo = queryContext.getMapInfo(); + SchemaMapInfo schemaMapInfo = chatQueryContext.getMapInfo(); clearOtherSchemaElementMatch(dataSetIds, schemaMapInfo); for (Long dataSetId : dataSetIds) { List schemaElementMatches = schemaMapInfo.getMatchedElements(dataSetId); @@ -37,7 +37,7 @@ public class QueryFilterMapper extends BaseMapper { schemaElementMatches = Lists.newArrayList(); schemaMapInfo.setMatchedElements(dataSetId, schemaElementMatches); } - addValueSchemaElementMatch(dataSetId, queryContext, schemaElementMatches); + addValueSchemaElementMatch(dataSetId, chatQueryContext, schemaElementMatches); } } @@ -49,9 +49,9 @@ public class QueryFilterMapper extends BaseMapper { } } - private void addValueSchemaElementMatch(Long dataSetId, QueryContext queryContext, + private void addValueSchemaElementMatch(Long dataSetId, ChatQueryContext chatQueryContext, List candidateElementMatches) { - QueryFilters queryFilters = queryContext.getQueryFilters(); + QueryFilters queryFilters = chatQueryContext.getQueryFilters(); if (queryFilters == null || CollectionUtils.isEmpty(queryFilters.getFilters())) { return; } @@ -75,7 +75,7 @@ public class QueryFilterMapper extends BaseMapper { .build(); candidateElementMatches.add(schemaElementMatch); } - queryContext.getMapInfo().setMatchedElements(dataSetId, candidateElementMatches); + chatQueryContext.getMapInfo().setMatchedElements(dataSetId, candidateElementMatches); } private boolean checkExistSameValueSchemaElementMatch(QueryFilter queryFilter, diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SchemaMapper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SchemaMapper.java index 4afcba726..15ee54723 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SchemaMapper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SchemaMapper.java @@ -1,7 +1,7 @@ package com.tencent.supersonic.headless.chat.mapper; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; /** * A schema mapper identifies references to schema elements(metrics/dimensions/entities/values) @@ -9,5 +9,5 @@ import com.tencent.supersonic.headless.chat.QueryContext; */ public interface SchemaMapper { - void map(QueryContext queryContext); + void map(ChatQueryContext chatQueryContext); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SearchMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SearchMatchStrategy.java index f9cf8b88e..2bcea8b29 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SearchMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SearchMatchStrategy.java @@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.chat.mapper; import com.google.common.collect.Lists; import com.tencent.supersonic.common.pojo.enums.DictWordType; import com.tencent.supersonic.headless.api.pojo.response.S2Term; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.knowledge.HanlpMapResult; import com.tencent.supersonic.headless.chat.knowledge.KnowledgeBaseService; import com.tencent.supersonic.headless.chat.knowledge.SearchService; @@ -32,9 +32,9 @@ public class SearchMatchStrategy extends BaseMatchStrategy { private KnowledgeBaseService knowledgeBaseService; @Override - public Map> match(QueryContext queryContext, List originals, + public Map> match(ChatQueryContext chatQueryContext, List originals, Set detectDataSetIds) { - String text = queryContext.getQueryText(); + String text = chatQueryContext.getQueryText(); Map regOffsetToLength = getRegOffsetToLength(originals); List detectIndexList = Lists.newArrayList(); @@ -58,9 +58,14 @@ public class SearchMatchStrategy extends BaseMatchStrategy { if (StringUtils.isNotEmpty(detectSegment)) { List hanlpMapResults = knowledgeBaseService.prefixSearch(detectSegment, - SearchService.SEARCH_SIZE, queryContext.getModelIdToDataSetIds(), detectDataSetIds); + SearchService.SEARCH_SIZE, + chatQueryContext.getModelIdToDataSetIds(), + detectDataSetIds); List suffixHanlpMapResults = knowledgeBaseService.suffixSearch( - detectSegment, SEARCH_SIZE, queryContext.getModelIdToDataSetIds(), detectDataSetIds); + detectSegment, + SEARCH_SIZE, + chatQueryContext.getModelIdToDataSetIds(), + detectDataSetIds); hanlpMapResults.addAll(suffixHanlpMapResults); // remove entity name where search hanlpMapResults = hanlpMapResults.stream().filter(entry -> { @@ -94,8 +99,8 @@ public class SearchMatchStrategy extends BaseMatchStrategy { } @Override - public void detectByStep(QueryContext queryContext, Set existResults, Set detectDataSetIds, - String detectSegment, int offset) { + public void detectByStep(ChatQueryContext chatQueryContext, Set existResults, + Set detectDataSetIds, String detectSegment, int offset) { } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java index afdab2d54..3d35587c7 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java @@ -12,7 +12,7 @@ import com.tencent.supersonic.headless.chat.query.SemanticQuery; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery; import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery; import com.tencent.supersonic.headless.chat.ChatContext; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.StringUtils; @@ -29,21 +29,21 @@ import java.util.stream.Collectors; public class QueryTypeParser implements SemanticParser { @Override - public void parse(QueryContext queryContext, ChatContext chatContext) { + public void parse(ChatQueryContext chatQueryContext, ChatContext chatContext) { - List candidateQueries = queryContext.getCandidateQueries(); - User user = queryContext.getUser(); + List candidateQueries = chatQueryContext.getCandidateQueries(); + User user = chatQueryContext.getUser(); for (SemanticQuery semanticQuery : candidateQueries) { // 1.init S2SQL - semanticQuery.initS2Sql(queryContext.getSemanticSchema(), user); + semanticQuery.initS2Sql(chatQueryContext.getSemanticSchema(), user); // 2.set queryType - QueryType queryType = getQueryType(queryContext, semanticQuery); + QueryType queryType = getQueryType(chatQueryContext, semanticQuery); semanticQuery.getParseInfo().setQueryType(queryType); } } - private QueryType getQueryType(QueryContext queryContext, SemanticQuery semanticQuery) { + private QueryType getQueryType(ChatQueryContext chatQueryContext, SemanticQuery semanticQuery) { SemanticParseInfo parseInfo = semanticQuery.getParseInfo(); SqlInfo sqlInfo = parseInfo.getSqlInfo(); if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getS2SQL())) { @@ -51,7 +51,7 @@ public class QueryTypeParser implements SemanticParser { } //1. entity queryType Long dataSetId = parseInfo.getDataSetId(); - SemanticSchema semanticSchema = queryContext.getSemanticSchema(); + SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema(); if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) { List whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getS2SQL()); List whereFilterByTimeFields = filterByTimeFields(whereFields); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/SatisfactionChecker.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/SatisfactionChecker.java index 875bd54f4..a945eaf3d 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/SatisfactionChecker.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/SatisfactionChecker.java @@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.chat.parser; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.query.SemanticQuery; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery; import lombok.extern.slf4j.Slf4j; @@ -21,12 +21,12 @@ import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_TE public class SatisfactionChecker { // check all the parse info in candidate - public static boolean isSkip(QueryContext queryContext) { - for (SemanticQuery query : queryContext.getCandidateQueries()) { + public static boolean isSkip(ChatQueryContext chatQueryContext) { + for (SemanticQuery query : chatQueryContext.getCandidateQueries()) { if (query.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) { continue; } - if (checkThreshold(queryContext.getQueryText(), query.getParseInfo())) { + if (checkThreshold(chatQueryContext.getQueryText(), query.getParseInfo())) { return true; } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/SemanticParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/SemanticParser.java index 75d12073c..9ebef3acf 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/SemanticParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/SemanticParser.java @@ -1,7 +1,7 @@ package com.tencent.supersonic.headless.chat.parser; import com.tencent.supersonic.headless.chat.ChatContext; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; /** * A semantic parser understands user queries and generates semantic query statement. @@ -10,5 +10,5 @@ import com.tencent.supersonic.headless.chat.QueryContext; */ public interface SemanticParser { - void parse(QueryContext queryContext, ChatContext chatContext); + void parse(ChatQueryContext chatQueryContext, ChatContext chatContext); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/DataSetResolver.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/DataSetResolver.java index 6adfb7cbe..fe044a671 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/DataSetResolver.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/DataSetResolver.java @@ -1,12 +1,12 @@ package com.tencent.supersonic.headless.chat.parser.llm; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import java.util.Set; public interface DataSetResolver { - Long resolve(QueryContext queryContext, Set restrictiveModels); + Long resolve(ChatQueryContext chatQueryContext, Set restrictiveModels); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/HeuristicDataSetResolver.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/HeuristicDataSetResolver.java index ec50edded..96d98e6ac 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/HeuristicDataSetResolver.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/HeuristicDataSetResolver.java @@ -4,7 +4,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.chat.query.SemanticQuery; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; import java.util.ArrayList; @@ -111,8 +111,8 @@ public class HeuristicDataSetResolver implements DataSetResolver { return dataSetCount; } - public Long resolve(QueryContext queryContext, Set agentDataSetIds) { - SchemaMapInfo mapInfo = queryContext.getMapInfo(); + public Long resolve(ChatQueryContext chatQueryContext, Set agentDataSetIds) { + SchemaMapInfo mapInfo = chatQueryContext.getMapInfo(); Set matchedDataSets = mapInfo.getMatchedDataSetInfos(); if (CollectionUtils.isNotEmpty(agentDataSetIds)) { matchedDataSets.retainAll(agentDataSetIds); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java index 227a45ad0..2e964a8a8 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java @@ -8,7 +8,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.parser.ParserConfig; import com.tencent.supersonic.headless.chat.parser.SatisfactionChecker; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; @@ -43,7 +43,7 @@ public class LLMRequestService { @Autowired private ParserConfig parserConfig; - public boolean isSkip(QueryContext queryCtx) { + public boolean isSkip(ChatQueryContext queryCtx) { if (!queryCtx.getText2SQLType().enableLLM()) { log.info("not enable llm, skip"); return true; @@ -57,12 +57,12 @@ public class LLMRequestService { return false; } - public Long getDataSetId(QueryContext queryCtx) { + public Long getDataSetId(ChatQueryContext queryCtx) { DataSetResolver dataSetResolver = ComponentFactory.getModelResolver(); return dataSetResolver.resolve(queryCtx, queryCtx.getDataSetIds()); } - public LLMReq getLlmReq(QueryContext queryCtx, Long dataSetId) { + public LLMReq getLlmReq(ChatQueryContext queryCtx, Long dataSetId) { LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class); List linkingValues = requestService.getValues(queryCtx, dataSetId); SemanticSchema semanticSchema = queryCtx.getSemanticSchema(); @@ -118,7 +118,7 @@ public class LLMRequestService { return result; } - protected List getFieldNameList(QueryContext queryCtx, Long dataSetId, + protected List getFieldNameList(ChatQueryContext queryCtx, Long dataSetId, LLMParserConfig llmParserConfig) { Set results = getTopNFieldNames(queryCtx, dataSetId, llmParserConfig); @@ -129,7 +129,7 @@ public class LLMRequestService { return new ArrayList<>(results); } - protected List getTerms(QueryContext queryCtx, Long dataSetId) { + protected List getTerms(ChatQueryContext queryCtx, Long dataSetId) { List matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId); if (CollectionUtils.isEmpty(matchedElements)) { return new ArrayList<>(); @@ -147,7 +147,7 @@ public class LLMRequestService { }).collect(Collectors.toList()); } - private String getPriorExts(QueryContext queryContext, List fieldNameList) { + private String getPriorExts(ChatQueryContext queryContext, List fieldNameList) { StringBuilder extraInfoSb = new StringBuilder(); SemanticSchema semanticSchema = queryContext.getSemanticSchema(); Map fieldNameToDataFormatType = semanticSchema.getMetrics() @@ -176,7 +176,7 @@ public class LLMRequestService { return extraInfoSb.toString(); } - public List getValues(QueryContext queryCtx, Long dataSetId) { + public List getValues(ChatQueryContext queryCtx, Long dataSetId) { Map itemIdToName = getItemIdToName(queryCtx, dataSetId); List matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId); if (CollectionUtils.isEmpty(matchedElements)) { @@ -198,14 +198,14 @@ public class LLMRequestService { return new ArrayList<>(valueMatches); } - protected Map getItemIdToName(QueryContext queryCtx, Long dataSetId) { + protected Map getItemIdToName(ChatQueryContext queryCtx, Long dataSetId) { SemanticSchema semanticSchema = queryCtx.getSemanticSchema(); List elements = semanticSchema.getDimensions(dataSetId); return elements.stream() .collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2)); } - private Set getTopNFieldNames(QueryContext queryCtx, Long dataSetId, LLMParserConfig llmParserConfig) { + private Set getTopNFieldNames(ChatQueryContext queryCtx, Long dataSetId, LLMParserConfig llmParserConfig) { SemanticSchema semanticSchema = queryCtx.getSemanticSchema(); Set results = new HashSet<>(); Set dimensions = semanticSchema.getDimensions(dataSetId).stream() @@ -223,7 +223,7 @@ public class LLMRequestService { return results; } - protected List getMatchedMetrics(QueryContext queryCtx, Long dataSetId) { + protected List getMatchedMetrics(ChatQueryContext queryCtx, Long dataSetId) { List matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId); if (CollectionUtils.isEmpty(matchedElements)) { return Collections.emptyList(); @@ -240,7 +240,7 @@ public class LLMRequestService { return schemaElements; } - protected List getMatchedDimensions(QueryContext queryCtx, Long dataSetId) { + protected List getMatchedDimensions(ChatQueryContext queryCtx, Long dataSetId) { List matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId); if (CollectionUtils.isEmpty(matchedElements)) { return Collections.emptyList(); @@ -257,7 +257,7 @@ public class LLMRequestService { return schemaElements; } - protected Set getMatchedFieldNames(QueryContext queryCtx, Long dataSetId) { + protected Set getMatchedFieldNames(ChatQueryContext queryCtx, Long dataSetId) { Map itemIdToName = getItemIdToName(queryCtx, dataSetId); List matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId); if (CollectionUtils.isEmpty(matchedElements)) { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java index 1146daa3d..4419a3217 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java @@ -8,7 +8,7 @@ import com.tencent.supersonic.headless.chat.query.llm.LLMSemanticQuery; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlResp; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import java.util.ArrayList; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.MapUtils; @@ -22,7 +22,8 @@ import java.util.Objects; @Service public class LLMResponseService { - public SemanticParseInfo addParseInfo(QueryContext queryCtx, ParseResult parseResult, String s2SQL, Double weight) { + public SemanticParseInfo addParseInfo(ChatQueryContext queryCtx, ParseResult parseResult, + String s2SQL, Double weight) { if (Objects.isNull(weight)) { weight = 0D; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMSqlParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMSqlParser.java index 3855cc69b..33ce39e8c 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMSqlParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMSqlParser.java @@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.chat.parser.llm; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.chat.ChatContext; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.parser.SemanticParser; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp; @@ -23,7 +23,7 @@ import org.apache.commons.collections.MapUtils; public class LLMSqlParser implements SemanticParser { @Override - public void parse(QueryContext queryCtx, ChatContext chatCtx) { + public void parse(ChatQueryContext queryCtx, ChatContext chatCtx) { try { LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class); //1.determine whether to skip this parser. @@ -44,7 +44,7 @@ public class LLMSqlParser implements SemanticParser { } } - private void tryParse(QueryContext queryCtx, Long dataSetId) { + private void tryParse(ChatQueryContext queryCtx, Long dataSetId) { LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class); LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class); int maxRetries = ContextUtils.getBean(LLMParserConfig.class).getRecallMaxRetries(); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/AggregateTypeParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/AggregateTypeParser.java index f6f2c1d93..1399f66bb 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/AggregateTypeParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/AggregateTypeParser.java @@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.chat.parser.rule; import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum; import com.tencent.supersonic.headless.chat.ChatContext; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.query.SemanticQuery; import com.tencent.supersonic.headless.chat.parser.SemanticParser; import lombok.AllArgsConstructor; @@ -41,11 +41,11 @@ public class AggregateTypeParser implements SemanticParser { ).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (k1, k2) -> k2)); @Override - public void parse(QueryContext queryContext, ChatContext chatContext) { - String queryText = queryContext.getQueryText(); + public void parse(ChatQueryContext chatQueryContext, ChatContext chatContext) { + String queryText = chatQueryContext.getQueryText(); AggregateConf aggregateConf = resolveAggregateConf(queryText); - for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) { + for (SemanticQuery semanticQuery : chatQueryContext.getCandidateQueries()) { if (!AggregateTypeEnum.NONE.equals(semanticQuery.getParseInfo().getAggType())) { continue; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/ContextInheritParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/ContextInheritParser.java index a3fa00366..c268dc537 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/ContextInheritParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/ContextInheritParser.java @@ -2,12 +2,12 @@ package com.tencent.supersonic.headless.chat.parser.rule; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementType; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.query.QueryManager; import com.tencent.supersonic.headless.chat.query.SemanticQuery; import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery; import com.tencent.supersonic.headless.chat.parser.SemanticParser; import com.tencent.supersonic.headless.chat.ChatContext; -import com.tencent.supersonic.headless.chat.QueryContext; import com.tencent.supersonic.headless.chat.query.rule.metric.MetricModelQuery; import com.tencent.supersonic.headless.chat.query.rule.metric.MetricSemanticQuery; import com.tencent.supersonic.headless.chat.query.rule.metric.MetricIdQuery; @@ -43,16 +43,16 @@ public class ContextInheritParser implements SemanticParser { ).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); @Override - public void parse(QueryContext queryContext, ChatContext chatContext) { - if (!shouldInherit(queryContext)) { + public void parse(ChatQueryContext chatQueryContext, ChatContext chatContext) { + if (!shouldInherit(chatQueryContext)) { return; } - Long dataSetId = getMatchedDataSet(queryContext, chatContext); + Long dataSetId = getMatchedDataSet(chatQueryContext, chatContext); if (dataSetId == null) { return; } - List elementMatches = queryContext.getMapInfo().getMatchedElements(dataSetId); + List elementMatches = chatQueryContext.getMapInfo().getMatchedElements(dataSetId); List matchesToInherit = new ArrayList<>(); for (SchemaElementMatch match : chatContext.getParseInfo().getElementMatches()) { @@ -66,18 +66,18 @@ public class ContextInheritParser implements SemanticParser { } elementMatches.addAll(matchesToInherit); - List queries = RuleSemanticQuery.resolve(dataSetId, elementMatches, queryContext); + List queries = RuleSemanticQuery.resolve(dataSetId, elementMatches, chatQueryContext); for (RuleSemanticQuery query : queries) { - query.fillParseInfo(queryContext, chatContext); - if (existSameQuery(query.getParseInfo().getDataSetId(), query.getQueryMode(), queryContext)) { + query.fillParseInfo(chatQueryContext, chatContext); + if (existSameQuery(query.getParseInfo().getDataSetId(), query.getQueryMode(), chatQueryContext)) { continue; } - queryContext.getCandidateQueries().add(query); + chatQueryContext.getCandidateQueries().add(query); } } - private boolean existSameQuery(Long dataSetId, String queryMode, QueryContext queryContext) { - for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) { + private boolean existSameQuery(Long dataSetId, String queryMode, ChatQueryContext chatQueryContext) { + for (SemanticQuery semanticQuery : chatQueryContext.getCandidateQueries()) { if (semanticQuery.getQueryMode().equals(queryMode) && semanticQuery.getParseInfo().getDataSetId().equals(dataSetId)) { return true; @@ -100,20 +100,20 @@ public class ContextInheritParser implements SemanticParser { }); } - protected boolean shouldInherit(QueryContext queryContext) { + protected boolean shouldInherit(ChatQueryContext chatQueryContext) { // if candidates only have MetricModel mode, count in context - List metricModelQueries = queryContext.getCandidateQueries().stream() + List metricModelQueries = chatQueryContext.getCandidateQueries().stream() .filter(query -> query instanceof MetricModelQuery).collect( Collectors.toList()); - return metricModelQueries.size() == queryContext.getCandidateQueries().size(); + return metricModelQueries.size() == chatQueryContext.getCandidateQueries().size(); } - protected Long getMatchedDataSet(QueryContext queryContext, ChatContext chatContext) { + protected Long getMatchedDataSet(ChatQueryContext chatQueryContext, ChatContext chatContext) { Long dataSetId = chatContext.getParseInfo().getDataSetId(); if (dataSetId == null) { return null; } - Set queryDataSets = queryContext.getMapInfo().getMatchedDataSetInfos(); + Set queryDataSets = chatQueryContext.getMapInfo().getMatchedDataSetInfos(); if (queryDataSets.contains(dataSetId)) { return dataSetId; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java index 46f55f041..7733b37c8 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java @@ -2,10 +2,10 @@ package com.tencent.supersonic.headless.chat.parser.rule; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery; import com.tencent.supersonic.headless.chat.parser.SemanticParser; import com.tencent.supersonic.headless.chat.ChatContext; -import com.tencent.supersonic.headless.chat.QueryContext; import lombok.extern.slf4j.Slf4j; import java.util.Arrays; import java.util.List; @@ -24,21 +24,21 @@ public class RuleSqlParser implements SemanticParser { ); @Override - public void parse(QueryContext queryContext, ChatContext chatContext) { - if (!queryContext.getText2SQLType().enableRule()) { + public void parse(ChatQueryContext chatQueryContext, ChatContext chatContext) { + if (!chatQueryContext.getText2SQLType().enableRule()) { return; } - SchemaMapInfo mapInfo = queryContext.getMapInfo(); + SchemaMapInfo mapInfo = chatQueryContext.getMapInfo(); // iterate all schemaElementMatches to resolve query mode for (Long dataSetId : mapInfo.getMatchedDataSetInfos()) { List elementMatches = mapInfo.getMatchedElements(dataSetId); - List queries = RuleSemanticQuery.resolve(dataSetId, elementMatches, queryContext); + List queries = RuleSemanticQuery.resolve(dataSetId, elementMatches, chatQueryContext); for (RuleSemanticQuery query : queries) { - query.fillParseInfo(queryContext, chatContext); - queryContext.getCandidateQueries().add(query); + query.fillParseInfo(chatQueryContext, chatContext); + chatQueryContext.getCandidateQueries().add(query); } } - auxiliaryParsers.stream().forEach(p -> p.parse(queryContext, chatContext)); + auxiliaryParsers.stream().forEach(p -> p.parse(chatQueryContext, chatContext)); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/TimeRangeParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/TimeRangeParser.java index 8fe5c6168..d03e598da 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/TimeRangeParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/TimeRangeParser.java @@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.chat.parser.rule; import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.DateConf; import com.tencent.supersonic.headless.chat.ChatContext; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.parser.SemanticParser; import com.tencent.supersonic.headless.chat.query.QueryManager; import com.tencent.supersonic.headless.chat.query.SemanticQuery; @@ -42,7 +42,7 @@ public class TimeRangeParser implements SemanticParser { private static final DateFormat DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd"); @Override - public void parse(QueryContext queryContext, ChatContext chatContext) { + public void parse(ChatQueryContext queryContext, ChatContext chatContext) { String queryText = queryContext.getQueryText(); DateConf dateConf = parseRecent(queryText); if (dateConf == null) { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java index 92d5a28d4..9eae71548 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java @@ -12,7 +12,7 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq; import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder; import com.tencent.supersonic.headless.chat.query.BaseSemanticQuery; import com.tencent.supersonic.headless.chat.query.QueryManager; @@ -40,7 +40,7 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery { } public List match(List candidateElementMatches, - QueryContext queryCtx) { + ChatQueryContext queryCtx) { return queryMatcher.match(candidateElementMatches); } @@ -49,9 +49,9 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery { initS2SqlByStruct(semanticSchema); } - public void fillParseInfo(QueryContext queryContext, ChatContext chatContext) { + public void fillParseInfo(ChatQueryContext chatQueryContext, ChatContext chatContext) { parseInfo.setQueryMode(getQueryMode()); - SemanticSchema semanticSchema = queryContext.getSemanticSchema(); + SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema(); fillSchemaElement(parseInfo, semanticSchema); fillScore(parseInfo); @@ -223,10 +223,10 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery { } public static List resolve(Long dataSetId, List candidateElementMatches, - QueryContext queryContext) { + ChatQueryContext chatQueryContext) { List matchedQueries = new ArrayList<>(); for (RuleSemanticQuery semanticQuery : QueryManager.getRuleQueries()) { - List matches = semanticQuery.match(candidateElementMatches, queryContext); + List matches = semanticQuery.match(candidateElementMatches, chatQueryContext); if (matches.size() > 0) { RuleSemanticQuery query = QueryManager.createRuleQuery(semanticQuery.getQueryMode()); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/detail/DetailListQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/detail/DetailListQuery.java index c646aed1b..256a1eb72 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/detail/DetailListQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/detail/DetailListQuery.java @@ -8,7 +8,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig; import com.tencent.supersonic.headless.chat.ChatContext; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import org.apache.commons.collections.CollectionUtils; import java.util.LinkedHashSet; @@ -19,15 +19,15 @@ import java.util.stream.Collectors; public abstract class DetailListQuery extends DetailSemanticQuery { @Override - public void fillParseInfo(QueryContext queryContext, ChatContext chatContext) { - super.fillParseInfo(queryContext, chatContext); - this.addEntityDetailAndOrderByMetric(queryContext, parseInfo); + public void fillParseInfo(ChatQueryContext chatQueryContext, ChatContext chatContext) { + super.fillParseInfo(chatQueryContext, chatContext); + this.addEntityDetailAndOrderByMetric(chatQueryContext, parseInfo); } - private void addEntityDetailAndOrderByMetric(QueryContext queryContext, SemanticParseInfo parseInfo) { + private void addEntityDetailAndOrderByMetric(ChatQueryContext chatQueryContext, SemanticParseInfo parseInfo) { Long dataSetId = parseInfo.getDataSetId(); if (Objects.nonNull(dataSetId) && dataSetId > 0L) { - DataSetSchema dataSetSchema = queryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId); + DataSetSchema dataSetSchema = chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId); if (dataSetSchema != null && Objects.nonNull(dataSetSchema.getEntity())) { Set dimensions = new LinkedHashSet<>(); Set metrics = new LinkedHashSet<>(); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/detail/DetailSemanticQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/detail/DetailSemanticQuery.java index 2d872ab01..a0cba04eb 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/detail/DetailSemanticQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/detail/DetailSemanticQuery.java @@ -7,7 +7,7 @@ import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption; import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery; import com.tencent.supersonic.headless.chat.ChatContext; @@ -30,19 +30,19 @@ public abstract class DetailSemanticQuery extends RuleSemanticQuery { @Override public List match(List candidateElementMatches, - QueryContext queryCtx) { + ChatQueryContext queryCtx) { return super.match(candidateElementMatches, queryCtx); } @Override - public void fillParseInfo(QueryContext queryContext, ChatContext chatContext) { - super.fillParseInfo(queryContext, chatContext); + public void fillParseInfo(ChatQueryContext chatQueryContext, ChatContext chatContext) { + super.fillParseInfo(chatQueryContext, chatContext); parseInfo.setQueryType(QueryType.DETAIL); parseInfo.setLimit(DETAIL_MAX_RESULTS); if (parseInfo.getDateInfo() == null) { DataSetSchema dataSetSchema = - queryContext.getSemanticSchema().getDataSetSchemaMap().get(parseInfo.getDataSetId()); + chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(parseInfo.getDataSetId()); TimeDefaultConfig timeDefaultConfig = dataSetSchema.getTagTypeTimeDefaultConfig(); DateConf dateInfo = new DateConf(); if (Objects.nonNull(timeDefaultConfig) && Objects.nonNull(timeDefaultConfig.getUnit())) { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricSemanticQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricSemanticQuery.java index 7b02ab1a6..931be856a 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricSemanticQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricSemanticQuery.java @@ -6,7 +6,7 @@ import com.tencent.supersonic.common.pojo.enums.TimeMode; import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery; import com.tencent.supersonic.headless.chat.ChatContext; import lombok.extern.slf4j.Slf4j; @@ -31,17 +31,17 @@ public abstract class MetricSemanticQuery extends RuleSemanticQuery { @Override public List match(List candidateElementMatches, - QueryContext queryCtx) { + ChatQueryContext queryCtx) { return super.match(candidateElementMatches, queryCtx); } @Override - public void fillParseInfo(QueryContext queryContext, ChatContext chatContext) { - super.fillParseInfo(queryContext, chatContext); + public void fillParseInfo(ChatQueryContext chatQueryContext, ChatContext chatContext) { + super.fillParseInfo(chatQueryContext, chatContext); parseInfo.setLimit(METRIC_MAX_RESULTS); if (parseInfo.getDateInfo() == null) { DataSetSchema dataSetSchema = - queryContext.getSemanticSchema().getDataSetSchemaMap().get(parseInfo.getDataSetId()); + chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(parseInfo.getDataSetId()); TimeDefaultConfig timeDefaultConfig = dataSetSchema.getMetricTypeTimeDefaultConfig(); DateConf dateInfo = new DateConf(); if (Objects.nonNull(timeDefaultConfig) && Objects.nonNull(timeDefaultConfig.getUnit())) { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricTopNQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricTopNQuery.java index 97fbafb33..0ffe301bc 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricTopNQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricTopNQuery.java @@ -5,7 +5,7 @@ import com.tencent.supersonic.common.pojo.Order; import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum; import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.ChatContext; import org.springframework.stereotype.Component; @@ -36,7 +36,7 @@ public class MetricTopNQuery extends MetricSemanticQuery { @Override public List match(List candidateElementMatches, - QueryContext queryCtx) { + ChatQueryContext queryCtx) { Matcher matcher = INTENT_PATTERN.matcher(queryCtx.getQueryText()); if (matcher.matches()) { return super.match(candidateElementMatches, queryCtx); @@ -50,8 +50,8 @@ public class MetricTopNQuery extends MetricSemanticQuery { } @Override - public void fillParseInfo(QueryContext queryContext, ChatContext chatContext) { - super.fillParseInfo(queryContext, chatContext); + public void fillParseInfo(ChatQueryContext chatQueryContext, ChatContext chatContext) { + super.fillParseInfo(chatQueryContext, chatContext); parseInfo.setLimit(ORDERBY_MAX_RESULTS); parseInfo.setScore(parseInfo.getScore() + 2.0); diff --git a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/AggCorrectorTest.java b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/AggCorrectorTest.java index f81e886df..013b01b18 100644 --- a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/AggCorrectorTest.java +++ b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/AggCorrectorTest.java @@ -6,7 +6,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.SqlInfo; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import org.junit.Assert; import org.junit.jupiter.api.Test; @@ -21,7 +21,7 @@ class AggCorrectorTest { void testDoCorrect() { AggCorrector corrector = new AggCorrector(); Long dataSetId = 1L; - QueryContext queryContext = buildQueryContext(dataSetId); + ChatQueryContext chatQueryContext = buildQueryContext(dataSetId); SemanticParseInfo semanticParseInfo = new SemanticParseInfo(); SchemaElement dataSet = new SchemaElement(); dataSet.setDataSet(dataSetId); @@ -33,15 +33,15 @@ class AggCorrectorTest { sqlInfo.setS2SQL(sql); sqlInfo.setCorrectS2SQL(sql); semanticParseInfo.setSqlInfo(sqlInfo); - corrector.correct(queryContext, semanticParseInfo); + corrector.correct(chatQueryContext, semanticParseInfo); Assert.assertEquals("SELECT 用户, SUM(访问次数) FROM 超音数数据集 WHERE 部门 = 'sales'" + " AND datediff('day', 数据日期, '2024-06-04') <= 7 GROUP BY 用户" + " ORDER BY SUM(访问次数) DESC LIMIT 1", semanticParseInfo.getSqlInfo().getCorrectS2SQL()); } - private QueryContext buildQueryContext(Long dataSetId) { - QueryContext queryContext = new QueryContext(); + private ChatQueryContext buildQueryContext(Long dataSetId) { + ChatQueryContext chatQueryContext = new ChatQueryContext(); List dataSetSchemaList = new ArrayList<>(); DataSetSchema dataSetSchema = new DataSetSchema(); QueryConfig queryConfig = new QueryConfig(); @@ -67,8 +67,8 @@ class AggCorrectorTest { dataSetSchemaList.add(dataSetSchema); SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList); - queryContext.setSemanticSchema(semanticSchema); - return queryContext; + chatQueryContext.setSemanticSchema(semanticSchema); + return chatQueryContext; } } diff --git a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrectorTest.java b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrectorTest.java index f2e76e4a2..ce1ee737e 100644 --- a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrectorTest.java +++ b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrectorTest.java @@ -10,7 +10,7 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.SqlInfo; import com.tencent.supersonic.headless.chat.parser.llm.ParseResult; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; import org.junit.Assert; import org.junit.jupiter.api.Test; @@ -56,7 +56,7 @@ class SchemaCorrectorTest { @Test void doCorrect() throws JsonProcessingException { Long dataSetId = 1L; - QueryContext queryContext = buildQueryContext(dataSetId); + ChatQueryContext chatQueryContext = buildQueryContext(dataSetId); ObjectMapper objectMapper = new ObjectMapper(); ParseResult parseResult = objectMapper.readValue(json, ParseResult.class); @@ -77,7 +77,7 @@ class SchemaCorrectorTest { semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult); SchemaCorrector schemaCorrector = new SchemaCorrector(); - schemaCorrector.removeFilterIfNotInLinkingValue(queryContext, semanticParseInfo); + schemaCorrector.removeFilterIfNotInLinkingValue(chatQueryContext, semanticParseInfo); Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' " + "ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectS2SQL()); @@ -94,14 +94,14 @@ class SchemaCorrectorTest { semanticParseInfo.getSqlInfo().setCorrectS2SQL(sql); semanticParseInfo.getSqlInfo().setS2SQL(sql); - schemaCorrector.removeFilterIfNotInLinkingValue(queryContext, semanticParseInfo); + schemaCorrector.removeFilterIfNotInLinkingValue(chatQueryContext, semanticParseInfo); Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' " + "AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectS2SQL()); } - private QueryContext buildQueryContext(Long dataSetId) { - QueryContext queryContext = new QueryContext(); + private ChatQueryContext buildQueryContext(Long dataSetId) { + ChatQueryContext chatQueryContext = new ChatQueryContext(); List dataSetSchemaList = new ArrayList<>(); DataSetSchema dataSetSchema = new DataSetSchema(); QueryConfig queryConfig = new QueryConfig(); @@ -129,7 +129,7 @@ class SchemaCorrectorTest { dataSetSchemaList.add(dataSetSchema); SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList); - queryContext.setSemanticSchema(semanticSchema); - return queryContext; + chatQueryContext.setSemanticSchema(semanticSchema); + return chatQueryContext; } } \ No newline at end of file diff --git a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrectorTest.java b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrectorTest.java index 51474df41..669524c74 100644 --- a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrectorTest.java +++ b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrectorTest.java @@ -10,7 +10,7 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.SqlInfo; import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import org.junit.Assert; import org.junit.jupiter.api.Test; import org.mockito.MockedStatic; @@ -36,7 +36,7 @@ class SelectCorrectorTest { when(mockEnvironment.getProperty(SelectCorrector.ADDITIONAL_INFORMATION)).thenReturn(""); BaseSemanticCorrector corrector = new SelectCorrector(); - QueryContext queryContext = buildQueryContext(dataSetId); + ChatQueryContext chatQueryContext = buildQueryContext(dataSetId); SemanticParseInfo semanticParseInfo = new SemanticParseInfo(); SchemaElement dataSet = new SchemaElement(); dataSet.setDataSet(dataSetId); @@ -47,13 +47,13 @@ class SelectCorrectorTest { sqlInfo.setS2SQL(sql); sqlInfo.setCorrectS2SQL(sql); semanticParseInfo.setSqlInfo(sqlInfo); - corrector.correct(queryContext, semanticParseInfo); + corrector.correct(chatQueryContext, semanticParseInfo); Assert.assertEquals("SELECT 粉丝数, 国籍, 艺人名, 性别 FROM 艺人库 WHERE 艺人名 = '周杰伦'", semanticParseInfo.getSqlInfo().getCorrectS2SQL()); } - private QueryContext buildQueryContext(Long dataSetId) { - QueryContext queryContext = new QueryContext(); + private ChatQueryContext buildQueryContext(Long dataSetId) { + ChatQueryContext chatQueryContext = new ChatQueryContext(); List dataSetSchemaList = new ArrayList<>(); DataSetSchema dataSetSchema = new DataSetSchema(); QueryConfig queryConfig = new QueryConfig(); @@ -108,7 +108,7 @@ class SelectCorrectorTest { dataSetSchemaList.add(dataSetSchema); SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList); - queryContext.setSemanticSchema(semanticSchema); - return queryContext; + chatQueryContext.setSemanticSchema(semanticSchema); + return chatQueryContext; } } \ No newline at end of file diff --git a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrectorTest.java b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrectorTest.java index 0e54ad5a0..b540ef1d5 100644 --- a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrectorTest.java +++ b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrectorTest.java @@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.chat.corrector; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SqlInfo; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import org.junit.Assert; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; @@ -13,7 +13,7 @@ class TimeCorrectorTest { @Test void testDoCorrect() { TimeCorrector corrector = new TimeCorrector(); - QueryContext queryContext = new QueryContext(); + ChatQueryContext chatQueryContext = new ChatQueryContext(); SemanticParseInfo semanticParseInfo = new SemanticParseInfo(); SqlInfo sqlInfo = new SqlInfo(); //1.数据日期 <= @@ -21,7 +21,7 @@ class TimeCorrectorTest { + "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1"; sqlInfo.setCorrectS2SQL(sql); semanticParseInfo.setSqlInfo(sqlInfo); - corrector.doCorrect(queryContext, semanticParseInfo); + corrector.doCorrect(chatQueryContext, semanticParseInfo); Assert.assertEquals( "SELECT 维度1, SUM(播放量) FROM 数据库 WHERE ((歌手名 = '张三') AND 数据日期 <= '2023-11-17') " @@ -32,7 +32,7 @@ class TimeCorrectorTest { sql = "SELECT 维度1, SUM(播放量) FROM 数据库 " + "WHERE (歌手名 = '张三') AND 数据日期 < '2023-11-17' GROUP BY 维度1"; sqlInfo.setCorrectS2SQL(sql); - corrector.doCorrect(queryContext, semanticParseInfo); + corrector.doCorrect(chatQueryContext, semanticParseInfo); Assert.assertEquals( "SELECT 维度1, SUM(播放量) FROM 数据库 WHERE ((歌手名 = '张三') AND 数据日期 < '2023-11-17') " @@ -43,7 +43,7 @@ class TimeCorrectorTest { sql = "SELECT 维度1, SUM(播放量) FROM 数据库 " + "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1"; sqlInfo.setCorrectS2SQL(sql); - corrector.doCorrect(queryContext, semanticParseInfo); + corrector.doCorrect(chatQueryContext, semanticParseInfo); Assert.assertEquals( "SELECT 维度1, SUM(播放量) FROM 数据库 " @@ -54,7 +54,7 @@ class TimeCorrectorTest { sql = "SELECT 维度1, SUM(播放量) FROM 数据库 " + "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1"; sqlInfo.setCorrectS2SQL(sql); - corrector.doCorrect(queryContext, semanticParseInfo); + corrector.doCorrect(chatQueryContext, semanticParseInfo); Assert.assertEquals( "SELECT 维度1, SUM(播放量) FROM 数据库 " @@ -65,7 +65,7 @@ class TimeCorrectorTest { sql = "SELECT 维度1, SUM(播放量) FROM 数据库 " + "WHERE 歌手名 = '张三' GROUP BY 维度1"; sqlInfo.setCorrectS2SQL(sql); - corrector.doCorrect(queryContext, semanticParseInfo); + corrector.doCorrect(chatQueryContext, semanticParseInfo); Assert.assertEquals( "SELECT 维度1, SUM(播放量) FROM 数据库 WHERE 歌手名 = '张三' GROUP BY 维度1", @@ -75,7 +75,7 @@ class TimeCorrectorTest { sql = "SELECT 维度1, SUM(播放量) FROM 数据库 " + "WHERE 歌手名 = '张三' AND 数据日期_月 <= '2024-01' GROUP BY 维度1"; sqlInfo.setCorrectS2SQL(sql); - corrector.doCorrect(queryContext, semanticParseInfo); + corrector.doCorrect(chatQueryContext, semanticParseInfo); Assert.assertEquals( "SELECT 维度1, SUM(播放量) FROM 数据库 WHERE (歌手名 = '张三' AND 数据日期_月 <= '2024-01') " @@ -86,7 +86,7 @@ class TimeCorrectorTest { sql = "SELECT 维度1, SUM(播放量) FROM 数据库 " + "WHERE 歌手名 = '张三' AND 数据日期_月 > '2024-01' GROUP BY 维度1"; sqlInfo.setCorrectS2SQL(sql); - corrector.doCorrect(queryContext, semanticParseInfo); + corrector.doCorrect(chatQueryContext, semanticParseInfo); Assert.assertEquals( "SELECT 维度1, SUM(播放量) FROM 数据库 " @@ -96,7 +96,7 @@ class TimeCorrectorTest { //8. no where sql = "SELECT COUNT(1) FROM 数据库"; sqlInfo.setCorrectS2SQL(sql); - corrector.doCorrect(queryContext, semanticParseInfo); + corrector.doCorrect(chatQueryContext, semanticParseInfo); Assert.assertEquals("SELECT COUNT(1) FROM 数据库", sqlInfo.getCorrectS2SQL()); } } diff --git a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrectorTest.java b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrectorTest.java index 1b385f0c4..7b2a5e811 100644 --- a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrectorTest.java +++ b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrectorTest.java @@ -6,7 +6,7 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SqlInfo; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import org.junit.Assert; import org.junit.jupiter.api.Test; @@ -22,7 +22,7 @@ class WhereCorrectorTest { sqlInfo.setCorrectS2SQL(sql); semanticParseInfo.setSqlInfo(sqlInfo); - QueryContext queryContext = new QueryContext(); + ChatQueryContext chatQueryContext = new ChatQueryContext(); QueryFilter filter1 = new QueryFilter(); filter1.setName("age"); @@ -49,10 +49,10 @@ class WhereCorrectorTest { queryFilters.getFilters().add(filter2); queryFilters.getFilters().add(filter3); queryFilters.getFilters().add(filter4); - queryContext.setQueryFilters(queryFilters); + chatQueryContext.setQueryFilters(queryFilters); WhereCorrector whereCorrector = new WhereCorrector(); - whereCorrector.addQueryFilter(queryContext, semanticParseInfo); + whereCorrector.addQueryFilter(chatQueryContext, semanticParseInfo); String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); diff --git a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/utils/S2SqlDateHelperTest.java b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/utils/S2SqlDateHelperTest.java index fc344c511..dd3d70bf2 100644 --- a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/utils/S2SqlDateHelperTest.java +++ b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/utils/S2SqlDateHelperTest.java @@ -10,7 +10,7 @@ import com.tencent.supersonic.headless.api.pojo.QueryConfig; import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.corrector.S2SqlDateHelper; import org.apache.commons.lang3.tuple.Pair; import org.junit.Assert; @@ -26,15 +26,15 @@ class S2SqlDateHelperTest { @Test void getReferenceDate() { Long dataSetId = 1L; - QueryContext queryContext = buildQueryContext(dataSetId); + ChatQueryContext chatQueryContext = buildQueryContext(dataSetId); - String referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, null); + String referenceDate = S2SqlDateHelper.getReferenceDate(chatQueryContext, null); Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(0)); - referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, dataSetId); + referenceDate = S2SqlDateHelper.getReferenceDate(chatQueryContext, dataSetId); Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(0)); - DataSetSchema dataSetSchema = queryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId); + DataSetSchema dataSetSchema = chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId); QueryConfig queryConfig = dataSetSchema.getQueryConfig(); TimeDefaultConfig timeDefaultConfig = new TimeDefaultConfig(); timeDefaultConfig.setTimeMode(TimeMode.LAST); @@ -42,32 +42,32 @@ class S2SqlDateHelperTest { timeDefaultConfig.setUnit(20); queryConfig.getTagTypeDefaultConfig().setTimeDefaultConfig(timeDefaultConfig); - referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, dataSetId); + referenceDate = S2SqlDateHelper.getReferenceDate(chatQueryContext, dataSetId); Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(20)); timeDefaultConfig.setUnit(1); - referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, dataSetId); + referenceDate = S2SqlDateHelper.getReferenceDate(chatQueryContext, dataSetId); Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(1)); timeDefaultConfig.setUnit(-1); - referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, dataSetId); + referenceDate = S2SqlDateHelper.getReferenceDate(chatQueryContext, dataSetId); Assert.assertNull(referenceDate); } @Test void getStartEndDate() { Long dataSetId = 1L; - QueryContext queryContext = buildQueryContext(dataSetId); + ChatQueryContext chatQueryContext = buildQueryContext(dataSetId); - Pair startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, null, QueryType.DETAIL); + Pair startEndDate = S2SqlDateHelper.getStartEndDate(chatQueryContext, null, QueryType.DETAIL); Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(0)); Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(0)); - startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.DETAIL); + startEndDate = S2SqlDateHelper.getStartEndDate(chatQueryContext, dataSetId, QueryType.DETAIL); Assert.assertNotNull(startEndDate.getLeft()); Assert.assertNotNull(startEndDate.getRight()); - DataSetSchema dataSetSchema = queryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId); + DataSetSchema dataSetSchema = chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId); QueryConfig queryConfig = dataSetSchema.getQueryConfig(); TimeDefaultConfig timeDefaultConfig = new TimeDefaultConfig(); timeDefaultConfig.setTimeMode(TimeMode.LAST); @@ -76,39 +76,39 @@ class S2SqlDateHelperTest { queryConfig.getTagTypeDefaultConfig().setTimeDefaultConfig(timeDefaultConfig); queryConfig.getMetricTypeDefaultConfig().setTimeDefaultConfig(timeDefaultConfig); - startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.DETAIL); + startEndDate = S2SqlDateHelper.getStartEndDate(chatQueryContext, dataSetId, QueryType.DETAIL); Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(20)); Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(20)); - startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.METRIC); + startEndDate = S2SqlDateHelper.getStartEndDate(chatQueryContext, dataSetId, QueryType.METRIC); Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(20)); Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(20)); timeDefaultConfig.setUnit(2); timeDefaultConfig.setTimeMode(TimeMode.RECENT); - startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.METRIC); + startEndDate = S2SqlDateHelper.getStartEndDate(chatQueryContext, dataSetId, QueryType.METRIC); Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(2)); Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(1)); - startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.DETAIL); + startEndDate = S2SqlDateHelper.getStartEndDate(chatQueryContext, dataSetId, QueryType.DETAIL); Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(2)); Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(1)); timeDefaultConfig.setUnit(-1); - startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.METRIC); + startEndDate = S2SqlDateHelper.getStartEndDate(chatQueryContext, dataSetId, QueryType.METRIC); Assert.assertNull(startEndDate.getLeft()); Assert.assertNull(startEndDate.getRight()); timeDefaultConfig.setTimeMode(TimeMode.LAST); timeDefaultConfig.setPeriod(Constants.DAY); timeDefaultConfig.setUnit(5); - startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.METRIC); + startEndDate = S2SqlDateHelper.getStartEndDate(chatQueryContext, dataSetId, QueryType.METRIC); Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(5)); Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(5)); } - private QueryContext buildQueryContext(Long dataSetId) { - QueryContext queryContext = new QueryContext(); + private ChatQueryContext buildQueryContext(Long dataSetId) { + ChatQueryContext chatQueryContext = new ChatQueryContext(); List dataSetSchemaList = new ArrayList<>(); DataSetSchema dataSetSchema = new DataSetSchema(); QueryConfig queryConfig = new QueryConfig(); @@ -119,7 +119,7 @@ class S2SqlDateHelperTest { dataSetSchemaList.add(dataSetSchema); SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList); - queryContext.setSemanticSchema(semanticSchema); - return queryContext; + chatQueryContext.setSemanticSchema(semanticSchema); + return chatQueryContext; } } \ No newline at end of file diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/ChatQueryServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/ChatQueryServiceImpl.java index ab5c1fad0..962b69904 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/ChatQueryServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/ChatQueryServiceImpl.java @@ -49,7 +49,7 @@ import com.tencent.supersonic.headless.api.pojo.response.QueryResult; import com.tencent.supersonic.headless.api.pojo.response.QueryState; import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; import com.tencent.supersonic.headless.chat.ChatContext; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.corrector.GrammarCorrector; import com.tencent.supersonic.headless.chat.corrector.SchemaCorrector; import com.tencent.supersonic.headless.chat.knowledge.HanlpMapResult; @@ -63,7 +63,7 @@ import com.tencent.supersonic.headless.chat.query.SemanticQuery; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery; import com.tencent.supersonic.headless.server.facade.service.ChatQueryService; import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService; -import com.tencent.supersonic.headless.server.utils.WorkflowEngine; +import com.tencent.supersonic.headless.server.utils.ChatWorkflowEngine; import com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO; import com.tencent.supersonic.headless.server.pojo.MetaFilter; import com.tencent.supersonic.headless.server.utils.ComponentFactory; @@ -115,12 +115,12 @@ public class ChatQueryServiceImpl implements ChatQueryService { @Autowired private DataSetService dataSetService; @Autowired - private WorkflowEngine workflowEngine; + private ChatWorkflowEngine chatWorkflowEngine; @Override public MapResp performMapping(QueryTextReq queryTextReq) { MapResp mapResp = new MapResp(); - QueryContext queryCtx = buildQueryContext(queryTextReq); + ChatQueryContext queryCtx = buildQueryContext(queryTextReq); ComponentFactory.getSchemaMappers().forEach(mapper -> { mapper.map(queryCtx); }); @@ -148,12 +148,12 @@ public class ChatQueryServiceImpl implements ChatQueryService { public ParseResp performParsing(QueryTextReq queryTextReq) { ParseResp parseResult = new ParseResp(queryTextReq.getChatId(), queryTextReq.getQueryText()); // build queryContext and chatContext - QueryContext queryCtx = buildQueryContext(queryTextReq); + ChatQueryContext queryCtx = buildQueryContext(queryTextReq); // in order to support multi-turn conversation, chat context is needed ChatContext chatCtx = chatContextService.getOrCreateContext(queryTextReq.getChatId()); - workflowEngine.startWorkflow(queryCtx, chatCtx, parseResult); + chatWorkflowEngine.execute(queryCtx, chatCtx, parseResult); List parseInfos = queryCtx.getCandidateQueries().stream() .map(SemanticQuery::getParseInfo).collect(Collectors.toList()); @@ -161,11 +161,11 @@ public class ChatQueryServiceImpl implements ChatQueryService { return parseResult; } - public QueryContext buildQueryContext(QueryTextReq queryTextReq) { + public ChatQueryContext buildQueryContext(QueryTextReq queryTextReq) { SemanticSchema semanticSchema = schemaService.getSemanticSchema(); Map> modelIdToDataSetIds = dataSetService.getModelIdToDataSetIds(); - QueryContext queryCtx = QueryContext.builder() + ChatQueryContext queryCtx = ChatQueryContext.builder() .queryFilters(queryTextReq.getQueryFilters()) .semanticSchema(semanticSchema) .candidateQueries(new ArrayList<>()) @@ -612,7 +612,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { } private SemanticParseInfo correctSqlReq(QuerySqlReq querySqlReq, User user) { - QueryContext queryCtx = new QueryContext(); + ChatQueryContext queryCtx = new ChatQueryContext(); SemanticSchema semanticSchema = schemaService.getSemanticSchema(); queryCtx.setSemanticSchema(semanticSchema); SemanticParseInfo semanticParseInfo = new SemanticParseInfo(); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/RetrieveServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/RetrieveServiceImpl.java index 6ef664859..3badfebc9 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/RetrieveServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/RetrieveServiceImpl.java @@ -11,7 +11,7 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; import com.tencent.supersonic.headless.api.pojo.request.QueryTextReq; import com.tencent.supersonic.headless.api.pojo.response.S2Term; import com.tencent.supersonic.headless.api.pojo.response.SearchResult; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.knowledge.DataSetInfoStat; import com.tencent.supersonic.headless.chat.knowledge.DictWord; import com.tencent.supersonic.headless.chat.knowledge.HanlpMapResult; @@ -78,12 +78,12 @@ public class RetrieveServiceImpl implements RetrieveService { log.debug("hanlp parse result: {}", originals); Set dataSetIds = queryTextReq.getDataSetIds(); - QueryContext queryContext = new QueryContext(); - BeanUtils.copyProperties(queryTextReq, queryContext); - queryContext.setModelIdToDataSetIds(dataSetService.getModelIdToDataSetIds()); + ChatQueryContext chatQueryContext = new ChatQueryContext(); + BeanUtils.copyProperties(queryTextReq, chatQueryContext); + chatQueryContext.setModelIdToDataSetIds(dataSetService.getModelIdToDataSetIds()); Map> regTextMap = - searchMatchStrategy.match(queryContext, originals, dataSetIds); + searchMatchStrategy.match(chatQueryContext, originals, dataSetIds); regTextMap.entrySet().stream().forEach(m -> HanlpHelper.transLetterOriginal(m.getValue())); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/ParseInfoProcessor.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/ParseInfoProcessor.java index 29db0b7cf..e7408fb81 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/ParseInfoProcessor.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/ParseInfoProcessor.java @@ -15,7 +15,7 @@ import com.tencent.supersonic.headless.api.pojo.SqlInfo; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.chat.ChatContext; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.query.SemanticQuery; import com.tencent.supersonic.headless.server.web.service.SchemaService; import lombok.extern.slf4j.Slf4j; @@ -40,8 +40,8 @@ import java.util.stream.Collectors; public class ParseInfoProcessor implements ResultProcessor { @Override - public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) { - List candidateQueries = queryContext.getCandidateQueries(); + public void process(ParseResp parseResp, ChatQueryContext chatQueryContext, ChatContext chatContext) { + List candidateQueries = chatQueryContext.getCandidateQueries(); if (CollectionUtils.isEmpty(candidateQueries)) { return; } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/ResultProcessor.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/ResultProcessor.java index f05e3211c..02d24642f 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/ResultProcessor.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/ResultProcessor.java @@ -2,13 +2,13 @@ package com.tencent.supersonic.headless.server.processor; import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.chat.ChatContext; -import com.tencent.supersonic.headless.chat.QueryContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; /** * A ParseResultProcessor wraps things up before returning results to users in parse stage. */ public interface ResultProcessor { - void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext); + void process(ParseResp parseResp, ChatQueryContext chatQueryContext, ChatContext chatContext); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/SqlInfoProcessor.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/SqlInfoProcessor.java deleted file mode 100644 index b2b26bc06..000000000 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/SqlInfoProcessor.java +++ /dev/null @@ -1,88 +0,0 @@ -package com.tencent.supersonic.headless.server.processor; - -import com.tencent.supersonic.common.util.ContextUtils; -import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.headless.api.pojo.SqlInfo; -import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod; -import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq; -import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; -import com.tencent.supersonic.headless.api.pojo.response.TranslateResp; -import com.tencent.supersonic.headless.api.pojo.response.ParseResp; -import com.tencent.supersonic.headless.chat.ChatContext; -import com.tencent.supersonic.headless.chat.QueryContext; -import com.tencent.supersonic.headless.chat.query.QueryManager; -import com.tencent.supersonic.headless.chat.query.SemanticQuery; -import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery; -import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang3.StringUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.util.CollectionUtils; -import java.util.List; -import java.util.Objects; -import java.util.stream.Collectors; - -/** - * SqlInfoProcessor adds intermediate S2SQL and final SQL to the parsing results - * so that technical users could verify SQL by themselves. - **/ -@Slf4j -public class SqlInfoProcessor implements ResultProcessor { - - private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline"); - - @Override - public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) { - long start = System.currentTimeMillis(); - List semanticQueries = queryContext.getCandidateQueries(); - if (CollectionUtils.isEmpty(semanticQueries)) { - return; - } - List selectedParses = semanticQueries.stream().map(SemanticQuery::getParseInfo) - .collect(Collectors.toList()); - addSqlInfo(queryContext, selectedParses); - parseResp.getParseTimeCost().setSqlTime(System.currentTimeMillis() - start); - } - - private void addSqlInfo(QueryContext queryContext, List semanticParseInfos) { - if (CollectionUtils.isEmpty(semanticParseInfos)) { - return; - } - semanticParseInfos.forEach(parseInfo -> { - try { - addSqlInfo(queryContext, parseInfo); - } catch (Exception e) { - log.warn("get sql info failed:{}", parseInfo, e); - } - }); - } - - private void addSqlInfo(QueryContext queryContext, SemanticParseInfo parseInfo) throws Exception { - SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode()); - if (Objects.isNull(semanticQuery)) { - return; - } - semanticQuery.setParseInfo(parseInfo); - SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq(); - SemanticLayerService queryService = ContextUtils.getBean(SemanticLayerService.class); - TranslateSqlReq translateSqlReq = TranslateSqlReq.builder().queryReq(semanticQueryReq) - .queryTypeEnum(QueryMethod.SQL).build(); - TranslateResp explain = queryService.translate(translateSqlReq, queryContext.getUser()); - String querySql = explain.getSql(); - if (StringUtils.isBlank(querySql)) { - return; - } - SqlInfo sqlInfo = parseInfo.getSqlInfo(); - if (semanticQuery instanceof LLMSqlQuery) { - keyPipelineLog.info("SqlInfoProcessor results:\n" - + "Parsed S2SQL: {}\nCorrected S2SQL: {}\nQuery SQL: {}", - StringUtils.normalizeSpace(sqlInfo.getS2SQL()), - StringUtils.normalizeSpace(sqlInfo.getCorrectS2SQL()), - StringUtils.normalizeSpace(querySql)); - } - sqlInfo.setQuerySQL(querySql); - sqlInfo.setSourceId(explain.getSourceId()); - } - -} diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java new file mode 100644 index 000000000..b052fc435 --- /dev/null +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java @@ -0,0 +1,148 @@ +package com.tencent.supersonic.headless.server.utils; + +import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.common.util.JsonUtil; +import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.headless.api.pojo.SqlInfo; +import com.tencent.supersonic.headless.api.pojo.enums.ChatWorkflowState; +import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod; +import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; +import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq; +import com.tencent.supersonic.headless.api.pojo.response.ParseResp; +import com.tencent.supersonic.headless.api.pojo.response.TranslateResp; +import com.tencent.supersonic.headless.chat.ChatContext; +import com.tencent.supersonic.headless.chat.ChatQueryContext; +import com.tencent.supersonic.headless.chat.corrector.SemanticCorrector; +import com.tencent.supersonic.headless.chat.mapper.SchemaMapper; +import com.tencent.supersonic.headless.chat.parser.SemanticParser; +import com.tencent.supersonic.headless.chat.query.QueryManager; +import com.tencent.supersonic.headless.chat.query.SemanticQuery; +import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery; +import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService; +import com.tencent.supersonic.headless.server.processor.ResultProcessor; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.collections.CollectionUtils; +import org.apache.commons.collections.MapUtils; +import org.apache.commons.lang3.StringUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.stereotype.Service; + +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +@Service +@Slf4j +public class ChatWorkflowEngine { + + private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline"); + private List schemaMappers = ComponentFactory.getSchemaMappers(); + private List semanticParsers = ComponentFactory.getSemanticParsers(); + private List semanticCorrectors = ComponentFactory.getSemanticCorrectors(); + private List resultProcessors = ComponentFactory.getResultProcessors(); + + public void execute(ChatQueryContext queryCtx, ChatContext chatCtx, ParseResp parseResult) { + queryCtx.setChatWorkflowState(ChatWorkflowState.MAPPING); + while (queryCtx.getChatWorkflowState() != ChatWorkflowState.FINISHED) { + switch (queryCtx.getChatWorkflowState()) { + case MAPPING: + performMapping(queryCtx); + queryCtx.setChatWorkflowState(ChatWorkflowState.PARSING); + break; + case PARSING: + performParsing(queryCtx, chatCtx); + queryCtx.setChatWorkflowState(ChatWorkflowState.CORRECTING); + break; + case CORRECTING: + performCorrecting(queryCtx); + queryCtx.setChatWorkflowState(ChatWorkflowState.TRANSLATING); + break; + case TRANSLATING: + long start = System.currentTimeMillis(); + performTranslating(queryCtx); + parseResult.getParseTimeCost().setSqlTime(System.currentTimeMillis() - start); + queryCtx.setChatWorkflowState(ChatWorkflowState.PROCESSING); + break; + case PROCESSING: + default: + performProcessing(queryCtx, chatCtx, parseResult); + queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED); + break; + } + } + } + + public void performMapping(ChatQueryContext queryCtx) { + if (Objects.isNull(queryCtx.getMapInfo()) + || MapUtils.isEmpty(queryCtx.getMapInfo().getDataSetElementMatches())) { + schemaMappers.forEach(mapper -> mapper.map(queryCtx)); + } + } + + public void performParsing(ChatQueryContext queryCtx, ChatContext chatCtx) { + semanticParsers.forEach(parser -> { + parser.parse(queryCtx, chatCtx); + log.debug("{} result:{}", parser.getClass().getSimpleName(), JsonUtil.toString(queryCtx)); + }); + } + + public void performCorrecting(ChatQueryContext queryCtx) { + List candidateQueries = queryCtx.getCandidateQueries(); + if (CollectionUtils.isNotEmpty(candidateQueries)) { + for (SemanticQuery semanticQuery : candidateQueries) { + if (semanticQuery instanceof RuleSemanticQuery) { + continue; + } + for (SemanticCorrector corrector : semanticCorrectors) { + corrector.correct(queryCtx, semanticQuery.getParseInfo()); + if (!ChatWorkflowState.CORRECTING.equals(queryCtx.getChatWorkflowState())) { + break; + } + } + } + } + } + + public void performProcessing(ChatQueryContext queryCtx, ChatContext chatCtx, ParseResp parseResult) { + resultProcessors.forEach(processor -> { + processor.process(parseResult, queryCtx, chatCtx); + }); + } + + private void performTranslating(ChatQueryContext chatQueryContext) { + List semanticParseInfos = chatQueryContext.getCandidateQueries().stream() + .map(SemanticQuery::getParseInfo) + .collect(Collectors.toList()); + + semanticParseInfos.forEach(parseInfo -> { + try { + SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode()); + if (Objects.isNull(semanticQuery)) { + return; + } + semanticQuery.setParseInfo(parseInfo); + SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq(); + SemanticLayerService queryService = ContextUtils.getBean(SemanticLayerService.class); + TranslateSqlReq translateSqlReq = TranslateSqlReq.builder().queryReq(semanticQueryReq) + .queryTypeEnum(QueryMethod.SQL).build(); + TranslateResp explain = queryService.translate(translateSqlReq, chatQueryContext.getUser()); + String querySql = explain.getSql(); + if (StringUtils.isBlank(querySql)) { + return; + } + SqlInfo sqlInfo = parseInfo.getSqlInfo(); + sqlInfo.setQuerySQL(querySql); + sqlInfo.setSourceId(explain.getSourceId()); + + keyPipelineLog.info("SqlInfoProcessor results:\n" + + "Parsed S2SQL: {}\nCorrected S2SQL: {}\nQuery SQL: {}", + StringUtils.normalizeSpace(sqlInfo.getS2SQL()), + StringUtils.normalizeSpace(sqlInfo.getCorrectS2SQL()), + StringUtils.normalizeSpace(querySql)); + } catch (Exception e) { + log.warn("get sql info failed:{}", parseInfo, e); + } + }); + } +} \ No newline at end of file diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/WorkflowEngine.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/WorkflowEngine.java deleted file mode 100644 index 21baeced4..000000000 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/WorkflowEngine.java +++ /dev/null @@ -1,91 +0,0 @@ -package com.tencent.supersonic.headless.server.utils; - -import com.tencent.supersonic.common.util.JsonUtil; -import com.tencent.supersonic.headless.api.pojo.enums.WorkflowState; -import com.tencent.supersonic.headless.api.pojo.response.ParseResp; -import com.tencent.supersonic.headless.chat.ChatContext; -import com.tencent.supersonic.headless.chat.QueryContext; -import com.tencent.supersonic.headless.chat.corrector.SemanticCorrector; -import com.tencent.supersonic.headless.chat.mapper.SchemaMapper; -import com.tencent.supersonic.headless.chat.parser.SemanticParser; -import com.tencent.supersonic.headless.chat.query.SemanticQuery; -import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery; -import com.tencent.supersonic.headless.server.processor.ResultProcessor; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.collections.CollectionUtils; -import org.apache.commons.collections.MapUtils; -import org.springframework.stereotype.Service; - -import java.util.List; -import java.util.Objects; - -@Service -@Slf4j -public class WorkflowEngine { - private List schemaMappers = ComponentFactory.getSchemaMappers(); - private List semanticParsers = ComponentFactory.getSemanticParsers(); - private List semanticCorrectors = ComponentFactory.getSemanticCorrectors(); - private List resultProcessors = ComponentFactory.getResultProcessors(); - - public void startWorkflow(QueryContext queryCtx, ChatContext chatCtx, ParseResp parseResult) { - queryCtx.setWorkflowState(WorkflowState.MAPPING); - while (queryCtx.getWorkflowState() != WorkflowState.FINISHED) { - switch (queryCtx.getWorkflowState()) { - case MAPPING: - performMapping(queryCtx); - queryCtx.setWorkflowState(WorkflowState.PARSING); - break; - case PARSING: - performParsing(queryCtx, chatCtx); - queryCtx.setWorkflowState(WorkflowState.CORRECTING); - break; - case CORRECTING: - performCorrecting(queryCtx); - queryCtx.setWorkflowState(WorkflowState.PROCESSING); - break; - case PROCESSING: - default: - performProcessing(queryCtx, chatCtx, parseResult); - queryCtx.setWorkflowState(WorkflowState.FINISHED); - break; - } - } - } - - public void performMapping(QueryContext queryCtx) { - if (Objects.isNull(queryCtx.getMapInfo()) - || MapUtils.isEmpty(queryCtx.getMapInfo().getDataSetElementMatches())) { - schemaMappers.forEach(mapper -> mapper.map(queryCtx)); - } - } - - public void performParsing(QueryContext queryCtx, ChatContext chatCtx) { - semanticParsers.forEach(parser -> { - parser.parse(queryCtx, chatCtx); - log.debug("{} result:{}", parser.getClass().getSimpleName(), JsonUtil.toString(queryCtx)); - }); - } - - public void performCorrecting(QueryContext queryCtx) { - List candidateQueries = queryCtx.getCandidateQueries(); - if (CollectionUtils.isNotEmpty(candidateQueries)) { - for (SemanticQuery semanticQuery : candidateQueries) { - if (semanticQuery instanceof RuleSemanticQuery) { - continue; - } - for (SemanticCorrector corrector : semanticCorrectors) { - corrector.correct(queryCtx, semanticQuery.getParseInfo()); - if (!WorkflowState.CORRECTING.equals(queryCtx.getWorkflowState())) { - break; - } - } - } - } - } - - public void performProcessing(QueryContext queryCtx, ChatContext chatCtx, ParseResp parseResult) { - resultProcessors.forEach(processor -> { - processor.process(parseResult, queryCtx, chatCtx); - }); - } -} \ No newline at end of file diff --git a/launchers/standalone/src/main/resources/META-INF/spring.factories b/launchers/standalone/src/main/resources/META-INF/spring.factories index e7e908e85..c78ec10c0 100644 --- a/launchers/standalone/src/main/resources/META-INF/spring.factories +++ b/launchers/standalone/src/main/resources/META-INF/spring.factories @@ -47,9 +47,7 @@ com.tencent.supersonic.headless.core.cache.QueryCache=\ ### headless-server SPIs com.tencent.supersonic.headless.server.processor.ResultProcessor=\ - com.tencent.supersonic.headless.server.processor.ParseInfoProcessor, \ - com.tencent.supersonic.headless.server.processor.SqlInfoProcessor - + com.tencent.supersonic.headless.server.processor.ParseInfoProcessor ### chat-server SPIs