mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-14 22:25:19 +00:00
(improvement)(Headless)(Chat) Change View to DataSet (#782)
* (improvement)(Headless)(Chat) Change view to dataSet --------- Co-authored-by: jolunoluo <jolunoluo@tencent.com>
This commit is contained in:
@@ -65,16 +65,16 @@ public class Agent extends RecordInfo {
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public Set<Long> getViewIds() {
|
||||
return getViewIds(null);
|
||||
public Set<Long> getDataSetIds() {
|
||||
return getDataSetIds(null);
|
||||
}
|
||||
|
||||
public Set<Long> getViewIds(AgentToolType agentToolType) {
|
||||
public Set<Long> getDataSetIds(AgentToolType agentToolType) {
|
||||
List<NL2SQLTool> commonAgentTools = getParserTools(agentToolType);
|
||||
if (CollectionUtils.isEmpty(commonAgentTools)) {
|
||||
return new HashSet<>();
|
||||
}
|
||||
return commonAgentTools.stream().map(NL2SQLTool::getViewIds)
|
||||
return commonAgentTools.stream().map(NL2SQLTool::getDataSetIds)
|
||||
.filter(modelIds -> !CollectionUtils.isEmpty(modelIds))
|
||||
.flatMap(Collection::stream)
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
@@ -12,6 +12,6 @@ import java.util.List;
|
||||
@AllArgsConstructor
|
||||
public class NL2SQLTool extends AgentTool {
|
||||
|
||||
protected List<Long> viewIds;
|
||||
protected List<Long> dataSetIds;
|
||||
|
||||
}
|
||||
@@ -15,7 +15,7 @@ public class RuleParserTool extends NL2SQLTool {
|
||||
private List<String> queryTypes;
|
||||
|
||||
public boolean isContainsAllModel() {
|
||||
return CollectionUtils.isNotEmpty(viewIds) && viewIds.contains(-1L);
|
||||
return CollectionUtils.isNotEmpty(dataSetIds) && dataSetIds.contains(-1L);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
|
||||
public abstract void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo);
|
||||
|
||||
protected Map<String, String> getFieldNameMap(QueryContext queryContext, Long viewId) {
|
||||
protected Map<String, String> getFieldNameMap(QueryContext queryContext, Long dataSetId) {
|
||||
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
|
||||
@@ -55,7 +55,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
|
||||
// support fieldName and field alias
|
||||
Map<String, String> result = dbAllFields.stream()
|
||||
.filter(entry -> viewId.equals(entry.getView()))
|
||||
.filter(entry -> dataSetId.equals(entry.getDataSet()))
|
||||
.flatMap(schemaElement -> {
|
||||
Set<String> elements = new HashSet<>();
|
||||
elements.add(schemaElement.getName());
|
||||
@@ -109,8 +109,8 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
protected void addAggregateToMetric(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
//add aggregate to all metric
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
Long viewId = semanticParseInfo.getView().getView();
|
||||
List<SchemaElement> metrics = getMetricElements(queryContext, viewId);
|
||||
Long dataSetId = semanticParseInfo.getDataSet().getDataSet();
|
||||
List<SchemaElement> metrics = getMetricElements(queryContext, dataSetId);
|
||||
|
||||
Map<String, String> metricToAggregate = metrics.stream()
|
||||
.map(schemaElement -> {
|
||||
@@ -135,13 +135,13 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql);
|
||||
}
|
||||
|
||||
protected List<SchemaElement> getMetricElements(QueryContext queryContext, Long viewId) {
|
||||
protected List<SchemaElement> getMetricElements(QueryContext queryContext, Long dataSetId) {
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
return semanticSchema.getMetrics(viewId);
|
||||
return semanticSchema.getMetrics(dataSetId);
|
||||
}
|
||||
|
||||
protected Set<String> getDimensions(Long viewId, SemanticSchema semanticSchema) {
|
||||
Set<String> dimensions = semanticSchema.getDimensions(viewId).stream()
|
||||
protected Set<String> getDimensions(Long dataSetId, SemanticSchema semanticSchema) {
|
||||
Set<String> dimensions = semanticSchema.getDimensions(dataSetId).stream()
|
||||
.flatMap(
|
||||
schemaElement -> {
|
||||
Set<String> elements = new HashSet<>();
|
||||
|
||||
@@ -8,12 +8,13 @@ import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetModelConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.Dim;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ViewResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
|
||||
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
|
||||
import com.tencent.supersonic.headless.server.service.ModelService;
|
||||
import com.tencent.supersonic.headless.server.service.ViewService;
|
||||
import com.tencent.supersonic.headless.server.service.DataSetService;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
@@ -37,11 +38,12 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
}
|
||||
|
||||
private Boolean needAddGroupBy(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
Long viewId = semanticParseInfo.getViewId();
|
||||
ViewService viewService = ContextUtils.getBean(ViewService.class);
|
||||
Long dataSetId = semanticParseInfo.getDataSetId();
|
||||
DataSetService dataSetService = ContextUtils.getBean(DataSetService.class);
|
||||
ModelService modelService = ContextUtils.getBean(ModelService.class);
|
||||
ViewResp viewResp = viewService.getView(viewId);
|
||||
List<Long> modelIds = viewResp.getViewDetail().getViewModelConfigs().stream().map(config -> config.getId())
|
||||
DataSetResp dataSetResp = dataSetService.getDataSet(dataSetId);
|
||||
List<Long> modelIds = dataSetResp.getDataSetDetail()
|
||||
.getDataSetModelConfigs().stream().map(DataSetModelConfig::getId)
|
||||
.collect(Collectors.toList());
|
||||
MetaFilter metaFilter = new MetaFilter();
|
||||
metaFilter.setIds(modelIds);
|
||||
@@ -64,7 +66,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
return false;
|
||||
}
|
||||
//add alias field name
|
||||
Set<String> dimensions = getDimensions(viewId, semanticSchema);
|
||||
Set<String> dimensions = getDimensions(dataSetId, semanticSchema);
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
|
||||
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) {
|
||||
return false;
|
||||
@@ -81,13 +83,13 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
}
|
||||
|
||||
private void addGroupByFields(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
Long viewId = semanticParseInfo.getViewId();
|
||||
Long dataSetId = semanticParseInfo.getDataSetId();
|
||||
//add dimension group by
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
//add alias field name
|
||||
Set<String> dimensions = getDimensions(viewId, semanticSchema);
|
||||
Set<String> dimensions = getDimensions(dataSetId, semanticSchema);
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
|
||||
List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
|
||||
Set<String> groupByFields = selectFields.stream()
|
||||
|
||||
@@ -39,11 +39,11 @@ public class HavingCorrector extends BaseSemanticCorrector {
|
||||
}
|
||||
|
||||
private void addHaving(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
Long viewId = semanticParseInfo.getView().getView();
|
||||
Long dataSet = semanticParseInfo.getDataSet().getDataSet();
|
||||
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
|
||||
Set<String> metrics = semanticSchema.getMetrics(viewId).stream()
|
||||
Set<String> metrics = semanticSchema.getMetrics(dataSet).stream()
|
||||
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
|
||||
|
||||
if (CollectionUtils.isEmpty(metrics)) {
|
||||
|
||||
@@ -16,15 +16,16 @@ import com.tencent.supersonic.common.util.jsqlparser.FieldExpression;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlRemoveHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* Perform schema corrections on the Schema information in S2SQL.
|
||||
@@ -62,7 +63,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
}
|
||||
|
||||
private void correctFieldName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
Map<String, String> fieldNameMap = getFieldNameMap(queryContext, semanticParseInfo.getViewId());
|
||||
Map<String, String> fieldNameMap = getFieldNameMap(queryContext, semanticParseInfo.getDataSetId());
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String sql = SqlReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
@@ -125,7 +126,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
}
|
||||
List<ElementValue> linkingValues = getLinkingValues(semanticParseInfo);
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
Set<String> dimensions = getDimensions(semanticParseInfo.getViewId(), semanticSchema);
|
||||
Set<String> dimensions = getDimensions(semanticParseInfo.getDataSetId(), semanticSchema);
|
||||
|
||||
if (CollectionUtils.isEmpty(linkingValues)) {
|
||||
linkingValues = new ArrayList<>();
|
||||
|
||||
@@ -67,7 +67,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
|
||||
if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) {
|
||||
Pair<String, String> startEndDate = S2SqlDateHelper.getStartEndDate(queryContext,
|
||||
semanticParseInfo.getViewId(), semanticParseInfo.getQueryType());
|
||||
semanticParseInfo.getDataSetId(), semanticParseInfo.getQueryType());
|
||||
if (StringUtils.isNotBlank(startEndDate.getLeft())
|
||||
&& StringUtils.isNotBlank(startEndDate.getRight())) {
|
||||
correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL);
|
||||
@@ -101,8 +101,8 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
|
||||
private void updateFieldValueByTechName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
Long viewId = semanticParseInfo.getViewId();
|
||||
List<SchemaElement> dimensions = semanticSchema.getDimensions(viewId);
|
||||
Long dataSetId = semanticParseInfo.getDataSetId();
|
||||
List<SchemaElement> dimensions = semanticSchema.getDimensions(dataSetId);
|
||||
|
||||
if (CollectionUtils.isEmpty(dimensions)) {
|
||||
return;
|
||||
|
||||
@@ -5,7 +5,7 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
@@ -26,7 +26,7 @@ public abstract class BaseMapper implements SchemaMapper {
|
||||
|
||||
String simpleName = this.getClass().getSimpleName();
|
||||
long startTime = System.currentTimeMillis();
|
||||
log.debug("before {},mapInfo:{}", simpleName, queryContext.getMapInfo().getViewElementMatches());
|
||||
log.debug("before {},mapInfo:{}", simpleName, queryContext.getMapInfo().getDataSetElementMatches());
|
||||
|
||||
try {
|
||||
doMap(queryContext);
|
||||
@@ -35,13 +35,14 @@ public abstract class BaseMapper implements SchemaMapper {
|
||||
}
|
||||
|
||||
long cost = System.currentTimeMillis() - startTime;
|
||||
log.debug("after {},cost:{},mapInfo:{}", simpleName, cost, queryContext.getMapInfo().getViewElementMatches());
|
||||
log.debug("after {},cost:{},mapInfo:{}", simpleName, cost,
|
||||
queryContext.getMapInfo().getDataSetElementMatches());
|
||||
}
|
||||
|
||||
public abstract void doMap(QueryContext queryContext);
|
||||
|
||||
public void addToSchemaMap(SchemaMapInfo schemaMap, Long modelId, SchemaElementMatch newElementMatch) {
|
||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getViewElementMatches();
|
||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getDataSetElementMatches();
|
||||
List<SchemaElementMatch> schemaElementMatches = modelElementMatches.putIfAbsent(modelId, new ArrayList<>());
|
||||
if (schemaElementMatches == null) {
|
||||
schemaElementMatches = modelElementMatches.get(modelId);
|
||||
@@ -67,14 +68,14 @@ public abstract class BaseMapper implements SchemaMapper {
|
||||
}
|
||||
}
|
||||
|
||||
public SchemaElement getSchemaElement(Long viewId, SchemaElementType elementType, Long elementID,
|
||||
public SchemaElement getSchemaElement(Long dataSetId, SchemaElementType elementType, Long elementID,
|
||||
SemanticSchema semanticSchema) {
|
||||
SchemaElement element = new SchemaElement();
|
||||
ViewSchema viewSchema = semanticSchema.getViewSchemaMap().get(viewId);
|
||||
if (Objects.isNull(viewSchema)) {
|
||||
DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId);
|
||||
if (Objects.isNull(dataSetSchema)) {
|
||||
return null;
|
||||
}
|
||||
SchemaElement elementDb = viewSchema.getElement(elementType, elementID);
|
||||
SchemaElement elementDb = dataSetSchema.getElement(elementType, elementID);
|
||||
if (Objects.isNull(elementDb)) {
|
||||
log.info("element is null, elementType:{},elementID:{}", elementType, elementID);
|
||||
return null;
|
||||
|
||||
@@ -28,22 +28,22 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms,
|
||||
Set<Long> detectViewIds) {
|
||||
Set<Long> detectDataSetIds) {
|
||||
String text = queryContext.getQueryText();
|
||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
log.debug("terms:{},,detectViewIds:{}", terms, detectViewIds);
|
||||
log.debug("terms:{},,detectDataSetIds:{}", terms, detectDataSetIds);
|
||||
|
||||
List<T> detects = detect(queryContext, terms, detectViewIds);
|
||||
List<T> detects = detect(queryContext, terms, detectDataSetIds);
|
||||
Map<MatchText, List<T>> result = new HashMap<>();
|
||||
|
||||
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
||||
return result;
|
||||
}
|
||||
|
||||
public List<T> detect(QueryContext queryContext, List<S2Term> terms, Set<Long> detectViewIds) {
|
||||
public List<T> detect(QueryContext queryContext, List<S2Term> terms, Set<Long> detectDataSetIds) {
|
||||
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(terms);
|
||||
String text = queryContext.getQueryText();
|
||||
Set<T> results = new HashSet<>();
|
||||
@@ -58,16 +58,16 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
if (index <= text.length()) {
|
||||
String detectSegment = text.substring(startIndex, index).trim();
|
||||
detectSegments.add(detectSegment);
|
||||
detectByStep(queryContext, results, detectViewIds, detectSegment, offset);
|
||||
detectByStep(queryContext, results, detectDataSetIds, detectSegment, offset);
|
||||
}
|
||||
}
|
||||
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
|
||||
}
|
||||
detectByBatch(queryContext, results, detectViewIds, detectSegments);
|
||||
detectByBatch(queryContext, results, detectDataSetIds, detectSegments);
|
||||
return new ArrayList<>(results);
|
||||
}
|
||||
|
||||
protected void detectByBatch(QueryContext queryContext, Set<T> results, Set<Long> detectViewIds,
|
||||
protected void detectByBatch(QueryContext queryContext, Set<T> results, Set<Long> detectDataSetIds,
|
||||
Set<String> detectSegments) {
|
||||
return;
|
||||
}
|
||||
@@ -104,9 +104,9 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
}
|
||||
|
||||
public List<T> getMatches(QueryContext queryContext, List<S2Term> terms) {
|
||||
Set<Long> viewIds = mapperHelper.getViewIds(queryContext.getViewId(), queryContext.getAgent());
|
||||
terms = filterByViewId(terms, viewIds);
|
||||
Map<MatchText, List<T>> matchResult = match(queryContext, terms, viewIds);
|
||||
Set<Long> dataSetIds = mapperHelper.getDataSetIds(queryContext.getDataSetId(), queryContext.getAgent());
|
||||
terms = filterByDataSetId(terms, dataSetIds);
|
||||
Map<MatchText, List<T>> matchResult = match(queryContext, terms, dataSetIds);
|
||||
List<T> matches = new ArrayList<>();
|
||||
if (Objects.isNull(matchResult)) {
|
||||
return matches;
|
||||
@@ -121,17 +121,17 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
return matches;
|
||||
}
|
||||
|
||||
public List<S2Term> filterByViewId(List<S2Term> terms, Set<Long> viewIds) {
|
||||
public List<S2Term> filterByDataSetId(List<S2Term> terms, Set<Long> dataSetIds) {
|
||||
logTerms(terms);
|
||||
if (CollectionUtils.isNotEmpty(viewIds)) {
|
||||
if (CollectionUtils.isNotEmpty(dataSetIds)) {
|
||||
terms = terms.stream().filter(term -> {
|
||||
Long viewId = NatureHelper.getViewId(term.getNature().toString());
|
||||
if (Objects.nonNull(viewId)) {
|
||||
return viewIds.contains(viewId);
|
||||
Long dataSetId = NatureHelper.getDataSetId(term.getNature().toString());
|
||||
if (Objects.nonNull(dataSetId)) {
|
||||
return dataSetIds.contains(dataSetId);
|
||||
}
|
||||
return false;
|
||||
}).collect(Collectors.toList());
|
||||
log.info("terms filter by viewId:{}", viewIds);
|
||||
log.info("terms filter by dataSetId:{}", dataSetIds);
|
||||
logTerms(terms);
|
||||
}
|
||||
return terms;
|
||||
@@ -150,7 +150,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
|
||||
public abstract String getMapKey(T a);
|
||||
|
||||
public abstract void detectByStep(QueryContext queryContext, Set<T> existResults, Set<Long> detectViewIds,
|
||||
public abstract void detectByStep(QueryContext queryContext, Set<T> existResults, Set<Long> detectDataSetIds,
|
||||
String detectSegment, int offset);
|
||||
|
||||
}
|
||||
|
||||
@@ -37,9 +37,9 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<DatabaseMapResult>> match(QueryContext queryContext, List<S2Term> terms,
|
||||
Set<Long> detectViewIds) {
|
||||
Set<Long> detectDataSetIds) {
|
||||
this.allElements = getSchemaElements(queryContext);
|
||||
return super.match(queryContext, terms, detectViewIds);
|
||||
return super.match(queryContext, terms, detectDataSetIds);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -54,7 +54,7 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
||||
+ Constants.UNDERLINE + a.getSchemaElement().getName();
|
||||
}
|
||||
|
||||
public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> existResults, Set<Long> detectViewIds,
|
||||
public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> existResults, Set<Long> detectDataSetIds,
|
||||
String detectSegment, int offset) {
|
||||
if (StringUtils.isBlank(detectSegment)) {
|
||||
return;
|
||||
@@ -70,9 +70,9 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
||||
continue;
|
||||
}
|
||||
Set<SchemaElement> schemaElements = entry.getValue();
|
||||
if (!CollectionUtils.isEmpty(detectViewIds)) {
|
||||
if (!CollectionUtils.isEmpty(detectDataSetIds)) {
|
||||
schemaElements = schemaElements.stream()
|
||||
.filter(schemaElement -> detectViewIds.contains(schemaElement.getView()))
|
||||
.filter(schemaElement -> detectDataSetIds.contains(schemaElement.getDataSet()))
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
for (SchemaElement schemaElement : schemaElements) {
|
||||
@@ -96,7 +96,7 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
||||
Double metricDimensionThresholdConfig = optimizationConfig.getMetricDimensionThresholdConfig();
|
||||
Double metricDimensionMinThresholdConfig = optimizationConfig.getMetricDimensionMinThresholdConfig();
|
||||
|
||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo().getViewElementMatches();
|
||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo().getDataSetElementMatches();
|
||||
|
||||
boolean existElement = modelElementMatches.entrySet().stream().anyMatch(entry -> entry.getValue().size() >= 1);
|
||||
|
||||
|
||||
@@ -36,12 +36,12 @@ public class EmbeddingMapper extends BaseMapper {
|
||||
//2. build SchemaElementMatch by info
|
||||
for (EmbeddingResult matchResult : matchResults) {
|
||||
Long elementId = Retrieval.getLongId(matchResult.getId());
|
||||
Long viewId = Retrieval.getLongId(matchResult.getMetadata().get("viewId"));
|
||||
if (Objects.isNull(viewId)) {
|
||||
Long dataSetId = Retrieval.getLongId(matchResult.getMetadata().get("dataSetId"));
|
||||
if (Objects.isNull(dataSetId)) {
|
||||
continue;
|
||||
}
|
||||
SchemaElementType elementType = SchemaElementType.valueOf(matchResult.getMetadata().get("type"));
|
||||
SchemaElement schemaElement = getSchemaElement(viewId, elementType, elementId,
|
||||
SchemaElement schemaElement = getSchemaElement(dataSetId, elementType, elementId,
|
||||
queryContext.getSemanticSchema());
|
||||
if (schemaElement == null) {
|
||||
continue;
|
||||
@@ -54,7 +54,7 @@ public class EmbeddingMapper extends BaseMapper {
|
||||
.detectWord(matchResult.getDetectWord())
|
||||
.build();
|
||||
//3. add to mapInfo
|
||||
addToSchemaMap(queryContext.getMapInfo(), viewId, schemaElementMatch);
|
||||
addToSchemaMap(queryContext.getMapInfo(), dataSetId, schemaElementMatch);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,13 +48,13 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectViewIds,
|
||||
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectDataSetIds,
|
||||
String detectSegment, int offset) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectViewIds,
|
||||
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectDataSetIds,
|
||||
Set<String> detectSegments) {
|
||||
|
||||
List<String> queryTextsList = detectSegments.stream()
|
||||
@@ -68,11 +68,11 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
optimizationConfig.getEmbeddingMapperBatch());
|
||||
|
||||
for (List<String> queryTextsSub : queryTextsSubList) {
|
||||
detectByQueryTextsSub(results, detectViewIds, queryTextsSub);
|
||||
detectByQueryTextsSub(results, detectDataSetIds, queryTextsSub);
|
||||
}
|
||||
}
|
||||
|
||||
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectViewIds,
|
||||
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectDataSetIds,
|
||||
List<String> queryTextsSub) {
|
||||
int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber();
|
||||
Double distance = optimizationConfig.getEmbeddingMapperDistanceThreshold();
|
||||
@@ -80,7 +80,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
|
||||
// step2. retrieveQuery by detectSegment
|
||||
List<RetrieveQueryResult> retrieveQueryResults = metaEmbeddingService.retrieveQuery(
|
||||
new ArrayList<>(detectViewIds), retrieveQuery, embeddingNumber);
|
||||
new ArrayList<>(detectDataSetIds), retrieveQuery, embeddingNumber);
|
||||
|
||||
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
|
||||
return;
|
||||
|
||||
@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
@@ -24,12 +24,12 @@ public class EntityMapper extends BaseMapper {
|
||||
@Override
|
||||
public void doMap(QueryContext queryContext) {
|
||||
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
||||
for (Long viewId : schemaMapInfo.getMatchedViewInfos()) {
|
||||
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(viewId);
|
||||
for (Long dataSetId : schemaMapInfo.getMatchedDataSetInfos()) {
|
||||
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(dataSetId);
|
||||
if (CollectionUtils.isEmpty(schemaElementMatchList)) {
|
||||
continue;
|
||||
}
|
||||
SchemaElement entity = getEntity(viewId, queryContext);
|
||||
SchemaElement entity = getEntity(dataSetId, queryContext);
|
||||
if (entity == null || entity.getId() == null) {
|
||||
continue;
|
||||
}
|
||||
@@ -65,9 +65,9 @@ public class EntityMapper extends BaseMapper {
|
||||
return false;
|
||||
}
|
||||
|
||||
private SchemaElement getEntity(Long viewId, QueryContext queryContext) {
|
||||
private SchemaElement getEntity(Long dataSetId, QueryContext queryContext) {
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
ViewSchema modelSchema = semanticSchema.getViewSchemaMap().get(viewId);
|
||||
DataSetSchema modelSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId);
|
||||
if (modelSchema != null && modelSchema.getEntity() != null) {
|
||||
return modelSchema.getEntity();
|
||||
}
|
||||
|
||||
@@ -39,15 +39,15 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> terms,
|
||||
Set<Long> detectViewIds) {
|
||||
Set<Long> detectDataSetIds) {
|
||||
String text = queryContext.getQueryText();
|
||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
log.debug("retryCount:{},terms:{},,detectModelIds:{}", terms, detectViewIds);
|
||||
log.debug("retryCount:{},terms:{},,detectModelIds:{}", terms, detectDataSetIds);
|
||||
|
||||
List<HanlpMapResult> detects = detect(queryContext, terms, detectViewIds);
|
||||
List<HanlpMapResult> detects = detect(queryContext, terms, detectDataSetIds);
|
||||
Map<MatchText, List<HanlpMapResult>> result = new HashMap<>();
|
||||
|
||||
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
||||
@@ -60,15 +60,15 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
&& existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length();
|
||||
}
|
||||
|
||||
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectViewIds,
|
||||
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectDataSetIds,
|
||||
String detectSegment, int offset) {
|
||||
// step1. pre search
|
||||
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
|
||||
LinkedHashSet<HanlpMapResult> hanlpMapResults = knowledgeService.prefixSearch(detectSegment,
|
||||
oneDetectionMaxSize, detectViewIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
oneDetectionMaxSize, detectDataSetIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
// step2. suffix search
|
||||
LinkedHashSet<HanlpMapResult> suffixHanlpMapResults = knowledgeService.suffixSearch(detectSegment,
|
||||
oneDetectionMaxSize, detectViewIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
oneDetectionMaxSize, detectDataSetIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
|
||||
hanlpMapResults.addAll(suffixHanlpMapResults);
|
||||
|
||||
|
||||
@@ -59,8 +59,8 @@ public class KeywordMapper extends BaseMapper {
|
||||
|
||||
for (HanlpMapResult hanlpMapResult : mapResults) {
|
||||
for (String nature : hanlpMapResult.getNatures()) {
|
||||
Long viewId = NatureHelper.getViewId(nature);
|
||||
if (Objects.isNull(viewId)) {
|
||||
Long dataSetId = NatureHelper.getDataSetId(nature);
|
||||
if (Objects.isNull(dataSetId)) {
|
||||
continue;
|
||||
}
|
||||
SchemaElementType elementType = NatureHelper.convertToElementType(nature);
|
||||
@@ -68,7 +68,7 @@ public class KeywordMapper extends BaseMapper {
|
||||
continue;
|
||||
}
|
||||
Long elementID = NatureHelper.getElementID(nature);
|
||||
SchemaElement element = getSchemaElement(viewId, elementType,
|
||||
SchemaElement element = getSchemaElement(dataSetId, elementType,
|
||||
elementID, queryContext.getSemanticSchema());
|
||||
if (element == null) {
|
||||
continue;
|
||||
@@ -85,7 +85,7 @@ public class KeywordMapper extends BaseMapper {
|
||||
.detectWord(hanlpMapResult.getDetectWord())
|
||||
.build();
|
||||
|
||||
addToSchemaMap(queryContext.getMapInfo(), viewId, schemaElementMatch);
|
||||
addToSchemaMap(queryContext.getMapInfo(), dataSetId, schemaElementMatch);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -106,12 +106,12 @@ public class KeywordMapper extends BaseMapper {
|
||||
.similarity(mapperHelper.getSimilarity(match.getDetectWord(), schemaElement.getName()))
|
||||
.build();
|
||||
log.info("add to schema, elementMatch {}", schemaElementMatch);
|
||||
addToSchemaMap(queryContext.getMapInfo(), schemaElement.getView(), schemaElementMatch);
|
||||
addToSchemaMap(queryContext.getMapInfo(), schemaElement.getDataSet(), schemaElementMatch);
|
||||
}
|
||||
}
|
||||
|
||||
private Set<Long> getRegElementSet(SchemaMapInfo schemaMap, SchemaElement schemaElement) {
|
||||
List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getView());
|
||||
List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getDataSet());
|
||||
if (CollectionUtils.isEmpty(elements)) {
|
||||
return new HashSet<>();
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ public class MapperHelper {
|
||||
*/
|
||||
public boolean existDimensionValues(List<String> natures) {
|
||||
for (String nature : natures) {
|
||||
if (NatureHelper.isDimensionValueViewId(nature)) {
|
||||
if (NatureHelper.isDimensionValueDataSetId(nature)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
@@ -82,33 +82,33 @@ public class MapperHelper {
|
||||
detectSegment.length());
|
||||
}
|
||||
|
||||
public Set<Long> getViewIds(Long viewId, Agent agent) {
|
||||
public Set<Long> getDataSetIds(Long dataSetId, Agent agent) {
|
||||
|
||||
Set<Long> detectViewIds = new HashSet<>();
|
||||
Set<Long> detectDataSetIds = new HashSet<>();
|
||||
if (Objects.nonNull(agent)) {
|
||||
detectViewIds = agent.getViewIds(null);
|
||||
detectDataSetIds = agent.getDataSetIds();
|
||||
}
|
||||
//contains all
|
||||
if (Agent.containsAllModel(detectViewIds)) {
|
||||
if (Objects.nonNull(viewId) && viewId > 0) {
|
||||
if (Agent.containsAllModel(detectDataSetIds)) {
|
||||
if (Objects.nonNull(dataSetId) && dataSetId > 0) {
|
||||
Set<Long> result = new HashSet<>();
|
||||
result.add(viewId);
|
||||
result.add(dataSetId);
|
||||
return result;
|
||||
}
|
||||
return new HashSet<>();
|
||||
}
|
||||
|
||||
if (Objects.nonNull(detectViewIds)) {
|
||||
detectViewIds = detectViewIds.stream().filter(entry -> entry > 0).collect(Collectors.toSet());
|
||||
if (Objects.nonNull(detectDataSetIds)) {
|
||||
detectDataSetIds = detectDataSetIds.stream().filter(entry -> entry > 0).collect(Collectors.toSet());
|
||||
}
|
||||
|
||||
if (Objects.nonNull(viewId) && viewId > 0 && Objects.nonNull(detectViewIds)) {
|
||||
if (detectViewIds.contains(viewId)) {
|
||||
if (Objects.nonNull(dataSetId) && dataSetId > 0 && Objects.nonNull(detectDataSetIds)) {
|
||||
if (detectDataSetIds.contains(dataSetId)) {
|
||||
Set<Long> result = new HashSet<>();
|
||||
result.add(viewId);
|
||||
result.add(dataSetId);
|
||||
return result;
|
||||
}
|
||||
}
|
||||
return detectViewIds;
|
||||
return detectDataSetIds;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,6 @@ import java.util.Set;
|
||||
*/
|
||||
public interface MatchStrategy<T> {
|
||||
|
||||
Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms, Set<Long> detectViewIds);
|
||||
Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms, Set<Long> detectDataSetIds);
|
||||
|
||||
}
|
||||
@@ -27,13 +27,13 @@ public class QueryFilterMapper implements SchemaMapper {
|
||||
@Override
|
||||
public void map(QueryContext queryContext) {
|
||||
Agent agent = queryContext.getAgent();
|
||||
if (agent == null || CollectionUtils.isEmpty(agent.getViewIds())) {
|
||||
if (agent == null || CollectionUtils.isEmpty(agent.getDataSetIds())) {
|
||||
return;
|
||||
}
|
||||
if (Agent.containsAllModel(agent.getViewIds())) {
|
||||
if (Agent.containsAllModel(agent.getDataSetIds())) {
|
||||
return;
|
||||
}
|
||||
Set<Long> viewIds = agent.getViewIds();
|
||||
Set<Long> viewIds = agent.getDataSetIds();
|
||||
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
||||
clearOtherSchemaElementMatch(viewIds, schemaMapInfo);
|
||||
for (Long viewId : viewIds) {
|
||||
@@ -47,7 +47,7 @@ public class QueryFilterMapper implements SchemaMapper {
|
||||
}
|
||||
|
||||
private void clearOtherSchemaElementMatch(Set<Long> viewIds, SchemaMapInfo schemaMapInfo) {
|
||||
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMapInfo.getViewElementMatches().entrySet()) {
|
||||
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMapInfo.getDataSetElementMatches().entrySet()) {
|
||||
if (!viewIds.contains(entry.getKey())) {
|
||||
entry.getValue().clear();
|
||||
}
|
||||
@@ -69,7 +69,7 @@ public class QueryFilterMapper implements SchemaMapper {
|
||||
.name(String.valueOf(filter.getValue()))
|
||||
.type(SchemaElementType.VALUE)
|
||||
.bizName(filter.getBizName())
|
||||
.view(viewId)
|
||||
.dataSet(viewId)
|
||||
.build();
|
||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||
.element(element)
|
||||
|
||||
@@ -32,7 +32,7 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> originals,
|
||||
Set<Long> detectViewIds) {
|
||||
Set<Long> detectDataSetIds) {
|
||||
String text = queryContext.getQueryText();
|
||||
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(originals);
|
||||
|
||||
@@ -57,9 +57,9 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
|
||||
if (StringUtils.isNotEmpty(detectSegment)) {
|
||||
List<HanlpMapResult> hanlpMapResults = knowledgeService.prefixSearch(detectSegment,
|
||||
SearchService.SEARCH_SIZE, detectViewIds);
|
||||
SearchService.SEARCH_SIZE, detectDataSetIds);
|
||||
List<HanlpMapResult> suffixHanlpMapResults = knowledgeService.suffixSearch(
|
||||
detectSegment, SEARCH_SIZE, detectViewIds);
|
||||
detectSegment, SEARCH_SIZE, detectDataSetIds);
|
||||
hanlpMapResults.addAll(suffixHanlpMapResults);
|
||||
// remove entity name where search
|
||||
hanlpMapResults = hanlpMapResults.stream().filter(entry -> {
|
||||
@@ -93,7 +93,7 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectViewIds,
|
||||
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectDataSetIds,
|
||||
String detectSegment, int offset) {
|
||||
|
||||
}
|
||||
|
||||
@@ -38,12 +38,12 @@ public class JavaLLMProxy implements LLMProxy {
|
||||
return false;
|
||||
}
|
||||
|
||||
public LLMResp query2sql(LLMReq llmReq, Long viewId) {
|
||||
public LLMResp query2sql(LLMReq llmReq, Long dataSetId) {
|
||||
|
||||
SqlGeneration sqlGeneration = SqlGenerationFactory.get(
|
||||
SqlGenerationMode.getMode(llmReq.getSqlGenerationMode()));
|
||||
String modelName = llmReq.getSchema().getViewName();
|
||||
LLMResp result = sqlGeneration.generation(llmReq, viewId);
|
||||
String modelName = llmReq.getSchema().getDataSetName();
|
||||
LLMResp result = sqlGeneration.generation(llmReq, dataSetId);
|
||||
result.setQuery(llmReq.getQueryText());
|
||||
result.setModelName(modelName);
|
||||
return result;
|
||||
|
||||
@@ -15,7 +15,7 @@ public interface LLMProxy {
|
||||
|
||||
boolean isSkip(QueryContext queryContext);
|
||||
|
||||
LLMResp query2sql(LLMReq llmReq, Long viewId);
|
||||
LLMResp query2sql(LLMReq llmReq, Long dataSetId);
|
||||
|
||||
FunctionResp requestFunction(FunctionReq functionReq);
|
||||
|
||||
|
||||
@@ -48,10 +48,10 @@ public class PythonLLMProxy implements LLMProxy {
|
||||
return false;
|
||||
}
|
||||
|
||||
public LLMResp query2sql(LLMReq llmReq, Long viewId) {
|
||||
public LLMResp query2sql(LLMReq llmReq, Long dataSetId) {
|
||||
long startTime = System.currentTimeMillis();
|
||||
log.info("requestLLM request, viewId:{},llmReq:{}", viewId, llmReq);
|
||||
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
|
||||
log.info("requestLLM request, dataSetId:{},llmReq:{}", dataSetId, llmReq);
|
||||
keyPipelineLog.info("dataSetId:{},llmReq:{}", dataSetId, llmReq);
|
||||
try {
|
||||
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ public class QueryTypeParser implements SemanticParser {
|
||||
return QueryType.ID;
|
||||
}
|
||||
//1. entity queryType
|
||||
Long viewId = parseInfo.getViewId();
|
||||
Long dataSetId = parseInfo.getDataSetId();
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
|
||||
//If all the fields in the SELECT statement are of tag type.
|
||||
@@ -59,12 +59,12 @@ public class QueryTypeParser implements SemanticParser {
|
||||
.collect(Collectors.toList());
|
||||
|
||||
if (CollectionUtils.isNotEmpty(whereFields)) {
|
||||
Set<String> ids = semanticSchema.getEntities(viewId).stream().map(SchemaElement::getName)
|
||||
Set<String> ids = semanticSchema.getEntities(dataSetId).stream().map(SchemaElement::getName)
|
||||
.collect(Collectors.toSet());
|
||||
if (CollectionUtils.isNotEmpty(ids) && ids.stream().anyMatch(whereFields::contains)) {
|
||||
return QueryType.ID;
|
||||
}
|
||||
Set<String> tags = semanticSchema.getTags(viewId).stream().map(SchemaElement::getName)
|
||||
Set<String> tags = semanticSchema.getTags(dataSetId).stream().map(SchemaElement::getName)
|
||||
.collect(Collectors.toSet());
|
||||
if (CollectionUtils.isNotEmpty(tags) && tags.containsAll(whereFields)) {
|
||||
return QueryType.TAG;
|
||||
@@ -73,7 +73,7 @@ public class QueryTypeParser implements SemanticParser {
|
||||
}
|
||||
//2. metric queryType
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||
List<SchemaElement> metrics = semanticSchema.getMetrics(viewId);
|
||||
List<SchemaElement> metrics = semanticSchema.getMetrics(dataSetId);
|
||||
if (CollectionUtils.isNotEmpty(metrics)) {
|
||||
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
|
||||
boolean containMetric = selectFields.stream().anyMatch(metricNameSet::contains);
|
||||
|
||||
@@ -55,13 +55,13 @@ public abstract class PluginParser implements SemanticParser {
|
||||
|
||||
public void buildQuery(QueryContext queryContext, PluginRecallResult pluginRecallResult) {
|
||||
Plugin plugin = pluginRecallResult.getPlugin();
|
||||
Set<Long> viewIds = pluginRecallResult.getViewIds();
|
||||
Set<Long> dataSetIds = pluginRecallResult.getDataSetIds();
|
||||
if (plugin.isContainsAllModel()) {
|
||||
viewIds = Sets.newHashSet(-1L);
|
||||
dataSetIds = Sets.newHashSet(-1L);
|
||||
}
|
||||
for (Long viewId : viewIds) {
|
||||
for (Long dataSetId : dataSetIds) {
|
||||
PluginSemanticQuery pluginQuery = QueryManager.createPluginQuery(plugin.getType());
|
||||
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(viewId, plugin,
|
||||
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(dataSetId, plugin,
|
||||
queryContext, pluginRecallResult.getDistance());
|
||||
semanticParseInfo.setQueryMode(pluginQuery.getQueryMode());
|
||||
semanticParseInfo.setScore(pluginRecallResult.getScore());
|
||||
@@ -74,19 +74,19 @@ public abstract class PluginParser implements SemanticParser {
|
||||
return PluginManager.getPluginAgentCanSupport(queryContext);
|
||||
}
|
||||
|
||||
protected SemanticParseInfo buildSemanticParseInfo(Long viewId, Plugin plugin,
|
||||
protected SemanticParseInfo buildSemanticParseInfo(Long dataSetId, Plugin plugin,
|
||||
QueryContext queryContext, double distance) {
|
||||
List<SchemaElementMatch> schemaElementMatches = queryContext.getMapInfo().getMatchedElements(viewId);
|
||||
List<SchemaElementMatch> schemaElementMatches = queryContext.getMapInfo().getMatchedElements(dataSetId);
|
||||
QueryFilters queryFilters = queryContext.getQueryFilters();
|
||||
if (viewId == null && !CollectionUtils.isEmpty(plugin.getViewList())) {
|
||||
viewId = plugin.getViewList().get(0);
|
||||
if (dataSetId == null && !CollectionUtils.isEmpty(plugin.getDataSetList())) {
|
||||
dataSetId = plugin.getDataSetList().get(0);
|
||||
}
|
||||
if (schemaElementMatches == null) {
|
||||
schemaElementMatches = Lists.newArrayList();
|
||||
}
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
semanticParseInfo.setElementMatches(schemaElementMatches);
|
||||
semanticParseInfo.setView(queryContext.getSemanticSchema().getView(viewId));
|
||||
semanticParseInfo.setDataSet(queryContext.getSemanticSchema().getDataSet(dataSetId));
|
||||
Map<String, Object> properties = new HashMap<>();
|
||||
PluginParseResult pluginParseResult = new PluginParseResult();
|
||||
pluginParseResult.setPlugin(plugin);
|
||||
|
||||
@@ -57,15 +57,15 @@ public class EmbeddingRecallParser extends PluginParser {
|
||||
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, queryContext);
|
||||
log.info("embedding plugin resolve: {}", pair);
|
||||
if (pair.getLeft()) {
|
||||
Set<Long> viewList = pair.getRight();
|
||||
if (CollectionUtils.isEmpty(viewList)) {
|
||||
Set<Long> dataSetList = pair.getRight();
|
||||
if (CollectionUtils.isEmpty(dataSetList)) {
|
||||
continue;
|
||||
}
|
||||
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
|
||||
double distance = embeddingRetrieval.getDistance();
|
||||
double score = queryContext.getQueryText().length() * (1 - distance);
|
||||
return PluginRecallResult.builder()
|
||||
.plugin(plugin).viewIds(viewList).score(score).distance(distance).build();
|
||||
.plugin(plugin).dataSetIds(dataSetList).score(score).distance(distance).build();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
|
||||
@@ -57,19 +57,19 @@ public class FunctionCallParser extends PluginParser {
|
||||
plugin.setParseMode(ParseMode.FUNCTION_CALL);
|
||||
Pair<Boolean, Set<Long>> pluginResolveResult = PluginManager.resolve(plugin, queryContext);
|
||||
if (pluginResolveResult.getLeft()) {
|
||||
Set<Long> viewList = pluginResolveResult.getRight();
|
||||
if (CollectionUtils.isEmpty(viewList)) {
|
||||
Set<Long> dataSetList = pluginResolveResult.getRight();
|
||||
if (CollectionUtils.isEmpty(dataSetList)) {
|
||||
return null;
|
||||
}
|
||||
double score = queryContext.getQueryText().length();
|
||||
return PluginRecallResult.builder().plugin(plugin).viewIds(viewList).score(score).build();
|
||||
return PluginRecallResult.builder().plugin(plugin).dataSetIds(dataSetList).score(score).build();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public FunctionResp functionCall(QueryContext queryContext) {
|
||||
List<PluginParseConfig> pluginToFunctionCall =
|
||||
getPluginToFunctionCall(queryContext.getViewId(), queryContext);
|
||||
getPluginToFunctionCall(queryContext.getDataSetId(), queryContext);
|
||||
if (CollectionUtils.isEmpty(pluginToFunctionCall)) {
|
||||
log.info("function call parser, plugin is empty, skip");
|
||||
return null;
|
||||
|
||||
@@ -3,7 +3,7 @@ package com.tencent.supersonic.chat.core.parser.sql.llm;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class ViewMatchResult {
|
||||
public class DataSetMatchResult {
|
||||
private Integer count = 0;
|
||||
private double maxSimilarity;
|
||||
}
|
||||
@@ -5,7 +5,7 @@ import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
|
||||
import java.util.Set;
|
||||
|
||||
public interface ViewResolver {
|
||||
public interface DataSetResolver {
|
||||
|
||||
Long resolve(QueryContext queryContext, Set<Long> restrictiveModels);
|
||||
|
||||
@@ -0,0 +1,138 @@
|
||||
package com.tencent.supersonic.chat.core.parser.sql.llm;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
public class HeuristicDataSetResolver implements DataSetResolver {
|
||||
|
||||
protected static Long selectDataSetBySchemaElementMatchScore(Map<Long, SemanticQuery> dataSetQueryModes,
|
||||
SchemaMapInfo schemaMap) {
|
||||
//dataSet count priority
|
||||
Long dataSetIdByDataSetCount = getDataSetIdByMatchDataSetScore(schemaMap);
|
||||
if (Objects.nonNull(dataSetIdByDataSetCount)) {
|
||||
log.info("selectDataSet by dataSet count:{}", dataSetIdByDataSetCount);
|
||||
return dataSetIdByDataSetCount;
|
||||
}
|
||||
|
||||
Map<Long, DataSetMatchResult> dataSetTypeMap = getDataSetTypeMap(schemaMap);
|
||||
if (dataSetTypeMap.size() == 1) {
|
||||
Long dataSetSelect = new ArrayList<>(dataSetTypeMap.entrySet()).get(0).getKey();
|
||||
if (dataSetQueryModes.containsKey(dataSetSelect)) {
|
||||
log.info("selectDataSet with only one DataSet [{}]", dataSetSelect);
|
||||
return dataSetSelect;
|
||||
}
|
||||
} else {
|
||||
Map.Entry<Long, DataSetMatchResult> maxDataSet = dataSetTypeMap.entrySet().stream()
|
||||
.filter(entry -> dataSetQueryModes.containsKey(entry.getKey()))
|
||||
.sorted((o1, o2) -> {
|
||||
int difference = o2.getValue().getCount() - o1.getValue().getCount();
|
||||
if (difference == 0) {
|
||||
return (int) ((o2.getValue().getMaxSimilarity()
|
||||
- o1.getValue().getMaxSimilarity()) * 100);
|
||||
}
|
||||
return difference;
|
||||
}).findFirst().orElse(null);
|
||||
if (maxDataSet != null) {
|
||||
log.info("selectDataSet with multiple DataSets [{}]", maxDataSet.getKey());
|
||||
return maxDataSet.getKey();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private static Long getDataSetIdByMatchDataSetScore(SchemaMapInfo schemaMap) {
|
||||
Map<Long, List<SchemaElementMatch>> dataSetElementMatches = schemaMap.getDataSetElementMatches();
|
||||
// calculate dataSet match score, matched element gets 1.0 point, and inherit element gets 0.5 point
|
||||
Map<Long, Double> dataSetIdToDataSetScore = new HashMap<>();
|
||||
if (Objects.nonNull(dataSetElementMatches)) {
|
||||
for (Entry<Long, List<SchemaElementMatch>> dataSetElementMatch : dataSetElementMatches.entrySet()) {
|
||||
Long dataSetId = dataSetElementMatch.getKey();
|
||||
List<Double> dataSetMatchesScore = dataSetElementMatch.getValue().stream()
|
||||
.filter(elementMatch -> elementMatch.getSimilarity() >= 1)
|
||||
.filter(elementMatch -> SchemaElementType.DATASET.equals(elementMatch.getElement().getType()))
|
||||
.map(elementMatch -> elementMatch.isInherited() ? 0.5 : 1.0).collect(Collectors.toList());
|
||||
|
||||
if (!CollectionUtils.isEmpty(dataSetMatchesScore)) {
|
||||
// get sum of dataSet match score
|
||||
double score = dataSetMatchesScore.stream().mapToDouble(Double::doubleValue).sum();
|
||||
dataSetIdToDataSetScore.put(dataSetId, score);
|
||||
}
|
||||
}
|
||||
Entry<Long, Double> maxDataSetScore = dataSetIdToDataSetScore.entrySet().stream()
|
||||
.max(Comparator.comparingDouble(Entry::getValue)).orElse(null);
|
||||
log.info("maxDataSetCount:{},dataSetIdToDataSetCount:{}", maxDataSetScore, dataSetIdToDataSetScore);
|
||||
if (Objects.nonNull(maxDataSetScore)) {
|
||||
return maxDataSetScore.getKey();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public static Map<Long, DataSetMatchResult> getDataSetTypeMap(SchemaMapInfo schemaMap) {
|
||||
Map<Long, DataSetMatchResult> dataSetCount = new HashMap<>();
|
||||
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMap.getDataSetElementMatches().entrySet()) {
|
||||
List<SchemaElementMatch> schemaElementMatches = schemaMap.getMatchedElements(entry.getKey());
|
||||
if (schemaElementMatches != null && schemaElementMatches.size() > 0) {
|
||||
if (!dataSetCount.containsKey(entry.getKey())) {
|
||||
dataSetCount.put(entry.getKey(), new DataSetMatchResult());
|
||||
}
|
||||
DataSetMatchResult dataSetMatchResult = dataSetCount.get(entry.getKey());
|
||||
Set<SchemaElementType> schemaElementTypes = new HashSet<>();
|
||||
schemaElementMatches.stream()
|
||||
.forEach(schemaElementMatch -> schemaElementTypes.add(
|
||||
schemaElementMatch.getElement().getType()));
|
||||
SchemaElementMatch schemaElementMatchMax = schemaElementMatches.stream()
|
||||
.sorted((o1, o2) ->
|
||||
((int) ((o2.getSimilarity() - o1.getSimilarity()) * 100))
|
||||
).findFirst().orElse(null);
|
||||
if (schemaElementMatchMax != null) {
|
||||
dataSetMatchResult.setMaxSimilarity(schemaElementMatchMax.getSimilarity());
|
||||
}
|
||||
dataSetMatchResult.setCount(schemaElementTypes.size());
|
||||
|
||||
}
|
||||
}
|
||||
return dataSetCount;
|
||||
}
|
||||
|
||||
public Long resolve(QueryContext queryContext, Set<Long> agentDataSetIds) {
|
||||
SchemaMapInfo mapInfo = queryContext.getMapInfo();
|
||||
Set<Long> matchedDataSets = mapInfo.getMatchedDataSetInfos();
|
||||
Long dataSetId = queryContext.getDataSetId();
|
||||
if (Objects.nonNull(dataSetId) && dataSetId > 0) {
|
||||
if (CollectionUtils.isEmpty(agentDataSetIds) || agentDataSetIds.contains(dataSetId)) {
|
||||
return dataSetId;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
if (CollectionUtils.isNotEmpty(agentDataSetIds)) {
|
||||
matchedDataSets.retainAll(agentDataSetIds);
|
||||
}
|
||||
Map<Long, SemanticQuery> dataSetQueryModes = new HashMap<>();
|
||||
for (Long dataSetIds : matchedDataSets) {
|
||||
dataSetQueryModes.put(dataSetIds, null);
|
||||
}
|
||||
if (dataSetQueryModes.size() == 1) {
|
||||
return dataSetQueryModes.keySet().stream().findFirst().get();
|
||||
}
|
||||
return selectDataSetBySchemaElementMatchScore(dataSetQueryModes, mapInfo);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,138 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.parser.sql.llm;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
public class HeuristicViewResolver implements ViewResolver {
|
||||
|
||||
protected static Long selectViewBySchemaElementMatchScore(Map<Long, SemanticQuery> viewQueryModes,
|
||||
SchemaMapInfo schemaMap) {
|
||||
//view count priority
|
||||
Long viewIdByViewCount = getViewIdByMatchViewScore(schemaMap);
|
||||
if (Objects.nonNull(viewIdByViewCount)) {
|
||||
log.info("selectView by view count:{}", viewIdByViewCount);
|
||||
return viewIdByViewCount;
|
||||
}
|
||||
|
||||
Map<Long, ViewMatchResult> viewTypeMap = getViewTypeMap(schemaMap);
|
||||
if (viewTypeMap.size() == 1) {
|
||||
Long viewSelect = new ArrayList<>(viewTypeMap.entrySet()).get(0).getKey();
|
||||
if (viewQueryModes.containsKey(viewSelect)) {
|
||||
log.info("selectView with only one View [{}]", viewSelect);
|
||||
return viewSelect;
|
||||
}
|
||||
} else {
|
||||
Map.Entry<Long, ViewMatchResult> maxView = viewTypeMap.entrySet().stream()
|
||||
.filter(entry -> viewQueryModes.containsKey(entry.getKey()))
|
||||
.sorted((o1, o2) -> {
|
||||
int difference = o2.getValue().getCount() - o1.getValue().getCount();
|
||||
if (difference == 0) {
|
||||
return (int) ((o2.getValue().getMaxSimilarity()
|
||||
- o1.getValue().getMaxSimilarity()) * 100);
|
||||
}
|
||||
return difference;
|
||||
}).findFirst().orElse(null);
|
||||
if (maxView != null) {
|
||||
log.info("selectView with multiple Views [{}]", maxView.getKey());
|
||||
return maxView.getKey();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private static Long getViewIdByMatchViewScore(SchemaMapInfo schemaMap) {
|
||||
Map<Long, List<SchemaElementMatch>> viewElementMatches = schemaMap.getViewElementMatches();
|
||||
// calculate view match score, matched element gets 1.0 point, and inherit element gets 0.5 point
|
||||
Map<Long, Double> viewIdToViewScore = new HashMap<>();
|
||||
if (Objects.nonNull(viewElementMatches)) {
|
||||
for (Entry<Long, List<SchemaElementMatch>> viewElementMatch : viewElementMatches.entrySet()) {
|
||||
Long viewId = viewElementMatch.getKey();
|
||||
List<Double> viewMatchesScore = viewElementMatch.getValue().stream()
|
||||
.filter(elementMatch -> elementMatch.getSimilarity() >= 1)
|
||||
.filter(elementMatch -> SchemaElementType.VIEW.equals(elementMatch.getElement().getType()))
|
||||
.map(elementMatch -> elementMatch.isInherited() ? 0.5 : 1.0).collect(Collectors.toList());
|
||||
|
||||
if (!CollectionUtils.isEmpty(viewMatchesScore)) {
|
||||
// get sum of view match score
|
||||
double score = viewMatchesScore.stream().mapToDouble(Double::doubleValue).sum();
|
||||
viewIdToViewScore.put(viewId, score);
|
||||
}
|
||||
}
|
||||
Entry<Long, Double> maxViewScore = viewIdToViewScore.entrySet().stream()
|
||||
.max(Comparator.comparingDouble(Entry::getValue)).orElse(null);
|
||||
log.info("maxViewCount:{},viewIdToViewCount:{}", maxViewScore, viewIdToViewScore);
|
||||
if (Objects.nonNull(maxViewScore)) {
|
||||
return maxViewScore.getKey();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public static Map<Long, ViewMatchResult> getViewTypeMap(SchemaMapInfo schemaMap) {
|
||||
Map<Long, ViewMatchResult> viewCount = new HashMap<>();
|
||||
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMap.getViewElementMatches().entrySet()) {
|
||||
List<SchemaElementMatch> schemaElementMatches = schemaMap.getMatchedElements(entry.getKey());
|
||||
if (schemaElementMatches != null && schemaElementMatches.size() > 0) {
|
||||
if (!viewCount.containsKey(entry.getKey())) {
|
||||
viewCount.put(entry.getKey(), new ViewMatchResult());
|
||||
}
|
||||
ViewMatchResult viewMatchResult = viewCount.get(entry.getKey());
|
||||
Set<SchemaElementType> schemaElementTypes = new HashSet<>();
|
||||
schemaElementMatches.stream()
|
||||
.forEach(schemaElementMatch -> schemaElementTypes.add(
|
||||
schemaElementMatch.getElement().getType()));
|
||||
SchemaElementMatch schemaElementMatchMax = schemaElementMatches.stream()
|
||||
.sorted((o1, o2) ->
|
||||
((int) ((o2.getSimilarity() - o1.getSimilarity()) * 100))
|
||||
).findFirst().orElse(null);
|
||||
if (schemaElementMatchMax != null) {
|
||||
viewMatchResult.setMaxSimilarity(schemaElementMatchMax.getSimilarity());
|
||||
}
|
||||
viewMatchResult.setCount(schemaElementTypes.size());
|
||||
|
||||
}
|
||||
}
|
||||
return viewCount;
|
||||
}
|
||||
|
||||
public Long resolve(QueryContext queryContext, Set<Long> agentViewIds) {
|
||||
SchemaMapInfo mapInfo = queryContext.getMapInfo();
|
||||
Set<Long> matchedViews = mapInfo.getMatchedViewInfos();
|
||||
Long viewId = queryContext.getViewId();
|
||||
if (Objects.nonNull(viewId) && viewId > 0) {
|
||||
if (CollectionUtils.isEmpty(agentViewIds) || agentViewIds.contains(viewId)) {
|
||||
return viewId;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
if (CollectionUtils.isNotEmpty(agentViewIds)) {
|
||||
matchedViews.retainAll(agentViewIds);
|
||||
}
|
||||
Map<Long, SemanticQuery> viewQueryModes = new HashMap<>();
|
||||
for (Long viewIds : matchedViews) {
|
||||
viewQueryModes.put(viewIds, null);
|
||||
}
|
||||
if (viewQueryModes.size() == 1) {
|
||||
return viewQueryModes.keySet().stream().findFirst().get();
|
||||
}
|
||||
return selectViewBySchemaElementMatchScore(viewQueryModes, mapInfo);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -22,7 +22,7 @@ import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ViewSchemaResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DataSetSchemaResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
@@ -61,20 +61,20 @@ public class LLMRequestService {
|
||||
return false;
|
||||
}
|
||||
|
||||
public Long getViewId(QueryContext queryCtx) {
|
||||
public Long getDataSetId(QueryContext queryCtx) {
|
||||
Agent agent = queryCtx.getAgent();
|
||||
Set<Long> agentViewIds = new HashSet<>();
|
||||
Set<Long> agentDataSetIds = new HashSet<>();
|
||||
if (Objects.nonNull(agent)) {
|
||||
agentViewIds = agent.getViewIds(AgentToolType.NL2SQL_LLM);
|
||||
agentDataSetIds = agent.getDataSetIds(AgentToolType.NL2SQL_LLM);
|
||||
}
|
||||
if (Agent.containsAllModel(agentViewIds)) {
|
||||
agentViewIds = new HashSet<>();
|
||||
if (Agent.containsAllModel(agentDataSetIds)) {
|
||||
agentDataSetIds = new HashSet<>();
|
||||
}
|
||||
ViewResolver viewResolver = ComponentFactory.getModelResolver();
|
||||
return viewResolver.resolve(queryCtx, agentViewIds);
|
||||
DataSetResolver dataSetResolver = ComponentFactory.getModelResolver();
|
||||
return dataSetResolver.resolve(queryCtx, agentDataSetIds);
|
||||
}
|
||||
|
||||
public NL2SQLTool getParserTool(QueryContext queryCtx, Long viewId) {
|
||||
public NL2SQLTool getParserTool(QueryContext queryCtx, Long dataSetId) {
|
||||
Agent agent = queryCtx.getAgent();
|
||||
if (Objects.isNull(agent)) {
|
||||
return null;
|
||||
@@ -82,19 +82,19 @@ public class LLMRequestService {
|
||||
List<NL2SQLTool> commonAgentTools = agent.getParserTools(AgentToolType.NL2SQL_LLM);
|
||||
Optional<NL2SQLTool> llmParserTool = commonAgentTools.stream()
|
||||
.filter(tool -> {
|
||||
List<Long> viewIds = tool.getViewIds();
|
||||
if (Agent.containsAllModel(new HashSet<>(viewIds))) {
|
||||
List<Long> dataSetIds = tool.getDataSetIds();
|
||||
if (Agent.containsAllModel(new HashSet<>(dataSetIds))) {
|
||||
return true;
|
||||
}
|
||||
return viewIds.contains(viewId);
|
||||
return dataSetIds.contains(dataSetId);
|
||||
})
|
||||
.findFirst();
|
||||
return llmParserTool.orElse(null);
|
||||
}
|
||||
|
||||
public LLMReq getLlmReq(QueryContext queryCtx, Long viewId,
|
||||
public LLMReq getLlmReq(QueryContext queryCtx, Long dataSetId,
|
||||
SemanticSchema semanticSchema, List<ElementValue> linkingValues) {
|
||||
Map<Long, String> viewIdToName = semanticSchema.getViewIdToName();
|
||||
Map<Long, String> dataSetIdToName = semanticSchema.getDataSetIdToName();
|
||||
String queryText = queryCtx.getQueryText();
|
||||
|
||||
LLMReq llmReq = new LLMReq();
|
||||
@@ -103,12 +103,12 @@ public class LLMRequestService {
|
||||
llmReq.setFilterCondition(filterCondition);
|
||||
|
||||
LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema();
|
||||
llmSchema.setViewName(viewIdToName.get(viewId));
|
||||
llmSchema.setDomainName(viewIdToName.get(viewId));
|
||||
llmSchema.setDataSetName(dataSetIdToName.get(dataSetId));
|
||||
llmSchema.setDomainName(dataSetIdToName.get(dataSetId));
|
||||
|
||||
List<String> fieldNameList = getFieldNameList(queryCtx, viewId, llmParserConfig);
|
||||
List<String> fieldNameList = getFieldNameList(queryCtx, dataSetId, llmParserConfig);
|
||||
|
||||
String priorExts = getPriorExts(viewId, fieldNameList);
|
||||
String priorExts = getPriorExts(dataSetId, fieldNameList);
|
||||
llmReq.setPriorExts(priorExts);
|
||||
|
||||
fieldNameList.add(TimeDimensionEnum.DAY.getChName());
|
||||
@@ -121,7 +121,7 @@ public class LLMRequestService {
|
||||
}
|
||||
llmReq.setLinking(linking);
|
||||
|
||||
String currentDate = S2SqlDateHelper.getReferenceDate(queryCtx, viewId);
|
||||
String currentDate = S2SqlDateHelper.getReferenceDate(queryCtx, dataSetId);
|
||||
if (StringUtils.isEmpty(currentDate)) {
|
||||
currentDate = DateUtils.getBeforeDate(0);
|
||||
}
|
||||
@@ -130,28 +130,28 @@ public class LLMRequestService {
|
||||
return llmReq;
|
||||
}
|
||||
|
||||
public LLMResp requestLLM(LLMReq llmReq, Long viewId) {
|
||||
return ComponentFactory.getLLMProxy().query2sql(llmReq, viewId);
|
||||
public LLMResp requestLLM(LLMReq llmReq, Long dataSetId) {
|
||||
return ComponentFactory.getLLMProxy().query2sql(llmReq, dataSetId);
|
||||
}
|
||||
|
||||
protected List<String> getFieldNameList(QueryContext queryCtx, Long viewId,
|
||||
protected List<String> getFieldNameList(QueryContext queryCtx, Long dataSetId,
|
||||
LLMParserConfig llmParserConfig) {
|
||||
|
||||
Set<String> results = getTopNFieldNames(queryCtx, viewId, llmParserConfig);
|
||||
Set<String> results = getTopNFieldNames(queryCtx, dataSetId, llmParserConfig);
|
||||
|
||||
Set<String> fieldNameList = getMatchedFieldNames(queryCtx, viewId);
|
||||
Set<String> fieldNameList = getMatchedFieldNames(queryCtx, dataSetId);
|
||||
|
||||
results.addAll(fieldNameList);
|
||||
return new ArrayList<>(results);
|
||||
}
|
||||
|
||||
private String getPriorExts(Long viewId, List<String> fieldNameList) {
|
||||
private String getPriorExts(Long dataSetId, List<String> fieldNameList) {
|
||||
StringBuilder extraInfoSb = new StringBuilder();
|
||||
List<ViewSchemaResp> viewSchemaResps = semanticInterpreter.fetchViewSchema(
|
||||
Lists.newArrayList(viewId), true);
|
||||
if (!CollectionUtils.isEmpty(viewSchemaResps)) {
|
||||
ViewSchemaResp viewSchemaResp = viewSchemaResps.get(0);
|
||||
Map<String, String> fieldNameToDataFormatType = viewSchemaResp.getMetrics()
|
||||
List<DataSetSchemaResp> dataSetSchemaResps = semanticInterpreter.fetchDataSetSchema(
|
||||
Lists.newArrayList(dataSetId), true);
|
||||
if (!CollectionUtils.isEmpty(dataSetSchemaResps)) {
|
||||
DataSetSchemaResp dataSetSchemaResp = dataSetSchemaResps.get(0);
|
||||
Map<String, String> fieldNameToDataFormatType = dataSetSchemaResp.getMetrics()
|
||||
.stream().filter(metricSchemaResp -> Objects.nonNull(metricSchemaResp.getDataFormatType()))
|
||||
.flatMap(metricSchemaResp -> {
|
||||
Set<Pair<String, String>> result = new HashSet<>();
|
||||
@@ -179,9 +179,9 @@ public class LLMRequestService {
|
||||
return extraInfoSb.toString();
|
||||
}
|
||||
|
||||
protected List<ElementValue> getValueList(QueryContext queryCtx, Long viewId) {
|
||||
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, viewId);
|
||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(viewId);
|
||||
protected List<ElementValue> getValueList(QueryContext queryCtx, Long dataSetId) {
|
||||
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, dataSetId);
|
||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
@@ -201,21 +201,21 @@ public class LLMRequestService {
|
||||
return new ArrayList<>(valueMatches);
|
||||
}
|
||||
|
||||
protected Map<Long, String> getItemIdToName(QueryContext queryCtx, Long viewId) {
|
||||
protected Map<Long, String> getItemIdToName(QueryContext queryCtx, Long dataSetId) {
|
||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||
return semanticSchema.getDimensions(viewId).stream()
|
||||
return semanticSchema.getDimensions(dataSetId).stream()
|
||||
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
|
||||
}
|
||||
|
||||
private Set<String> getTopNFieldNames(QueryContext queryCtx, Long viewId, LLMParserConfig llmParserConfig) {
|
||||
private Set<String> getTopNFieldNames(QueryContext queryCtx, Long dataSetId, LLMParserConfig llmParserConfig) {
|
||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||
Set<String> results = semanticSchema.getDimensions(viewId).stream()
|
||||
Set<String> results = semanticSchema.getDimensions(dataSetId).stream()
|
||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||
.limit(llmParserConfig.getDimensionTopN())
|
||||
.map(entry -> entry.getName())
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
Set<String> metrics = semanticSchema.getMetrics(viewId).stream()
|
||||
Set<String> metrics = semanticSchema.getMetrics(dataSetId).stream()
|
||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||
.limit(llmParserConfig.getMetricTopN())
|
||||
.map(entry -> entry.getName())
|
||||
@@ -225,9 +225,9 @@ public class LLMRequestService {
|
||||
return results;
|
||||
}
|
||||
|
||||
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, Long viewId) {
|
||||
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, viewId);
|
||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(viewId);
|
||||
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, Long dataSetId) {
|
||||
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, dataSetId);
|
||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||
return new HashSet<>();
|
||||
}
|
||||
|
||||
@@ -28,9 +28,9 @@ public class LLMResponseService {
|
||||
}
|
||||
LLMSemanticQuery semanticQuery = QueryManager.createLLMQuery(LLMSqlQuery.QUERY_MODE);
|
||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||
parseInfo.setView(queryCtx.getSemanticSchema().getView(parseResult.getViewId()));
|
||||
parseInfo.setDataSet(queryCtx.getSemanticSchema().getDataSet(parseResult.getDataSetId()));
|
||||
NL2SQLTool commonAgentTool = parseResult.getCommonAgentTool();
|
||||
parseInfo.getElementMatches().addAll(queryCtx.getMapInfo().getMatchedElements(parseInfo.getViewId()));
|
||||
parseInfo.getElementMatches().addAll(queryCtx.getMapInfo().getMatchedElements(parseInfo.getDataSetId()));
|
||||
|
||||
Map<String, Object> properties = new HashMap<>();
|
||||
properties.put(Constants.CONTEXT, parseResult);
|
||||
|
||||
@@ -29,21 +29,21 @@ public class LLMSqlParser implements SemanticParser {
|
||||
}
|
||||
try {
|
||||
//2.get modelId from queryCtx and chatCtx.
|
||||
Long viewId = requestService.getViewId(queryCtx);
|
||||
if (viewId == null) {
|
||||
Long dataSetId = requestService.getDataSetId(queryCtx);
|
||||
if (dataSetId == null) {
|
||||
return;
|
||||
}
|
||||
//3.get agent tool and determine whether to skip this parser.
|
||||
NL2SQLTool commonAgentTool = requestService.getParserTool(queryCtx, viewId);
|
||||
NL2SQLTool commonAgentTool = requestService.getParserTool(queryCtx, dataSetId);
|
||||
if (Objects.isNull(commonAgentTool)) {
|
||||
log.info("no tool in this agent, skip {}", LLMSqlParser.class);
|
||||
return;
|
||||
}
|
||||
//4.construct a request, call the API for the large model, and retrieve the results.
|
||||
List<ElementValue> linkingValues = requestService.getValueList(queryCtx, viewId);
|
||||
List<ElementValue> linkingValues = requestService.getValueList(queryCtx, dataSetId);
|
||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||
LLMReq llmReq = requestService.getLlmReq(queryCtx, viewId, semanticSchema, linkingValues);
|
||||
LLMResp llmResp = requestService.requestLLM(llmReq, viewId);
|
||||
LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId, semanticSchema, linkingValues);
|
||||
LLMResp llmResp = requestService.requestLLM(llmReq, dataSetId);
|
||||
|
||||
if (Objects.isNull(llmResp)) {
|
||||
return;
|
||||
@@ -52,7 +52,7 @@ public class LLMSqlParser implements SemanticParser {
|
||||
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
|
||||
Map<String, LLMSqlResp> deduplicationSqlResp = responseService.getDeduplicationSqlResp(llmResp);
|
||||
ParseResult parseResult = ParseResult.builder()
|
||||
.viewId(viewId)
|
||||
.dataSetId(dataSetId)
|
||||
.commonAgentTool(commonAgentTool)
|
||||
.llmReq(llmReq)
|
||||
.llmResp(llmResp)
|
||||
|
||||
@@ -43,9 +43,9 @@ public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
private SqlPromptGenerator sqlPromptGenerator;
|
||||
|
||||
@Override
|
||||
public LLMResp generation(LLMReq llmReq, Long viewId) {
|
||||
public LLMResp generation(LLMReq llmReq, Long dataSetId) {
|
||||
//1.retriever sqlExamples and generate exampleListPool
|
||||
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
|
||||
keyPipelineLog.info("dataSetId:{},llmReq:{}", dataSetId, llmReq);
|
||||
|
||||
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlExampleNum());
|
||||
|
||||
@@ -41,9 +41,9 @@ public class OnePassSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
private SqlPromptGenerator sqlPromptGenerator;
|
||||
|
||||
@Override
|
||||
public LLMResp generation(LLMReq llmReq, Long viewId) {
|
||||
public LLMResp generation(LLMReq llmReq, Long dataSetId) {
|
||||
//1.retriever sqlExamples
|
||||
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
|
||||
keyPipelineLog.info("dataSetId:{},llmReq:{}", dataSetId, llmReq);
|
||||
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlExampleNum());
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ import java.util.List;
|
||||
@NoArgsConstructor
|
||||
public class ParseResult {
|
||||
|
||||
private Long viewId;
|
||||
private Long dataSetId;
|
||||
|
||||
private LLMReq llmReq;
|
||||
|
||||
|
||||
@@ -12,9 +12,9 @@ public interface SqlGeneration {
|
||||
/***
|
||||
* generate llmResp (sql, schemaLink, prompt, etc.) through LLMReq.
|
||||
* @param llmReq
|
||||
* @param viewId
|
||||
* @param dataSetId
|
||||
* @return
|
||||
*/
|
||||
LLMResp generation(LLMReq llmReq, Long viewId);
|
||||
LLMResp generation(LLMReq llmReq, Long dataSetId);
|
||||
|
||||
}
|
||||
|
||||
@@ -96,7 +96,7 @@ public class SqlPromptGenerator {
|
||||
}
|
||||
|
||||
public Pair<String, String> transformQuestionPrompt(LLMReq llmReq) {
|
||||
String modelName = llmReq.getSchema().getViewName();
|
||||
String modelName = llmReq.getSchema().getDataSetName();
|
||||
List<String> fieldNameList = llmReq.getSchema().getFieldNameList();
|
||||
List<ElementValue> linking = llmReq.getLinking();
|
||||
String currentDate = llmReq.getCurrentDate();
|
||||
|
||||
@@ -40,9 +40,9 @@ public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
private SqlPromptGenerator sqlPromptGenerator;
|
||||
|
||||
@Override
|
||||
public LLMResp generation(LLMReq llmReq, Long viewId) {
|
||||
public LLMResp generation(LLMReq llmReq, Long dataSetId) {
|
||||
//1.retriever sqlExamples and generate exampleListPool
|
||||
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
|
||||
keyPipelineLog.info("dataSetId:{},llmReq:{}", dataSetId, llmReq);
|
||||
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlExampleNum());
|
||||
|
||||
|
||||
@@ -40,8 +40,8 @@ public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
private SqlPromptGenerator sqlPromptGenerator;
|
||||
|
||||
@Override
|
||||
public LLMResp generation(LLMReq llmReq, Long viewId) {
|
||||
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
|
||||
public LLMResp generation(LLMReq llmReq, Long dataSetId) {
|
||||
keyPipelineLog.info("dataSetId:{},llmReq:{}", dataSetId, llmReq);
|
||||
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlExampleNum());
|
||||
|
||||
|
||||
@@ -58,13 +58,13 @@ public class AgentCheckParser implements SemanticParser {
|
||||
}
|
||||
}
|
||||
}
|
||||
if (CollectionUtils.isEmpty(tool.getViewIds())) {
|
||||
if (CollectionUtils.isEmpty(tool.getDataSetIds())) {
|
||||
return true;
|
||||
}
|
||||
if (tool.isContainsAllModel()) {
|
||||
return false;
|
||||
}
|
||||
return !tool.getViewIds().contains(query.getParseInfo().getViewId());
|
||||
return !tool.getDataSetIds().contains(query.getParseInfo().getDataSetId());
|
||||
}
|
||||
return true;
|
||||
});
|
||||
|
||||
@@ -39,7 +39,7 @@ public class ContextInheritParser implements SemanticParser {
|
||||
SchemaElementType.VALUE, Arrays.asList(SchemaElementType.VALUE, SchemaElementType.DIMENSION)),
|
||||
new AbstractMap.SimpleEntry<>(SchemaElementType.ENTITY, Arrays.asList(SchemaElementType.ENTITY)),
|
||||
new AbstractMap.SimpleEntry<>(SchemaElementType.TAG, Arrays.asList(SchemaElementType.TAG)),
|
||||
new AbstractMap.SimpleEntry<>(SchemaElementType.VIEW, Arrays.asList(SchemaElementType.VIEW)),
|
||||
new AbstractMap.SimpleEntry<>(SchemaElementType.DATASET, Arrays.asList(SchemaElementType.DATASET)),
|
||||
new AbstractMap.SimpleEntry<>(SchemaElementType.ID, Arrays.asList(SchemaElementType.ID))
|
||||
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
|
||||
|
||||
@@ -48,12 +48,12 @@ public class ContextInheritParser implements SemanticParser {
|
||||
if (!shouldInherit(queryContext)) {
|
||||
return;
|
||||
}
|
||||
Long viewId = getMatchedView(queryContext, chatContext);
|
||||
if (viewId == null) {
|
||||
Long dataSetId = getMatchedDataSet(queryContext, chatContext);
|
||||
if (dataSetId == null) {
|
||||
return;
|
||||
}
|
||||
|
||||
List<SchemaElementMatch> elementMatches = queryContext.getMapInfo().getMatchedElements(viewId);
|
||||
List<SchemaElementMatch> elementMatches = queryContext.getMapInfo().getMatchedElements(dataSetId);
|
||||
|
||||
List<SchemaElementMatch> matchesToInherit = new ArrayList<>();
|
||||
for (SchemaElementMatch match : chatContext.getParseInfo().getElementMatches()) {
|
||||
@@ -70,17 +70,17 @@ public class ContextInheritParser implements SemanticParser {
|
||||
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(elementMatches, queryContext);
|
||||
for (RuleSemanticQuery query : queries) {
|
||||
query.fillParseInfo(queryContext, chatContext);
|
||||
if (existSameQuery(query.getParseInfo().getViewId(), query.getQueryMode(), queryContext)) {
|
||||
if (existSameQuery(query.getParseInfo().getDataSetId(), query.getQueryMode(), queryContext)) {
|
||||
continue;
|
||||
}
|
||||
queryContext.getCandidateQueries().add(query);
|
||||
}
|
||||
}
|
||||
|
||||
private boolean existSameQuery(Long viewId, String queryMode, QueryContext queryContext) {
|
||||
private boolean existSameQuery(Long dataSetId, String queryMode, QueryContext queryContext) {
|
||||
for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) {
|
||||
if (semanticQuery.getQueryMode().equals(queryMode)
|
||||
&& semanticQuery.getParseInfo().getViewId().equals(viewId)) {
|
||||
&& semanticQuery.getParseInfo().getDataSetId().equals(dataSetId)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
@@ -109,16 +109,16 @@ public class ContextInheritParser implements SemanticParser {
|
||||
return metricModelQueries.size() == queryContext.getCandidateQueries().size();
|
||||
}
|
||||
|
||||
protected Long getMatchedView(QueryContext queryContext, ChatContext chatContext) {
|
||||
Long viewId = chatContext.getParseInfo().getViewId();
|
||||
if (viewId == null) {
|
||||
protected Long getMatchedDataSet(QueryContext queryContext, ChatContext chatContext) {
|
||||
Long dataSetId = chatContext.getParseInfo().getDataSetId();
|
||||
if (dataSetId == null) {
|
||||
return null;
|
||||
}
|
||||
Set<Long> queryViews = queryContext.getMapInfo().getMatchedViewInfos();
|
||||
if (queryViews.contains(viewId)) {
|
||||
return viewId;
|
||||
Set<Long> queryDataSets = queryContext.getMapInfo().getMatchedDataSetInfos();
|
||||
if (queryDataSets.contains(dataSetId)) {
|
||||
return dataSetId;
|
||||
}
|
||||
return viewId;
|
||||
return dataSetId;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -29,8 +29,8 @@ public class RuleSqlParser implements SemanticParser {
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
SchemaMapInfo mapInfo = queryContext.getMapInfo();
|
||||
// iterate all schemaElementMatches to resolve query mode
|
||||
for (Long viewId : mapInfo.getMatchedViewInfos()) {
|
||||
List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(viewId);
|
||||
for (Long dataSetId : mapInfo.getMatchedDataSetInfos()) {
|
||||
List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(dataSetId);
|
||||
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(elementMatches, queryContext);
|
||||
for (RuleSemanticQuery query : queries) {
|
||||
query.fillParseInfo(queryContext, chatContext);
|
||||
|
||||
@@ -20,7 +20,7 @@ public class Plugin extends RecordInfo {
|
||||
*/
|
||||
private String type;
|
||||
|
||||
private List<Long> viewList = Lists.newArrayList();
|
||||
private List<Long> dataSetList = Lists.newArrayList();
|
||||
|
||||
/**
|
||||
* description, for parsing
|
||||
@@ -52,7 +52,7 @@ public class Plugin extends RecordInfo {
|
||||
}
|
||||
|
||||
public boolean isContainsAllModel() {
|
||||
return CollectionUtils.isNotEmpty(viewList) && viewList.contains(-1L);
|
||||
return CollectionUtils.isNotEmpty(dataSetList) && dataSetList.contains(-1L);
|
||||
}
|
||||
|
||||
public Long getDefaultMode() {
|
||||
|
||||
@@ -266,14 +266,14 @@ public class PluginManager {
|
||||
}
|
||||
|
||||
private static Set<Long> getPluginMatchedModel(Plugin plugin, QueryContext queryContext) {
|
||||
Set<Long> matchedViews = queryContext.getMapInfo().getMatchedViewInfos();
|
||||
Set<Long> matchedDataSets = queryContext.getMapInfo().getMatchedDataSetInfos();
|
||||
if (plugin.isContainsAllModel()) {
|
||||
return Sets.newHashSet(plugin.getDefaultMode());
|
||||
}
|
||||
List<Long> modelIds = plugin.getViewList();
|
||||
List<Long> modelIds = plugin.getDataSetList();
|
||||
Set<Long> pluginMatchedModel = Sets.newHashSet();
|
||||
for (Long modelId : modelIds) {
|
||||
if (matchedViews.contains(modelId)) {
|
||||
if (matchedDataSets.contains(modelId)) {
|
||||
pluginMatchedModel.add(modelId);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ public class PluginRecallResult {
|
||||
|
||||
private Plugin plugin;
|
||||
|
||||
private Set<Long> viewIds;
|
||||
private Set<Long> dataSetIds;
|
||||
|
||||
private double score;
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ public class QueryContext {
|
||||
|
||||
private String queryText;
|
||||
private Integer chatId;
|
||||
private Long viewId;
|
||||
private Long dataSetId;
|
||||
private User user;
|
||||
private boolean saveAnswer = true;
|
||||
private Integer agentId;
|
||||
|
||||
@@ -49,7 +49,7 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
|
||||
explainSqlReq = ExplainSqlReq.builder()
|
||||
.queryTypeEnum(QueryType.SQL)
|
||||
.queryReq(QueryReqBuilder.buildS2SQLReq(
|
||||
sqlInfo.getCorrectS2SQL(), parseInfo.getViewId()
|
||||
sqlInfo.getCorrectS2SQL(), parseInfo.getDataSetId()
|
||||
))
|
||||
.build();
|
||||
} else {
|
||||
@@ -84,7 +84,7 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
|
||||
}
|
||||
|
||||
protected void convertBizNameToName(SemanticSchema semanticSchema, QueryStructReq queryStructReq) {
|
||||
Map<String, String> bizNameToName = semanticSchema.getBizNameToName(queryStructReq.getViewId());
|
||||
Map<String, String> bizNameToName = semanticSchema.getBizNameToName(queryStructReq.getDataSetId());
|
||||
bizNameToName.putAll(TimeDimensionEnum.getNameToNameMap());
|
||||
|
||||
List<Order> orders = queryStructReq.getOrders();
|
||||
|
||||
@@ -36,7 +36,7 @@ public class LLMReq {
|
||||
|
||||
private String domainName;
|
||||
|
||||
private String viewName;
|
||||
private String dataSetName;
|
||||
|
||||
private List<String> fieldNameList;
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ public class LLMSqlQuery extends LLMSemanticQuery {
|
||||
|
||||
long startTime = System.currentTimeMillis();
|
||||
String querySql = parseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
QuerySqlReq querySQLReq = QueryReqBuilder.buildS2SQLReq(querySql, parseInfo.getViewId());
|
||||
QuerySqlReq querySQLReq = QueryReqBuilder.buildS2SQLReq(querySql, parseInfo.getDataSetId());
|
||||
SemanticQueryResp queryResp = semanticInterpreter.queryByS2SQL(querySQLReq, user);
|
||||
|
||||
log.info("queryByS2SQL cost:{},querySql:{}", System.currentTimeMillis() - startTime, querySql);
|
||||
|
||||
@@ -79,7 +79,7 @@ public abstract class PluginSemanticQuery extends BaseSemanticQuery {
|
||||
if (!CollectionUtils.isEmpty(webPage.getParamOptions()) && !CollectionUtils.isEmpty(elementValueMap)) {
|
||||
for (ParamOption paramOption : webPage.getParamOptions()) {
|
||||
if (paramOption.getModelId() != null
|
||||
&& !parseInfo.getViewId().equals(paramOption.getModelId())) {
|
||||
&& !parseInfo.getDataSetId().equals(paramOption.getModelId())) {
|
||||
continue;
|
||||
}
|
||||
paramOptions.add(paramOption);
|
||||
|
||||
@@ -25,7 +25,7 @@ public class QueryMatcher {
|
||||
|
||||
public QueryMatcher() {
|
||||
for (SchemaElementType type : SchemaElementType.values()) {
|
||||
if (type.equals(SchemaElementType.VIEW)) {
|
||||
if (type.equals(SchemaElementType.DATASET)) {
|
||||
elementOptionMap.put(type, QueryMatchOption.optional());
|
||||
} else {
|
||||
elementOptionMap.put(type, QueryMatchOption.unused());
|
||||
|
||||
@@ -102,9 +102,9 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
}
|
||||
|
||||
private void fillSchemaElement(SemanticParseInfo parseInfo, SemanticSchema semanticSchema) {
|
||||
Set<Long> viewIds = parseInfo.getElementMatches().stream().map(SchemaElementMatch::getElement)
|
||||
.map(SchemaElement::getView).collect(Collectors.toSet());
|
||||
parseInfo.setView(semanticSchema.getView(viewIds.iterator().next()));
|
||||
Set<Long> dataSetIds = parseInfo.getElementMatches().stream().map(SchemaElementMatch::getElement)
|
||||
.map(SchemaElement::getDataSet).collect(Collectors.toSet());
|
||||
parseInfo.setDataSet(semanticSchema.getDataSet(dataSetIds.iterator().next()));
|
||||
Map<Long, List<SchemaElementMatch>> dim2Values = new HashMap<>();
|
||||
Map<Long, List<SchemaElementMatch>> id2Values = new HashMap<>();
|
||||
|
||||
@@ -189,7 +189,7 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
public QueryResult execute(User user) {
|
||||
String queryMode = parseInfo.getQueryMode();
|
||||
|
||||
if (parseInfo.getViewId() == null || StringUtils.isEmpty(queryMode)
|
||||
if (parseInfo.getDataSetId() == null || StringUtils.isEmpty(queryMode)
|
||||
|| !QueryManager.containsRuleQuery(queryMode)) {
|
||||
// reach here some error may happen
|
||||
log.error("not find QueryMode");
|
||||
@@ -230,7 +230,7 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
public QueryResult multiStructExecute(User user) {
|
||||
String queryMode = parseInfo.getQueryMode();
|
||||
|
||||
if (parseInfo.getViewId() != null || StringUtils.isEmpty(queryMode)
|
||||
if (parseInfo.getDataSetId() != null || StringUtils.isEmpty(queryMode)
|
||||
|| !QueryManager.containsRuleQuery(queryMode)) {
|
||||
// reach here some error may happen
|
||||
log.error("not find QueryMode");
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
package com.tencent.supersonic.chat.core.query.rule.metric;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import static com.tencent.supersonic.chat.core.query.rule.QueryMatchOption.OptionType.OPTIONAL;
|
||||
import static com.tencent.supersonic.chat.core.query.rule.QueryMatchOption.RequireNumberType.AT_MOST;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import org.springframework.stereotype.Component;
|
||||
@Component
|
||||
public class MetricModelQuery extends MetricSemanticQuery {
|
||||
|
||||
@@ -14,7 +14,7 @@ public class MetricModelQuery extends MetricSemanticQuery {
|
||||
|
||||
public MetricModelQuery() {
|
||||
super();
|
||||
queryMatcher.addOption(SchemaElementType.VIEW, OPTIONAL, AT_MOST, 1);
|
||||
queryMatcher.addOption(SchemaElementType.DATASET, OPTIONAL, AT_MOST, 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -1,21 +1,22 @@
|
||||
package com.tencent.supersonic.chat.core.query.rule.metric;
|
||||
|
||||
import static com.tencent.supersonic.chat.core.query.rule.QueryMatchOption.OptionType.REQUIRED;
|
||||
import static com.tencent.supersonic.chat.core.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
|
||||
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.METRIC;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeMode;
|
||||
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.time.LocalDate;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import static com.tencent.supersonic.chat.core.query.rule.QueryMatchOption.OptionType.REQUIRED;
|
||||
import static com.tencent.supersonic.chat.core.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
|
||||
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.METRIC;
|
||||
|
||||
@Slf4j
|
||||
public abstract class MetricSemanticQuery extends RuleSemanticQuery {
|
||||
@@ -38,8 +39,9 @@ public abstract class MetricSemanticQuery extends RuleSemanticQuery {
|
||||
super.fillParseInfo(queryContext, chatContext);
|
||||
parseInfo.setLimit(METRIC_MAX_RESULTS);
|
||||
if (parseInfo.getDateInfo() == null) {
|
||||
ViewSchema viewSchema = queryContext.getSemanticSchema().getViewSchemaMap().get(parseInfo.getViewId());
|
||||
TimeDefaultConfig timeDefaultConfig = viewSchema.getMetricTypeTimeDefaultConfig();
|
||||
DataSetSchema dataSetSchema =
|
||||
queryContext.getSemanticSchema().getDataSetSchemaMap().get(parseInfo.getDataSetId());
|
||||
TimeDefaultConfig timeDefaultConfig = dataSetSchema.getMetricTypeTimeDefaultConfig();
|
||||
DateConf dateInfo = new DateConf();
|
||||
if (Objects.nonNull(timeDefaultConfig) && Objects.nonNull(timeDefaultConfig.getUnit())) {
|
||||
int unit = timeDefaultConfig.getUnit();
|
||||
|
||||
@@ -3,7 +3,7 @@ package com.tencent.supersonic.chat.core.query.rule.tag;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
@@ -25,19 +25,19 @@ public abstract class TagListQuery extends TagSemanticQuery {
|
||||
}
|
||||
|
||||
private void addEntityDetailAndOrderByMetric(QueryContext queryContext, SemanticParseInfo parseInfo) {
|
||||
Long viewId = parseInfo.getViewId();
|
||||
if (Objects.nonNull(viewId) && viewId > 0L) {
|
||||
ViewSchema viewSchema = queryContext.getSemanticSchema().getViewSchemaMap().get(viewId);
|
||||
if (viewSchema != null && Objects.nonNull(viewSchema.getEntity())) {
|
||||
Long dataSetId = parseInfo.getDataSetId();
|
||||
if (Objects.nonNull(dataSetId) && dataSetId > 0L) {
|
||||
DataSetSchema dataSetSchema = queryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
|
||||
if (dataSetSchema != null && Objects.nonNull(dataSetSchema.getEntity())) {
|
||||
Set<SchemaElement> dimensions = new LinkedHashSet<>();
|
||||
Set<SchemaElement> metrics = new LinkedHashSet<>();
|
||||
Set<Order> orders = new LinkedHashSet<>();
|
||||
TagTypeDefaultConfig tagTypeDefaultConfig = viewSchema.getTagTypeDefaultConfig();
|
||||
TagTypeDefaultConfig tagTypeDefaultConfig = dataSetSchema.getTagTypeDefaultConfig();
|
||||
if (tagTypeDefaultConfig != null && tagTypeDefaultConfig.getDefaultDisplayInfo() != null) {
|
||||
if (CollectionUtils.isNotEmpty(tagTypeDefaultConfig.getDefaultDisplayInfo().getMetricIds())) {
|
||||
metrics = tagTypeDefaultConfig.getDefaultDisplayInfo().getMetricIds()
|
||||
.stream().map(id -> {
|
||||
SchemaElement metric = viewSchema.getElement(SchemaElementType.METRIC, id);
|
||||
SchemaElement metric = dataSetSchema.getElement(SchemaElementType.METRIC, id);
|
||||
if (metric != null) {
|
||||
orders.add(new Order(metric.getBizName(), Constants.DESC_UPPER));
|
||||
}
|
||||
@@ -46,7 +46,7 @@ public abstract class TagListQuery extends TagSemanticQuery {
|
||||
}
|
||||
if (CollectionUtils.isNotEmpty(tagTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds())) {
|
||||
dimensions = tagTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds().stream()
|
||||
.map(id -> viewSchema.getElement(SchemaElementType.DIMENSION, id))
|
||||
.map(id -> dataSetSchema.getElement(SchemaElementType.DIMENSION, id))
|
||||
.filter(Objects::nonNull).collect(Collectors.toSet());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.chat.core.query.rule.tag;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery;
|
||||
@@ -42,8 +42,9 @@ public abstract class TagSemanticQuery extends RuleSemanticQuery {
|
||||
parseInfo.setQueryType(QueryType.TAG);
|
||||
parseInfo.setLimit(TAG_MAX_RESULTS);
|
||||
if (parseInfo.getDateInfo() == null) {
|
||||
ViewSchema viewSchema = queryContext.getSemanticSchema().getViewSchemaMap().get(parseInfo.getViewId());
|
||||
TimeDefaultConfig timeDefaultConfig = viewSchema.getTagTypeTimeDefaultConfig();
|
||||
DataSetSchema dataSetSchema =
|
||||
queryContext.getSemanticSchema().getDataSetSchemaMap().get(parseInfo.getDataSetId());
|
||||
TimeDefaultConfig timeDefaultConfig = dataSetSchema.getTagTypeTimeDefaultConfig();
|
||||
DateConf dateInfo = new DateConf();
|
||||
if (Objects.nonNull(timeDefaultConfig) && Objects.nonNull(timeDefaultConfig.getUnit())) {
|
||||
int unit = timeDefaultConfig.getUnit();
|
||||
|
||||
@@ -2,8 +2,8 @@ package com.tencent.supersonic.chat.core.query.semantic;
|
||||
|
||||
import com.google.common.cache.Cache;
|
||||
import com.google.common.cache.CacheBuilder;
|
||||
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ViewSchemaResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DataSetSchemaResp;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
@@ -16,53 +16,53 @@ import java.util.concurrent.TimeUnit;
|
||||
@Slf4j
|
||||
public abstract class BaseSemanticInterpreter implements SemanticInterpreter {
|
||||
|
||||
protected final Cache<String, List<ViewSchemaResp>> viewSchemaCache =
|
||||
protected final Cache<String, List<DataSetSchemaResp>> dataSetSchemaCache =
|
||||
CacheBuilder.newBuilder().expireAfterWrite(10, TimeUnit.SECONDS).build();
|
||||
|
||||
@SneakyThrows
|
||||
public List<ViewSchemaResp> fetchViewSchema(List<Long> ids, Boolean cacheEnable) {
|
||||
public List<DataSetSchemaResp> fetchDataSetSchema(List<Long> ids, Boolean cacheEnable) {
|
||||
if (cacheEnable) {
|
||||
return viewSchemaCache.get(String.valueOf(ids), () -> {
|
||||
List<ViewSchemaResp> data = doFetchViewSchema(ids);
|
||||
viewSchemaCache.put(String.valueOf(ids), data);
|
||||
return dataSetSchemaCache.get(String.valueOf(ids), () -> {
|
||||
List<DataSetSchemaResp> data = doFetchDataSetSchema(ids);
|
||||
dataSetSchemaCache.put(String.valueOf(ids), data);
|
||||
return data;
|
||||
});
|
||||
}
|
||||
return doFetchViewSchema(ids);
|
||||
return doFetchDataSetSchema(ids);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ViewSchema getViewSchema(Long viewId, Boolean cacheEnable) {
|
||||
public DataSetSchema getDataSetSchema(Long dataSetId, Boolean cacheEnable) {
|
||||
List<Long> ids = new ArrayList<>();
|
||||
ids.add(viewId);
|
||||
List<ViewSchemaResp> viewSchemaResps = fetchViewSchema(ids, cacheEnable);
|
||||
if (!CollectionUtils.isEmpty(viewSchemaResps)) {
|
||||
Optional<ViewSchemaResp> viewSchemaResp = viewSchemaResps.stream()
|
||||
.filter(d -> d.getId().equals(viewId)).findFirst();
|
||||
if (viewSchemaResp.isPresent()) {
|
||||
ViewSchemaResp viewSchema = viewSchemaResp.get();
|
||||
return ViewSchemaBuilder.build(viewSchema);
|
||||
ids.add(dataSetId);
|
||||
List<DataSetSchemaResp> dataSetSchemaResps = fetchDataSetSchema(ids, cacheEnable);
|
||||
if (!CollectionUtils.isEmpty(dataSetSchemaResps)) {
|
||||
Optional<DataSetSchemaResp> dataSetSchemaResp = dataSetSchemaResps.stream()
|
||||
.filter(d -> d.getId().equals(dataSetId)).findFirst();
|
||||
if (dataSetSchemaResp.isPresent()) {
|
||||
DataSetSchemaResp dataSetSchema = dataSetSchemaResp.get();
|
||||
return DataSetSchemaBuilder.build(dataSetSchema);
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ViewSchema> getViewSchema() {
|
||||
return getViewSchema(new ArrayList<>());
|
||||
public List<DataSetSchema> getDataSetSchema() {
|
||||
return getDataSetSchema(new ArrayList<>());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ViewSchema> getViewSchema(List<Long> ids) {
|
||||
List<ViewSchema> domainSchemaList = new ArrayList<>();
|
||||
public List<DataSetSchema> getDataSetSchema(List<Long> ids) {
|
||||
List<DataSetSchema> domainSchemaList = new ArrayList<>();
|
||||
|
||||
for (ViewSchemaResp resp : fetchViewSchema(ids, true)) {
|
||||
domainSchemaList.add(ViewSchemaBuilder.build(resp));
|
||||
for (DataSetSchemaResp resp : fetchDataSetSchema(ids, true)) {
|
||||
domainSchemaList.add(DataSetSchemaBuilder.build(resp));
|
||||
}
|
||||
|
||||
return domainSchemaList;
|
||||
}
|
||||
|
||||
protected abstract List<ViewSchemaResp> doFetchViewSchema(List<Long> ids);
|
||||
protected abstract List<DataSetSchemaResp> doFetchDataSetSchema(List<Long> ids);
|
||||
|
||||
}
|
||||
|
||||
@@ -5,13 +5,13 @@ import com.tencent.supersonic.headless.api.pojo.RelatedSchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaValueMap;
|
||||
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.DimValueMap;
|
||||
import com.tencent.supersonic.headless.api.pojo.RelateDimension;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ViewSchemaResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DataSetSchemaResp;
|
||||
import org.apache.logging.log4j.util.Strings;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
@@ -23,19 +23,19 @@ import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class ViewSchemaBuilder {
|
||||
public class DataSetSchemaBuilder {
|
||||
|
||||
public static ViewSchema build(ViewSchemaResp resp) {
|
||||
ViewSchema viewSchema = new ViewSchema();
|
||||
viewSchema.setQueryConfig(resp.getQueryConfig());
|
||||
SchemaElement model = SchemaElement.builder()
|
||||
.view(resp.getId())
|
||||
public static DataSetSchema build(DataSetSchemaResp resp) {
|
||||
DataSetSchema dataSetSchema = new DataSetSchema();
|
||||
dataSetSchema.setQueryConfig(resp.getQueryConfig());
|
||||
SchemaElement dataSet = SchemaElement.builder()
|
||||
.dataSet(resp.getId())
|
||||
.id(resp.getId())
|
||||
.name(resp.getName())
|
||||
.bizName(resp.getBizName())
|
||||
.type(SchemaElementType.VIEW)
|
||||
.type(SchemaElementType.DATASET)
|
||||
.build();
|
||||
viewSchema.setView(model);
|
||||
dataSetSchema.setDataSet(dataSet);
|
||||
|
||||
Set<SchemaElement> metrics = new HashSet<>();
|
||||
for (MetricSchemaResp metric : resp.getMetrics()) {
|
||||
@@ -43,7 +43,7 @@ public class ViewSchemaBuilder {
|
||||
List<String> alias = SchemaItem.getAliasList(metric.getAlias());
|
||||
|
||||
SchemaElement metricToAdd = SchemaElement.builder()
|
||||
.view(resp.getId())
|
||||
.dataSet(resp.getId())
|
||||
.model(metric.getModelId())
|
||||
.id(metric.getId())
|
||||
.name(metric.getName())
|
||||
@@ -57,7 +57,7 @@ public class ViewSchemaBuilder {
|
||||
metrics.add(metricToAdd);
|
||||
|
||||
}
|
||||
viewSchema.getMetrics().addAll(metrics);
|
||||
dataSetSchema.getMetrics().addAll(metrics);
|
||||
|
||||
Set<SchemaElement> dimensions = new HashSet<>();
|
||||
Set<SchemaElement> dimensionValues = new HashSet<>();
|
||||
@@ -84,7 +84,7 @@ public class ViewSchemaBuilder {
|
||||
|
||||
}
|
||||
SchemaElement dimToAdd = SchemaElement.builder()
|
||||
.view(resp.getId())
|
||||
.dataSet(resp.getId())
|
||||
.model(dim.getModelId())
|
||||
.id(dim.getId())
|
||||
.name(dim.getName())
|
||||
@@ -97,7 +97,7 @@ public class ViewSchemaBuilder {
|
||||
dimensions.add(dimToAdd);
|
||||
|
||||
SchemaElement dimValueToAdd = SchemaElement.builder()
|
||||
.view(resp.getId())
|
||||
.dataSet(resp.getId())
|
||||
.model(dim.getModelId())
|
||||
.id(dim.getId())
|
||||
.name(dim.getName())
|
||||
@@ -109,7 +109,7 @@ public class ViewSchemaBuilder {
|
||||
dimensionValues.add(dimValueToAdd);
|
||||
if (dim.getIsTag() == 1) {
|
||||
SchemaElement tagToAdd = SchemaElement.builder()
|
||||
.view(resp.getId())
|
||||
.dataSet(resp.getId())
|
||||
.model(dim.getModelId())
|
||||
.id(dim.getId())
|
||||
.name(dim.getName())
|
||||
@@ -122,14 +122,14 @@ public class ViewSchemaBuilder {
|
||||
tags.add(tagToAdd);
|
||||
}
|
||||
}
|
||||
viewSchema.getDimensions().addAll(dimensions);
|
||||
viewSchema.getDimensionValues().addAll(dimensionValues);
|
||||
viewSchema.getTags().addAll(tags);
|
||||
dataSetSchema.getDimensions().addAll(dimensions);
|
||||
dataSetSchema.getDimensionValues().addAll(dimensionValues);
|
||||
dataSetSchema.getTags().addAll(tags);
|
||||
|
||||
DimSchemaResp dim = resp.getPrimaryKey();
|
||||
if (dim != null) {
|
||||
SchemaElement entity = SchemaElement.builder()
|
||||
.view(resp.getId())
|
||||
.dataSet(resp.getId())
|
||||
.model(dim.getModelId())
|
||||
.id(dim.getId())
|
||||
.name(dim.getName())
|
||||
@@ -138,9 +138,9 @@ public class ViewSchemaBuilder {
|
||||
.useCnt(dim.getUseCnt())
|
||||
.alias(dim.getEntityAlias())
|
||||
.build();
|
||||
viewSchema.setEntity(entity);
|
||||
dataSetSchema.setEntity(entity);
|
||||
}
|
||||
return viewSchema;
|
||||
return dataSetSchema;
|
||||
}
|
||||
|
||||
private static List<RelatedSchemaElement> getRelateSchemaElement(MetricSchemaResp metricSchemaResp) {
|
||||
@@ -10,15 +10,15 @@ import com.tencent.supersonic.headless.api.pojo.request.PageMetricReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ViewFilterReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DataSetFilterReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DomainResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ExplainResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ItemResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ViewResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ViewSchemaResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DataSetSchemaResp;
|
||||
import com.tencent.supersonic.headless.server.service.DimensionService;
|
||||
import com.tencent.supersonic.headless.server.service.MetricService;
|
||||
import com.tencent.supersonic.headless.server.service.QueryService;
|
||||
@@ -44,7 +44,7 @@ public class LocalSemanticInterpreter extends BaseSemanticInterpreter {
|
||||
if (StringUtils.isNotBlank(queryStructReq.getCorrectS2SQL())) {
|
||||
QuerySqlReq querySqlReq = new QuerySqlReq();
|
||||
querySqlReq.setSql(queryStructReq.getCorrectS2SQL());
|
||||
querySqlReq.setViewId(queryStructReq.getViewId());
|
||||
querySqlReq.setDataSetId(queryStructReq.getDataSetId());
|
||||
querySqlReq.setParams(new ArrayList<>());
|
||||
return queryByS2SQL(querySqlReq, user);
|
||||
}
|
||||
@@ -68,11 +68,11 @@ public class LocalSemanticInterpreter extends BaseSemanticInterpreter {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ViewSchemaResp> doFetchViewSchema(List<Long> ids) {
|
||||
ViewFilterReq filter = new ViewFilterReq();
|
||||
filter.setViewIds(ids);
|
||||
public List<DataSetSchemaResp> doFetchDataSetSchema(List<Long> ids) {
|
||||
DataSetFilterReq filter = new DataSetFilterReq();
|
||||
filter.setDataSetIds(ids);
|
||||
schemaService = ContextUtils.getBean(SchemaService.class);
|
||||
return schemaService.fetchViewSchema(filter);
|
||||
return schemaService.fetchDataSetSchema(filter);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -82,9 +82,9 @@ public class LocalSemanticInterpreter extends BaseSemanticInterpreter {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ViewResp> getViewList(Long domainId) {
|
||||
public List<DataSetResp> getDataSetList(Long domainId) {
|
||||
schemaService = ContextUtils.getBean(SchemaService.class);
|
||||
return schemaService.getViewList(domainId);
|
||||
return schemaService.getDataSetList(domainId);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -106,8 +106,8 @@ public class LocalSemanticInterpreter extends BaseSemanticInterpreter {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ItemResp> getDomainViewTree() {
|
||||
return schemaService.getDomainViewTree();
|
||||
public List<ItemResp> getDomainDataSetTree() {
|
||||
return schemaService.getDomainDataSetTree();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -24,8 +24,8 @@ import com.tencent.supersonic.headless.api.pojo.response.ExplainResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ItemResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ViewResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ViewSchemaResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DataSetSchemaResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.logging.log4j.util.Strings;
|
||||
@@ -250,17 +250,17 @@ public class RemoteSemanticInterpreter extends BaseSemanticInterpreter {
|
||||
}
|
||||
|
||||
@Override
|
||||
protected List<ViewSchemaResp> doFetchViewSchema(List<Long> ids) {
|
||||
protected List<DataSetSchemaResp> doFetchDataSetSchema(List<Long> ids) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ItemResp> getDomainViewTree() {
|
||||
public List<ItemResp> getDomainDataSetTree() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ViewResp> getViewList(Long domainId) {
|
||||
public List<DataSetResp> getDataSetList(Long domainId) {
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.chat.core.query.semantic;
|
||||
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ExplainSqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.PageDimensionReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.PageMetricReq;
|
||||
@@ -15,14 +15,14 @@ import com.tencent.supersonic.headless.api.pojo.response.ExplainResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ItemResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ViewResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ViewSchemaResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DataSetSchemaResp;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* A semantic layer provides a simplified and consistent view of data from multiple sources.
|
||||
* It abstracts away the complexity of the underlying data sources and provides a unified view
|
||||
* A semantic layer provides a simplified and consistent view of data from multiple sources.
|
||||
* It abstracts away the complexity of the underlying data sources and provides a unified view
|
||||
* of the data that is easier to understand and use.
|
||||
* <p>
|
||||
* The interface defines methods for getting metadata as well as querying data in the semantic layer.
|
||||
@@ -39,11 +39,11 @@ public interface SemanticInterpreter {
|
||||
|
||||
SemanticQueryResp queryByS2SQL(QuerySqlReq querySQLReq, User user);
|
||||
|
||||
List<ViewSchema> getViewSchema();
|
||||
List<DataSetSchema> getDataSetSchema();
|
||||
|
||||
List<ViewSchema> getViewSchema(List<Long> ids);
|
||||
List<DataSetSchema> getDataSetSchema(List<Long> ids);
|
||||
|
||||
ViewSchema getViewSchema(Long model, Boolean cacheEnable);
|
||||
DataSetSchema getDataSetSchema(Long model, Boolean cacheEnable);
|
||||
|
||||
PageInfo<DimensionResp> getDimensionPage(PageDimensionReq pageDimensionReq);
|
||||
|
||||
@@ -53,10 +53,10 @@ public interface SemanticInterpreter {
|
||||
|
||||
<T> ExplainResp explain(ExplainSqlReq<T> explainSqlReq, User user) throws Exception;
|
||||
|
||||
List<ViewSchemaResp> fetchViewSchema(List<Long> ids, Boolean cacheEnable);
|
||||
List<DataSetSchemaResp> fetchDataSetSchema(List<Long> ids, Boolean cacheEnable);
|
||||
|
||||
List<ViewResp> getViewList(Long domainId);
|
||||
List<DataSetResp> getDataSetList(Long domainId);
|
||||
|
||||
List<ItemResp> getDomainViewTree();
|
||||
List<ItemResp> getDomainDataSetTree();
|
||||
|
||||
}
|
||||
|
||||
@@ -1,22 +1,23 @@
|
||||
package com.tencent.supersonic.chat.core.utils;
|
||||
|
||||
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
|
||||
import com.tencent.supersonic.chat.core.parser.JavaLLMProxy;
|
||||
import com.tencent.supersonic.chat.core.parser.LLMProxy;
|
||||
import com.tencent.supersonic.chat.core.parser.sql.llm.ViewResolver;
|
||||
import com.tencent.supersonic.chat.core.parser.sql.llm.DataSetResolver;
|
||||
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.core.io.support.SpringFactoriesLoader;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
@Slf4j
|
||||
public class ComponentFactory {
|
||||
|
||||
private static SemanticInterpreter semanticInterpreter;
|
||||
private static LLMProxy llmProxy;
|
||||
private static ViewResolver modelResolver;
|
||||
private static DataSetResolver modelResolver;
|
||||
|
||||
public static SemanticInterpreter getSemanticLayer() {
|
||||
if (Objects.isNull(semanticInterpreter)) {
|
||||
@@ -44,9 +45,9 @@ public class ComponentFactory {
|
||||
return llmProxy;
|
||||
}
|
||||
|
||||
public static ViewResolver getModelResolver() {
|
||||
public static DataSetResolver getModelResolver() {
|
||||
if (Objects.isNull(modelResolver)) {
|
||||
modelResolver = init(ViewResolver.class);
|
||||
modelResolver = init(DataSetResolver.class);
|
||||
}
|
||||
return modelResolver;
|
||||
}
|
||||
|
||||
@@ -37,8 +37,8 @@ public class QueryReqBuilder {
|
||||
|
||||
public static QueryStructReq buildStructReq(SemanticParseInfo parseInfo) {
|
||||
QueryStructReq queryStructReq = new QueryStructReq();
|
||||
queryStructReq.setViewId(parseInfo.getViewId());
|
||||
queryStructReq.setViewName(parseInfo.getView().getName());
|
||||
queryStructReq.setDataSetId(parseInfo.getDataSetId());
|
||||
queryStructReq.setDataSetName(parseInfo.getDataSet().getName());
|
||||
queryStructReq.setQueryType(parseInfo.getQueryType());
|
||||
queryStructReq.setDateInfo(rewrite2Between(parseInfo.getDateInfo()));
|
||||
|
||||
@@ -119,7 +119,7 @@ public class QueryReqBuilder {
|
||||
for (Filter dimensionFilter : queryStructReq.getDimensionFilters()) {
|
||||
QueryStructReq req = new QueryStructReq();
|
||||
BeanUtils.copyProperties(queryStructReq, req);
|
||||
req.setViewId(parseInfo.getViewId());
|
||||
req.setDataSetId(parseInfo.getDataSetId());
|
||||
req.setDimensionFilters(Lists.newArrayList(dimensionFilter));
|
||||
queryStructReqs.add(req);
|
||||
}
|
||||
@@ -131,15 +131,15 @@ public class QueryReqBuilder {
|
||||
* convert to QuerySqlReq
|
||||
*
|
||||
* @param querySql
|
||||
* @param viewId
|
||||
* @param dataSetId
|
||||
* @return
|
||||
*/
|
||||
public static QuerySqlReq buildS2SQLReq(String querySql, Long viewId) {
|
||||
public static QuerySqlReq buildS2SQLReq(String querySql, Long dataSetId) {
|
||||
QuerySqlReq querySQLReq = new QuerySqlReq();
|
||||
if (Objects.nonNull(querySql)) {
|
||||
querySQLReq.setSql(querySql);
|
||||
}
|
||||
querySQLReq.setViewId(viewId);
|
||||
querySQLReq.setDataSetId(dataSetId);
|
||||
return querySQLReq;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,43 +1,44 @@
|
||||
package com.tencent.supersonic.chat.core.utils;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeMode;
|
||||
import com.tencent.supersonic.common.util.DatePeriodEnum;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
|
||||
import java.util.Objects;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
public class S2SqlDateHelper {
|
||||
|
||||
public static String getReferenceDate(QueryContext queryContext, Long viewId) {
|
||||
public static String getReferenceDate(QueryContext queryContext, Long dataSetId) {
|
||||
String defaultDate = DateUtils.getBeforeDate(0);
|
||||
if (Objects.isNull(viewId)) {
|
||||
if (Objects.isNull(dataSetId)) {
|
||||
return defaultDate;
|
||||
}
|
||||
ViewSchema viewSchema = queryContext.getSemanticSchema().getViewSchemaMap().get(viewId);
|
||||
if (viewSchema == null || viewSchema.getTagTypeTimeDefaultConfig() == null) {
|
||||
DataSetSchema dataSetSchema = queryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
|
||||
if (dataSetSchema == null || dataSetSchema.getTagTypeTimeDefaultConfig() == null) {
|
||||
return defaultDate;
|
||||
}
|
||||
TimeDefaultConfig tagTypeTimeDefaultConfig = viewSchema.getTagTypeTimeDefaultConfig();
|
||||
TimeDefaultConfig tagTypeTimeDefaultConfig = dataSetSchema.getTagTypeTimeDefaultConfig();
|
||||
return getDefaultDate(defaultDate, tagTypeTimeDefaultConfig).getLeft();
|
||||
}
|
||||
|
||||
public static Pair<String, String> getStartEndDate(QueryContext queryContext,
|
||||
Long viewId, QueryType queryType) {
|
||||
Long dataSetId, QueryType queryType) {
|
||||
String defaultDate = DateUtils.getBeforeDate(0);
|
||||
if (Objects.isNull(viewId)) {
|
||||
if (Objects.isNull(dataSetId)) {
|
||||
return Pair.of(defaultDate, defaultDate);
|
||||
}
|
||||
ViewSchema viewSchema = queryContext.getSemanticSchema().getViewSchemaMap().get(viewId);
|
||||
if (viewSchema == null) {
|
||||
DataSetSchema dataSetSchema = queryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
|
||||
if (dataSetSchema == null) {
|
||||
return Pair.of(defaultDate, defaultDate);
|
||||
}
|
||||
TimeDefaultConfig defaultConfig = viewSchema.getMetricTypeTimeDefaultConfig();
|
||||
TimeDefaultConfig defaultConfig = dataSetSchema.getMetricTypeTimeDefaultConfig();
|
||||
if (QueryType.TAG.equals(queryType)) {
|
||||
defaultConfig = viewSchema.getTagTypeTimeDefaultConfig();
|
||||
defaultConfig = dataSetSchema.getTagTypeTimeDefaultConfig();
|
||||
}
|
||||
return getDefaultDate(defaultDate, defaultConfig);
|
||||
}
|
||||
|
||||
@@ -57,7 +57,7 @@ public class SimilarQueryManager {
|
||||
embeddingQuery.setQuery(queryText);
|
||||
|
||||
Map<String, Object> metaData = new HashMap<>();
|
||||
metaData.put("modelId", similarQueryReq.getViewId());
|
||||
metaData.put("modelId", similarQueryReq.getDataSetId());
|
||||
metaData.put("agentId", similarQueryReq.getAgentId());
|
||||
embeddingQuery.setMetadata(metaData);
|
||||
String solvedQueryCollection = embeddingConfig.getSolvedQueryCollection();
|
||||
|
||||
@@ -6,7 +6,7 @@ import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
import com.tencent.supersonic.chat.core.parser.sql.llm.ParseResult;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
@@ -23,7 +23,7 @@ import org.junit.jupiter.api.Test;
|
||||
class SchemaCorrectorTest {
|
||||
|
||||
private String json = "{\n"
|
||||
+ " \"viewId\": 1,\n"
|
||||
+ " \"dataSetId\": 1,\n"
|
||||
+ " \"llmReq\": {\n"
|
||||
+ " \"queryText\": \"xxx2024年播放量最高的十首歌\",\n"
|
||||
+ " \"filterCondition\": {\n"
|
||||
@@ -31,7 +31,7 @@ class SchemaCorrectorTest {
|
||||
+ " },\n"
|
||||
+ " \"schema\": {\n"
|
||||
+ " \"domainName\": \"歌曲\",\n"
|
||||
+ " \"viewName\": \"歌曲\",\n"
|
||||
+ " \"dataSetName\": \"歌曲\",\n"
|
||||
+ " \"fieldNameList\": [\n"
|
||||
+ " \"商务组\",\n"
|
||||
+ " \"歌曲名\",\n"
|
||||
@@ -52,7 +52,7 @@ class SchemaCorrectorTest {
|
||||
+ " \"id\": \"y3LqVSRL\",\n"
|
||||
+ " \"name\": \"大模型语义解析\",\n"
|
||||
+ " \"type\": \"NL2SQL_LLM\",\n"
|
||||
+ " \"viewIds\": [\n"
|
||||
+ " \"dataSetIds\": [\n"
|
||||
+ " 1\n"
|
||||
+ " ]\n"
|
||||
+ " },\n"
|
||||
@@ -63,8 +63,8 @@ class SchemaCorrectorTest {
|
||||
|
||||
@Test
|
||||
void doCorrect() throws JsonProcessingException {
|
||||
Long viewId = 1L;
|
||||
QueryContext queryContext = buildQueryContext(viewId);
|
||||
Long dataSetId = 1L;
|
||||
QueryContext queryContext = buildQueryContext(dataSetId);
|
||||
ObjectMapper objectMapper = new ObjectMapper();
|
||||
ParseResult parseResult = objectMapper.readValue(json, ParseResult.class);
|
||||
|
||||
@@ -78,8 +78,8 @@ class SchemaCorrectorTest {
|
||||
semanticParseInfo.setSqlInfo(sqlInfo);
|
||||
|
||||
SchemaElement schemaElement = new SchemaElement();
|
||||
schemaElement.setView(viewId);
|
||||
semanticParseInfo.setView(schemaElement);
|
||||
schemaElement.setDataSet(dataSetId);
|
||||
semanticParseInfo.setDataSet(schemaElement);
|
||||
|
||||
|
||||
semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult);
|
||||
@@ -108,35 +108,35 @@ class SchemaCorrectorTest {
|
||||
|
||||
}
|
||||
|
||||
private QueryContext buildQueryContext(Long viewId) {
|
||||
private QueryContext buildQueryContext(Long dataSetId) {
|
||||
QueryContext queryContext = new QueryContext();
|
||||
List<ViewSchema> viewSchemaList = new ArrayList<>();
|
||||
ViewSchema viewSchema = new ViewSchema();
|
||||
List<DataSetSchema> dataSetSchemaList = new ArrayList<>();
|
||||
DataSetSchema dataSetSchema = new DataSetSchema();
|
||||
QueryConfig queryConfig = new QueryConfig();
|
||||
viewSchema.setQueryConfig(queryConfig);
|
||||
dataSetSchema.setQueryConfig(queryConfig);
|
||||
SchemaElement schemaElement = new SchemaElement();
|
||||
schemaElement.setView(viewId);
|
||||
viewSchema.setView(schemaElement);
|
||||
schemaElement.setDataSet(dataSetId);
|
||||
dataSetSchema.setDataSet(schemaElement);
|
||||
Set<SchemaElement> dimensions = new HashSet<>();
|
||||
SchemaElement element1 = new SchemaElement();
|
||||
element1.setView(1L);
|
||||
element1.setDataSet(1L);
|
||||
element1.setName("歌曲名");
|
||||
dimensions.add(element1);
|
||||
|
||||
SchemaElement element2 = new SchemaElement();
|
||||
element2.setView(1L);
|
||||
element2.setDataSet(1L);
|
||||
element2.setName("商务组");
|
||||
dimensions.add(element2);
|
||||
|
||||
SchemaElement element3 = new SchemaElement();
|
||||
element3.setView(1L);
|
||||
element3.setDataSet(1L);
|
||||
element3.setName("发行日期");
|
||||
dimensions.add(element3);
|
||||
|
||||
viewSchema.setDimensions(dimensions);
|
||||
viewSchemaList.add(viewSchema);
|
||||
dataSetSchema.setDimensions(dimensions);
|
||||
dataSetSchemaList.add(dataSetSchema);
|
||||
|
||||
SemanticSchema semanticSchema = new SemanticSchema(viewSchemaList);
|
||||
SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList);
|
||||
queryContext.setSemanticSchema(semanticSchema);
|
||||
return queryContext;
|
||||
}
|
||||
|
||||
@@ -29,7 +29,7 @@ class LLMSqlParserTest {
|
||||
SchemaElement schemaElement = SchemaElement.builder()
|
||||
.bizName("singer_name")
|
||||
.name("歌手名")
|
||||
.view(2L)
|
||||
.dataSet(2L)
|
||||
.schemaValueMaps(schemaValueMaps)
|
||||
.build();
|
||||
dimensions.add(schemaElement);
|
||||
@@ -37,7 +37,7 @@ class LLMSqlParserTest {
|
||||
SchemaElement schemaElement2 = SchemaElement.builder()
|
||||
.bizName("publish_time")
|
||||
.name("发布时间")
|
||||
.view(2L)
|
||||
.dataSet(2L)
|
||||
.build();
|
||||
dimensions.add(schemaElement2);
|
||||
|
||||
@@ -47,7 +47,7 @@ class LLMSqlParserTest {
|
||||
SchemaElement metric = SchemaElement.builder()
|
||||
.bizName("play_count")
|
||||
.name("播放量")
|
||||
.view(2L)
|
||||
.dataSet(2L)
|
||||
.build();
|
||||
metrics.add(metric);
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package com.tencent.supersonic.chat.core.utils;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
@@ -11,60 +11,60 @@ import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
class S2SqlDateHelperTest {
|
||||
|
||||
@Test
|
||||
void getReferenceDate() {
|
||||
Long viewId = 1L;
|
||||
QueryContext queryContext = buildQueryContext(viewId);
|
||||
Long dataSetId = 1L;
|
||||
QueryContext queryContext = buildQueryContext(dataSetId);
|
||||
|
||||
String referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, null);
|
||||
Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(0));
|
||||
|
||||
referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, viewId);
|
||||
referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, dataSetId);
|
||||
Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(0));
|
||||
|
||||
ViewSchema viewSchema = queryContext.getSemanticSchema().getViewSchemaMap().get(viewId);
|
||||
QueryConfig queryConfig = viewSchema.getQueryConfig();
|
||||
DataSetSchema dataSetSchema = queryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
|
||||
QueryConfig queryConfig = dataSetSchema.getQueryConfig();
|
||||
TimeDefaultConfig timeDefaultConfig = new TimeDefaultConfig();
|
||||
timeDefaultConfig.setTimeMode(TimeMode.LAST);
|
||||
timeDefaultConfig.setPeriod(Constants.DAY);
|
||||
timeDefaultConfig.setUnit(20);
|
||||
queryConfig.getTagTypeDefaultConfig().setTimeDefaultConfig(timeDefaultConfig);
|
||||
|
||||
referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, viewId);
|
||||
referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, dataSetId);
|
||||
Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(20));
|
||||
|
||||
timeDefaultConfig.setUnit(1);
|
||||
referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, viewId);
|
||||
referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, dataSetId);
|
||||
Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(1));
|
||||
|
||||
timeDefaultConfig.setUnit(-1);
|
||||
referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, viewId);
|
||||
referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, dataSetId);
|
||||
Assert.assertNull(referenceDate);
|
||||
}
|
||||
|
||||
@Test
|
||||
void getStartEndDate() {
|
||||
Long viewId = 1L;
|
||||
QueryContext queryContext = buildQueryContext(viewId);
|
||||
Long dataSetId = 1L;
|
||||
QueryContext queryContext = buildQueryContext(dataSetId);
|
||||
|
||||
Pair<String, String> startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, null, QueryType.TAG);
|
||||
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(0));
|
||||
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(0));
|
||||
|
||||
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, viewId, QueryType.TAG);
|
||||
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.TAG);
|
||||
Assert.assertNull(startEndDate.getLeft());
|
||||
Assert.assertNull(startEndDate.getRight());
|
||||
|
||||
ViewSchema viewSchema = queryContext.getSemanticSchema().getViewSchemaMap().get(viewId);
|
||||
QueryConfig queryConfig = viewSchema.getQueryConfig();
|
||||
DataSetSchema dataSetSchema = queryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
|
||||
QueryConfig queryConfig = dataSetSchema.getQueryConfig();
|
||||
TimeDefaultConfig timeDefaultConfig = new TimeDefaultConfig();
|
||||
timeDefaultConfig.setTimeMode(TimeMode.LAST);
|
||||
timeDefaultConfig.setPeriod(Constants.DAY);
|
||||
@@ -72,49 +72,49 @@ class S2SqlDateHelperTest {
|
||||
queryConfig.getTagTypeDefaultConfig().setTimeDefaultConfig(timeDefaultConfig);
|
||||
queryConfig.getMetricTypeDefaultConfig().setTimeDefaultConfig(timeDefaultConfig);
|
||||
|
||||
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, viewId, QueryType.TAG);
|
||||
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.TAG);
|
||||
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(20));
|
||||
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(20));
|
||||
|
||||
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, viewId, QueryType.METRIC);
|
||||
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.METRIC);
|
||||
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(20));
|
||||
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(20));
|
||||
|
||||
timeDefaultConfig.setUnit(2);
|
||||
timeDefaultConfig.setTimeMode(TimeMode.RECENT);
|
||||
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, viewId, QueryType.METRIC);
|
||||
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.METRIC);
|
||||
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(2));
|
||||
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(1));
|
||||
|
||||
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, viewId, QueryType.TAG);
|
||||
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.TAG);
|
||||
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(2));
|
||||
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(1));
|
||||
|
||||
timeDefaultConfig.setUnit(-1);
|
||||
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, viewId, QueryType.METRIC);
|
||||
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.METRIC);
|
||||
Assert.assertNull(startEndDate.getLeft());
|
||||
Assert.assertNull(startEndDate.getRight());
|
||||
|
||||
timeDefaultConfig.setTimeMode(TimeMode.LAST);
|
||||
timeDefaultConfig.setPeriod(Constants.DAY);
|
||||
timeDefaultConfig.setUnit(5);
|
||||
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, viewId, QueryType.METRIC);
|
||||
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.METRIC);
|
||||
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(5));
|
||||
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(5));
|
||||
}
|
||||
|
||||
private QueryContext buildQueryContext(Long viewId) {
|
||||
private QueryContext buildQueryContext(Long dataSetId) {
|
||||
QueryContext queryContext = new QueryContext();
|
||||
List<ViewSchema> viewSchemaList = new ArrayList<>();
|
||||
ViewSchema viewSchema = new ViewSchema();
|
||||
List<DataSetSchema> dataSetSchemaList = new ArrayList<>();
|
||||
DataSetSchema dataSetSchema = new DataSetSchema();
|
||||
QueryConfig queryConfig = new QueryConfig();
|
||||
viewSchema.setQueryConfig(queryConfig);
|
||||
dataSetSchema.setQueryConfig(queryConfig);
|
||||
SchemaElement schemaElement = new SchemaElement();
|
||||
schemaElement.setView(viewId);
|
||||
viewSchema.setView(schemaElement);
|
||||
viewSchemaList.add(viewSchema);
|
||||
schemaElement.setDataSet(dataSetId);
|
||||
dataSetSchema.setDataSet(schemaElement);
|
||||
dataSetSchemaList.add(dataSetSchema);
|
||||
|
||||
SemanticSchema semanticSchema = new SemanticSchema(viewSchemaList);
|
||||
SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList);
|
||||
queryContext.setSemanticSchema(semanticSchema);
|
||||
return queryContext;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user