mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 12:37:55 +00:00
(improvement)(headless)Add explicit TRANSLATING stage and rename several classes by the way.
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
package com.tencent.supersonic.headless.api.pojo.enums;
|
package com.tencent.supersonic.headless.api.pojo.enums;
|
||||||
|
|
||||||
public enum WorkflowState {
|
public enum ChatWorkflowState {
|
||||||
MAPPING,
|
MAPPING,
|
||||||
PARSING,
|
PARSING,
|
||||||
CORRECTING,
|
CORRECTING,
|
||||||
@@ -4,14 +4,14 @@ import com.fasterxml.jackson.annotation.JsonIgnore;
|
|||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.common.config.ModelConfig;
|
import com.tencent.supersonic.common.config.ModelConfig;
|
||||||
import com.tencent.supersonic.common.config.PromptConfig;
|
import com.tencent.supersonic.common.config.PromptConfig;
|
||||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
|
||||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
|
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||||
|
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||||
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||||
import com.tencent.supersonic.headless.api.pojo.enums.WorkflowState;
|
import com.tencent.supersonic.headless.api.pojo.enums.ChatWorkflowState;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||||
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
|
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
|
||||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||||
@@ -20,6 +20,7 @@ import lombok.Builder;
|
|||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -31,7 +32,7 @@ import java.util.stream.Collectors;
|
|||||||
@Builder
|
@Builder
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
public class QueryContext {
|
public class ChatQueryContext {
|
||||||
|
|
||||||
private String queryText;
|
private String queryText;
|
||||||
private Integer chatId;
|
private Integer chatId;
|
||||||
@@ -39,6 +40,7 @@ public class QueryContext {
|
|||||||
private Map<Long, List<Long>> modelIdToDataSetIds;
|
private Map<Long, List<Long>> modelIdToDataSetIds;
|
||||||
private User user;
|
private User user;
|
||||||
private boolean saveAnswer;
|
private boolean saveAnswer;
|
||||||
|
@Builder.Default
|
||||||
private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM;
|
private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM;
|
||||||
private QueryFilters queryFilters;
|
private QueryFilters queryFilters;
|
||||||
private List<SemanticQuery> candidateQueries = new ArrayList<>();
|
private List<SemanticQuery> candidateQueries = new ArrayList<>();
|
||||||
@@ -47,7 +49,7 @@ public class QueryContext {
|
|||||||
@JsonIgnore
|
@JsonIgnore
|
||||||
private SemanticSchema semanticSchema;
|
private SemanticSchema semanticSchema;
|
||||||
@JsonIgnore
|
@JsonIgnore
|
||||||
private WorkflowState workflowState;
|
private ChatWorkflowState chatWorkflowState;
|
||||||
private QueryDataType queryDataType = QueryDataType.ALL;
|
private QueryDataType queryDataType = QueryDataType.ALL;
|
||||||
private ModelConfig modelConfig;
|
private ModelConfig modelConfig;
|
||||||
private PromptConfig promptConfig;
|
private PromptConfig promptConfig;
|
||||||
@@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.chat.corrector;
|
|||||||
|
|
||||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
@@ -16,17 +16,17 @@ import java.util.List;
|
|||||||
public class AggCorrector extends BaseSemanticCorrector {
|
public class AggCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
addAggregate(queryContext, semanticParseInfo);
|
addAggregate(chatQueryContext, semanticParseInfo);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void addAggregate(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
private void addAggregate(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
List<String> sqlGroupByFields = SqlSelectHelper.getGroupByFields(
|
List<String> sqlGroupByFields = SqlSelectHelper.getGroupByFields(
|
||||||
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
||||||
if (CollectionUtils.isEmpty(sqlGroupByFields)) {
|
if (CollectionUtils.isEmpty(sqlGroupByFields)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
addAggregateToMetric(queryContext, semanticParseInfo);
|
addAggregateToMetric(chatQueryContext, semanticParseInfo);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,7 +6,8 @@ import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -26,23 +27,23 @@ import org.springframework.util.CollectionUtils;
|
|||||||
@Slf4j
|
@Slf4j
|
||||||
public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||||
|
|
||||||
public void correct(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
public void correct(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
try {
|
try {
|
||||||
if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectS2SQL())) {
|
if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectS2SQL())) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
doCorrect(queryContext, semanticParseInfo);
|
doCorrect(chatQueryContext, semanticParseInfo);
|
||||||
log.debug("sqlCorrection:{} sql:{}", this.getClass().getSimpleName(), semanticParseInfo.getSqlInfo());
|
log.debug("sqlCorrection:{} sql:{}", this.getClass().getSimpleName(), semanticParseInfo.getSqlInfo());
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error(String.format("correct error,sqlInfo:%s", semanticParseInfo.getSqlInfo()), e);
|
log.error(String.format("correct error,sqlInfo:%s", semanticParseInfo.getSqlInfo()), e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public abstract void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo);
|
public abstract void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo);
|
||||||
|
|
||||||
protected Map<String, String> getFieldNameMap(QueryContext queryContext, Long dataSetId) {
|
protected Map<String, String> getFieldNameMap(ChatQueryContext chatQueryContext, Long dataSetId) {
|
||||||
|
|
||||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||||
|
|
||||||
List<SchemaElement> dbAllFields = new ArrayList<>();
|
List<SchemaElement> dbAllFields = new ArrayList<>();
|
||||||
dbAllFields.addAll(semanticSchema.getMetrics());
|
dbAllFields.addAll(semanticSchema.getMetrics());
|
||||||
@@ -71,11 +72,11 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void addAggregateToMetric(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
protected void addAggregateToMetric(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
//add aggregate to all metric
|
//add aggregate to all metric
|
||||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
Long dataSetId = semanticParseInfo.getDataSet().getDataSet();
|
Long dataSetId = semanticParseInfo.getDataSet().getDataSet();
|
||||||
List<SchemaElement> metrics = getMetricElements(queryContext, dataSetId);
|
List<SchemaElement> metrics = getMetricElements(chatQueryContext, dataSetId);
|
||||||
|
|
||||||
Map<String, String> metricToAggregate = metrics.stream()
|
Map<String, String> metricToAggregate = metrics.stream()
|
||||||
.map(schemaElement -> {
|
.map(schemaElement -> {
|
||||||
@@ -100,8 +101,8 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
|||||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql);
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected List<SchemaElement> getMetricElements(QueryContext queryContext, Long dataSetId) {
|
protected List<SchemaElement> getMetricElements(ChatQueryContext chatQueryContext, Long dataSetId) {
|
||||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||||
return semanticSchema.getMetrics(dataSetId);
|
return semanticSchema.getMetrics(dataSetId);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.chat.corrector;
|
|||||||
|
|
||||||
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
|
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
@@ -26,9 +26,9 @@ public class GrammarCorrector extends BaseSemanticCorrector {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
for (BaseSemanticCorrector corrector : correctors) {
|
for (BaseSemanticCorrector corrector : correctors) {
|
||||||
corrector.correct(queryContext, semanticParseInfo);
|
corrector.correct(chatQueryContext, semanticParseInfo);
|
||||||
}
|
}
|
||||||
removeSameFieldFromSelect(semanticParseInfo);
|
removeSameFieldFromSelect(semanticParseInfo);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.springframework.core.env.Environment;
|
import org.springframework.core.env.Environment;
|
||||||
@@ -23,20 +23,20 @@ import java.util.stream.Collectors;
|
|||||||
public class GroupByCorrector extends BaseSemanticCorrector {
|
public class GroupByCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
Boolean needAddGroupBy = needAddGroupBy(queryContext, semanticParseInfo);
|
Boolean needAddGroupBy = needAddGroupBy(chatQueryContext, semanticParseInfo);
|
||||||
if (!needAddGroupBy) {
|
if (!needAddGroupBy) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
addGroupByFields(queryContext, semanticParseInfo);
|
addGroupByFields(chatQueryContext, semanticParseInfo);
|
||||||
}
|
}
|
||||||
|
|
||||||
private Boolean needAddGroupBy(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
private Boolean needAddGroupBy(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
Long dataSetId = semanticParseInfo.getDataSetId();
|
Long dataSetId = semanticParseInfo.getDataSetId();
|
||||||
//add dimension group by
|
//add dimension group by
|
||||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
||||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||||
// check has distinct
|
// check has distinct
|
||||||
if (SqlSelectHelper.hasDistinct(correctS2SQL)) {
|
if (SqlSelectHelper.hasDistinct(correctS2SQL)) {
|
||||||
log.debug("no need to add groupby ,existed distinct in s2sql:{}", correctS2SQL);
|
log.debug("no need to add groupby ,existed distinct in s2sql:{}", correctS2SQL);
|
||||||
@@ -64,12 +64,12 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
private void addGroupByFields(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
private void addGroupByFields(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
Long dataSetId = semanticParseInfo.getDataSetId();
|
Long dataSetId = semanticParseInfo.getDataSetId();
|
||||||
//add dimension group by
|
//add dimension group by
|
||||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
||||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||||
//add alias field name
|
//add alias field name
|
||||||
Set<String> dimensions = getDimensions(dataSetId, semanticSchema);
|
Set<String> dimensions = getDimensions(dataSetId, semanticSchema);
|
||||||
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
|
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper;
|
|||||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import net.sf.jsqlparser.expression.Expression;
|
import net.sf.jsqlparser.expression.Expression;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
@@ -24,10 +24,10 @@ import java.util.stream.Collectors;
|
|||||||
public class HavingCorrector extends BaseSemanticCorrector {
|
public class HavingCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
|
|
||||||
//add aggregate to all metric
|
//add aggregate to all metric
|
||||||
addHaving(queryContext, semanticParseInfo);
|
addHaving(chatQueryContext, semanticParseInfo);
|
||||||
|
|
||||||
//decide whether add having expression field to select
|
//decide whether add having expression field to select
|
||||||
Environment environment = ContextUtils.getBean(Environment.class);
|
Environment environment = ContextUtils.getBean(Environment.class);
|
||||||
@@ -38,10 +38,10 @@ public class HavingCorrector extends BaseSemanticCorrector {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void addHaving(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
private void addHaving(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
Long dataSet = semanticParseInfo.getDataSet().getDataSet();
|
Long dataSet = semanticParseInfo.getDataSet().getDataSet();
|
||||||
|
|
||||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||||
|
|
||||||
Set<String> metrics = semanticSchema.getMetrics(dataSet).stream()
|
Set<String> metrics = semanticSchema.getMetrics(dataSet).stream()
|
||||||
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
|
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
|
||||||
|
|||||||
@@ -6,19 +6,19 @@ import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum;
|
|||||||
import com.tencent.supersonic.common.util.DateUtils;
|
import com.tencent.supersonic.common.util.DateUtils;
|
||||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||||
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
|
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
|
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
|
||||||
public class S2SqlDateHelper {
|
public class S2SqlDateHelper {
|
||||||
|
|
||||||
public static String getReferenceDate(QueryContext queryContext, Long dataSetId) {
|
public static String getReferenceDate(ChatQueryContext chatQueryContext, Long dataSetId) {
|
||||||
String defaultDate = DateUtils.getBeforeDate(0);
|
String defaultDate = DateUtils.getBeforeDate(0);
|
||||||
if (Objects.isNull(dataSetId)) {
|
if (Objects.isNull(dataSetId)) {
|
||||||
return defaultDate;
|
return defaultDate;
|
||||||
}
|
}
|
||||||
DataSetSchema dataSetSchema = queryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
|
DataSetSchema dataSetSchema = chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
|
||||||
if (dataSetSchema == null || dataSetSchema.getTagTypeTimeDefaultConfig() == null) {
|
if (dataSetSchema == null || dataSetSchema.getTagTypeTimeDefaultConfig() == null) {
|
||||||
return defaultDate;
|
return defaultDate;
|
||||||
}
|
}
|
||||||
@@ -26,13 +26,13 @@ public class S2SqlDateHelper {
|
|||||||
return getDefaultDate(defaultDate, tagTypeTimeDefaultConfig).getLeft();
|
return getDefaultDate(defaultDate, tagTypeTimeDefaultConfig).getLeft();
|
||||||
}
|
}
|
||||||
|
|
||||||
public static Pair<String, String> getStartEndDate(QueryContext queryContext, Long dataSetId,
|
public static Pair<String, String> getStartEndDate(ChatQueryContext chatQueryContext, Long dataSetId,
|
||||||
QueryType queryType) {
|
QueryType queryType) {
|
||||||
String defaultDate = DateUtils.getBeforeDate(0);
|
String defaultDate = DateUtils.getBeforeDate(0);
|
||||||
if (Objects.isNull(dataSetId)) {
|
if (Objects.isNull(dataSetId)) {
|
||||||
return Pair.of(defaultDate, defaultDate);
|
return Pair.of(defaultDate, defaultDate);
|
||||||
}
|
}
|
||||||
DataSetSchema dataSetSchema = queryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
|
DataSetSchema dataSetSchema = chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
|
||||||
if (dataSetSchema == null) {
|
if (dataSetSchema == null) {
|
||||||
return Pair.of(defaultDate, defaultDate);
|
return Pair.of(defaultDate, defaultDate);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||||
import com.tencent.supersonic.headless.chat.parser.llm.ParseResult;
|
import com.tencent.supersonic.headless.chat.parser.llm.ParseResult;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
@@ -34,7 +34,7 @@ import java.util.stream.Collectors;
|
|||||||
public class SchemaCorrector extends BaseSemanticCorrector {
|
public class SchemaCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
|
|
||||||
correctAggFunction(semanticParseInfo);
|
correctAggFunction(semanticParseInfo);
|
||||||
|
|
||||||
@@ -44,7 +44,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
|||||||
|
|
||||||
updateFieldValueByLinkingValue(semanticParseInfo);
|
updateFieldValueByLinkingValue(semanticParseInfo);
|
||||||
|
|
||||||
correctFieldName(queryContext, semanticParseInfo);
|
correctFieldName(chatQueryContext, semanticParseInfo);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void correctAggFunction(SemanticParseInfo semanticParseInfo) {
|
private void correctAggFunction(SemanticParseInfo semanticParseInfo) {
|
||||||
@@ -60,8 +60,8 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
|||||||
sqlInfo.setCorrectS2SQL(replaceAlias);
|
sqlInfo.setCorrectS2SQL(replaceAlias);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void correctFieldName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
private void correctFieldName(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
Map<String, String> fieldNameMap = getFieldNameMap(queryContext, semanticParseInfo.getDataSetId());
|
Map<String, String> fieldNameMap = getFieldNameMap(chatQueryContext, semanticParseInfo.getDataSetId());
|
||||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
String sql = SqlReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap);
|
String sql = SqlReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap);
|
||||||
sqlInfo.setCorrectS2SQL(sql);
|
sqlInfo.setCorrectS2SQL(sql);
|
||||||
@@ -115,7 +115,8 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
|||||||
sqlInfo.setCorrectS2SQL(sql);
|
sqlInfo.setCorrectS2SQL(sql);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void removeFilterIfNotInLinkingValue(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
public void removeFilterIfNotInLinkingValue(ChatQueryContext chatQueryContext,
|
||||||
|
SemanticParseInfo semanticParseInfo) {
|
||||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||||
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
||||||
List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctS2SQL);
|
List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctS2SQL);
|
||||||
@@ -123,7 +124,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
List<LLMReq.ElementValue> linkingValues = getLinkingValues(semanticParseInfo);
|
List<LLMReq.ElementValue> linkingValues = getLinkingValues(semanticParseInfo);
|
||||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||||
Set<String> dimensions = getDimensions(semanticParseInfo.getDataSetId(), semanticSchema);
|
Set<String> dimensions = getDimensions(semanticParseInfo.getDataSetId(), semanticSchema);
|
||||||
|
|
||||||
if (CollectionUtils.isEmpty(linkingValues)) {
|
if (CollectionUtils.isEmpty(linkingValues)) {
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import com.tencent.supersonic.common.pojo.enums.QueryType;
|
|||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.springframework.core.env.Environment;
|
import org.springframework.core.env.Environment;
|
||||||
@@ -32,7 +32,7 @@ public class SelectCorrector extends BaseSemanticCorrector {
|
|||||||
public static final String ADDITIONAL_INFORMATION = "s2.corrector.additional.information";
|
public static final String ADDITIONAL_INFORMATION = "s2.corrector.additional.information";
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
|
List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
|
||||||
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
|
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
|
||||||
@@ -42,14 +42,14 @@ public class SelectCorrector extends BaseSemanticCorrector {
|
|||||||
&& aggregateFields.size() == selectFields.size()) {
|
&& aggregateFields.size() == selectFields.size()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
correctS2SQL = addFieldsToSelect(queryContext, semanticParseInfo, correctS2SQL);
|
correctS2SQL = addFieldsToSelect(chatQueryContext, semanticParseInfo, correctS2SQL);
|
||||||
String querySql = SqlReplaceHelper.dealAliasToOrderBy(correctS2SQL);
|
String querySql = SqlReplaceHelper.dealAliasToOrderBy(correctS2SQL);
|
||||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(querySql);
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(querySql);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected String addFieldsToSelect(QueryContext queryContext, SemanticParseInfo semanticParseInfo,
|
protected String addFieldsToSelect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo,
|
||||||
String correctS2SQL) {
|
String correctS2SQL) {
|
||||||
correctS2SQL = addTagDefaultFields(queryContext, semanticParseInfo, correctS2SQL);
|
correctS2SQL = addTagDefaultFields(chatQueryContext, semanticParseInfo, correctS2SQL);
|
||||||
|
|
||||||
Set<String> selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL));
|
Set<String> selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL));
|
||||||
Set<String> needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL));
|
Set<String> needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL));
|
||||||
@@ -69,7 +69,7 @@ public class SelectCorrector extends BaseSemanticCorrector {
|
|||||||
return addFieldsToSelectSql;
|
return addFieldsToSelectSql;
|
||||||
}
|
}
|
||||||
|
|
||||||
private String addTagDefaultFields(QueryContext queryContext, SemanticParseInfo semanticParseInfo,
|
private String addTagDefaultFields(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo,
|
||||||
String correctS2SQL) {
|
String correctS2SQL) {
|
||||||
//If it is in DETAIL mode and select *, add default metrics and dimensions.
|
//If it is in DETAIL mode and select *, add default metrics and dimensions.
|
||||||
boolean hasAsterisk = SqlSelectFunctionHelper.hasAsterisk(correctS2SQL);
|
boolean hasAsterisk = SqlSelectFunctionHelper.hasAsterisk(correctS2SQL);
|
||||||
@@ -77,7 +77,7 @@ public class SelectCorrector extends BaseSemanticCorrector {
|
|||||||
return correctS2SQL;
|
return correctS2SQL;
|
||||||
}
|
}
|
||||||
Long dataSetId = semanticParseInfo.getDataSetId();
|
Long dataSetId = semanticParseInfo.getDataSetId();
|
||||||
DataSetSchema dataSetSchema = queryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
|
DataSetSchema dataSetSchema = chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
|
||||||
Set<String> needAddDefaultFields = new HashSet<>();
|
Set<String> needAddDefaultFields = new HashSet<>();
|
||||||
if (Objects.nonNull(dataSetSchema)) {
|
if (Objects.nonNull(dataSetSchema)) {
|
||||||
if (!CollectionUtils.isEmpty(dataSetSchema.getTagDefaultMetrics())) {
|
if (!CollectionUtils.isEmpty(dataSetSchema.getTagDefaultMetrics())) {
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.chat.corrector;
|
|||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A semantic corrector checks validity of extracted semantic information and
|
* A semantic corrector checks validity of extracted semantic information and
|
||||||
@@ -10,5 +10,5 @@ import com.tencent.supersonic.headless.chat.QueryContext;
|
|||||||
*/
|
*/
|
||||||
public interface SemanticCorrector {
|
public interface SemanticCorrector {
|
||||||
|
|
||||||
void correct(QueryContext queryContext, SemanticParseInfo semanticParseInfo);
|
void correct(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
|||||||
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
|
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
|
||||||
import com.tencent.supersonic.common.jsqlparser.DateVisitor.DateBoundInfo;
|
import com.tencent.supersonic.common.jsqlparser.DateVisitor.DateBoundInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import net.sf.jsqlparser.JSQLParserException;
|
import net.sf.jsqlparser.JSQLParserException;
|
||||||
import net.sf.jsqlparser.expression.Expression;
|
import net.sf.jsqlparser.expression.Expression;
|
||||||
@@ -32,11 +32,11 @@ import java.util.Set;
|
|||||||
public class TimeCorrector extends BaseSemanticCorrector {
|
public class TimeCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
|
|
||||||
addDateIfNotExist(queryContext, semanticParseInfo);
|
addDateIfNotExist(chatQueryContext, semanticParseInfo);
|
||||||
|
|
||||||
removeDateIfExist(queryContext, semanticParseInfo);
|
removeDateIfExist(chatQueryContext, semanticParseInfo);
|
||||||
|
|
||||||
parserDateDiffFunction(semanticParseInfo);
|
parserDateDiffFunction(semanticParseInfo);
|
||||||
|
|
||||||
@@ -44,7 +44,7 @@ public class TimeCorrector extends BaseSemanticCorrector {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void removeDateIfExist(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
private void removeDateIfExist(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
//decide whether remove date field from where
|
//decide whether remove date field from where
|
||||||
Environment environment = ContextUtils.getBean(Environment.class);
|
Environment environment = ContextUtils.getBean(Environment.class);
|
||||||
@@ -59,7 +59,7 @@ public class TimeCorrector extends BaseSemanticCorrector {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void addDateIfNotExist(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
private void addDateIfNotExist(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
|
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
|
||||||
|
|
||||||
@@ -71,7 +71,7 @@ public class TimeCorrector extends BaseSemanticCorrector {
|
|||||||
}
|
}
|
||||||
if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) {
|
if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) {
|
||||||
|
|
||||||
Pair<String, String> startEndDate = S2SqlDateHelper.getStartEndDate(queryContext,
|
Pair<String, String> startEndDate = S2SqlDateHelper.getStartEndDate(chatQueryContext,
|
||||||
semanticParseInfo.getDataSetId(), semanticParseInfo.getQueryType());
|
semanticParseInfo.getDataSetId(), semanticParseInfo.getQueryType());
|
||||||
|
|
||||||
if (StringUtils.isNotBlank(startEndDate.getLeft())
|
if (StringUtils.isNotBlank(startEndDate.getLeft())
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaValueMap;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.utils.QueryFilterParser;
|
import com.tencent.supersonic.headless.chat.utils.QueryFilterParser;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import net.sf.jsqlparser.JSQLParserException;
|
import net.sf.jsqlparser.JSQLParserException;
|
||||||
@@ -29,15 +29,15 @@ import java.util.Objects;
|
|||||||
public class WhereCorrector extends BaseSemanticCorrector {
|
public class WhereCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
|
|
||||||
addQueryFilter(queryContext, semanticParseInfo);
|
addQueryFilter(chatQueryContext, semanticParseInfo);
|
||||||
|
|
||||||
updateFieldValueByTechName(queryContext, semanticParseInfo);
|
updateFieldValueByTechName(chatQueryContext, semanticParseInfo);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void addQueryFilter(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
protected void addQueryFilter(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
String queryFilter = getQueryFilter(queryContext.getQueryFilters());
|
String queryFilter = getQueryFilter(chatQueryContext.getQueryFilters());
|
||||||
|
|
||||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
|
|
||||||
@@ -61,8 +61,8 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
|||||||
return QueryFilterParser.parse(queryFilters);
|
return QueryFilterParser.parse(queryFilters);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void updateFieldValueByTechName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
private void updateFieldValueByTechName(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||||
Long dataSetId = semanticParseInfo.getDataSetId();
|
Long dataSetId = semanticParseInfo.getDataSetId();
|
||||||
List<SchemaElement> dimensions = semanticSchema.getDimensions(dataSetId);
|
List<SchemaElement> dimensions = semanticSchema.getDimensions(dataSetId);
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.springframework.beans.BeanUtils;
|
import org.springframework.beans.BeanUtils;
|
||||||
@@ -26,37 +26,37 @@ import java.util.stream.Collectors;
|
|||||||
public abstract class BaseMapper implements SchemaMapper {
|
public abstract class BaseMapper implements SchemaMapper {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void map(QueryContext queryContext) {
|
public void map(ChatQueryContext chatQueryContext) {
|
||||||
|
|
||||||
String simpleName = this.getClass().getSimpleName();
|
String simpleName = this.getClass().getSimpleName();
|
||||||
long startTime = System.currentTimeMillis();
|
long startTime = System.currentTimeMillis();
|
||||||
log.debug("before {},mapInfo:{}", simpleName,
|
log.debug("before {},mapInfo:{}", simpleName,
|
||||||
queryContext.getMapInfo().getDataSetElementMatches());
|
chatQueryContext.getMapInfo().getDataSetElementMatches());
|
||||||
|
|
||||||
try {
|
try {
|
||||||
doMap(queryContext);
|
doMap(chatQueryContext);
|
||||||
filter(queryContext);
|
filter(chatQueryContext);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("work error", e);
|
log.error("work error", e);
|
||||||
}
|
}
|
||||||
|
|
||||||
long cost = System.currentTimeMillis() - startTime;
|
long cost = System.currentTimeMillis() - startTime;
|
||||||
log.debug("after {},cost:{},mapInfo:{}", simpleName, cost,
|
log.debug("after {},cost:{},mapInfo:{}", simpleName, cost,
|
||||||
queryContext.getMapInfo().getDataSetElementMatches());
|
chatQueryContext.getMapInfo().getDataSetElementMatches());
|
||||||
}
|
}
|
||||||
|
|
||||||
private void filter(QueryContext queryContext) {
|
private void filter(ChatQueryContext chatQueryContext) {
|
||||||
filterByDataSetId(queryContext);
|
filterByDataSetId(chatQueryContext);
|
||||||
filterByDetectWordLenLessThanOne(queryContext);
|
filterByDetectWordLenLessThanOne(chatQueryContext);
|
||||||
switch (queryContext.getQueryDataType()) {
|
switch (chatQueryContext.getQueryDataType()) {
|
||||||
case TAG:
|
case TAG:
|
||||||
filterByQueryDataType(queryContext, element -> !(element.getIsTag() > 0));
|
filterByQueryDataType(chatQueryContext, element -> !(element.getIsTag() > 0));
|
||||||
break;
|
break;
|
||||||
case METRIC:
|
case METRIC:
|
||||||
filterByQueryDataType(queryContext, element -> !SchemaElementType.METRIC.equals(element.getType()));
|
filterByQueryDataType(chatQueryContext, element -> !SchemaElementType.METRIC.equals(element.getType()));
|
||||||
break;
|
break;
|
||||||
case DIMENSION:
|
case DIMENSION:
|
||||||
filterByQueryDataType(queryContext, element -> {
|
filterByQueryDataType(chatQueryContext, element -> {
|
||||||
boolean isDimensionOrValue = SchemaElementType.DIMENSION.equals(element.getType())
|
boolean isDimensionOrValue = SchemaElementType.DIMENSION.equals(element.getType())
|
||||||
|| SchemaElementType.VALUE.equals(element.getType());
|
|| SchemaElementType.VALUE.equals(element.getType());
|
||||||
return !isDimensionOrValue;
|
return !isDimensionOrValue;
|
||||||
@@ -68,22 +68,22 @@ public abstract class BaseMapper implements SchemaMapper {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void filterByDataSetId(QueryContext queryContext) {
|
private static void filterByDataSetId(ChatQueryContext chatQueryContext) {
|
||||||
Set<Long> dataSetIds = queryContext.getDataSetIds();
|
Set<Long> dataSetIds = chatQueryContext.getDataSetIds();
|
||||||
if (CollectionUtils.isEmpty(dataSetIds)) {
|
if (CollectionUtils.isEmpty(dataSetIds)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
Set<Long> dataSetIdInMapInfo = new HashSet<>(queryContext.getMapInfo().getDataSetElementMatches().keySet());
|
Set<Long> dataSetIdInMapInfo = new HashSet<>(chatQueryContext.getMapInfo().getDataSetElementMatches().keySet());
|
||||||
for (Long dataSetId : dataSetIdInMapInfo) {
|
for (Long dataSetId : dataSetIdInMapInfo) {
|
||||||
if (!dataSetIds.contains(dataSetId)) {
|
if (!dataSetIds.contains(dataSetId)) {
|
||||||
queryContext.getMapInfo().getDataSetElementMatches().remove(dataSetId);
|
chatQueryContext.getMapInfo().getDataSetElementMatches().remove(dataSetId);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void filterByDetectWordLenLessThanOne(QueryContext queryContext) {
|
private static void filterByDetectWordLenLessThanOne(ChatQueryContext chatQueryContext) {
|
||||||
Map<Long, List<SchemaElementMatch>> dataSetElementMatches =
|
Map<Long, List<SchemaElementMatch>> dataSetElementMatches =
|
||||||
queryContext.getMapInfo().getDataSetElementMatches();
|
chatQueryContext.getMapInfo().getDataSetElementMatches();
|
||||||
for (Map.Entry<Long, List<SchemaElementMatch>> entry : dataSetElementMatches.entrySet()) {
|
for (Map.Entry<Long, List<SchemaElementMatch>> entry : dataSetElementMatches.entrySet()) {
|
||||||
List<SchemaElementMatch> value = entry.getValue();
|
List<SchemaElementMatch> value = entry.getValue();
|
||||||
if (!CollectionUtils.isEmpty(value)) {
|
if (!CollectionUtils.isEmpty(value)) {
|
||||||
@@ -93,8 +93,9 @@ public abstract class BaseMapper implements SchemaMapper {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static void filterByQueryDataType(QueryContext queryContext, Predicate<SchemaElement> needRemovePredicate) {
|
private static void filterByQueryDataType(ChatQueryContext chatQueryContext,
|
||||||
queryContext.getMapInfo().getDataSetElementMatches().values().stream().forEach(
|
Predicate<SchemaElement> needRemovePredicate) {
|
||||||
|
chatQueryContext.getMapInfo().getDataSetElementMatches().values().stream().forEach(
|
||||||
schemaElementMatches -> schemaElementMatches.removeIf(
|
schemaElementMatches -> schemaElementMatches.removeIf(
|
||||||
schemaElementMatch -> {
|
schemaElementMatch -> {
|
||||||
SchemaElement element = schemaElementMatch.getElement();
|
SchemaElement element = schemaElementMatch.getElement();
|
||||||
@@ -108,7 +109,7 @@ public abstract class BaseMapper implements SchemaMapper {
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
public abstract void doMap(QueryContext queryContext);
|
public abstract void doMap(ChatQueryContext chatQueryContext);
|
||||||
|
|
||||||
public void addToSchemaMap(SchemaMapInfo schemaMap, Long dataSetId, SchemaElementMatch newElementMatch) {
|
public void addToSchemaMap(SchemaMapInfo schemaMap, Long dataSetId, SchemaElementMatch newElementMatch) {
|
||||||
Map<Long, List<SchemaElementMatch>> dataSetElementMatches = schemaMap.getDataSetElementMatches();
|
Map<Long, List<SchemaElementMatch>> dataSetElementMatches = schemaMap.getDataSetElementMatches();
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.chat.mapper;
|
|||||||
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper;
|
import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
@@ -33,25 +33,25 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
|||||||
protected MapperConfig mapperConfig;
|
protected MapperConfig mapperConfig;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms,
|
public Map<MatchText, List<T>> match(ChatQueryContext chatQueryContext, List<S2Term> terms,
|
||||||
Set<Long> detectDataSetIds) {
|
Set<Long> detectDataSetIds) {
|
||||||
String text = queryContext.getQueryText();
|
String text = chatQueryContext.getQueryText();
|
||||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
log.debug("terms:{},,detectDataSetIds:{}", terms, detectDataSetIds);
|
log.debug("terms:{},,detectDataSetIds:{}", terms, detectDataSetIds);
|
||||||
|
|
||||||
List<T> detects = detect(queryContext, terms, detectDataSetIds);
|
List<T> detects = detect(chatQueryContext, terms, detectDataSetIds);
|
||||||
Map<MatchText, List<T>> result = new HashMap<>();
|
Map<MatchText, List<T>> result = new HashMap<>();
|
||||||
|
|
||||||
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<T> detect(QueryContext queryContext, List<S2Term> terms, Set<Long> detectDataSetIds) {
|
public List<T> detect(ChatQueryContext chatQueryContext, List<S2Term> terms, Set<Long> detectDataSetIds) {
|
||||||
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(terms);
|
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(terms);
|
||||||
String text = queryContext.getQueryText();
|
String text = chatQueryContext.getQueryText();
|
||||||
Set<T> results = new HashSet<>();
|
Set<T> results = new HashSet<>();
|
||||||
|
|
||||||
Set<String> detectSegments = new HashSet<>();
|
Set<String> detectSegments = new HashSet<>();
|
||||||
@@ -64,16 +64,16 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
|||||||
if (index <= text.length()) {
|
if (index <= text.length()) {
|
||||||
String detectSegment = text.substring(startIndex, index).trim();
|
String detectSegment = text.substring(startIndex, index).trim();
|
||||||
detectSegments.add(detectSegment);
|
detectSegments.add(detectSegment);
|
||||||
detectByStep(queryContext, results, detectDataSetIds, detectSegment, offset);
|
detectByStep(chatQueryContext, results, detectDataSetIds, detectSegment, offset);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
|
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
|
||||||
}
|
}
|
||||||
detectByBatch(queryContext, results, detectDataSetIds, detectSegments);
|
detectByBatch(chatQueryContext, results, detectDataSetIds, detectSegments);
|
||||||
return new ArrayList<>(results);
|
return new ArrayList<>(results);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void detectByBatch(QueryContext queryContext, Set<T> results, Set<Long> detectDataSetIds,
|
protected void detectByBatch(ChatQueryContext chatQueryContext, Set<T> results, Set<Long> detectDataSetIds,
|
||||||
Set<String> detectSegments) {
|
Set<String> detectSegments) {
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -108,10 +108,10 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<T> getMatches(QueryContext queryContext, List<S2Term> terms) {
|
public List<T> getMatches(ChatQueryContext chatQueryContext, List<S2Term> terms) {
|
||||||
Set<Long> dataSetIds = queryContext.getDataSetIds();
|
Set<Long> dataSetIds = chatQueryContext.getDataSetIds();
|
||||||
terms = filterByDataSetId(terms, dataSetIds);
|
terms = filterByDataSetId(terms, dataSetIds);
|
||||||
Map<MatchText, List<T>> matchResult = match(queryContext, terms, dataSetIds);
|
Map<MatchText, List<T>> matchResult = match(chatQueryContext, terms, dataSetIds);
|
||||||
List<T> matches = new ArrayList<>();
|
List<T> matches = new ArrayList<>();
|
||||||
if (Objects.isNull(matchResult)) {
|
if (Objects.isNull(matchResult)) {
|
||||||
return matches;
|
return matches;
|
||||||
@@ -155,8 +155,8 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
|||||||
|
|
||||||
public abstract String getMapKey(T a);
|
public abstract String getMapKey(T a);
|
||||||
|
|
||||||
public abstract void detectByStep(QueryContext queryContext, Set<T> existResults, Set<Long> detectDataSetIds,
|
public abstract void detectByStep(ChatQueryContext chatQueryContext, Set<T> existResults,
|
||||||
String detectSegment, int offset);
|
Set<Long> detectDataSetIds, String detectSegment, int offset);
|
||||||
|
|
||||||
public double getThreshold(Double threshold, Double minThreshold, MapModeEnum mapModeEnum) {
|
public double getThreshold(Double threshold, Double minThreshold, MapModeEnum mapModeEnum) {
|
||||||
double decreaseAmount = (threshold - minThreshold) / 4;
|
double decreaseAmount = (threshold - minThreshold) / 4;
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import com.tencent.supersonic.common.pojo.Constants;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.DatabaseMapResult;
|
import com.tencent.supersonic.headless.chat.knowledge.DatabaseMapResult;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
@@ -31,10 +31,10 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
|||||||
private List<SchemaElement> allElements;
|
private List<SchemaElement> allElements;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<MatchText, List<DatabaseMapResult>> match(QueryContext queryContext, List<S2Term> terms,
|
public Map<MatchText, List<DatabaseMapResult>> match(ChatQueryContext chatQueryContext, List<S2Term> terms,
|
||||||
Set<Long> detectDataSetIds) {
|
Set<Long> detectDataSetIds) {
|
||||||
this.allElements = getSchemaElements(queryContext);
|
this.allElements = getSchemaElements(chatQueryContext);
|
||||||
return super.match(queryContext, terms, detectDataSetIds);
|
return super.match(chatQueryContext, terms, detectDataSetIds);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -49,13 +49,13 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
|||||||
+ Constants.UNDERLINE + a.getSchemaElement().getName();
|
+ Constants.UNDERLINE + a.getSchemaElement().getName();
|
||||||
}
|
}
|
||||||
|
|
||||||
public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> existResults, Set<Long> detectDataSetIds,
|
public void detectByStep(ChatQueryContext chatQueryContext, Set<DatabaseMapResult> existResults,
|
||||||
String detectSegment, int offset) {
|
Set<Long> detectDataSetIds, String detectSegment, int offset) {
|
||||||
if (StringUtils.isBlank(detectSegment)) {
|
if (StringUtils.isBlank(detectSegment)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
Double metricDimensionThresholdConfig = getThreshold(queryContext);
|
Double metricDimensionThresholdConfig = getThreshold(chatQueryContext);
|
||||||
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(allElements);
|
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(allElements);
|
||||||
|
|
||||||
for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
|
for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
|
||||||
@@ -80,18 +80,19 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<SchemaElement> getSchemaElements(QueryContext queryContext) {
|
private List<SchemaElement> getSchemaElements(ChatQueryContext chatQueryContext) {
|
||||||
List<SchemaElement> allElements = new ArrayList<>();
|
List<SchemaElement> allElements = new ArrayList<>();
|
||||||
allElements.addAll(queryContext.getSemanticSchema().getDimensions());
|
allElements.addAll(chatQueryContext.getSemanticSchema().getDimensions());
|
||||||
allElements.addAll(queryContext.getSemanticSchema().getMetrics());
|
allElements.addAll(chatQueryContext.getSemanticSchema().getMetrics());
|
||||||
return allElements;
|
return allElements;
|
||||||
}
|
}
|
||||||
|
|
||||||
private Double getThreshold(QueryContext queryContext) {
|
private Double getThreshold(ChatQueryContext chatQueryContext) {
|
||||||
Double threshold = Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD));
|
Double threshold = Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD));
|
||||||
Double minThreshold = Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD_MIN));
|
Double minThreshold = Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD_MIN));
|
||||||
|
|
||||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo().getDataSetElementMatches();
|
Map<Long, List<SchemaElementMatch>> modelElementMatches = chatQueryContext.getMapInfo()
|
||||||
|
.getDataSetElementMatches();
|
||||||
|
|
||||||
boolean existElement = modelElementMatches.entrySet().stream().anyMatch(entry -> entry.getValue().size() >= 1);
|
boolean existElement = modelElementMatches.entrySet().stream().anyMatch(entry -> entry.getValue().size() >= 1);
|
||||||
|
|
||||||
@@ -100,7 +101,7 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
|||||||
log.debug("ModelElementMatches:{},not exist Element threshold reduce by half:{}",
|
log.debug("ModelElementMatches:{},not exist Element threshold reduce by half:{}",
|
||||||
modelElementMatches, threshold);
|
modelElementMatches, threshold);
|
||||||
}
|
}
|
||||||
return getThreshold(threshold, minThreshold, queryContext.getMapModeEnum());
|
return getThreshold(threshold, minThreshold, chatQueryContext.getMapModeEnum());
|
||||||
}
|
}
|
||||||
|
|
||||||
private Map<String, Set<SchemaElement>> getNameToItems(List<SchemaElement> models) {
|
private Map<String, Set<SchemaElement>> getNameToItems(List<SchemaElement> models) {
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
|||||||
import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult;
|
import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.builder.BaseWordBuilder;
|
import com.tencent.supersonic.headless.chat.knowledge.builder.BaseWordBuilder;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper;
|
import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -23,13 +23,13 @@ import java.util.Objects;
|
|||||||
public class EmbeddingMapper extends BaseMapper {
|
public class EmbeddingMapper extends BaseMapper {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void doMap(QueryContext queryContext) {
|
public void doMap(ChatQueryContext chatQueryContext) {
|
||||||
//1. query from embedding by queryText
|
//1. query from embedding by queryText
|
||||||
String queryText = queryContext.getQueryText();
|
String queryText = chatQueryContext.getQueryText();
|
||||||
List<S2Term> terms = HanlpHelper.getTerms(queryText, queryContext.getModelIdToDataSetIds());
|
List<S2Term> terms = HanlpHelper.getTerms(queryText, chatQueryContext.getModelIdToDataSetIds());
|
||||||
|
|
||||||
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
|
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
|
||||||
List<EmbeddingResult> matchResults = matchStrategy.getMatches(queryContext, terms);
|
List<EmbeddingResult> matchResults = matchStrategy.getMatches(chatQueryContext, terms);
|
||||||
|
|
||||||
HanlpHelper.transLetterOriginal(matchResults);
|
HanlpHelper.transLetterOriginal(matchResults);
|
||||||
|
|
||||||
@@ -42,7 +42,7 @@ public class EmbeddingMapper extends BaseMapper {
|
|||||||
}
|
}
|
||||||
SchemaElementType elementType = SchemaElementType.valueOf(matchResult.getMetadata().get("type"));
|
SchemaElementType elementType = SchemaElementType.valueOf(matchResult.getMetadata().get("type"));
|
||||||
SchemaElement schemaElement = getSchemaElement(dataSetId, elementType, elementId,
|
SchemaElement schemaElement = getSchemaElement(dataSetId, elementType, elementId,
|
||||||
queryContext.getSemanticSchema());
|
chatQueryContext.getSemanticSchema());
|
||||||
if (schemaElement == null) {
|
if (schemaElement == null) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -54,7 +54,7 @@ public class EmbeddingMapper extends BaseMapper {
|
|||||||
.detectWord(matchResult.getDetectWord())
|
.detectWord(matchResult.getDetectWord())
|
||||||
.build();
|
.build();
|
||||||
//3. add to mapInfo
|
//3. add to mapInfo
|
||||||
addToSchemaMap(queryContext.getMapInfo(), dataSetId, schemaElementMatch);
|
addToSchemaMap(chatQueryContext.getMapInfo(), dataSetId, schemaElementMatch);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import com.tencent.supersonic.common.pojo.Constants;
|
|||||||
import dev.langchain4j.store.embedding.Retrieval;
|
import dev.langchain4j.store.embedding.Retrieval;
|
||||||
import dev.langchain4j.store.embedding.RetrieveQuery;
|
import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||||
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult;
|
import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.MetaEmbeddingService;
|
import com.tencent.supersonic.headless.chat.knowledge.MetaEmbeddingService;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
@@ -49,13 +49,13 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults,
|
public void detectByStep(ChatQueryContext chatQueryContext, Set<EmbeddingResult> existResults,
|
||||||
Set<Long> detectDataSetIds, String detectSegment, int offset) {
|
Set<Long> detectDataSetIds, String detectSegment, int offset) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results,
|
protected void detectByBatch(ChatQueryContext chatQueryContext, Set<EmbeddingResult> results,
|
||||||
Set<Long> detectDataSetIds, Set<String> detectSegments) {
|
Set<Long> detectDataSetIds, Set<String> detectSegments) {
|
||||||
int embedddingMapperMin = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_MIN));
|
int embedddingMapperMin = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_MIN));
|
||||||
int embedddingMapperMax = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_MAX));
|
int embedddingMapperMax = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_MAX));
|
||||||
@@ -72,16 +72,16 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
|||||||
embeddingMapperBatch);
|
embeddingMapperBatch);
|
||||||
|
|
||||||
for (List<String> queryTextsSub : queryTextsSubList) {
|
for (List<String> queryTextsSub : queryTextsSubList) {
|
||||||
detectByQueryTextsSub(results, detectDataSetIds, queryTextsSub, queryContext);
|
detectByQueryTextsSub(results, detectDataSetIds, queryTextsSub, chatQueryContext);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectDataSetIds,
|
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectDataSetIds,
|
||||||
List<String> queryTextsSub, QueryContext queryContext) {
|
List<String> queryTextsSub, ChatQueryContext chatQueryContext) {
|
||||||
Map<Long, List<Long>> modelIdToDataSetIds = queryContext.getModelIdToDataSetIds();
|
Map<Long, List<Long>> modelIdToDataSetIds = chatQueryContext.getModelIdToDataSetIds();
|
||||||
double embeddingThreshold = Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD));
|
double embeddingThreshold = Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD));
|
||||||
double embeddingThresholdMin = Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD_MIN));
|
double embeddingThresholdMin = Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD_MIN));
|
||||||
double threshold = getThreshold(embeddingThreshold, embeddingThresholdMin, queryContext.getMapModeEnum());
|
double threshold = getThreshold(embeddingThreshold, embeddingThresholdMin, chatQueryContext.getMapModeEnum());
|
||||||
|
|
||||||
// step1. build query params
|
// step1. build query params
|
||||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
|
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.beans.BeanUtils;
|
import org.springframework.beans.BeanUtils;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
@@ -21,14 +21,14 @@ import java.util.stream.Collectors;
|
|||||||
public class EntityMapper extends BaseMapper {
|
public class EntityMapper extends BaseMapper {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void doMap(QueryContext queryContext) {
|
public void doMap(ChatQueryContext chatQueryContext) {
|
||||||
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
SchemaMapInfo schemaMapInfo = chatQueryContext.getMapInfo();
|
||||||
for (Long dataSetId : schemaMapInfo.getMatchedDataSetInfos()) {
|
for (Long dataSetId : schemaMapInfo.getMatchedDataSetInfos()) {
|
||||||
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(dataSetId);
|
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(dataSetId);
|
||||||
if (CollectionUtils.isEmpty(schemaElementMatchList)) {
|
if (CollectionUtils.isEmpty(schemaElementMatchList)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
SchemaElement entity = getEntity(dataSetId, queryContext);
|
SchemaElement entity = getEntity(dataSetId, chatQueryContext);
|
||||||
if (entity == null || entity.getId() == null) {
|
if (entity == null || entity.getId() == null) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -64,8 +64,8 @@ public class EntityMapper extends BaseMapper {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
private SchemaElement getEntity(Long dataSetId, QueryContext queryContext) {
|
private SchemaElement getEntity(Long dataSetId, ChatQueryContext chatQueryContext) {
|
||||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||||
DataSetSchema modelSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId);
|
DataSetSchema modelSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId);
|
||||||
if (modelSchema != null && modelSchema.getEntity() != null) {
|
if (modelSchema != null && modelSchema.getEntity() != null) {
|
||||||
return modelSchema.getEntity();
|
return modelSchema.getEntity();
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.chat.mapper;
|
|||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.HanlpMapResult;
|
import com.tencent.supersonic.headless.chat.knowledge.HanlpMapResult;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.KnowledgeBaseService;
|
import com.tencent.supersonic.headless.chat.knowledge.KnowledgeBaseService;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
@@ -37,16 +37,16 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
|||||||
private KnowledgeBaseService knowledgeBaseService;
|
private KnowledgeBaseService knowledgeBaseService;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> terms,
|
public Map<MatchText, List<HanlpMapResult>> match(ChatQueryContext chatQueryContext, List<S2Term> terms,
|
||||||
Set<Long> detectDataSetIds) {
|
Set<Long> detectDataSetIds) {
|
||||||
String text = queryContext.getQueryText();
|
String text = chatQueryContext.getQueryText();
|
||||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
log.debug("terms:{},detectModelIds:{}", terms, detectDataSetIds);
|
log.debug("terms:{},detectModelIds:{}", terms, detectDataSetIds);
|
||||||
|
|
||||||
List<HanlpMapResult> detects = detect(queryContext, terms, detectDataSetIds);
|
List<HanlpMapResult> detects = detect(chatQueryContext, terms, detectDataSetIds);
|
||||||
Map<MatchText, List<HanlpMapResult>> result = new HashMap<>();
|
Map<MatchText, List<HanlpMapResult>> result = new HashMap<>();
|
||||||
|
|
||||||
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
||||||
@@ -59,16 +59,17 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
|||||||
&& existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length();
|
&& existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length();
|
||||||
}
|
}
|
||||||
|
|
||||||
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectDataSetIds,
|
public void detectByStep(ChatQueryContext chatQueryContext, Set<HanlpMapResult> existResults,
|
||||||
|
Set<Long> detectDataSetIds,
|
||||||
String detectSegment, int offset) {
|
String detectSegment, int offset) {
|
||||||
// step1. pre search
|
// step1. pre search
|
||||||
Integer oneDetectionMaxSize = Integer.valueOf(mapperConfig.getParameterValue(MAPPER_DETECTION_MAX_SIZE));
|
Integer oneDetectionMaxSize = Integer.valueOf(mapperConfig.getParameterValue(MAPPER_DETECTION_MAX_SIZE));
|
||||||
LinkedHashSet<HanlpMapResult> hanlpMapResults = knowledgeBaseService.prefixSearch(detectSegment,
|
LinkedHashSet<HanlpMapResult> hanlpMapResults = knowledgeBaseService.prefixSearch(detectSegment,
|
||||||
oneDetectionMaxSize, queryContext.getModelIdToDataSetIds(), detectDataSetIds)
|
oneDetectionMaxSize, chatQueryContext.getModelIdToDataSetIds(), detectDataSetIds)
|
||||||
.stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
.stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||||
// step2. suffix search
|
// step2. suffix search
|
||||||
LinkedHashSet<HanlpMapResult> suffixHanlpMapResults = knowledgeBaseService.suffixSearch(detectSegment,
|
LinkedHashSet<HanlpMapResult> suffixHanlpMapResults = knowledgeBaseService.suffixSearch(detectSegment,
|
||||||
oneDetectionMaxSize, queryContext.getModelIdToDataSetIds(), detectDataSetIds)
|
oneDetectionMaxSize, chatQueryContext.getModelIdToDataSetIds(), detectDataSetIds)
|
||||||
.stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
.stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||||
|
|
||||||
hanlpMapResults.addAll(suffixHanlpMapResults);
|
hanlpMapResults.addAll(suffixHanlpMapResults);
|
||||||
@@ -83,7 +84,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
|||||||
// step4. filter by similarity
|
// step4. filter by similarity
|
||||||
hanlpMapResults = hanlpMapResults.stream()
|
hanlpMapResults = hanlpMapResults.stream()
|
||||||
.filter(term -> mapperHelper.getSimilarity(detectSegment, term.getName())
|
.filter(term -> mapperHelper.getSimilarity(detectSegment, term.getName())
|
||||||
>= getThresholdMatch(term.getNatures(), queryContext))
|
>= getThresholdMatch(term.getNatures(), chatQueryContext))
|
||||||
.filter(term -> CollectionUtils.isNotEmpty(term.getNatures()))
|
.filter(term -> CollectionUtils.isNotEmpty(term.getNatures()))
|
||||||
.collect(Collectors.toCollection(LinkedHashSet::new));
|
.collect(Collectors.toCollection(LinkedHashSet::new));
|
||||||
|
|
||||||
@@ -126,7 +127,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
|||||||
return a.getName() + Constants.UNDERLINE + String.join(Constants.UNDERLINE, a.getNatures());
|
return a.getName() + Constants.UNDERLINE + String.join(Constants.UNDERLINE, a.getNatures());
|
||||||
}
|
}
|
||||||
|
|
||||||
public double getThresholdMatch(List<String> natures, QueryContext queryContext) {
|
public double getThresholdMatch(List<String> natures, ChatQueryContext chatQueryContext) {
|
||||||
Double threshold = Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD));
|
Double threshold = Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD));
|
||||||
Double minThreshold = Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD_MIN));
|
Double minThreshold = Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD_MIN));
|
||||||
if (mapperHelper.existDimensionValues(natures)) {
|
if (mapperHelper.existDimensionValues(natures)) {
|
||||||
@@ -134,7 +135,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
|||||||
minThreshold = Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_VALUE_THRESHOLD_MIN));
|
minThreshold = Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_VALUE_THRESHOLD_MIN));
|
||||||
}
|
}
|
||||||
|
|
||||||
return getThreshold(threshold, minThreshold, queryContext.getMapModeEnum());
|
return getThreshold(threshold, minThreshold, chatQueryContext.getMapModeEnum());
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,12 +6,12 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.DatabaseMapResult;
|
import com.tencent.supersonic.headless.chat.knowledge.DatabaseMapResult;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.HanlpMapResult;
|
import com.tencent.supersonic.headless.chat.knowledge.HanlpMapResult;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.builder.BaseWordBuilder;
|
import com.tencent.supersonic.headless.chat.knowledge.builder.BaseWordBuilder;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper;
|
import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper;
|
import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
@@ -30,23 +30,23 @@ import java.util.stream.Collectors;
|
|||||||
public class KeywordMapper extends BaseMapper {
|
public class KeywordMapper extends BaseMapper {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void doMap(QueryContext queryContext) {
|
public void doMap(ChatQueryContext chatQueryContext) {
|
||||||
String queryText = queryContext.getQueryText();
|
String queryText = chatQueryContext.getQueryText();
|
||||||
//1.hanlpDict Match
|
//1.hanlpDict Match
|
||||||
List<S2Term> terms = HanlpHelper.getTerms(queryText, queryContext.getModelIdToDataSetIds());
|
List<S2Term> terms = HanlpHelper.getTerms(queryText, chatQueryContext.getModelIdToDataSetIds());
|
||||||
HanlpDictMatchStrategy hanlpMatchStrategy = ContextUtils.getBean(HanlpDictMatchStrategy.class);
|
HanlpDictMatchStrategy hanlpMatchStrategy = ContextUtils.getBean(HanlpDictMatchStrategy.class);
|
||||||
|
|
||||||
List<HanlpMapResult> hanlpMapResults = hanlpMatchStrategy.getMatches(queryContext, terms);
|
List<HanlpMapResult> hanlpMapResults = hanlpMatchStrategy.getMatches(chatQueryContext, terms);
|
||||||
convertHanlpMapResultToMapInfo(hanlpMapResults, queryContext, terms);
|
convertHanlpMapResultToMapInfo(hanlpMapResults, chatQueryContext, terms);
|
||||||
|
|
||||||
//2.database Match
|
//2.database Match
|
||||||
DatabaseMatchStrategy databaseMatchStrategy = ContextUtils.getBean(DatabaseMatchStrategy.class);
|
DatabaseMatchStrategy databaseMatchStrategy = ContextUtils.getBean(DatabaseMatchStrategy.class);
|
||||||
|
|
||||||
List<DatabaseMapResult> databaseResults = databaseMatchStrategy.getMatches(queryContext, terms);
|
List<DatabaseMapResult> databaseResults = databaseMatchStrategy.getMatches(chatQueryContext, terms);
|
||||||
convertDatabaseMapResultToMapInfo(queryContext, databaseResults);
|
convertDatabaseMapResultToMapInfo(chatQueryContext, databaseResults);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void convertHanlpMapResultToMapInfo(List<HanlpMapResult> mapResults, QueryContext queryContext,
|
private void convertHanlpMapResultToMapInfo(List<HanlpMapResult> mapResults, ChatQueryContext chatQueryContext,
|
||||||
List<S2Term> terms) {
|
List<S2Term> terms) {
|
||||||
if (CollectionUtils.isEmpty(mapResults)) {
|
if (CollectionUtils.isEmpty(mapResults)) {
|
||||||
return;
|
return;
|
||||||
@@ -68,7 +68,7 @@ public class KeywordMapper extends BaseMapper {
|
|||||||
}
|
}
|
||||||
Long elementID = NatureHelper.getElementID(nature);
|
Long elementID = NatureHelper.getElementID(nature);
|
||||||
SchemaElement element = getSchemaElement(dataSetId, elementType,
|
SchemaElement element = getSchemaElement(dataSetId, elementType,
|
||||||
elementID, queryContext.getSemanticSchema());
|
elementID, chatQueryContext.getSemanticSchema());
|
||||||
if (element == null) {
|
if (element == null) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -81,16 +81,17 @@ public class KeywordMapper extends BaseMapper {
|
|||||||
.detectWord(hanlpMapResult.getDetectWord())
|
.detectWord(hanlpMapResult.getDetectWord())
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
addToSchemaMap(queryContext.getMapInfo(), dataSetId, schemaElementMatch);
|
addToSchemaMap(chatQueryContext.getMapInfo(), dataSetId, schemaElementMatch);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void convertDatabaseMapResultToMapInfo(QueryContext queryContext, List<DatabaseMapResult> mapResults) {
|
private void convertDatabaseMapResultToMapInfo(ChatQueryContext chatQueryContext,
|
||||||
|
List<DatabaseMapResult> mapResults) {
|
||||||
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
|
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
|
||||||
for (DatabaseMapResult match : mapResults) {
|
for (DatabaseMapResult match : mapResults) {
|
||||||
SchemaElement schemaElement = match.getSchemaElement();
|
SchemaElement schemaElement = match.getSchemaElement();
|
||||||
Set<Long> regElementSet = getRegElementSet(queryContext.getMapInfo(), schemaElement);
|
Set<Long> regElementSet = getRegElementSet(chatQueryContext.getMapInfo(), schemaElement);
|
||||||
if (regElementSet.contains(schemaElement.getId())) {
|
if (regElementSet.contains(schemaElement.getId())) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -102,7 +103,7 @@ public class KeywordMapper extends BaseMapper {
|
|||||||
.similarity(mapperHelper.getSimilarity(match.getDetectWord(), schemaElement.getName()))
|
.similarity(mapperHelper.getSimilarity(match.getDetectWord(), schemaElement.getName()))
|
||||||
.build();
|
.build();
|
||||||
log.info("add to schema, elementMatch {}", schemaElementMatch);
|
log.info("add to schema, elementMatch {}", schemaElementMatch);
|
||||||
addToSchemaMap(queryContext.getMapInfo(), schemaElement.getDataSet(), schemaElementMatch);
|
addToSchemaMap(chatQueryContext.getMapInfo(), schemaElement.getDataSet(), schemaElementMatch);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.chat.mapper;
|
|||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
@@ -14,6 +14,6 @@ import java.util.Set;
|
|||||||
*/
|
*/
|
||||||
public interface MatchStrategy<T> {
|
public interface MatchStrategy<T> {
|
||||||
|
|
||||||
Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms, Set<Long> detectDataSetIds);
|
Map<MatchText, List<T>> match(ChatQueryContext chatQueryContext, List<S2Term> terms, Set<Long> detectDataSetIds);
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -8,8 +8,8 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||||
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.builder.BaseWordBuilder;
|
import com.tencent.supersonic.headless.chat.knowledge.builder.BaseWordBuilder;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
@@ -24,12 +24,12 @@ public class QueryFilterMapper extends BaseMapper {
|
|||||||
private double similarity = 1.0;
|
private double similarity = 1.0;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void doMap(QueryContext queryContext) {
|
public void doMap(ChatQueryContext chatQueryContext) {
|
||||||
Set<Long> dataSetIds = queryContext.getDataSetIds();
|
Set<Long> dataSetIds = chatQueryContext.getDataSetIds();
|
||||||
if (CollectionUtils.isEmpty(dataSetIds)) {
|
if (CollectionUtils.isEmpty(dataSetIds)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
SchemaMapInfo schemaMapInfo = chatQueryContext.getMapInfo();
|
||||||
clearOtherSchemaElementMatch(dataSetIds, schemaMapInfo);
|
clearOtherSchemaElementMatch(dataSetIds, schemaMapInfo);
|
||||||
for (Long dataSetId : dataSetIds) {
|
for (Long dataSetId : dataSetIds) {
|
||||||
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(dataSetId);
|
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(dataSetId);
|
||||||
@@ -37,7 +37,7 @@ public class QueryFilterMapper extends BaseMapper {
|
|||||||
schemaElementMatches = Lists.newArrayList();
|
schemaElementMatches = Lists.newArrayList();
|
||||||
schemaMapInfo.setMatchedElements(dataSetId, schemaElementMatches);
|
schemaMapInfo.setMatchedElements(dataSetId, schemaElementMatches);
|
||||||
}
|
}
|
||||||
addValueSchemaElementMatch(dataSetId, queryContext, schemaElementMatches);
|
addValueSchemaElementMatch(dataSetId, chatQueryContext, schemaElementMatches);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -49,9 +49,9 @@ public class QueryFilterMapper extends BaseMapper {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void addValueSchemaElementMatch(Long dataSetId, QueryContext queryContext,
|
private void addValueSchemaElementMatch(Long dataSetId, ChatQueryContext chatQueryContext,
|
||||||
List<SchemaElementMatch> candidateElementMatches) {
|
List<SchemaElementMatch> candidateElementMatches) {
|
||||||
QueryFilters queryFilters = queryContext.getQueryFilters();
|
QueryFilters queryFilters = chatQueryContext.getQueryFilters();
|
||||||
if (queryFilters == null || CollectionUtils.isEmpty(queryFilters.getFilters())) {
|
if (queryFilters == null || CollectionUtils.isEmpty(queryFilters.getFilters())) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -75,7 +75,7 @@ public class QueryFilterMapper extends BaseMapper {
|
|||||||
.build();
|
.build();
|
||||||
candidateElementMatches.add(schemaElementMatch);
|
candidateElementMatches.add(schemaElementMatch);
|
||||||
}
|
}
|
||||||
queryContext.getMapInfo().setMatchedElements(dataSetId, candidateElementMatches);
|
chatQueryContext.getMapInfo().setMatchedElements(dataSetId, candidateElementMatches);
|
||||||
}
|
}
|
||||||
|
|
||||||
private boolean checkExistSameValueSchemaElementMatch(QueryFilter queryFilter,
|
private boolean checkExistSameValueSchemaElementMatch(QueryFilter queryFilter,
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package com.tencent.supersonic.headless.chat.mapper;
|
package com.tencent.supersonic.headless.chat.mapper;
|
||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A schema mapper identifies references to schema elements(metrics/dimensions/entities/values)
|
* A schema mapper identifies references to schema elements(metrics/dimensions/entities/values)
|
||||||
@@ -9,5 +9,5 @@ import com.tencent.supersonic.headless.chat.QueryContext;
|
|||||||
*/
|
*/
|
||||||
public interface SchemaMapper {
|
public interface SchemaMapper {
|
||||||
|
|
||||||
void map(QueryContext queryContext);
|
void map(ChatQueryContext chatQueryContext);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.chat.mapper;
|
|||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.HanlpMapResult;
|
import com.tencent.supersonic.headless.chat.knowledge.HanlpMapResult;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.KnowledgeBaseService;
|
import com.tencent.supersonic.headless.chat.knowledge.KnowledgeBaseService;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.SearchService;
|
import com.tencent.supersonic.headless.chat.knowledge.SearchService;
|
||||||
@@ -32,9 +32,9 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
|||||||
private KnowledgeBaseService knowledgeBaseService;
|
private KnowledgeBaseService knowledgeBaseService;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> originals,
|
public Map<MatchText, List<HanlpMapResult>> match(ChatQueryContext chatQueryContext, List<S2Term> originals,
|
||||||
Set<Long> detectDataSetIds) {
|
Set<Long> detectDataSetIds) {
|
||||||
String text = queryContext.getQueryText();
|
String text = chatQueryContext.getQueryText();
|
||||||
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(originals);
|
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(originals);
|
||||||
|
|
||||||
List<Integer> detectIndexList = Lists.newArrayList();
|
List<Integer> detectIndexList = Lists.newArrayList();
|
||||||
@@ -58,9 +58,14 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
|||||||
|
|
||||||
if (StringUtils.isNotEmpty(detectSegment)) {
|
if (StringUtils.isNotEmpty(detectSegment)) {
|
||||||
List<HanlpMapResult> hanlpMapResults = knowledgeBaseService.prefixSearch(detectSegment,
|
List<HanlpMapResult> hanlpMapResults = knowledgeBaseService.prefixSearch(detectSegment,
|
||||||
SearchService.SEARCH_SIZE, queryContext.getModelIdToDataSetIds(), detectDataSetIds);
|
SearchService.SEARCH_SIZE,
|
||||||
|
chatQueryContext.getModelIdToDataSetIds(),
|
||||||
|
detectDataSetIds);
|
||||||
List<HanlpMapResult> suffixHanlpMapResults = knowledgeBaseService.suffixSearch(
|
List<HanlpMapResult> suffixHanlpMapResults = knowledgeBaseService.suffixSearch(
|
||||||
detectSegment, SEARCH_SIZE, queryContext.getModelIdToDataSetIds(), detectDataSetIds);
|
detectSegment,
|
||||||
|
SEARCH_SIZE,
|
||||||
|
chatQueryContext.getModelIdToDataSetIds(),
|
||||||
|
detectDataSetIds);
|
||||||
hanlpMapResults.addAll(suffixHanlpMapResults);
|
hanlpMapResults.addAll(suffixHanlpMapResults);
|
||||||
// remove entity name where search
|
// remove entity name where search
|
||||||
hanlpMapResults = hanlpMapResults.stream().filter(entry -> {
|
hanlpMapResults = hanlpMapResults.stream().filter(entry -> {
|
||||||
@@ -94,8 +99,8 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectDataSetIds,
|
public void detectByStep(ChatQueryContext chatQueryContext, Set<HanlpMapResult> existResults,
|
||||||
String detectSegment, int offset) {
|
Set<Long> detectDataSetIds, String detectSegment, int offset) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
|||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
|
||||||
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
||||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
@@ -29,21 +29,21 @@ import java.util.stream.Collectors;
|
|||||||
public class QueryTypeParser implements SemanticParser {
|
public class QueryTypeParser implements SemanticParser {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
public void parse(ChatQueryContext chatQueryContext, ChatContext chatContext) {
|
||||||
|
|
||||||
List<SemanticQuery> candidateQueries = queryContext.getCandidateQueries();
|
List<SemanticQuery> candidateQueries = chatQueryContext.getCandidateQueries();
|
||||||
User user = queryContext.getUser();
|
User user = chatQueryContext.getUser();
|
||||||
|
|
||||||
for (SemanticQuery semanticQuery : candidateQueries) {
|
for (SemanticQuery semanticQuery : candidateQueries) {
|
||||||
// 1.init S2SQL
|
// 1.init S2SQL
|
||||||
semanticQuery.initS2Sql(queryContext.getSemanticSchema(), user);
|
semanticQuery.initS2Sql(chatQueryContext.getSemanticSchema(), user);
|
||||||
// 2.set queryType
|
// 2.set queryType
|
||||||
QueryType queryType = getQueryType(queryContext, semanticQuery);
|
QueryType queryType = getQueryType(chatQueryContext, semanticQuery);
|
||||||
semanticQuery.getParseInfo().setQueryType(queryType);
|
semanticQuery.getParseInfo().setQueryType(queryType);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private QueryType getQueryType(QueryContext queryContext, SemanticQuery semanticQuery) {
|
private QueryType getQueryType(ChatQueryContext chatQueryContext, SemanticQuery semanticQuery) {
|
||||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||||
if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getS2SQL())) {
|
if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getS2SQL())) {
|
||||||
@@ -51,7 +51,7 @@ public class QueryTypeParser implements SemanticParser {
|
|||||||
}
|
}
|
||||||
//1. entity queryType
|
//1. entity queryType
|
||||||
Long dataSetId = parseInfo.getDataSetId();
|
Long dataSetId = parseInfo.getDataSetId();
|
||||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||||
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
|
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
|
||||||
List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getS2SQL());
|
List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getS2SQL());
|
||||||
List<String> whereFilterByTimeFields = filterByTimeFields(whereFields);
|
List<String> whereFilterByTimeFields = filterByTimeFields(whereFields);
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.chat.parser;
|
|||||||
|
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
@@ -21,12 +21,12 @@ import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_TE
|
|||||||
public class SatisfactionChecker {
|
public class SatisfactionChecker {
|
||||||
|
|
||||||
// check all the parse info in candidate
|
// check all the parse info in candidate
|
||||||
public static boolean isSkip(QueryContext queryContext) {
|
public static boolean isSkip(ChatQueryContext chatQueryContext) {
|
||||||
for (SemanticQuery query : queryContext.getCandidateQueries()) {
|
for (SemanticQuery query : chatQueryContext.getCandidateQueries()) {
|
||||||
if (query.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) {
|
if (query.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (checkThreshold(queryContext.getQueryText(), query.getParseInfo())) {
|
if (checkThreshold(chatQueryContext.getQueryText(), query.getParseInfo())) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package com.tencent.supersonic.headless.chat.parser;
|
package com.tencent.supersonic.headless.chat.parser;
|
||||||
|
|
||||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A semantic parser understands user queries and generates semantic query statement.
|
* A semantic parser understands user queries and generates semantic query statement.
|
||||||
@@ -10,5 +10,5 @@ import com.tencent.supersonic.headless.chat.QueryContext;
|
|||||||
*/
|
*/
|
||||||
public interface SemanticParser {
|
public interface SemanticParser {
|
||||||
|
|
||||||
void parse(QueryContext queryContext, ChatContext chatContext);
|
void parse(ChatQueryContext chatQueryContext, ChatContext chatContext);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||||
|
|
||||||
|
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
|
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
public interface DataSetResolver {
|
public interface DataSetResolver {
|
||||||
|
|
||||||
Long resolve(QueryContext queryContext, Set<Long> restrictiveModels);
|
Long resolve(ChatQueryContext chatQueryContext, Set<Long> restrictiveModels);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
@@ -111,8 +111,8 @@ public class HeuristicDataSetResolver implements DataSetResolver {
|
|||||||
return dataSetCount;
|
return dataSetCount;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Long resolve(QueryContext queryContext, Set<Long> agentDataSetIds) {
|
public Long resolve(ChatQueryContext chatQueryContext, Set<Long> agentDataSetIds) {
|
||||||
SchemaMapInfo mapInfo = queryContext.getMapInfo();
|
SchemaMapInfo mapInfo = chatQueryContext.getMapInfo();
|
||||||
Set<Long> matchedDataSets = mapInfo.getMatchedDataSetInfos();
|
Set<Long> matchedDataSets = mapInfo.getMatchedDataSetInfos();
|
||||||
if (CollectionUtils.isNotEmpty(agentDataSetIds)) {
|
if (CollectionUtils.isNotEmpty(agentDataSetIds)) {
|
||||||
matchedDataSets.retainAll(agentDataSetIds);
|
matchedDataSets.retainAll(agentDataSetIds);
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
|
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
|
||||||
import com.tencent.supersonic.headless.chat.parser.SatisfactionChecker;
|
import com.tencent.supersonic.headless.chat.parser.SatisfactionChecker;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||||
@@ -43,7 +43,7 @@ public class LLMRequestService {
|
|||||||
@Autowired
|
@Autowired
|
||||||
private ParserConfig parserConfig;
|
private ParserConfig parserConfig;
|
||||||
|
|
||||||
public boolean isSkip(QueryContext queryCtx) {
|
public boolean isSkip(ChatQueryContext queryCtx) {
|
||||||
if (!queryCtx.getText2SQLType().enableLLM()) {
|
if (!queryCtx.getText2SQLType().enableLLM()) {
|
||||||
log.info("not enable llm, skip");
|
log.info("not enable llm, skip");
|
||||||
return true;
|
return true;
|
||||||
@@ -57,12 +57,12 @@ public class LLMRequestService {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
public Long getDataSetId(QueryContext queryCtx) {
|
public Long getDataSetId(ChatQueryContext queryCtx) {
|
||||||
DataSetResolver dataSetResolver = ComponentFactory.getModelResolver();
|
DataSetResolver dataSetResolver = ComponentFactory.getModelResolver();
|
||||||
return dataSetResolver.resolve(queryCtx, queryCtx.getDataSetIds());
|
return dataSetResolver.resolve(queryCtx, queryCtx.getDataSetIds());
|
||||||
}
|
}
|
||||||
|
|
||||||
public LLMReq getLlmReq(QueryContext queryCtx, Long dataSetId) {
|
public LLMReq getLlmReq(ChatQueryContext queryCtx, Long dataSetId) {
|
||||||
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
|
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
|
||||||
List<LLMReq.ElementValue> linkingValues = requestService.getValues(queryCtx, dataSetId);
|
List<LLMReq.ElementValue> linkingValues = requestService.getValues(queryCtx, dataSetId);
|
||||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||||
@@ -118,7 +118,7 @@ public class LLMRequestService {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected List<String> getFieldNameList(QueryContext queryCtx, Long dataSetId,
|
protected List<String> getFieldNameList(ChatQueryContext queryCtx, Long dataSetId,
|
||||||
LLMParserConfig llmParserConfig) {
|
LLMParserConfig llmParserConfig) {
|
||||||
|
|
||||||
Set<String> results = getTopNFieldNames(queryCtx, dataSetId, llmParserConfig);
|
Set<String> results = getTopNFieldNames(queryCtx, dataSetId, llmParserConfig);
|
||||||
@@ -129,7 +129,7 @@ public class LLMRequestService {
|
|||||||
return new ArrayList<>(results);
|
return new ArrayList<>(results);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected List<LLMReq.Term> getTerms(QueryContext queryCtx, Long dataSetId) {
|
protected List<LLMReq.Term> getTerms(ChatQueryContext queryCtx, Long dataSetId) {
|
||||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||||
return new ArrayList<>();
|
return new ArrayList<>();
|
||||||
@@ -147,7 +147,7 @@ public class LLMRequestService {
|
|||||||
}).collect(Collectors.toList());
|
}).collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
private String getPriorExts(QueryContext queryContext, List<String> fieldNameList) {
|
private String getPriorExts(ChatQueryContext queryContext, List<String> fieldNameList) {
|
||||||
StringBuilder extraInfoSb = new StringBuilder();
|
StringBuilder extraInfoSb = new StringBuilder();
|
||||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||||
Map<String, String> fieldNameToDataFormatType = semanticSchema.getMetrics()
|
Map<String, String> fieldNameToDataFormatType = semanticSchema.getMetrics()
|
||||||
@@ -176,7 +176,7 @@ public class LLMRequestService {
|
|||||||
return extraInfoSb.toString();
|
return extraInfoSb.toString();
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<LLMReq.ElementValue> getValues(QueryContext queryCtx, Long dataSetId) {
|
public List<LLMReq.ElementValue> getValues(ChatQueryContext queryCtx, Long dataSetId) {
|
||||||
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, dataSetId);
|
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, dataSetId);
|
||||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||||
@@ -198,14 +198,14 @@ public class LLMRequestService {
|
|||||||
return new ArrayList<>(valueMatches);
|
return new ArrayList<>(valueMatches);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected Map<Long, String> getItemIdToName(QueryContext queryCtx, Long dataSetId) {
|
protected Map<Long, String> getItemIdToName(ChatQueryContext queryCtx, Long dataSetId) {
|
||||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||||
List<SchemaElement> elements = semanticSchema.getDimensions(dataSetId);
|
List<SchemaElement> elements = semanticSchema.getDimensions(dataSetId);
|
||||||
return elements.stream()
|
return elements.stream()
|
||||||
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
|
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
|
||||||
}
|
}
|
||||||
|
|
||||||
private Set<String> getTopNFieldNames(QueryContext queryCtx, Long dataSetId, LLMParserConfig llmParserConfig) {
|
private Set<String> getTopNFieldNames(ChatQueryContext queryCtx, Long dataSetId, LLMParserConfig llmParserConfig) {
|
||||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||||
Set<String> results = new HashSet<>();
|
Set<String> results = new HashSet<>();
|
||||||
Set<String> dimensions = semanticSchema.getDimensions(dataSetId).stream()
|
Set<String> dimensions = semanticSchema.getDimensions(dataSetId).stream()
|
||||||
@@ -223,7 +223,7 @@ public class LLMRequestService {
|
|||||||
return results;
|
return results;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected List<SchemaElement> getMatchedMetrics(QueryContext queryCtx, Long dataSetId) {
|
protected List<SchemaElement> getMatchedMetrics(ChatQueryContext queryCtx, Long dataSetId) {
|
||||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||||
return Collections.emptyList();
|
return Collections.emptyList();
|
||||||
@@ -240,7 +240,7 @@ public class LLMRequestService {
|
|||||||
return schemaElements;
|
return schemaElements;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected List<SchemaElement> getMatchedDimensions(QueryContext queryCtx, Long dataSetId) {
|
protected List<SchemaElement> getMatchedDimensions(ChatQueryContext queryCtx, Long dataSetId) {
|
||||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||||
return Collections.emptyList();
|
return Collections.emptyList();
|
||||||
@@ -257,7 +257,7 @@ public class LLMRequestService {
|
|||||||
return schemaElements;
|
return schemaElements;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, Long dataSetId) {
|
protected Set<String> getMatchedFieldNames(ChatQueryContext queryCtx, Long dataSetId) {
|
||||||
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, dataSetId);
|
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, dataSetId);
|
||||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import com.tencent.supersonic.headless.chat.query.llm.LLMSemanticQuery;
|
|||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlResp;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlResp;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections.MapUtils;
|
import org.apache.commons.collections.MapUtils;
|
||||||
@@ -22,7 +22,8 @@ import java.util.Objects;
|
|||||||
@Service
|
@Service
|
||||||
public class LLMResponseService {
|
public class LLMResponseService {
|
||||||
|
|
||||||
public SemanticParseInfo addParseInfo(QueryContext queryCtx, ParseResult parseResult, String s2SQL, Double weight) {
|
public SemanticParseInfo addParseInfo(ChatQueryContext queryCtx, ParseResult parseResult,
|
||||||
|
String s2SQL, Double weight) {
|
||||||
if (Objects.isNull(weight)) {
|
if (Objects.isNull(weight)) {
|
||||||
weight = 0D;
|
weight = 0D;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.chat.parser.llm;
|
|||||||
|
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||||
@@ -23,7 +23,7 @@ import org.apache.commons.collections.MapUtils;
|
|||||||
public class LLMSqlParser implements SemanticParser {
|
public class LLMSqlParser implements SemanticParser {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void parse(QueryContext queryCtx, ChatContext chatCtx) {
|
public void parse(ChatQueryContext queryCtx, ChatContext chatCtx) {
|
||||||
try {
|
try {
|
||||||
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
|
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
|
||||||
//1.determine whether to skip this parser.
|
//1.determine whether to skip this parser.
|
||||||
@@ -44,7 +44,7 @@ public class LLMSqlParser implements SemanticParser {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void tryParse(QueryContext queryCtx, Long dataSetId) {
|
private void tryParse(ChatQueryContext queryCtx, Long dataSetId) {
|
||||||
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
|
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
|
||||||
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
|
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
|
||||||
int maxRetries = ContextUtils.getBean(LLMParserConfig.class).getRecallMaxRetries();
|
int maxRetries = ContextUtils.getBean(LLMParserConfig.class).getRecallMaxRetries();
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.chat.parser.rule;
|
|||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
@@ -41,11 +41,11 @@ public class AggregateTypeParser implements SemanticParser {
|
|||||||
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (k1, k2) -> k2));
|
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (k1, k2) -> k2));
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
public void parse(ChatQueryContext chatQueryContext, ChatContext chatContext) {
|
||||||
String queryText = queryContext.getQueryText();
|
String queryText = chatQueryContext.getQueryText();
|
||||||
AggregateConf aggregateConf = resolveAggregateConf(queryText);
|
AggregateConf aggregateConf = resolveAggregateConf(queryText);
|
||||||
|
|
||||||
for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) {
|
for (SemanticQuery semanticQuery : chatQueryContext.getCandidateQueries()) {
|
||||||
if (!AggregateTypeEnum.NONE.equals(semanticQuery.getParseInfo().getAggType())) {
|
if (!AggregateTypeEnum.NONE.equals(semanticQuery.getParseInfo().getAggType())) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,12 +2,12 @@ package com.tencent.supersonic.headless.chat.parser.rule;
|
|||||||
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||||
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
||||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
|
||||||
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricModelQuery;
|
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricModelQuery;
|
||||||
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricSemanticQuery;
|
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricSemanticQuery;
|
||||||
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricIdQuery;
|
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricIdQuery;
|
||||||
@@ -43,16 +43,16 @@ public class ContextInheritParser implements SemanticParser {
|
|||||||
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
|
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
public void parse(ChatQueryContext chatQueryContext, ChatContext chatContext) {
|
||||||
if (!shouldInherit(queryContext)) {
|
if (!shouldInherit(chatQueryContext)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
Long dataSetId = getMatchedDataSet(queryContext, chatContext);
|
Long dataSetId = getMatchedDataSet(chatQueryContext, chatContext);
|
||||||
if (dataSetId == null) {
|
if (dataSetId == null) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
List<SchemaElementMatch> elementMatches = queryContext.getMapInfo().getMatchedElements(dataSetId);
|
List<SchemaElementMatch> elementMatches = chatQueryContext.getMapInfo().getMatchedElements(dataSetId);
|
||||||
|
|
||||||
List<SchemaElementMatch> matchesToInherit = new ArrayList<>();
|
List<SchemaElementMatch> matchesToInherit = new ArrayList<>();
|
||||||
for (SchemaElementMatch match : chatContext.getParseInfo().getElementMatches()) {
|
for (SchemaElementMatch match : chatContext.getParseInfo().getElementMatches()) {
|
||||||
@@ -66,18 +66,18 @@ public class ContextInheritParser implements SemanticParser {
|
|||||||
}
|
}
|
||||||
elementMatches.addAll(matchesToInherit);
|
elementMatches.addAll(matchesToInherit);
|
||||||
|
|
||||||
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(dataSetId, elementMatches, queryContext);
|
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(dataSetId, elementMatches, chatQueryContext);
|
||||||
for (RuleSemanticQuery query : queries) {
|
for (RuleSemanticQuery query : queries) {
|
||||||
query.fillParseInfo(queryContext, chatContext);
|
query.fillParseInfo(chatQueryContext, chatContext);
|
||||||
if (existSameQuery(query.getParseInfo().getDataSetId(), query.getQueryMode(), queryContext)) {
|
if (existSameQuery(query.getParseInfo().getDataSetId(), query.getQueryMode(), chatQueryContext)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
queryContext.getCandidateQueries().add(query);
|
chatQueryContext.getCandidateQueries().add(query);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private boolean existSameQuery(Long dataSetId, String queryMode, QueryContext queryContext) {
|
private boolean existSameQuery(Long dataSetId, String queryMode, ChatQueryContext chatQueryContext) {
|
||||||
for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) {
|
for (SemanticQuery semanticQuery : chatQueryContext.getCandidateQueries()) {
|
||||||
if (semanticQuery.getQueryMode().equals(queryMode)
|
if (semanticQuery.getQueryMode().equals(queryMode)
|
||||||
&& semanticQuery.getParseInfo().getDataSetId().equals(dataSetId)) {
|
&& semanticQuery.getParseInfo().getDataSetId().equals(dataSetId)) {
|
||||||
return true;
|
return true;
|
||||||
@@ -100,20 +100,20 @@ public class ContextInheritParser implements SemanticParser {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
protected boolean shouldInherit(QueryContext queryContext) {
|
protected boolean shouldInherit(ChatQueryContext chatQueryContext) {
|
||||||
// if candidates only have MetricModel mode, count in context
|
// if candidates only have MetricModel mode, count in context
|
||||||
List<SemanticQuery> metricModelQueries = queryContext.getCandidateQueries().stream()
|
List<SemanticQuery> metricModelQueries = chatQueryContext.getCandidateQueries().stream()
|
||||||
.filter(query -> query instanceof MetricModelQuery).collect(
|
.filter(query -> query instanceof MetricModelQuery).collect(
|
||||||
Collectors.toList());
|
Collectors.toList());
|
||||||
return metricModelQueries.size() == queryContext.getCandidateQueries().size();
|
return metricModelQueries.size() == chatQueryContext.getCandidateQueries().size();
|
||||||
}
|
}
|
||||||
|
|
||||||
protected Long getMatchedDataSet(QueryContext queryContext, ChatContext chatContext) {
|
protected Long getMatchedDataSet(ChatQueryContext chatQueryContext, ChatContext chatContext) {
|
||||||
Long dataSetId = chatContext.getParseInfo().getDataSetId();
|
Long dataSetId = chatContext.getParseInfo().getDataSetId();
|
||||||
if (dataSetId == null) {
|
if (dataSetId == null) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
Set<Long> queryDataSets = queryContext.getMapInfo().getMatchedDataSetInfos();
|
Set<Long> queryDataSets = chatQueryContext.getMapInfo().getMatchedDataSetInfos();
|
||||||
if (queryDataSets.contains(dataSetId)) {
|
if (queryDataSets.contains(dataSetId)) {
|
||||||
return dataSetId;
|
return dataSetId;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,10 +2,10 @@ package com.tencent.supersonic.headless.chat.parser.rule;
|
|||||||
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||||
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
||||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -24,21 +24,21 @@ public class RuleSqlParser implements SemanticParser {
|
|||||||
);
|
);
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
public void parse(ChatQueryContext chatQueryContext, ChatContext chatContext) {
|
||||||
if (!queryContext.getText2SQLType().enableRule()) {
|
if (!chatQueryContext.getText2SQLType().enableRule()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
SchemaMapInfo mapInfo = queryContext.getMapInfo();
|
SchemaMapInfo mapInfo = chatQueryContext.getMapInfo();
|
||||||
// iterate all schemaElementMatches to resolve query mode
|
// iterate all schemaElementMatches to resolve query mode
|
||||||
for (Long dataSetId : mapInfo.getMatchedDataSetInfos()) {
|
for (Long dataSetId : mapInfo.getMatchedDataSetInfos()) {
|
||||||
List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(dataSetId);
|
List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(dataSetId);
|
||||||
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(dataSetId, elementMatches, queryContext);
|
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(dataSetId, elementMatches, chatQueryContext);
|
||||||
for (RuleSemanticQuery query : queries) {
|
for (RuleSemanticQuery query : queries) {
|
||||||
query.fillParseInfo(queryContext, chatContext);
|
query.fillParseInfo(chatQueryContext, chatContext);
|
||||||
queryContext.getCandidateQueries().add(query);
|
chatQueryContext.getCandidateQueries().add(query);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auxiliaryParsers.stream().forEach(p -> p.parse(queryContext, chatContext));
|
auxiliaryParsers.stream().forEach(p -> p.parse(chatQueryContext, chatContext));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.chat.parser.rule;
|
|||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.common.pojo.DateConf;
|
import com.tencent.supersonic.common.pojo.DateConf;
|
||||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||||
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||||
@@ -42,7 +42,7 @@ public class TimeRangeParser implements SemanticParser {
|
|||||||
private static final DateFormat DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd");
|
private static final DateFormat DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd");
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
public void parse(ChatQueryContext queryContext, ChatContext chatContext) {
|
||||||
String queryText = queryContext.getQueryText();
|
String queryText = queryContext.getQueryText();
|
||||||
DateConf dateConf = parseRecent(queryText);
|
DateConf dateConf = parseRecent(queryText);
|
||||||
if (dateConf == null) {
|
if (dateConf == null) {
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
|
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
|
||||||
import com.tencent.supersonic.headless.chat.query.BaseSemanticQuery;
|
import com.tencent.supersonic.headless.chat.query.BaseSemanticQuery;
|
||||||
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||||
@@ -40,7 +40,7 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches,
|
public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches,
|
||||||
QueryContext queryCtx) {
|
ChatQueryContext queryCtx) {
|
||||||
return queryMatcher.match(candidateElementMatches);
|
return queryMatcher.match(candidateElementMatches);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -49,9 +49,9 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
|||||||
initS2SqlByStruct(semanticSchema);
|
initS2SqlByStruct(semanticSchema);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void fillParseInfo(QueryContext queryContext, ChatContext chatContext) {
|
public void fillParseInfo(ChatQueryContext chatQueryContext, ChatContext chatContext) {
|
||||||
parseInfo.setQueryMode(getQueryMode());
|
parseInfo.setQueryMode(getQueryMode());
|
||||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||||
|
|
||||||
fillSchemaElement(parseInfo, semanticSchema);
|
fillSchemaElement(parseInfo, semanticSchema);
|
||||||
fillScore(parseInfo);
|
fillScore(parseInfo);
|
||||||
@@ -223,10 +223,10 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public static List<RuleSemanticQuery> resolve(Long dataSetId, List<SchemaElementMatch> candidateElementMatches,
|
public static List<RuleSemanticQuery> resolve(Long dataSetId, List<SchemaElementMatch> candidateElementMatches,
|
||||||
QueryContext queryContext) {
|
ChatQueryContext chatQueryContext) {
|
||||||
List<RuleSemanticQuery> matchedQueries = new ArrayList<>();
|
List<RuleSemanticQuery> matchedQueries = new ArrayList<>();
|
||||||
for (RuleSemanticQuery semanticQuery : QueryManager.getRuleQueries()) {
|
for (RuleSemanticQuery semanticQuery : QueryManager.getRuleQueries()) {
|
||||||
List<SchemaElementMatch> matches = semanticQuery.match(candidateElementMatches, queryContext);
|
List<SchemaElementMatch> matches = semanticQuery.match(candidateElementMatches, chatQueryContext);
|
||||||
|
|
||||||
if (matches.size() > 0) {
|
if (matches.size() > 0) {
|
||||||
RuleSemanticQuery query = QueryManager.createRuleQuery(semanticQuery.getQueryMode());
|
RuleSemanticQuery query = QueryManager.createRuleQuery(semanticQuery.getQueryMode());
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
|
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
|
||||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
|
|
||||||
import java.util.LinkedHashSet;
|
import java.util.LinkedHashSet;
|
||||||
@@ -19,15 +19,15 @@ import java.util.stream.Collectors;
|
|||||||
public abstract class DetailListQuery extends DetailSemanticQuery {
|
public abstract class DetailListQuery extends DetailSemanticQuery {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void fillParseInfo(QueryContext queryContext, ChatContext chatContext) {
|
public void fillParseInfo(ChatQueryContext chatQueryContext, ChatContext chatContext) {
|
||||||
super.fillParseInfo(queryContext, chatContext);
|
super.fillParseInfo(chatQueryContext, chatContext);
|
||||||
this.addEntityDetailAndOrderByMetric(queryContext, parseInfo);
|
this.addEntityDetailAndOrderByMetric(chatQueryContext, parseInfo);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void addEntityDetailAndOrderByMetric(QueryContext queryContext, SemanticParseInfo parseInfo) {
|
private void addEntityDetailAndOrderByMetric(ChatQueryContext chatQueryContext, SemanticParseInfo parseInfo) {
|
||||||
Long dataSetId = parseInfo.getDataSetId();
|
Long dataSetId = parseInfo.getDataSetId();
|
||||||
if (Objects.nonNull(dataSetId) && dataSetId > 0L) {
|
if (Objects.nonNull(dataSetId) && dataSetId > 0L) {
|
||||||
DataSetSchema dataSetSchema = queryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
|
DataSetSchema dataSetSchema = chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
|
||||||
if (dataSetSchema != null && Objects.nonNull(dataSetSchema.getEntity())) {
|
if (dataSetSchema != null && Objects.nonNull(dataSetSchema.getEntity())) {
|
||||||
Set<SchemaElement> dimensions = new LinkedHashSet<>();
|
Set<SchemaElement> dimensions = new LinkedHashSet<>();
|
||||||
Set<SchemaElement> metrics = new LinkedHashSet<>();
|
Set<SchemaElement> metrics = new LinkedHashSet<>();
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
|
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption;
|
import com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption;
|
||||||
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
||||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||||
@@ -30,19 +30,19 @@ public abstract class DetailSemanticQuery extends RuleSemanticQuery {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches,
|
public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches,
|
||||||
QueryContext queryCtx) {
|
ChatQueryContext queryCtx) {
|
||||||
return super.match(candidateElementMatches, queryCtx);
|
return super.match(candidateElementMatches, queryCtx);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void fillParseInfo(QueryContext queryContext, ChatContext chatContext) {
|
public void fillParseInfo(ChatQueryContext chatQueryContext, ChatContext chatContext) {
|
||||||
super.fillParseInfo(queryContext, chatContext);
|
super.fillParseInfo(chatQueryContext, chatContext);
|
||||||
|
|
||||||
parseInfo.setQueryType(QueryType.DETAIL);
|
parseInfo.setQueryType(QueryType.DETAIL);
|
||||||
parseInfo.setLimit(DETAIL_MAX_RESULTS);
|
parseInfo.setLimit(DETAIL_MAX_RESULTS);
|
||||||
if (parseInfo.getDateInfo() == null) {
|
if (parseInfo.getDateInfo() == null) {
|
||||||
DataSetSchema dataSetSchema =
|
DataSetSchema dataSetSchema =
|
||||||
queryContext.getSemanticSchema().getDataSetSchemaMap().get(parseInfo.getDataSetId());
|
chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(parseInfo.getDataSetId());
|
||||||
TimeDefaultConfig timeDefaultConfig = dataSetSchema.getTagTypeTimeDefaultConfig();
|
TimeDefaultConfig timeDefaultConfig = dataSetSchema.getTagTypeTimeDefaultConfig();
|
||||||
DateConf dateInfo = new DateConf();
|
DateConf dateInfo = new DateConf();
|
||||||
if (Objects.nonNull(timeDefaultConfig) && Objects.nonNull(timeDefaultConfig.getUnit())) {
|
if (Objects.nonNull(timeDefaultConfig) && Objects.nonNull(timeDefaultConfig.getUnit())) {
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import com.tencent.supersonic.common.pojo.enums.TimeMode;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
|
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
||||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
@@ -31,17 +31,17 @@ public abstract class MetricSemanticQuery extends RuleSemanticQuery {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches,
|
public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches,
|
||||||
QueryContext queryCtx) {
|
ChatQueryContext queryCtx) {
|
||||||
return super.match(candidateElementMatches, queryCtx);
|
return super.match(candidateElementMatches, queryCtx);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void fillParseInfo(QueryContext queryContext, ChatContext chatContext) {
|
public void fillParseInfo(ChatQueryContext chatQueryContext, ChatContext chatContext) {
|
||||||
super.fillParseInfo(queryContext, chatContext);
|
super.fillParseInfo(chatQueryContext, chatContext);
|
||||||
parseInfo.setLimit(METRIC_MAX_RESULTS);
|
parseInfo.setLimit(METRIC_MAX_RESULTS);
|
||||||
if (parseInfo.getDateInfo() == null) {
|
if (parseInfo.getDateInfo() == null) {
|
||||||
DataSetSchema dataSetSchema =
|
DataSetSchema dataSetSchema =
|
||||||
queryContext.getSemanticSchema().getDataSetSchemaMap().get(parseInfo.getDataSetId());
|
chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(parseInfo.getDataSetId());
|
||||||
TimeDefaultConfig timeDefaultConfig = dataSetSchema.getMetricTypeTimeDefaultConfig();
|
TimeDefaultConfig timeDefaultConfig = dataSetSchema.getMetricTypeTimeDefaultConfig();
|
||||||
DateConf dateInfo = new DateConf();
|
DateConf dateInfo = new DateConf();
|
||||||
if (Objects.nonNull(timeDefaultConfig) && Objects.nonNull(timeDefaultConfig.getUnit())) {
|
if (Objects.nonNull(timeDefaultConfig) && Objects.nonNull(timeDefaultConfig.getUnit())) {
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import com.tencent.supersonic.common.pojo.Order;
|
|||||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
@@ -36,7 +36,7 @@ public class MetricTopNQuery extends MetricSemanticQuery {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches,
|
public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches,
|
||||||
QueryContext queryCtx) {
|
ChatQueryContext queryCtx) {
|
||||||
Matcher matcher = INTENT_PATTERN.matcher(queryCtx.getQueryText());
|
Matcher matcher = INTENT_PATTERN.matcher(queryCtx.getQueryText());
|
||||||
if (matcher.matches()) {
|
if (matcher.matches()) {
|
||||||
return super.match(candidateElementMatches, queryCtx);
|
return super.match(candidateElementMatches, queryCtx);
|
||||||
@@ -50,8 +50,8 @@ public class MetricTopNQuery extends MetricSemanticQuery {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void fillParseInfo(QueryContext queryContext, ChatContext chatContext) {
|
public void fillParseInfo(ChatQueryContext chatQueryContext, ChatContext chatContext) {
|
||||||
super.fillParseInfo(queryContext, chatContext);
|
super.fillParseInfo(chatQueryContext, chatContext);
|
||||||
|
|
||||||
parseInfo.setLimit(ORDERBY_MAX_RESULTS);
|
parseInfo.setLimit(ORDERBY_MAX_RESULTS);
|
||||||
parseInfo.setScore(parseInfo.getScore() + 2.0);
|
parseInfo.setScore(parseInfo.getScore() + 2.0);
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
@@ -21,7 +21,7 @@ class AggCorrectorTest {
|
|||||||
void testDoCorrect() {
|
void testDoCorrect() {
|
||||||
AggCorrector corrector = new AggCorrector();
|
AggCorrector corrector = new AggCorrector();
|
||||||
Long dataSetId = 1L;
|
Long dataSetId = 1L;
|
||||||
QueryContext queryContext = buildQueryContext(dataSetId);
|
ChatQueryContext chatQueryContext = buildQueryContext(dataSetId);
|
||||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||||
SchemaElement dataSet = new SchemaElement();
|
SchemaElement dataSet = new SchemaElement();
|
||||||
dataSet.setDataSet(dataSetId);
|
dataSet.setDataSet(dataSetId);
|
||||||
@@ -33,15 +33,15 @@ class AggCorrectorTest {
|
|||||||
sqlInfo.setS2SQL(sql);
|
sqlInfo.setS2SQL(sql);
|
||||||
sqlInfo.setCorrectS2SQL(sql);
|
sqlInfo.setCorrectS2SQL(sql);
|
||||||
semanticParseInfo.setSqlInfo(sqlInfo);
|
semanticParseInfo.setSqlInfo(sqlInfo);
|
||||||
corrector.correct(queryContext, semanticParseInfo);
|
corrector.correct(chatQueryContext, semanticParseInfo);
|
||||||
Assert.assertEquals("SELECT 用户, SUM(访问次数) FROM 超音数数据集 WHERE 部门 = 'sales'"
|
Assert.assertEquals("SELECT 用户, SUM(访问次数) FROM 超音数数据集 WHERE 部门 = 'sales'"
|
||||||
+ " AND datediff('day', 数据日期, '2024-06-04') <= 7 GROUP BY 用户"
|
+ " AND datediff('day', 数据日期, '2024-06-04') <= 7 GROUP BY 用户"
|
||||||
+ " ORDER BY SUM(访问次数) DESC LIMIT 1",
|
+ " ORDER BY SUM(访问次数) DESC LIMIT 1",
|
||||||
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
||||||
}
|
}
|
||||||
|
|
||||||
private QueryContext buildQueryContext(Long dataSetId) {
|
private ChatQueryContext buildQueryContext(Long dataSetId) {
|
||||||
QueryContext queryContext = new QueryContext();
|
ChatQueryContext chatQueryContext = new ChatQueryContext();
|
||||||
List<DataSetSchema> dataSetSchemaList = new ArrayList<>();
|
List<DataSetSchema> dataSetSchemaList = new ArrayList<>();
|
||||||
DataSetSchema dataSetSchema = new DataSetSchema();
|
DataSetSchema dataSetSchema = new DataSetSchema();
|
||||||
QueryConfig queryConfig = new QueryConfig();
|
QueryConfig queryConfig = new QueryConfig();
|
||||||
@@ -67,8 +67,8 @@ class AggCorrectorTest {
|
|||||||
dataSetSchemaList.add(dataSetSchema);
|
dataSetSchemaList.add(dataSetSchema);
|
||||||
|
|
||||||
SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList);
|
SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList);
|
||||||
queryContext.setSemanticSchema(semanticSchema);
|
chatQueryContext.setSemanticSchema(semanticSchema);
|
||||||
return queryContext;
|
return chatQueryContext;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||||
import com.tencent.supersonic.headless.chat.parser.llm.ParseResult;
|
import com.tencent.supersonic.headless.chat.parser.llm.ParseResult;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
@@ -56,7 +56,7 @@ class SchemaCorrectorTest {
|
|||||||
@Test
|
@Test
|
||||||
void doCorrect() throws JsonProcessingException {
|
void doCorrect() throws JsonProcessingException {
|
||||||
Long dataSetId = 1L;
|
Long dataSetId = 1L;
|
||||||
QueryContext queryContext = buildQueryContext(dataSetId);
|
ChatQueryContext chatQueryContext = buildQueryContext(dataSetId);
|
||||||
ObjectMapper objectMapper = new ObjectMapper();
|
ObjectMapper objectMapper = new ObjectMapper();
|
||||||
ParseResult parseResult = objectMapper.readValue(json, ParseResult.class);
|
ParseResult parseResult = objectMapper.readValue(json, ParseResult.class);
|
||||||
|
|
||||||
@@ -77,7 +77,7 @@ class SchemaCorrectorTest {
|
|||||||
semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult);
|
semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult);
|
||||||
|
|
||||||
SchemaCorrector schemaCorrector = new SchemaCorrector();
|
SchemaCorrector schemaCorrector = new SchemaCorrector();
|
||||||
schemaCorrector.removeFilterIfNotInLinkingValue(queryContext, semanticParseInfo);
|
schemaCorrector.removeFilterIfNotInLinkingValue(chatQueryContext, semanticParseInfo);
|
||||||
|
|
||||||
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' "
|
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' "
|
||||||
+ "ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
+ "ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
||||||
@@ -94,14 +94,14 @@ class SchemaCorrectorTest {
|
|||||||
|
|
||||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(sql);
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(sql);
|
||||||
semanticParseInfo.getSqlInfo().setS2SQL(sql);
|
semanticParseInfo.getSqlInfo().setS2SQL(sql);
|
||||||
schemaCorrector.removeFilterIfNotInLinkingValue(queryContext, semanticParseInfo);
|
schemaCorrector.removeFilterIfNotInLinkingValue(chatQueryContext, semanticParseInfo);
|
||||||
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' "
|
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' "
|
||||||
+ "AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
+ "AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private QueryContext buildQueryContext(Long dataSetId) {
|
private ChatQueryContext buildQueryContext(Long dataSetId) {
|
||||||
QueryContext queryContext = new QueryContext();
|
ChatQueryContext chatQueryContext = new ChatQueryContext();
|
||||||
List<DataSetSchema> dataSetSchemaList = new ArrayList<>();
|
List<DataSetSchema> dataSetSchemaList = new ArrayList<>();
|
||||||
DataSetSchema dataSetSchema = new DataSetSchema();
|
DataSetSchema dataSetSchema = new DataSetSchema();
|
||||||
QueryConfig queryConfig = new QueryConfig();
|
QueryConfig queryConfig = new QueryConfig();
|
||||||
@@ -129,7 +129,7 @@ class SchemaCorrectorTest {
|
|||||||
dataSetSchemaList.add(dataSetSchema);
|
dataSetSchemaList.add(dataSetSchema);
|
||||||
|
|
||||||
SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList);
|
SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList);
|
||||||
queryContext.setSemanticSchema(semanticSchema);
|
chatQueryContext.setSemanticSchema(semanticSchema);
|
||||||
return queryContext;
|
return chatQueryContext;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -10,7 +10,7 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
|
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
import org.mockito.MockedStatic;
|
import org.mockito.MockedStatic;
|
||||||
@@ -36,7 +36,7 @@ class SelectCorrectorTest {
|
|||||||
when(mockEnvironment.getProperty(SelectCorrector.ADDITIONAL_INFORMATION)).thenReturn("");
|
when(mockEnvironment.getProperty(SelectCorrector.ADDITIONAL_INFORMATION)).thenReturn("");
|
||||||
|
|
||||||
BaseSemanticCorrector corrector = new SelectCorrector();
|
BaseSemanticCorrector corrector = new SelectCorrector();
|
||||||
QueryContext queryContext = buildQueryContext(dataSetId);
|
ChatQueryContext chatQueryContext = buildQueryContext(dataSetId);
|
||||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||||
SchemaElement dataSet = new SchemaElement();
|
SchemaElement dataSet = new SchemaElement();
|
||||||
dataSet.setDataSet(dataSetId);
|
dataSet.setDataSet(dataSetId);
|
||||||
@@ -47,13 +47,13 @@ class SelectCorrectorTest {
|
|||||||
sqlInfo.setS2SQL(sql);
|
sqlInfo.setS2SQL(sql);
|
||||||
sqlInfo.setCorrectS2SQL(sql);
|
sqlInfo.setCorrectS2SQL(sql);
|
||||||
semanticParseInfo.setSqlInfo(sqlInfo);
|
semanticParseInfo.setSqlInfo(sqlInfo);
|
||||||
corrector.correct(queryContext, semanticParseInfo);
|
corrector.correct(chatQueryContext, semanticParseInfo);
|
||||||
Assert.assertEquals("SELECT 粉丝数, 国籍, 艺人名, 性别 FROM 艺人库 WHERE 艺人名 = '周杰伦'",
|
Assert.assertEquals("SELECT 粉丝数, 国籍, 艺人名, 性别 FROM 艺人库 WHERE 艺人名 = '周杰伦'",
|
||||||
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
||||||
}
|
}
|
||||||
|
|
||||||
private QueryContext buildQueryContext(Long dataSetId) {
|
private ChatQueryContext buildQueryContext(Long dataSetId) {
|
||||||
QueryContext queryContext = new QueryContext();
|
ChatQueryContext chatQueryContext = new ChatQueryContext();
|
||||||
List<DataSetSchema> dataSetSchemaList = new ArrayList<>();
|
List<DataSetSchema> dataSetSchemaList = new ArrayList<>();
|
||||||
DataSetSchema dataSetSchema = new DataSetSchema();
|
DataSetSchema dataSetSchema = new DataSetSchema();
|
||||||
QueryConfig queryConfig = new QueryConfig();
|
QueryConfig queryConfig = new QueryConfig();
|
||||||
@@ -108,7 +108,7 @@ class SelectCorrectorTest {
|
|||||||
dataSetSchemaList.add(dataSetSchema);
|
dataSetSchemaList.add(dataSetSchema);
|
||||||
|
|
||||||
SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList);
|
SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList);
|
||||||
queryContext.setSemanticSchema(semanticSchema);
|
chatQueryContext.setSemanticSchema(semanticSchema);
|
||||||
return queryContext;
|
return chatQueryContext;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.chat.corrector;
|
|||||||
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.jupiter.api.Disabled;
|
import org.junit.jupiter.api.Disabled;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
@@ -13,7 +13,7 @@ class TimeCorrectorTest {
|
|||||||
@Test
|
@Test
|
||||||
void testDoCorrect() {
|
void testDoCorrect() {
|
||||||
TimeCorrector corrector = new TimeCorrector();
|
TimeCorrector corrector = new TimeCorrector();
|
||||||
QueryContext queryContext = new QueryContext();
|
ChatQueryContext chatQueryContext = new ChatQueryContext();
|
||||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||||
SqlInfo sqlInfo = new SqlInfo();
|
SqlInfo sqlInfo = new SqlInfo();
|
||||||
//1.数据日期 <=
|
//1.数据日期 <=
|
||||||
@@ -21,7 +21,7 @@ class TimeCorrectorTest {
|
|||||||
+ "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1";
|
+ "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1";
|
||||||
sqlInfo.setCorrectS2SQL(sql);
|
sqlInfo.setCorrectS2SQL(sql);
|
||||||
semanticParseInfo.setSqlInfo(sqlInfo);
|
semanticParseInfo.setSqlInfo(sqlInfo);
|
||||||
corrector.doCorrect(queryContext, semanticParseInfo);
|
corrector.doCorrect(chatQueryContext, semanticParseInfo);
|
||||||
|
|
||||||
Assert.assertEquals(
|
Assert.assertEquals(
|
||||||
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE ((歌手名 = '张三') AND 数据日期 <= '2023-11-17') "
|
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE ((歌手名 = '张三') AND 数据日期 <= '2023-11-17') "
|
||||||
@@ -32,7 +32,7 @@ class TimeCorrectorTest {
|
|||||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||||
+ "WHERE (歌手名 = '张三') AND 数据日期 < '2023-11-17' GROUP BY 维度1";
|
+ "WHERE (歌手名 = '张三') AND 数据日期 < '2023-11-17' GROUP BY 维度1";
|
||||||
sqlInfo.setCorrectS2SQL(sql);
|
sqlInfo.setCorrectS2SQL(sql);
|
||||||
corrector.doCorrect(queryContext, semanticParseInfo);
|
corrector.doCorrect(chatQueryContext, semanticParseInfo);
|
||||||
|
|
||||||
Assert.assertEquals(
|
Assert.assertEquals(
|
||||||
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE ((歌手名 = '张三') AND 数据日期 < '2023-11-17') "
|
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE ((歌手名 = '张三') AND 数据日期 < '2023-11-17') "
|
||||||
@@ -43,7 +43,7 @@ class TimeCorrectorTest {
|
|||||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||||
+ "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1";
|
+ "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1";
|
||||||
sqlInfo.setCorrectS2SQL(sql);
|
sqlInfo.setCorrectS2SQL(sql);
|
||||||
corrector.doCorrect(queryContext, semanticParseInfo);
|
corrector.doCorrect(chatQueryContext, semanticParseInfo);
|
||||||
|
|
||||||
Assert.assertEquals(
|
Assert.assertEquals(
|
||||||
"SELECT 维度1, SUM(播放量) FROM 数据库 "
|
"SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||||
@@ -54,7 +54,7 @@ class TimeCorrectorTest {
|
|||||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||||
+ "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1";
|
+ "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1";
|
||||||
sqlInfo.setCorrectS2SQL(sql);
|
sqlInfo.setCorrectS2SQL(sql);
|
||||||
corrector.doCorrect(queryContext, semanticParseInfo);
|
corrector.doCorrect(chatQueryContext, semanticParseInfo);
|
||||||
|
|
||||||
Assert.assertEquals(
|
Assert.assertEquals(
|
||||||
"SELECT 维度1, SUM(播放量) FROM 数据库 "
|
"SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||||
@@ -65,7 +65,7 @@ class TimeCorrectorTest {
|
|||||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||||
+ "WHERE 歌手名 = '张三' GROUP BY 维度1";
|
+ "WHERE 歌手名 = '张三' GROUP BY 维度1";
|
||||||
sqlInfo.setCorrectS2SQL(sql);
|
sqlInfo.setCorrectS2SQL(sql);
|
||||||
corrector.doCorrect(queryContext, semanticParseInfo);
|
corrector.doCorrect(chatQueryContext, semanticParseInfo);
|
||||||
|
|
||||||
Assert.assertEquals(
|
Assert.assertEquals(
|
||||||
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE 歌手名 = '张三' GROUP BY 维度1",
|
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE 歌手名 = '张三' GROUP BY 维度1",
|
||||||
@@ -75,7 +75,7 @@ class TimeCorrectorTest {
|
|||||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||||
+ "WHERE 歌手名 = '张三' AND 数据日期_月 <= '2024-01' GROUP BY 维度1";
|
+ "WHERE 歌手名 = '张三' AND 数据日期_月 <= '2024-01' GROUP BY 维度1";
|
||||||
sqlInfo.setCorrectS2SQL(sql);
|
sqlInfo.setCorrectS2SQL(sql);
|
||||||
corrector.doCorrect(queryContext, semanticParseInfo);
|
corrector.doCorrect(chatQueryContext, semanticParseInfo);
|
||||||
|
|
||||||
Assert.assertEquals(
|
Assert.assertEquals(
|
||||||
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE (歌手名 = '张三' AND 数据日期_月 <= '2024-01') "
|
"SELECT 维度1, SUM(播放量) FROM 数据库 WHERE (歌手名 = '张三' AND 数据日期_月 <= '2024-01') "
|
||||||
@@ -86,7 +86,7 @@ class TimeCorrectorTest {
|
|||||||
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||||
+ "WHERE 歌手名 = '张三' AND 数据日期_月 > '2024-01' GROUP BY 维度1";
|
+ "WHERE 歌手名 = '张三' AND 数据日期_月 > '2024-01' GROUP BY 维度1";
|
||||||
sqlInfo.setCorrectS2SQL(sql);
|
sqlInfo.setCorrectS2SQL(sql);
|
||||||
corrector.doCorrect(queryContext, semanticParseInfo);
|
corrector.doCorrect(chatQueryContext, semanticParseInfo);
|
||||||
|
|
||||||
Assert.assertEquals(
|
Assert.assertEquals(
|
||||||
"SELECT 维度1, SUM(播放量) FROM 数据库 "
|
"SELECT 维度1, SUM(播放量) FROM 数据库 "
|
||||||
@@ -96,7 +96,7 @@ class TimeCorrectorTest {
|
|||||||
//8. no where
|
//8. no where
|
||||||
sql = "SELECT COUNT(1) FROM 数据库";
|
sql = "SELECT COUNT(1) FROM 数据库";
|
||||||
sqlInfo.setCorrectS2SQL(sql);
|
sqlInfo.setCorrectS2SQL(sql);
|
||||||
corrector.doCorrect(queryContext, semanticParseInfo);
|
corrector.doCorrect(chatQueryContext, semanticParseInfo);
|
||||||
Assert.assertEquals("SELECT COUNT(1) FROM 数据库", sqlInfo.getCorrectS2SQL());
|
Assert.assertEquals("SELECT COUNT(1) FROM 数据库", sqlInfo.getCorrectS2SQL());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
import org.junit.jupiter.api.Test;
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
@@ -22,7 +22,7 @@ class WhereCorrectorTest {
|
|||||||
sqlInfo.setCorrectS2SQL(sql);
|
sqlInfo.setCorrectS2SQL(sql);
|
||||||
semanticParseInfo.setSqlInfo(sqlInfo);
|
semanticParseInfo.setSqlInfo(sqlInfo);
|
||||||
|
|
||||||
QueryContext queryContext = new QueryContext();
|
ChatQueryContext chatQueryContext = new ChatQueryContext();
|
||||||
|
|
||||||
QueryFilter filter1 = new QueryFilter();
|
QueryFilter filter1 = new QueryFilter();
|
||||||
filter1.setName("age");
|
filter1.setName("age");
|
||||||
@@ -49,10 +49,10 @@ class WhereCorrectorTest {
|
|||||||
queryFilters.getFilters().add(filter2);
|
queryFilters.getFilters().add(filter2);
|
||||||
queryFilters.getFilters().add(filter3);
|
queryFilters.getFilters().add(filter3);
|
||||||
queryFilters.getFilters().add(filter4);
|
queryFilters.getFilters().add(filter4);
|
||||||
queryContext.setQueryFilters(queryFilters);
|
chatQueryContext.setQueryFilters(queryFilters);
|
||||||
|
|
||||||
WhereCorrector whereCorrector = new WhereCorrector();
|
WhereCorrector whereCorrector = new WhereCorrector();
|
||||||
whereCorrector.addQueryFilter(queryContext, semanticParseInfo);
|
whereCorrector.addQueryFilter(chatQueryContext, semanticParseInfo);
|
||||||
|
|
||||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import com.tencent.supersonic.headless.api.pojo.QueryConfig;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||||
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
|
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.corrector.S2SqlDateHelper;
|
import com.tencent.supersonic.headless.chat.corrector.S2SqlDateHelper;
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
@@ -26,15 +26,15 @@ class S2SqlDateHelperTest {
|
|||||||
@Test
|
@Test
|
||||||
void getReferenceDate() {
|
void getReferenceDate() {
|
||||||
Long dataSetId = 1L;
|
Long dataSetId = 1L;
|
||||||
QueryContext queryContext = buildQueryContext(dataSetId);
|
ChatQueryContext chatQueryContext = buildQueryContext(dataSetId);
|
||||||
|
|
||||||
String referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, null);
|
String referenceDate = S2SqlDateHelper.getReferenceDate(chatQueryContext, null);
|
||||||
Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(0));
|
Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(0));
|
||||||
|
|
||||||
referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, dataSetId);
|
referenceDate = S2SqlDateHelper.getReferenceDate(chatQueryContext, dataSetId);
|
||||||
Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(0));
|
Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(0));
|
||||||
|
|
||||||
DataSetSchema dataSetSchema = queryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
|
DataSetSchema dataSetSchema = chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
|
||||||
QueryConfig queryConfig = dataSetSchema.getQueryConfig();
|
QueryConfig queryConfig = dataSetSchema.getQueryConfig();
|
||||||
TimeDefaultConfig timeDefaultConfig = new TimeDefaultConfig();
|
TimeDefaultConfig timeDefaultConfig = new TimeDefaultConfig();
|
||||||
timeDefaultConfig.setTimeMode(TimeMode.LAST);
|
timeDefaultConfig.setTimeMode(TimeMode.LAST);
|
||||||
@@ -42,32 +42,32 @@ class S2SqlDateHelperTest {
|
|||||||
timeDefaultConfig.setUnit(20);
|
timeDefaultConfig.setUnit(20);
|
||||||
queryConfig.getTagTypeDefaultConfig().setTimeDefaultConfig(timeDefaultConfig);
|
queryConfig.getTagTypeDefaultConfig().setTimeDefaultConfig(timeDefaultConfig);
|
||||||
|
|
||||||
referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, dataSetId);
|
referenceDate = S2SqlDateHelper.getReferenceDate(chatQueryContext, dataSetId);
|
||||||
Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(20));
|
Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(20));
|
||||||
|
|
||||||
timeDefaultConfig.setUnit(1);
|
timeDefaultConfig.setUnit(1);
|
||||||
referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, dataSetId);
|
referenceDate = S2SqlDateHelper.getReferenceDate(chatQueryContext, dataSetId);
|
||||||
Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(1));
|
Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(1));
|
||||||
|
|
||||||
timeDefaultConfig.setUnit(-1);
|
timeDefaultConfig.setUnit(-1);
|
||||||
referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, dataSetId);
|
referenceDate = S2SqlDateHelper.getReferenceDate(chatQueryContext, dataSetId);
|
||||||
Assert.assertNull(referenceDate);
|
Assert.assertNull(referenceDate);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
void getStartEndDate() {
|
void getStartEndDate() {
|
||||||
Long dataSetId = 1L;
|
Long dataSetId = 1L;
|
||||||
QueryContext queryContext = buildQueryContext(dataSetId);
|
ChatQueryContext chatQueryContext = buildQueryContext(dataSetId);
|
||||||
|
|
||||||
Pair<String, String> startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, null, QueryType.DETAIL);
|
Pair<String, String> startEndDate = S2SqlDateHelper.getStartEndDate(chatQueryContext, null, QueryType.DETAIL);
|
||||||
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(0));
|
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(0));
|
||||||
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(0));
|
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(0));
|
||||||
|
|
||||||
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.DETAIL);
|
startEndDate = S2SqlDateHelper.getStartEndDate(chatQueryContext, dataSetId, QueryType.DETAIL);
|
||||||
Assert.assertNotNull(startEndDate.getLeft());
|
Assert.assertNotNull(startEndDate.getLeft());
|
||||||
Assert.assertNotNull(startEndDate.getRight());
|
Assert.assertNotNull(startEndDate.getRight());
|
||||||
|
|
||||||
DataSetSchema dataSetSchema = queryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
|
DataSetSchema dataSetSchema = chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
|
||||||
QueryConfig queryConfig = dataSetSchema.getQueryConfig();
|
QueryConfig queryConfig = dataSetSchema.getQueryConfig();
|
||||||
TimeDefaultConfig timeDefaultConfig = new TimeDefaultConfig();
|
TimeDefaultConfig timeDefaultConfig = new TimeDefaultConfig();
|
||||||
timeDefaultConfig.setTimeMode(TimeMode.LAST);
|
timeDefaultConfig.setTimeMode(TimeMode.LAST);
|
||||||
@@ -76,39 +76,39 @@ class S2SqlDateHelperTest {
|
|||||||
queryConfig.getTagTypeDefaultConfig().setTimeDefaultConfig(timeDefaultConfig);
|
queryConfig.getTagTypeDefaultConfig().setTimeDefaultConfig(timeDefaultConfig);
|
||||||
queryConfig.getMetricTypeDefaultConfig().setTimeDefaultConfig(timeDefaultConfig);
|
queryConfig.getMetricTypeDefaultConfig().setTimeDefaultConfig(timeDefaultConfig);
|
||||||
|
|
||||||
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.DETAIL);
|
startEndDate = S2SqlDateHelper.getStartEndDate(chatQueryContext, dataSetId, QueryType.DETAIL);
|
||||||
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(20));
|
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(20));
|
||||||
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(20));
|
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(20));
|
||||||
|
|
||||||
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.METRIC);
|
startEndDate = S2SqlDateHelper.getStartEndDate(chatQueryContext, dataSetId, QueryType.METRIC);
|
||||||
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(20));
|
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(20));
|
||||||
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(20));
|
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(20));
|
||||||
|
|
||||||
timeDefaultConfig.setUnit(2);
|
timeDefaultConfig.setUnit(2);
|
||||||
timeDefaultConfig.setTimeMode(TimeMode.RECENT);
|
timeDefaultConfig.setTimeMode(TimeMode.RECENT);
|
||||||
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.METRIC);
|
startEndDate = S2SqlDateHelper.getStartEndDate(chatQueryContext, dataSetId, QueryType.METRIC);
|
||||||
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(2));
|
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(2));
|
||||||
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(1));
|
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(1));
|
||||||
|
|
||||||
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.DETAIL);
|
startEndDate = S2SqlDateHelper.getStartEndDate(chatQueryContext, dataSetId, QueryType.DETAIL);
|
||||||
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(2));
|
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(2));
|
||||||
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(1));
|
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(1));
|
||||||
|
|
||||||
timeDefaultConfig.setUnit(-1);
|
timeDefaultConfig.setUnit(-1);
|
||||||
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.METRIC);
|
startEndDate = S2SqlDateHelper.getStartEndDate(chatQueryContext, dataSetId, QueryType.METRIC);
|
||||||
Assert.assertNull(startEndDate.getLeft());
|
Assert.assertNull(startEndDate.getLeft());
|
||||||
Assert.assertNull(startEndDate.getRight());
|
Assert.assertNull(startEndDate.getRight());
|
||||||
|
|
||||||
timeDefaultConfig.setTimeMode(TimeMode.LAST);
|
timeDefaultConfig.setTimeMode(TimeMode.LAST);
|
||||||
timeDefaultConfig.setPeriod(Constants.DAY);
|
timeDefaultConfig.setPeriod(Constants.DAY);
|
||||||
timeDefaultConfig.setUnit(5);
|
timeDefaultConfig.setUnit(5);
|
||||||
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.METRIC);
|
startEndDate = S2SqlDateHelper.getStartEndDate(chatQueryContext, dataSetId, QueryType.METRIC);
|
||||||
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(5));
|
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(5));
|
||||||
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(5));
|
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(5));
|
||||||
}
|
}
|
||||||
|
|
||||||
private QueryContext buildQueryContext(Long dataSetId) {
|
private ChatQueryContext buildQueryContext(Long dataSetId) {
|
||||||
QueryContext queryContext = new QueryContext();
|
ChatQueryContext chatQueryContext = new ChatQueryContext();
|
||||||
List<DataSetSchema> dataSetSchemaList = new ArrayList<>();
|
List<DataSetSchema> dataSetSchemaList = new ArrayList<>();
|
||||||
DataSetSchema dataSetSchema = new DataSetSchema();
|
DataSetSchema dataSetSchema = new DataSetSchema();
|
||||||
QueryConfig queryConfig = new QueryConfig();
|
QueryConfig queryConfig = new QueryConfig();
|
||||||
@@ -119,7 +119,7 @@ class S2SqlDateHelperTest {
|
|||||||
dataSetSchemaList.add(dataSetSchema);
|
dataSetSchemaList.add(dataSetSchema);
|
||||||
|
|
||||||
SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList);
|
SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList);
|
||||||
queryContext.setSemanticSchema(semanticSchema);
|
chatQueryContext.setSemanticSchema(semanticSchema);
|
||||||
return queryContext;
|
return chatQueryContext;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -49,7 +49,7 @@ import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.corrector.GrammarCorrector;
|
import com.tencent.supersonic.headless.chat.corrector.GrammarCorrector;
|
||||||
import com.tencent.supersonic.headless.chat.corrector.SchemaCorrector;
|
import com.tencent.supersonic.headless.chat.corrector.SchemaCorrector;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.HanlpMapResult;
|
import com.tencent.supersonic.headless.chat.knowledge.HanlpMapResult;
|
||||||
@@ -63,7 +63,7 @@ import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
|||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
|
||||||
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
|
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
|
||||||
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
||||||
import com.tencent.supersonic.headless.server.utils.WorkflowEngine;
|
import com.tencent.supersonic.headless.server.utils.ChatWorkflowEngine;
|
||||||
import com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO;
|
import com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO;
|
||||||
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
|
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
|
||||||
import com.tencent.supersonic.headless.server.utils.ComponentFactory;
|
import com.tencent.supersonic.headless.server.utils.ComponentFactory;
|
||||||
@@ -115,12 +115,12 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
@Autowired
|
@Autowired
|
||||||
private DataSetService dataSetService;
|
private DataSetService dataSetService;
|
||||||
@Autowired
|
@Autowired
|
||||||
private WorkflowEngine workflowEngine;
|
private ChatWorkflowEngine chatWorkflowEngine;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public MapResp performMapping(QueryTextReq queryTextReq) {
|
public MapResp performMapping(QueryTextReq queryTextReq) {
|
||||||
MapResp mapResp = new MapResp();
|
MapResp mapResp = new MapResp();
|
||||||
QueryContext queryCtx = buildQueryContext(queryTextReq);
|
ChatQueryContext queryCtx = buildQueryContext(queryTextReq);
|
||||||
ComponentFactory.getSchemaMappers().forEach(mapper -> {
|
ComponentFactory.getSchemaMappers().forEach(mapper -> {
|
||||||
mapper.map(queryCtx);
|
mapper.map(queryCtx);
|
||||||
});
|
});
|
||||||
@@ -148,12 +148,12 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
public ParseResp performParsing(QueryTextReq queryTextReq) {
|
public ParseResp performParsing(QueryTextReq queryTextReq) {
|
||||||
ParseResp parseResult = new ParseResp(queryTextReq.getChatId(), queryTextReq.getQueryText());
|
ParseResp parseResult = new ParseResp(queryTextReq.getChatId(), queryTextReq.getQueryText());
|
||||||
// build queryContext and chatContext
|
// build queryContext and chatContext
|
||||||
QueryContext queryCtx = buildQueryContext(queryTextReq);
|
ChatQueryContext queryCtx = buildQueryContext(queryTextReq);
|
||||||
|
|
||||||
// in order to support multi-turn conversation, chat context is needed
|
// in order to support multi-turn conversation, chat context is needed
|
||||||
ChatContext chatCtx = chatContextService.getOrCreateContext(queryTextReq.getChatId());
|
ChatContext chatCtx = chatContextService.getOrCreateContext(queryTextReq.getChatId());
|
||||||
|
|
||||||
workflowEngine.startWorkflow(queryCtx, chatCtx, parseResult);
|
chatWorkflowEngine.execute(queryCtx, chatCtx, parseResult);
|
||||||
|
|
||||||
List<SemanticParseInfo> parseInfos = queryCtx.getCandidateQueries().stream()
|
List<SemanticParseInfo> parseInfos = queryCtx.getCandidateQueries().stream()
|
||||||
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
|
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
|
||||||
@@ -161,11 +161,11 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
return parseResult;
|
return parseResult;
|
||||||
}
|
}
|
||||||
|
|
||||||
public QueryContext buildQueryContext(QueryTextReq queryTextReq) {
|
public ChatQueryContext buildQueryContext(QueryTextReq queryTextReq) {
|
||||||
|
|
||||||
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
|
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
|
||||||
Map<Long, List<Long>> modelIdToDataSetIds = dataSetService.getModelIdToDataSetIds();
|
Map<Long, List<Long>> modelIdToDataSetIds = dataSetService.getModelIdToDataSetIds();
|
||||||
QueryContext queryCtx = QueryContext.builder()
|
ChatQueryContext queryCtx = ChatQueryContext.builder()
|
||||||
.queryFilters(queryTextReq.getQueryFilters())
|
.queryFilters(queryTextReq.getQueryFilters())
|
||||||
.semanticSchema(semanticSchema)
|
.semanticSchema(semanticSchema)
|
||||||
.candidateQueries(new ArrayList<>())
|
.candidateQueries(new ArrayList<>())
|
||||||
@@ -612,7 +612,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private SemanticParseInfo correctSqlReq(QuerySqlReq querySqlReq, User user) {
|
private SemanticParseInfo correctSqlReq(QuerySqlReq querySqlReq, User user) {
|
||||||
QueryContext queryCtx = new QueryContext();
|
ChatQueryContext queryCtx = new ChatQueryContext();
|
||||||
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
|
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
|
||||||
queryCtx.setSemanticSchema(semanticSchema);
|
queryCtx.setSemanticSchema(semanticSchema);
|
||||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryTextReq;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryTextReq;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
|
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.DataSetInfoStat;
|
import com.tencent.supersonic.headless.chat.knowledge.DataSetInfoStat;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.DictWord;
|
import com.tencent.supersonic.headless.chat.knowledge.DictWord;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.HanlpMapResult;
|
import com.tencent.supersonic.headless.chat.knowledge.HanlpMapResult;
|
||||||
@@ -78,12 +78,12 @@ public class RetrieveServiceImpl implements RetrieveService {
|
|||||||
log.debug("hanlp parse result: {}", originals);
|
log.debug("hanlp parse result: {}", originals);
|
||||||
Set<Long> dataSetIds = queryTextReq.getDataSetIds();
|
Set<Long> dataSetIds = queryTextReq.getDataSetIds();
|
||||||
|
|
||||||
QueryContext queryContext = new QueryContext();
|
ChatQueryContext chatQueryContext = new ChatQueryContext();
|
||||||
BeanUtils.copyProperties(queryTextReq, queryContext);
|
BeanUtils.copyProperties(queryTextReq, chatQueryContext);
|
||||||
queryContext.setModelIdToDataSetIds(dataSetService.getModelIdToDataSetIds());
|
chatQueryContext.setModelIdToDataSetIds(dataSetService.getModelIdToDataSetIds());
|
||||||
|
|
||||||
Map<MatchText, List<HanlpMapResult>> regTextMap =
|
Map<MatchText, List<HanlpMapResult>> regTextMap =
|
||||||
searchMatchStrategy.match(queryContext, originals, dataSetIds);
|
searchMatchStrategy.match(chatQueryContext, originals, dataSetIds);
|
||||||
|
|
||||||
regTextMap.entrySet().stream().forEach(m -> HanlpHelper.transLetterOriginal(m.getValue()));
|
regTextMap.entrySet().stream().forEach(m -> HanlpHelper.transLetterOriginal(m.getValue()));
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||||
import com.tencent.supersonic.headless.server.web.service.SchemaService;
|
import com.tencent.supersonic.headless.server.web.service.SchemaService;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
@@ -40,8 +40,8 @@ import java.util.stream.Collectors;
|
|||||||
public class ParseInfoProcessor implements ResultProcessor {
|
public class ParseInfoProcessor implements ResultProcessor {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
|
public void process(ParseResp parseResp, ChatQueryContext chatQueryContext, ChatContext chatContext) {
|
||||||
List<SemanticQuery> candidateQueries = queryContext.getCandidateQueries();
|
List<SemanticQuery> candidateQueries = chatQueryContext.getCandidateQueries();
|
||||||
if (CollectionUtils.isEmpty(candidateQueries)) {
|
if (CollectionUtils.isEmpty(candidateQueries)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,13 +2,13 @@ package com.tencent.supersonic.headless.server.processor;
|
|||||||
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A ParseResultProcessor wraps things up before returning results to users in parse stage.
|
* A ParseResultProcessor wraps things up before returning results to users in parse stage.
|
||||||
*/
|
*/
|
||||||
public interface ResultProcessor {
|
public interface ResultProcessor {
|
||||||
|
|
||||||
void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext);
|
void process(ParseResp parseResp, ChatQueryContext chatQueryContext, ChatContext chatContext);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,88 +0,0 @@
|
|||||||
package com.tencent.supersonic.headless.server.processor;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.TranslateResp;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
|
||||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
|
||||||
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
|
||||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
|
|
||||||
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
import org.slf4j.Logger;
|
|
||||||
import org.slf4j.LoggerFactory;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* SqlInfoProcessor adds intermediate S2SQL and final SQL to the parsing results
|
|
||||||
* so that technical users could verify SQL by themselves.
|
|
||||||
**/
|
|
||||||
@Slf4j
|
|
||||||
public class SqlInfoProcessor implements ResultProcessor {
|
|
||||||
|
|
||||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
|
|
||||||
long start = System.currentTimeMillis();
|
|
||||||
List<SemanticQuery> semanticQueries = queryContext.getCandidateQueries();
|
|
||||||
if (CollectionUtils.isEmpty(semanticQueries)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
List<SemanticParseInfo> selectedParses = semanticQueries.stream().map(SemanticQuery::getParseInfo)
|
|
||||||
.collect(Collectors.toList());
|
|
||||||
addSqlInfo(queryContext, selectedParses);
|
|
||||||
parseResp.getParseTimeCost().setSqlTime(System.currentTimeMillis() - start);
|
|
||||||
}
|
|
||||||
|
|
||||||
private void addSqlInfo(QueryContext queryContext, List<SemanticParseInfo> semanticParseInfos) {
|
|
||||||
if (CollectionUtils.isEmpty(semanticParseInfos)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
semanticParseInfos.forEach(parseInfo -> {
|
|
||||||
try {
|
|
||||||
addSqlInfo(queryContext, parseInfo);
|
|
||||||
} catch (Exception e) {
|
|
||||||
log.warn("get sql info failed:{}", parseInfo, e);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
private void addSqlInfo(QueryContext queryContext, SemanticParseInfo parseInfo) throws Exception {
|
|
||||||
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
|
|
||||||
if (Objects.isNull(semanticQuery)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
semanticQuery.setParseInfo(parseInfo);
|
|
||||||
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
|
|
||||||
SemanticLayerService queryService = ContextUtils.getBean(SemanticLayerService.class);
|
|
||||||
TranslateSqlReq<Object> translateSqlReq = TranslateSqlReq.builder().queryReq(semanticQueryReq)
|
|
||||||
.queryTypeEnum(QueryMethod.SQL).build();
|
|
||||||
TranslateResp explain = queryService.translate(translateSqlReq, queryContext.getUser());
|
|
||||||
String querySql = explain.getSql();
|
|
||||||
if (StringUtils.isBlank(querySql)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
|
||||||
if (semanticQuery instanceof LLMSqlQuery) {
|
|
||||||
keyPipelineLog.info("SqlInfoProcessor results:\n"
|
|
||||||
+ "Parsed S2SQL: {}\nCorrected S2SQL: {}\nQuery SQL: {}",
|
|
||||||
StringUtils.normalizeSpace(sqlInfo.getS2SQL()),
|
|
||||||
StringUtils.normalizeSpace(sqlInfo.getCorrectS2SQL()),
|
|
||||||
StringUtils.normalizeSpace(querySql));
|
|
||||||
}
|
|
||||||
sqlInfo.setQuerySQL(querySql);
|
|
||||||
sqlInfo.setSourceId(explain.getSourceId());
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,148 @@
|
|||||||
|
package com.tencent.supersonic.headless.server.utils;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.enums.ChatWorkflowState;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.response.TranslateResp;
|
||||||
|
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||||
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
|
import com.tencent.supersonic.headless.chat.corrector.SemanticCorrector;
|
||||||
|
import com.tencent.supersonic.headless.chat.mapper.SchemaMapper;
|
||||||
|
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||||
|
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||||
|
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||||
|
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
||||||
|
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
||||||
|
import com.tencent.supersonic.headless.server.processor.ResultProcessor;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
|
import org.apache.commons.collections.MapUtils;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
@Service
|
||||||
|
@Slf4j
|
||||||
|
public class ChatWorkflowEngine {
|
||||||
|
|
||||||
|
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||||
|
private List<SchemaMapper> schemaMappers = ComponentFactory.getSchemaMappers();
|
||||||
|
private List<SemanticParser> semanticParsers = ComponentFactory.getSemanticParsers();
|
||||||
|
private List<SemanticCorrector> semanticCorrectors = ComponentFactory.getSemanticCorrectors();
|
||||||
|
private List<ResultProcessor> resultProcessors = ComponentFactory.getResultProcessors();
|
||||||
|
|
||||||
|
public void execute(ChatQueryContext queryCtx, ChatContext chatCtx, ParseResp parseResult) {
|
||||||
|
queryCtx.setChatWorkflowState(ChatWorkflowState.MAPPING);
|
||||||
|
while (queryCtx.getChatWorkflowState() != ChatWorkflowState.FINISHED) {
|
||||||
|
switch (queryCtx.getChatWorkflowState()) {
|
||||||
|
case MAPPING:
|
||||||
|
performMapping(queryCtx);
|
||||||
|
queryCtx.setChatWorkflowState(ChatWorkflowState.PARSING);
|
||||||
|
break;
|
||||||
|
case PARSING:
|
||||||
|
performParsing(queryCtx, chatCtx);
|
||||||
|
queryCtx.setChatWorkflowState(ChatWorkflowState.CORRECTING);
|
||||||
|
break;
|
||||||
|
case CORRECTING:
|
||||||
|
performCorrecting(queryCtx);
|
||||||
|
queryCtx.setChatWorkflowState(ChatWorkflowState.TRANSLATING);
|
||||||
|
break;
|
||||||
|
case TRANSLATING:
|
||||||
|
long start = System.currentTimeMillis();
|
||||||
|
performTranslating(queryCtx);
|
||||||
|
parseResult.getParseTimeCost().setSqlTime(System.currentTimeMillis() - start);
|
||||||
|
queryCtx.setChatWorkflowState(ChatWorkflowState.PROCESSING);
|
||||||
|
break;
|
||||||
|
case PROCESSING:
|
||||||
|
default:
|
||||||
|
performProcessing(queryCtx, chatCtx, parseResult);
|
||||||
|
queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void performMapping(ChatQueryContext queryCtx) {
|
||||||
|
if (Objects.isNull(queryCtx.getMapInfo())
|
||||||
|
|| MapUtils.isEmpty(queryCtx.getMapInfo().getDataSetElementMatches())) {
|
||||||
|
schemaMappers.forEach(mapper -> mapper.map(queryCtx));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void performParsing(ChatQueryContext queryCtx, ChatContext chatCtx) {
|
||||||
|
semanticParsers.forEach(parser -> {
|
||||||
|
parser.parse(queryCtx, chatCtx);
|
||||||
|
log.debug("{} result:{}", parser.getClass().getSimpleName(), JsonUtil.toString(queryCtx));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
public void performCorrecting(ChatQueryContext queryCtx) {
|
||||||
|
List<SemanticQuery> candidateQueries = queryCtx.getCandidateQueries();
|
||||||
|
if (CollectionUtils.isNotEmpty(candidateQueries)) {
|
||||||
|
for (SemanticQuery semanticQuery : candidateQueries) {
|
||||||
|
if (semanticQuery instanceof RuleSemanticQuery) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
for (SemanticCorrector corrector : semanticCorrectors) {
|
||||||
|
corrector.correct(queryCtx, semanticQuery.getParseInfo());
|
||||||
|
if (!ChatWorkflowState.CORRECTING.equals(queryCtx.getChatWorkflowState())) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void performProcessing(ChatQueryContext queryCtx, ChatContext chatCtx, ParseResp parseResult) {
|
||||||
|
resultProcessors.forEach(processor -> {
|
||||||
|
processor.process(parseResult, queryCtx, chatCtx);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
private void performTranslating(ChatQueryContext chatQueryContext) {
|
||||||
|
List<SemanticParseInfo> semanticParseInfos = chatQueryContext.getCandidateQueries().stream()
|
||||||
|
.map(SemanticQuery::getParseInfo)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
|
||||||
|
semanticParseInfos.forEach(parseInfo -> {
|
||||||
|
try {
|
||||||
|
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
|
||||||
|
if (Objects.isNull(semanticQuery)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
semanticQuery.setParseInfo(parseInfo);
|
||||||
|
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
|
||||||
|
SemanticLayerService queryService = ContextUtils.getBean(SemanticLayerService.class);
|
||||||
|
TranslateSqlReq<Object> translateSqlReq = TranslateSqlReq.builder().queryReq(semanticQueryReq)
|
||||||
|
.queryTypeEnum(QueryMethod.SQL).build();
|
||||||
|
TranslateResp explain = queryService.translate(translateSqlReq, chatQueryContext.getUser());
|
||||||
|
String querySql = explain.getSql();
|
||||||
|
if (StringUtils.isBlank(querySql)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||||
|
sqlInfo.setQuerySQL(querySql);
|
||||||
|
sqlInfo.setSourceId(explain.getSourceId());
|
||||||
|
|
||||||
|
keyPipelineLog.info("SqlInfoProcessor results:\n"
|
||||||
|
+ "Parsed S2SQL: {}\nCorrected S2SQL: {}\nQuery SQL: {}",
|
||||||
|
StringUtils.normalizeSpace(sqlInfo.getS2SQL()),
|
||||||
|
StringUtils.normalizeSpace(sqlInfo.getCorrectS2SQL()),
|
||||||
|
StringUtils.normalizeSpace(querySql));
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.warn("get sql info failed:{}", parseInfo, e);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,91 +0,0 @@
|
|||||||
package com.tencent.supersonic.headless.server.utils;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.enums.WorkflowState;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
|
||||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
|
||||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
|
||||||
import com.tencent.supersonic.headless.chat.corrector.SemanticCorrector;
|
|
||||||
import com.tencent.supersonic.headless.chat.mapper.SchemaMapper;
|
|
||||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
|
||||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
|
||||||
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
|
||||||
import com.tencent.supersonic.headless.server.processor.ResultProcessor;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
|
||||||
import org.apache.commons.collections.MapUtils;
|
|
||||||
import org.springframework.stereotype.Service;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Objects;
|
|
||||||
|
|
||||||
@Service
|
|
||||||
@Slf4j
|
|
||||||
public class WorkflowEngine {
|
|
||||||
private List<SchemaMapper> schemaMappers = ComponentFactory.getSchemaMappers();
|
|
||||||
private List<SemanticParser> semanticParsers = ComponentFactory.getSemanticParsers();
|
|
||||||
private List<SemanticCorrector> semanticCorrectors = ComponentFactory.getSemanticCorrectors();
|
|
||||||
private List<ResultProcessor> resultProcessors = ComponentFactory.getResultProcessors();
|
|
||||||
|
|
||||||
public void startWorkflow(QueryContext queryCtx, ChatContext chatCtx, ParseResp parseResult) {
|
|
||||||
queryCtx.setWorkflowState(WorkflowState.MAPPING);
|
|
||||||
while (queryCtx.getWorkflowState() != WorkflowState.FINISHED) {
|
|
||||||
switch (queryCtx.getWorkflowState()) {
|
|
||||||
case MAPPING:
|
|
||||||
performMapping(queryCtx);
|
|
||||||
queryCtx.setWorkflowState(WorkflowState.PARSING);
|
|
||||||
break;
|
|
||||||
case PARSING:
|
|
||||||
performParsing(queryCtx, chatCtx);
|
|
||||||
queryCtx.setWorkflowState(WorkflowState.CORRECTING);
|
|
||||||
break;
|
|
||||||
case CORRECTING:
|
|
||||||
performCorrecting(queryCtx);
|
|
||||||
queryCtx.setWorkflowState(WorkflowState.PROCESSING);
|
|
||||||
break;
|
|
||||||
case PROCESSING:
|
|
||||||
default:
|
|
||||||
performProcessing(queryCtx, chatCtx, parseResult);
|
|
||||||
queryCtx.setWorkflowState(WorkflowState.FINISHED);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public void performMapping(QueryContext queryCtx) {
|
|
||||||
if (Objects.isNull(queryCtx.getMapInfo())
|
|
||||||
|| MapUtils.isEmpty(queryCtx.getMapInfo().getDataSetElementMatches())) {
|
|
||||||
schemaMappers.forEach(mapper -> mapper.map(queryCtx));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public void performParsing(QueryContext queryCtx, ChatContext chatCtx) {
|
|
||||||
semanticParsers.forEach(parser -> {
|
|
||||||
parser.parse(queryCtx, chatCtx);
|
|
||||||
log.debug("{} result:{}", parser.getClass().getSimpleName(), JsonUtil.toString(queryCtx));
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
public void performCorrecting(QueryContext queryCtx) {
|
|
||||||
List<SemanticQuery> candidateQueries = queryCtx.getCandidateQueries();
|
|
||||||
if (CollectionUtils.isNotEmpty(candidateQueries)) {
|
|
||||||
for (SemanticQuery semanticQuery : candidateQueries) {
|
|
||||||
if (semanticQuery instanceof RuleSemanticQuery) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
for (SemanticCorrector corrector : semanticCorrectors) {
|
|
||||||
corrector.correct(queryCtx, semanticQuery.getParseInfo());
|
|
||||||
if (!WorkflowState.CORRECTING.equals(queryCtx.getWorkflowState())) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public void performProcessing(QueryContext queryCtx, ChatContext chatCtx, ParseResp parseResult) {
|
|
||||||
resultProcessors.forEach(processor -> {
|
|
||||||
processor.process(parseResult, queryCtx, chatCtx);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -47,9 +47,7 @@ com.tencent.supersonic.headless.core.cache.QueryCache=\
|
|||||||
### headless-server SPIs
|
### headless-server SPIs
|
||||||
|
|
||||||
com.tencent.supersonic.headless.server.processor.ResultProcessor=\
|
com.tencent.supersonic.headless.server.processor.ResultProcessor=\
|
||||||
com.tencent.supersonic.headless.server.processor.ParseInfoProcessor, \
|
com.tencent.supersonic.headless.server.processor.ParseInfoProcessor
|
||||||
com.tencent.supersonic.headless.server.processor.SqlInfoProcessor
|
|
||||||
|
|
||||||
|
|
||||||
### chat-server SPIs
|
### chat-server SPIs
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user