(improvement)(Headless)(Chat) Change View to DataSet (#782)

* (improvement)(Headless)(Chat) Change view to dataSet



---------

Co-authored-by: jolunoluo <jolunoluo@tencent.com>
LXW
2024-03-04 11:48:41 +08:00
committed by GitHub
parent b29e429271
commit a41da3f5fe
184 changed files with 1628 additions and 1532 deletions
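
The diff below is a mechanical rename of the "View" concept to "DataSet" across the chat and headless modules: ViewSchema becomes DataSetSchema, viewId fields and parameters become dataSetId, accessors such as getView()/getViewIds() become getDataSet()/getDataSetIds(), and the VIEW element type becomes DATASET. As a hedged illustration only (not part of the commit), the sketch below shows how a caller of SemanticSchema reads after the rename; it assumes the chat-api and headless-api classes shown in this diff are on the classpath, and the wrapper class name and dataSetId value are ours.

// Hedged sketch, not part of the commit: caller code after the View -> DataSet rename.
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;

import java.util.List;

class RenameUsageSketch {
    static void describe(SemanticSchema schema, Long dataSetId) {
        // before: schema.getView(viewId) / schema.getMetrics(viewId) / schema.getDimensions(viewId)
        SchemaElement dataSet = schema.getDataSet(dataSetId);
        List<SchemaElement> metrics = schema.getMetrics(dataSetId);
        List<SchemaElement> dimensions = schema.getDimensions(dataSetId);
        System.out.printf("%s: %d metrics, %d dimensions%n",
                dataSet == null ? "unknown" : dataSet.getName(),
                metrics.size(), dimensions.size());
    }
}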

View File

@@ -12,9 +12,9 @@ import java.util.Optional;
import java.util.Set;
@Data
public class ViewSchema {
public class DataSetSchema {
private SchemaElement view;
private SchemaElement dataSet;
private Set<SchemaElement> metrics = new HashSet<>();
private Set<SchemaElement> dimensions = new HashSet<>();
private Set<SchemaElement> dimensionValues = new HashSet<>();
@@ -29,8 +29,8 @@ public class ViewSchema {
case ENTITY:
element = Optional.ofNullable(entity);
break;
case VIEW:
element = Optional.of(view);
case DATASET:
element = Optional.of(dataSet);
break;
case METRIC:
element = metrics.stream().filter(e -> e.getId() == elementID).findFirst();
@@ -61,8 +61,8 @@ public class ViewSchema {
case ENTITY:
element = Optional.ofNullable(entity);
break;
case VIEW:
element = Optional.of(view);
case DATASET:
element = Optional.of(dataSet);
break;
case METRIC:
element = metrics.stream().filter(e -> name.equals(e.getName())).findFirst();

View File

@@ -9,25 +9,25 @@ import java.util.Set;
public class SchemaMapInfo {
private Map<Long, List<SchemaElementMatch>> viewElementMatches = new HashMap<>();
private Map<Long, List<SchemaElementMatch>> dataSetElementMatches = new HashMap<>();
public Set<Long> getMatchedViewInfos() {
return viewElementMatches.keySet();
public Set<Long> getMatchedDataSetInfos() {
return dataSetElementMatches.keySet();
}
public List<SchemaElementMatch> getMatchedElements(Long view) {
return viewElementMatches.getOrDefault(view, Lists.newArrayList());
public List<SchemaElementMatch> getMatchedElements(Long dataSet) {
return dataSetElementMatches.getOrDefault(dataSet, Lists.newArrayList());
}
public Map<Long, List<SchemaElementMatch>> getViewElementMatches() {
return viewElementMatches;
public Map<Long, List<SchemaElementMatch>> getDataSetElementMatches() {
return dataSetElementMatches;
}
public void setViewElementMatches(Map<Long, List<SchemaElementMatch>> viewElementMatches) {
this.viewElementMatches = viewElementMatches;
public void setDataSetElementMatches(Map<Long, List<SchemaElementMatch>> dataSetElementMatches) {
this.dataSetElementMatches = dataSetElementMatches;
}
public void setMatchedElements(Long view, List<SchemaElementMatch> elementMatches) {
viewElementMatches.put(view, elementMatches);
public void setMatchedElements(Long dataSet, List<SchemaElementMatch> elementMatches) {
dataSetElementMatches.put(dataSet, elementMatches);
}
}
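
SchemaMapInfo keeps its shape, a Map<Long, List<SchemaElementMatch>>, but is now keyed by data set id rather than view id, with its accessors renamed to match. A hedged sketch (not part of the commit) of how a mapper reads and writes it after the change; the wrapper class, ids and match objects are placeholders, and the package paths are taken from import lines elsewhere in this diff.

// Hedged sketch: reading/writing SchemaMapInfo after the rename.
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;

import java.util.ArrayList;
import java.util.List;

class SchemaMapInfoSketch {
    static void addMatch(SchemaMapInfo mapInfo, Long dataSetId, SchemaElementMatch match) {
        // before: getMatchedElements(viewId) / setMatchedElements(viewId, ...) / getMatchedViewInfos()
        List<SchemaElementMatch> matches = new ArrayList<>(mapInfo.getMatchedElements(dataSetId));
        matches.add(match);
        mapInfo.setMatchedElements(dataSetId, matches);
        System.out.println("matched data sets: " + mapInfo.getMatchedDataSetInfos());
    }
}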

View File

@@ -26,7 +26,7 @@ public class SemanticParseInfo {
private Integer id;
private String queryMode;
private SchemaElement view;
private SchemaElement dataSet;
private Set<SchemaElement> metrics = new TreeSet<>(new SchemaNameLengthComparator());
private Set<SchemaElement> dimensions = new LinkedHashSet();
private SchemaElement entity;
@@ -72,15 +72,11 @@ public class SemanticParseInfo {
return metrics;
}
public Long getViewId() {
if (view == null) {
public Long getDataSetId() {
if (dataSet == null) {
return null;
}
return view.getView();
}
public SchemaElement getModel() {
return view;
return dataSet.getDataSet();
}
}

View File

@@ -2,6 +2,8 @@ package com.tencent.supersonic.chat.api.pojo;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import org.springframework.util.CollectionUtils;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
@@ -9,18 +11,17 @@ import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.springframework.util.CollectionUtils;
public class SemanticSchema implements Serializable {
private List<ViewSchema> viewSchemaList;
private List<DataSetSchema> dataSetSchemaList;
public SemanticSchema(List<ViewSchema> viewSchemaList) {
this.viewSchemaList = viewSchemaList;
public SemanticSchema(List<DataSetSchema> dataSetSchemaList) {
this.dataSetSchemaList = dataSetSchemaList;
}
public void add(ViewSchema schema) {
viewSchemaList.add(schema);
public void add(DataSetSchema schema) {
dataSetSchemaList.add(schema);
}
public SchemaElement getElement(SchemaElementType elementType, long elementID) {
@@ -30,8 +31,8 @@ public class SemanticSchema implements Serializable {
case ENTITY:
element = getElementsById(elementID, getEntities());
break;
case VIEW:
element = getElementsById(elementID, getViews());
case DATASET:
element = getElementsById(elementID, getDataSets());
break;
case METRIC:
element = getElementsById(elementID, getMetrics());
@@ -52,26 +53,26 @@ public class SemanticSchema implements Serializable {
}
}
public Map<Long, String> getViewIdToName() {
return viewSchemaList.stream()
.collect(Collectors.toMap(a -> a.getView().getId(), a -> a.getView().getName(), (k1, k2) -> k1));
public Map<Long, String> getDataSetIdToName() {
return dataSetSchemaList.stream()
.collect(Collectors.toMap(a -> a.getDataSet().getId(), a -> a.getDataSet().getName(), (k1, k2) -> k1));
}
public List<SchemaElement> getDimensionValues() {
List<SchemaElement> dimensionValues = new ArrayList<>();
viewSchemaList.stream().forEach(d -> dimensionValues.addAll(d.getDimensionValues()));
dataSetSchemaList.stream().forEach(d -> dimensionValues.addAll(d.getDimensionValues()));
return dimensionValues;
}
public List<SchemaElement> getDimensions() {
List<SchemaElement> dimensions = new ArrayList<>();
viewSchemaList.stream().forEach(d -> dimensions.addAll(d.getDimensions()));
dataSetSchemaList.stream().forEach(d -> dimensions.addAll(d.getDimensions()));
return dimensions;
}
public List<SchemaElement> getDimensions(Long viewId) {
public List<SchemaElement> getDimensions(Long dataSetId) {
List<SchemaElement> dimensions = getDimensions();
return getElementsByViewId(viewId, dimensions);
return getElementsByDataSetId(dataSetId, dimensions);
}
public SchemaElement getDimension(Long id) {
@@ -82,43 +83,43 @@ public class SemanticSchema implements Serializable {
public List<SchemaElement> getTags() {
List<SchemaElement> tags = new ArrayList<>();
viewSchemaList.stream().forEach(d -> tags.addAll(d.getTags()));
dataSetSchemaList.stream().forEach(d -> tags.addAll(d.getTags()));
return tags;
}
public List<SchemaElement> getTags(Long viewId) {
public List<SchemaElement> getTags(Long dataSetId) {
List<SchemaElement> tags = new ArrayList<>();
viewSchemaList.stream().filter(schemaElement ->
viewId.equals(schemaElement.getView().getView()))
dataSetSchemaList.stream().filter(schemaElement ->
dataSetId.equals(schemaElement.getDataSet().getDataSet()))
.forEach(d -> tags.addAll(d.getTags()));
return tags;
}
public List<SchemaElement> getMetrics() {
List<SchemaElement> metrics = new ArrayList<>();
viewSchemaList.stream().forEach(d -> metrics.addAll(d.getMetrics()));
dataSetSchemaList.stream().forEach(d -> metrics.addAll(d.getMetrics()));
return metrics;
}
public List<SchemaElement> getMetrics(Long viewId) {
public List<SchemaElement> getMetrics(Long dataSetId) {
List<SchemaElement> metrics = getMetrics();
return getElementsByViewId(viewId, metrics);
return getElementsByDataSetId(dataSetId, metrics);
}
public List<SchemaElement> getEntities() {
List<SchemaElement> entities = new ArrayList<>();
viewSchemaList.stream().forEach(d -> entities.add(d.getEntity()));
dataSetSchemaList.stream().forEach(d -> entities.add(d.getEntity()));
return entities;
}
public List<SchemaElement> getEntities(Long viewId) {
public List<SchemaElement> getEntities(Long dataSetId) {
List<SchemaElement> entities = getEntities();
return getElementsByViewId(viewId, entities);
return getElementsByDataSetId(dataSetId, entities);
}
private List<SchemaElement> getElementsByViewId(Long viewId, List<SchemaElement> elements) {
private List<SchemaElement> getElementsByDataSetId(Long dataSetId, List<SchemaElement> elements) {
return elements.stream()
.filter(schemaElement -> viewId.equals(schemaElement.getView()))
.filter(schemaElement -> dataSetId.equals(schemaElement.getDataSet()))
.collect(Collectors.toList());
}
@@ -128,30 +129,30 @@ public class SemanticSchema implements Serializable {
.findFirst();
}
public SchemaElement getView(Long viewId) {
List<SchemaElement> views = getViews();
return getElementsById(viewId, views).orElse(null);
public SchemaElement getDataSet(Long dataSetId) {
List<SchemaElement> dataSets = getDataSets();
return getElementsById(dataSetId, dataSets).orElse(null);
}
public List<SchemaElement> getViews() {
List<SchemaElement> views = new ArrayList<>();
viewSchemaList.stream().forEach(d -> views.add(d.getView()));
return views;
public List<SchemaElement> getDataSets() {
List<SchemaElement> dataSets = new ArrayList<>();
dataSetSchemaList.stream().forEach(d -> dataSets.add(d.getDataSet()));
return dataSets;
}
public Map<String, String> getBizNameToName(Long viewId) {
public Map<String, String> getBizNameToName(Long dataSetId) {
List<SchemaElement> allElements = new ArrayList<>();
allElements.addAll(getDimensions(viewId));
allElements.addAll(getMetrics(viewId));
allElements.addAll(getDimensions(dataSetId));
allElements.addAll(getMetrics(dataSetId));
return allElements.stream()
.collect(Collectors.toMap(SchemaElement::getBizName, SchemaElement::getName, (k1, k2) -> k1));
}
public Map<Long, ViewSchema> getViewSchemaMap() {
if (CollectionUtils.isEmpty(viewSchemaList)) {
public Map<Long, DataSetSchema> getDataSetSchemaMap() {
if (CollectionUtils.isEmpty(dataSetSchemaList)) {
return new HashMap<>();
}
return viewSchemaList.stream().collect(Collectors.toMap(viewSchema
-> viewSchema.getView().getView(), viewSchema -> viewSchema));
return dataSetSchemaList.stream().collect(Collectors.toMap(dataSetSchema
-> dataSetSchema.getDataSet().getDataSet(), dataSetSchema -> dataSetSchema));
}
}

View File

@@ -13,7 +13,7 @@ public class PluginQueryReq {
private String type;
private String view;
private String dataSet;
private String pattern;

View File

@@ -7,7 +7,7 @@ import lombok.Data;
public class QueryReq {
private String queryText;
private Integer chatId;
private Long viewId;
private Long dataSetId;
private User user;
private QueryFilters queryFilters;
private boolean saveAnswer = true;

View File

@@ -18,7 +18,7 @@ public class SimilarQueryReq {
private String queryText;
private Long viewId;
private Long dataSetId;
private Integer agentId;

View File

@@ -6,7 +6,7 @@ import java.io.Serializable;
import java.util.List;
@Data
public class ViewInfo extends DataInfo implements Serializable {
public class DataSetInfo extends DataInfo implements Serializable {
private List<String> words;
private String primaryKey;

View File

@@ -8,7 +8,7 @@ import java.util.List;
@Data
public class EntityInfo {
private ViewInfo viewInfo = new ViewInfo();
private DataSetInfo dataSetInfo = new DataSetInfo();
private List<DataInfo> dimensions = new ArrayList<>();
private List<DataInfo> metrics = new ArrayList<>();
private String entityId;

View File

@@ -65,16 +65,16 @@ public class Agent extends RecordInfo {
.collect(Collectors.toList());
}
public Set<Long> getViewIds() {
return getViewIds(null);
public Set<Long> getDataSetIds() {
return getDataSetIds(null);
}
public Set<Long> getViewIds(AgentToolType agentToolType) {
public Set<Long> getDataSetIds(AgentToolType agentToolType) {
List<NL2SQLTool> commonAgentTools = getParserTools(agentToolType);
if (CollectionUtils.isEmpty(commonAgentTools)) {
return new HashSet<>();
}
return commonAgentTools.stream().map(NL2SQLTool::getViewIds)
return commonAgentTools.stream().map(NL2SQLTool::getDataSetIds)
.filter(modelIds -> !CollectionUtils.isEmpty(modelIds))
.flatMap(Collection::stream)
.collect(Collectors.toSet());

View File

@@ -12,6 +12,6 @@ import java.util.List;
@AllArgsConstructor
public class NL2SQLTool extends AgentTool {
protected List<Long> viewIds;
protected List<Long> dataSetIds;
}

View File

@@ -15,7 +15,7 @@ public class RuleParserTool extends NL2SQLTool {
private List<String> queryTypes;
public boolean isContainsAllModel() {
return CollectionUtils.isNotEmpty(viewIds) && viewIds.contains(-1L);
return CollectionUtils.isNotEmpty(dataSetIds) && dataSetIds.contains(-1L);
}
}

View File

@@ -45,7 +45,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
public abstract void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo);
protected Map<String, String> getFieldNameMap(QueryContext queryContext, Long viewId) {
protected Map<String, String> getFieldNameMap(QueryContext queryContext, Long dataSetId) {
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
@@ -55,7 +55,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
// support fieldName and field alias
Map<String, String> result = dbAllFields.stream()
.filter(entry -> viewId.equals(entry.getView()))
.filter(entry -> dataSetId.equals(entry.getDataSet()))
.flatMap(schemaElement -> {
Set<String> elements = new HashSet<>();
elements.add(schemaElement.getName());
@@ -109,8 +109,8 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
protected void addAggregateToMetric(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
//add aggregate to all metric
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
Long viewId = semanticParseInfo.getView().getView();
List<SchemaElement> metrics = getMetricElements(queryContext, viewId);
Long dataSetId = semanticParseInfo.getDataSet().getDataSet();
List<SchemaElement> metrics = getMetricElements(queryContext, dataSetId);
Map<String, String> metricToAggregate = metrics.stream()
.map(schemaElement -> {
@@ -135,13 +135,13 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql);
}
protected List<SchemaElement> getMetricElements(QueryContext queryContext, Long viewId) {
protected List<SchemaElement> getMetricElements(QueryContext queryContext, Long dataSetId) {
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
return semanticSchema.getMetrics(viewId);
return semanticSchema.getMetrics(dataSetId);
}
protected Set<String> getDimensions(Long viewId, SemanticSchema semanticSchema) {
Set<String> dimensions = semanticSchema.getDimensions(viewId).stream()
protected Set<String> getDimensions(Long dataSetId, SemanticSchema semanticSchema) {
Set<String> dimensions = semanticSchema.getDimensions(dataSetId).stream()
.flatMap(
schemaElement -> {
Set<String> elements = new HashSet<>();

View File

@@ -8,12 +8,13 @@ import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.headless.api.pojo.DataSetModelConfig;
import com.tencent.supersonic.headless.api.pojo.Dim;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import com.tencent.supersonic.headless.api.pojo.response.ViewResp;
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
import com.tencent.supersonic.headless.server.service.ModelService;
import com.tencent.supersonic.headless.server.service.ViewService;
import com.tencent.supersonic.headless.server.service.DataSetService;
import java.util.List;
import java.util.Objects;
import java.util.Set;
@@ -37,11 +38,12 @@ public class GroupByCorrector extends BaseSemanticCorrector {
}
private Boolean needAddGroupBy(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
Long viewId = semanticParseInfo.getViewId();
ViewService viewService = ContextUtils.getBean(ViewService.class);
Long dataSetId = semanticParseInfo.getDataSetId();
DataSetService dataSetService = ContextUtils.getBean(DataSetService.class);
ModelService modelService = ContextUtils.getBean(ModelService.class);
ViewResp viewResp = viewService.getView(viewId);
List<Long> modelIds = viewResp.getViewDetail().getViewModelConfigs().stream().map(config -> config.getId())
DataSetResp dataSetResp = dataSetService.getDataSet(dataSetId);
List<Long> modelIds = dataSetResp.getDataSetDetail()
.getDataSetModelConfigs().stream().map(DataSetModelConfig::getId)
.collect(Collectors.toList());
MetaFilter metaFilter = new MetaFilter();
metaFilter.setIds(modelIds);
@@ -64,7 +66,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
return false;
}
//add alias field name
Set<String> dimensions = getDimensions(viewId, semanticSchema);
Set<String> dimensions = getDimensions(dataSetId, semanticSchema);
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) {
return false;
@@ -81,13 +83,13 @@ public class GroupByCorrector extends BaseSemanticCorrector {
}
private void addGroupByFields(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
Long viewId = semanticParseInfo.getViewId();
Long dataSetId = semanticParseInfo.getDataSetId();
//add dimension group by
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String correctS2SQL = sqlInfo.getCorrectS2SQL();
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
//add alias field name
Set<String> dimensions = getDimensions(viewId, semanticSchema);
Set<String> dimensions = getDimensions(dataSetId, semanticSchema);
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
Set<String> groupByFields = selectFields.stream()

View File

@@ -39,11 +39,11 @@ public class HavingCorrector extends BaseSemanticCorrector {
}
private void addHaving(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
Long viewId = semanticParseInfo.getView().getView();
Long dataSet = semanticParseInfo.getDataSet().getDataSet();
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
Set<String> metrics = semanticSchema.getMetrics(viewId).stream()
Set<String> metrics = semanticSchema.getMetrics(dataSet).stream()
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
if (CollectionUtils.isEmpty(metrics)) {

View File

@@ -16,15 +16,16 @@ import com.tencent.supersonic.common.util.jsqlparser.FieldExpression;
import com.tencent.supersonic.common.util.jsqlparser.SqlRemoveHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
/**
* Perform schema corrections on the Schema information in S2SQL.
@@ -62,7 +63,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
}
private void correctFieldName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
Map<String, String> fieldNameMap = getFieldNameMap(queryContext, semanticParseInfo.getViewId());
Map<String, String> fieldNameMap = getFieldNameMap(queryContext, semanticParseInfo.getDataSetId());
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String sql = SqlReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap);
sqlInfo.setCorrectS2SQL(sql);
@@ -125,7 +126,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
}
List<ElementValue> linkingValues = getLinkingValues(semanticParseInfo);
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
Set<String> dimensions = getDimensions(semanticParseInfo.getViewId(), semanticSchema);
Set<String> dimensions = getDimensions(semanticParseInfo.getDataSetId(), semanticSchema);
if (CollectionUtils.isEmpty(linkingValues)) {
linkingValues = new ArrayList<>();

View File

@@ -67,7 +67,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) {
Pair<String, String> startEndDate = S2SqlDateHelper.getStartEndDate(queryContext,
semanticParseInfo.getViewId(), semanticParseInfo.getQueryType());
semanticParseInfo.getDataSetId(), semanticParseInfo.getQueryType());
if (StringUtils.isNotBlank(startEndDate.getLeft())
&& StringUtils.isNotBlank(startEndDate.getRight())) {
correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL);
@@ -101,8 +101,8 @@ public class WhereCorrector extends BaseSemanticCorrector {
private void updateFieldValueByTechName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
Long viewId = semanticParseInfo.getViewId();
List<SchemaElement> dimensions = semanticSchema.getDimensions(viewId);
Long dataSetId = semanticParseInfo.getDataSetId();
List<SchemaElement> dimensions = semanticSchema.getDimensions(dataSetId);
if (CollectionUtils.isEmpty(dimensions)) {
return;

View File

@@ -5,7 +5,7 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
@@ -26,7 +26,7 @@ public abstract class BaseMapper implements SchemaMapper {
String simpleName = this.getClass().getSimpleName();
long startTime = System.currentTimeMillis();
log.debug("before {},mapInfo:{}", simpleName, queryContext.getMapInfo().getViewElementMatches());
log.debug("before {},mapInfo:{}", simpleName, queryContext.getMapInfo().getDataSetElementMatches());
try {
doMap(queryContext);
@@ -35,13 +35,14 @@ public abstract class BaseMapper implements SchemaMapper {
}
long cost = System.currentTimeMillis() - startTime;
log.debug("after {},cost:{},mapInfo:{}", simpleName, cost, queryContext.getMapInfo().getViewElementMatches());
log.debug("after {},cost:{},mapInfo:{}", simpleName, cost,
queryContext.getMapInfo().getDataSetElementMatches());
}
public abstract void doMap(QueryContext queryContext);
public void addToSchemaMap(SchemaMapInfo schemaMap, Long modelId, SchemaElementMatch newElementMatch) {
Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getViewElementMatches();
Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getDataSetElementMatches();
List<SchemaElementMatch> schemaElementMatches = modelElementMatches.putIfAbsent(modelId, new ArrayList<>());
if (schemaElementMatches == null) {
schemaElementMatches = modelElementMatches.get(modelId);
@@ -67,14 +68,14 @@ public abstract class BaseMapper implements SchemaMapper {
}
}
public SchemaElement getSchemaElement(Long viewId, SchemaElementType elementType, Long elementID,
public SchemaElement getSchemaElement(Long dataSetId, SchemaElementType elementType, Long elementID,
SemanticSchema semanticSchema) {
SchemaElement element = new SchemaElement();
ViewSchema viewSchema = semanticSchema.getViewSchemaMap().get(viewId);
if (Objects.isNull(viewSchema)) {
DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId);
if (Objects.isNull(dataSetSchema)) {
return null;
}
SchemaElement elementDb = viewSchema.getElement(elementType, elementID);
SchemaElement elementDb = dataSetSchema.getElement(elementType, elementID);
if (Objects.isNull(elementDb)) {
log.info("element is null, elementType:{},elementID:{}", elementType, elementID);
return null;

View File

@@ -28,22 +28,22 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
@Override
public Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms,
Set<Long> detectViewIds) {
Set<Long> detectDataSetIds) {
String text = queryContext.getQueryText();
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
return null;
}
log.debug("terms:{},,detectViewIds:{}", terms, detectViewIds);
log.debug("terms:{},,detectDataSetIds:{}", terms, detectDataSetIds);
List<T> detects = detect(queryContext, terms, detectViewIds);
List<T> detects = detect(queryContext, terms, detectDataSetIds);
Map<MatchText, List<T>> result = new HashMap<>();
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
return result;
}
public List<T> detect(QueryContext queryContext, List<S2Term> terms, Set<Long> detectViewIds) {
public List<T> detect(QueryContext queryContext, List<S2Term> terms, Set<Long> detectDataSetIds) {
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(terms);
String text = queryContext.getQueryText();
Set<T> results = new HashSet<>();
@@ -58,16 +58,16 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
if (index <= text.length()) {
String detectSegment = text.substring(startIndex, index).trim();
detectSegments.add(detectSegment);
detectByStep(queryContext, results, detectViewIds, detectSegment, offset);
detectByStep(queryContext, results, detectDataSetIds, detectSegment, offset);
}
}
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
}
detectByBatch(queryContext, results, detectViewIds, detectSegments);
detectByBatch(queryContext, results, detectDataSetIds, detectSegments);
return new ArrayList<>(results);
}
protected void detectByBatch(QueryContext queryContext, Set<T> results, Set<Long> detectViewIds,
protected void detectByBatch(QueryContext queryContext, Set<T> results, Set<Long> detectDataSetIds,
Set<String> detectSegments) {
return;
}
@@ -104,9 +104,9 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
}
public List<T> getMatches(QueryContext queryContext, List<S2Term> terms) {
Set<Long> viewIds = mapperHelper.getViewIds(queryContext.getViewId(), queryContext.getAgent());
terms = filterByViewId(terms, viewIds);
Map<MatchText, List<T>> matchResult = match(queryContext, terms, viewIds);
Set<Long> dataSetIds = mapperHelper.getDataSetIds(queryContext.getDataSetId(), queryContext.getAgent());
terms = filterByDataSetId(terms, dataSetIds);
Map<MatchText, List<T>> matchResult = match(queryContext, terms, dataSetIds);
List<T> matches = new ArrayList<>();
if (Objects.isNull(matchResult)) {
return matches;
@@ -121,17 +121,17 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
return matches;
}
public List<S2Term> filterByViewId(List<S2Term> terms, Set<Long> viewIds) {
public List<S2Term> filterByDataSetId(List<S2Term> terms, Set<Long> dataSetIds) {
logTerms(terms);
if (CollectionUtils.isNotEmpty(viewIds)) {
if (CollectionUtils.isNotEmpty(dataSetIds)) {
terms = terms.stream().filter(term -> {
Long viewId = NatureHelper.getViewId(term.getNature().toString());
if (Objects.nonNull(viewId)) {
return viewIds.contains(viewId);
Long dataSetId = NatureHelper.getDataSetId(term.getNature().toString());
if (Objects.nonNull(dataSetId)) {
return dataSetIds.contains(dataSetId);
}
return false;
}).collect(Collectors.toList());
log.info("terms filter by viewId:{}", viewIds);
log.info("terms filter by dataSetId:{}", dataSetIds);
logTerms(terms);
}
return terms;
@@ -150,7 +150,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
public abstract String getMapKey(T a);
public abstract void detectByStep(QueryContext queryContext, Set<T> existResults, Set<Long> detectViewIds,
public abstract void detectByStep(QueryContext queryContext, Set<T> existResults, Set<Long> detectDataSetIds,
String detectSegment, int offset);
}

View File

@@ -37,9 +37,9 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
@Override
public Map<MatchText, List<DatabaseMapResult>> match(QueryContext queryContext, List<S2Term> terms,
Set<Long> detectViewIds) {
Set<Long> detectDataSetIds) {
this.allElements = getSchemaElements(queryContext);
return super.match(queryContext, terms, detectViewIds);
return super.match(queryContext, terms, detectDataSetIds);
}
@Override
@@ -54,7 +54,7 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
+ Constants.UNDERLINE + a.getSchemaElement().getName();
}
public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> existResults, Set<Long> detectViewIds,
public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> existResults, Set<Long> detectDataSetIds,
String detectSegment, int offset) {
if (StringUtils.isBlank(detectSegment)) {
return;
@@ -70,9 +70,9 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
continue;
}
Set<SchemaElement> schemaElements = entry.getValue();
if (!CollectionUtils.isEmpty(detectViewIds)) {
if (!CollectionUtils.isEmpty(detectDataSetIds)) {
schemaElements = schemaElements.stream()
.filter(schemaElement -> detectViewIds.contains(schemaElement.getView()))
.filter(schemaElement -> detectDataSetIds.contains(schemaElement.getDataSet()))
.collect(Collectors.toSet());
}
for (SchemaElement schemaElement : schemaElements) {
@@ -96,7 +96,7 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
Double metricDimensionThresholdConfig = optimizationConfig.getMetricDimensionThresholdConfig();
Double metricDimensionMinThresholdConfig = optimizationConfig.getMetricDimensionMinThresholdConfig();
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo().getViewElementMatches();
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo().getDataSetElementMatches();
boolean existElement = modelElementMatches.entrySet().stream().anyMatch(entry -> entry.getValue().size() >= 1);

View File

@@ -36,12 +36,12 @@ public class EmbeddingMapper extends BaseMapper {
//2. build SchemaElementMatch by info
for (EmbeddingResult matchResult : matchResults) {
Long elementId = Retrieval.getLongId(matchResult.getId());
Long viewId = Retrieval.getLongId(matchResult.getMetadata().get("viewId"));
if (Objects.isNull(viewId)) {
Long dataSetId = Retrieval.getLongId(matchResult.getMetadata().get("dataSetId"));
if (Objects.isNull(dataSetId)) {
continue;
}
SchemaElementType elementType = SchemaElementType.valueOf(matchResult.getMetadata().get("type"));
SchemaElement schemaElement = getSchemaElement(viewId, elementType, elementId,
SchemaElement schemaElement = getSchemaElement(dataSetId, elementType, elementId,
queryContext.getSemanticSchema());
if (schemaElement == null) {
continue;
@@ -54,7 +54,7 @@ public class EmbeddingMapper extends BaseMapper {
.detectWord(matchResult.getDetectWord())
.build();
//3. add to mapInfo
addToSchemaMap(queryContext.getMapInfo(), viewId, schemaElementMatch);
addToSchemaMap(queryContext.getMapInfo(), dataSetId, schemaElementMatch);
}
}
}

View File

@@ -48,13 +48,13 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
}
@Override
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectViewIds,
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectDataSetIds,
String detectSegment, int offset) {
}
@Override
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectViewIds,
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectDataSetIds,
Set<String> detectSegments) {
List<String> queryTextsList = detectSegments.stream()
@@ -68,11 +68,11 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
optimizationConfig.getEmbeddingMapperBatch());
for (List<String> queryTextsSub : queryTextsSubList) {
detectByQueryTextsSub(results, detectViewIds, queryTextsSub);
detectByQueryTextsSub(results, detectDataSetIds, queryTextsSub);
}
}
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectViewIds,
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectDataSetIds,
List<String> queryTextsSub) {
int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber();
Double distance = optimizationConfig.getEmbeddingMapperDistanceThreshold();
@@ -80,7 +80,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
// step2. retrieveQuery by detectSegment
List<RetrieveQueryResult> retrieveQueryResults = metaEmbeddingService.retrieveQuery(
new ArrayList<>(detectViewIds), retrieveQuery, embeddingNumber);
new ArrayList<>(detectDataSetIds), retrieveQuery, embeddingNumber);
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
return;

View File

@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
@@ -24,12 +24,12 @@ public class EntityMapper extends BaseMapper {
@Override
public void doMap(QueryContext queryContext) {
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
for (Long viewId : schemaMapInfo.getMatchedViewInfos()) {
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(viewId);
for (Long dataSetId : schemaMapInfo.getMatchedDataSetInfos()) {
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(dataSetId);
if (CollectionUtils.isEmpty(schemaElementMatchList)) {
continue;
}
SchemaElement entity = getEntity(viewId, queryContext);
SchemaElement entity = getEntity(dataSetId, queryContext);
if (entity == null || entity.getId() == null) {
continue;
}
@@ -65,9 +65,9 @@ public class EntityMapper extends BaseMapper {
return false;
}
private SchemaElement getEntity(Long viewId, QueryContext queryContext) {
private SchemaElement getEntity(Long dataSetId, QueryContext queryContext) {
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
ViewSchema modelSchema = semanticSchema.getViewSchemaMap().get(viewId);
DataSetSchema modelSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId);
if (modelSchema != null && modelSchema.getEntity() != null) {
return modelSchema.getEntity();
}

View File

@@ -39,15 +39,15 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
@Override
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> terms,
Set<Long> detectViewIds) {
Set<Long> detectDataSetIds) {
String text = queryContext.getQueryText();
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
return null;
}
log.debug("retryCount:{},terms:{},,detectModelIds:{}", terms, detectViewIds);
log.debug("retryCount:{},terms:{},,detectModelIds:{}", terms, detectDataSetIds);
List<HanlpMapResult> detects = detect(queryContext, terms, detectViewIds);
List<HanlpMapResult> detects = detect(queryContext, terms, detectDataSetIds);
Map<MatchText, List<HanlpMapResult>> result = new HashMap<>();
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
@@ -60,15 +60,15 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
&& existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length();
}
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectViewIds,
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectDataSetIds,
String detectSegment, int offset) {
// step1. pre search
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
LinkedHashSet<HanlpMapResult> hanlpMapResults = knowledgeService.prefixSearch(detectSegment,
oneDetectionMaxSize, detectViewIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
oneDetectionMaxSize, detectDataSetIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
// step2. suffix search
LinkedHashSet<HanlpMapResult> suffixHanlpMapResults = knowledgeService.suffixSearch(detectSegment,
oneDetectionMaxSize, detectViewIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
oneDetectionMaxSize, detectDataSetIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
hanlpMapResults.addAll(suffixHanlpMapResults);

View File

@@ -59,8 +59,8 @@ public class KeywordMapper extends BaseMapper {
for (HanlpMapResult hanlpMapResult : mapResults) {
for (String nature : hanlpMapResult.getNatures()) {
Long viewId = NatureHelper.getViewId(nature);
if (Objects.isNull(viewId)) {
Long dataSetId = NatureHelper.getDataSetId(nature);
if (Objects.isNull(dataSetId)) {
continue;
}
SchemaElementType elementType = NatureHelper.convertToElementType(nature);
@@ -68,7 +68,7 @@ public class KeywordMapper extends BaseMapper {
continue;
}
Long elementID = NatureHelper.getElementID(nature);
SchemaElement element = getSchemaElement(viewId, elementType,
SchemaElement element = getSchemaElement(dataSetId, elementType,
elementID, queryContext.getSemanticSchema());
if (element == null) {
continue;
@@ -85,7 +85,7 @@ public class KeywordMapper extends BaseMapper {
.detectWord(hanlpMapResult.getDetectWord())
.build();
addToSchemaMap(queryContext.getMapInfo(), viewId, schemaElementMatch);
addToSchemaMap(queryContext.getMapInfo(), dataSetId, schemaElementMatch);
}
}
}
@@ -106,12 +106,12 @@ public class KeywordMapper extends BaseMapper {
.similarity(mapperHelper.getSimilarity(match.getDetectWord(), schemaElement.getName()))
.build();
log.info("add to schema, elementMatch {}", schemaElementMatch);
addToSchemaMap(queryContext.getMapInfo(), schemaElement.getView(), schemaElementMatch);
addToSchemaMap(queryContext.getMapInfo(), schemaElement.getDataSet(), schemaElementMatch);
}
}
private Set<Long> getRegElementSet(SchemaMapInfo schemaMap, SchemaElement schemaElement) {
List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getView());
List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getDataSet());
if (CollectionUtils.isEmpty(elements)) {
return new HashSet<>();
}

View File

@@ -62,7 +62,7 @@ public class MapperHelper {
*/
public boolean existDimensionValues(List<String> natures) {
for (String nature : natures) {
if (NatureHelper.isDimensionValueViewId(nature)) {
if (NatureHelper.isDimensionValueDataSetId(nature)) {
return true;
}
}
@@ -82,33 +82,33 @@ public class MapperHelper {
detectSegment.length());
}
public Set<Long> getViewIds(Long viewId, Agent agent) {
public Set<Long> getDataSetIds(Long dataSetId, Agent agent) {
Set<Long> detectViewIds = new HashSet<>();
Set<Long> detectDataSetIds = new HashSet<>();
if (Objects.nonNull(agent)) {
detectViewIds = agent.getViewIds(null);
detectDataSetIds = agent.getDataSetIds();
}
//contains all
if (Agent.containsAllModel(detectViewIds)) {
if (Objects.nonNull(viewId) && viewId > 0) {
if (Agent.containsAllModel(detectDataSetIds)) {
if (Objects.nonNull(dataSetId) && dataSetId > 0) {
Set<Long> result = new HashSet<>();
result.add(viewId);
result.add(dataSetId);
return result;
}
return new HashSet<>();
}
if (Objects.nonNull(detectViewIds)) {
detectViewIds = detectViewIds.stream().filter(entry -> entry > 0).collect(Collectors.toSet());
if (Objects.nonNull(detectDataSetIds)) {
detectDataSetIds = detectDataSetIds.stream().filter(entry -> entry > 0).collect(Collectors.toSet());
}
if (Objects.nonNull(viewId) && viewId > 0 && Objects.nonNull(detectViewIds)) {
if (detectViewIds.contains(viewId)) {
if (Objects.nonNull(dataSetId) && dataSetId > 0 && Objects.nonNull(detectDataSetIds)) {
if (detectDataSetIds.contains(dataSetId)) {
Set<Long> result = new HashSet<>();
result.add(viewId);
result.add(dataSetId);
return result;
}
}
return detectViewIds;
return detectDataSetIds;
}
}
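
MapperHelper.getDataSetIds above determines the detection scope: an agent marked as containing all models defers to the requested data set id (or to nothing), otherwise only positive agent data set ids are kept, and a requested id narrows the scope only when the agent's set contains it. A hedged, standalone restatement of that rule with hypothetical ids, for illustration only; class and method names here are ours.

// Hedged restatement of the scope rules in MapperHelper.getDataSetIds, illustration only.
import java.util.Set;
import java.util.stream.Collectors;

class DataSetScopeSketch {
    static Set<Long> scope(Long requestedId, Set<Long> agentDataSetIds, boolean containsAllModel) {
        if (containsAllModel) {
            return (requestedId != null && requestedId > 0) ? Set.of(requestedId) : Set.of();
        }
        Set<Long> allowed = agentDataSetIds.stream().filter(id -> id > 0).collect(Collectors.toSet());
        if (requestedId != null && requestedId > 0 && allowed.contains(requestedId)) {
            return Set.of(requestedId);   // narrow to the requested data set
        }
        return allowed;                   // otherwise detect across the agent's data sets
    }

    public static void main(String[] args) {
        System.out.println(scope(3L, Set.of(1L, 2L, 3L), false)); // [3]
        System.out.println(scope(9L, Set.of(1L, 2L, 3L), false)); // the agent's {1, 2, 3}
        System.out.println(scope(3L, Set.of(-1L), true));         // [3]
    }
}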

View File

@@ -13,6 +13,6 @@ import java.util.Set;
*/
public interface MatchStrategy<T> {
Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms, Set<Long> detectViewIds);
Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms, Set<Long> detectDataSetIds);
}

View File

@@ -27,13 +27,13 @@ public class QueryFilterMapper implements SchemaMapper {
@Override
public void map(QueryContext queryContext) {
Agent agent = queryContext.getAgent();
if (agent == null || CollectionUtils.isEmpty(agent.getViewIds())) {
if (agent == null || CollectionUtils.isEmpty(agent.getDataSetIds())) {
return;
}
if (Agent.containsAllModel(agent.getViewIds())) {
if (Agent.containsAllModel(agent.getDataSetIds())) {
return;
}
Set<Long> viewIds = agent.getViewIds();
Set<Long> viewIds = agent.getDataSetIds();
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
clearOtherSchemaElementMatch(viewIds, schemaMapInfo);
for (Long viewId : viewIds) {
@@ -47,7 +47,7 @@ public class QueryFilterMapper implements SchemaMapper {
}
private void clearOtherSchemaElementMatch(Set<Long> viewIds, SchemaMapInfo schemaMapInfo) {
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMapInfo.getViewElementMatches().entrySet()) {
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMapInfo.getDataSetElementMatches().entrySet()) {
if (!viewIds.contains(entry.getKey())) {
entry.getValue().clear();
}
@@ -69,7 +69,7 @@ public class QueryFilterMapper implements SchemaMapper {
.name(String.valueOf(filter.getValue()))
.type(SchemaElementType.VALUE)
.bizName(filter.getBizName())
.view(viewId)
.dataSet(viewId)
.build();
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
.element(element)

View File

@@ -32,7 +32,7 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
@Override
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> originals,
Set<Long> detectViewIds) {
Set<Long> detectDataSetIds) {
String text = queryContext.getQueryText();
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(originals);
@@ -57,9 +57,9 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
if (StringUtils.isNotEmpty(detectSegment)) {
List<HanlpMapResult> hanlpMapResults = knowledgeService.prefixSearch(detectSegment,
SearchService.SEARCH_SIZE, detectViewIds);
SearchService.SEARCH_SIZE, detectDataSetIds);
List<HanlpMapResult> suffixHanlpMapResults = knowledgeService.suffixSearch(
detectSegment, SEARCH_SIZE, detectViewIds);
detectSegment, SEARCH_SIZE, detectDataSetIds);
hanlpMapResults.addAll(suffixHanlpMapResults);
// remove entity name where search
hanlpMapResults = hanlpMapResults.stream().filter(entry -> {
@@ -93,7 +93,7 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
}
@Override
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectViewIds,
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectDataSetIds,
String detectSegment, int offset) {
}

View File

@@ -38,12 +38,12 @@ public class JavaLLMProxy implements LLMProxy {
return false;
}
public LLMResp query2sql(LLMReq llmReq, Long viewId) {
public LLMResp query2sql(LLMReq llmReq, Long dataSetId) {
SqlGeneration sqlGeneration = SqlGenerationFactory.get(
SqlGenerationMode.getMode(llmReq.getSqlGenerationMode()));
String modelName = llmReq.getSchema().getViewName();
LLMResp result = sqlGeneration.generation(llmReq, viewId);
String modelName = llmReq.getSchema().getDataSetName();
LLMResp result = sqlGeneration.generation(llmReq, dataSetId);
result.setQuery(llmReq.getQueryText());
result.setModelName(modelName);
return result;

View File

@@ -15,7 +15,7 @@ public interface LLMProxy {
boolean isSkip(QueryContext queryContext);
LLMResp query2sql(LLMReq llmReq, Long viewId);
LLMResp query2sql(LLMReq llmReq, Long dataSetId);
FunctionResp requestFunction(FunctionReq functionReq);

View File

@@ -48,10 +48,10 @@ public class PythonLLMProxy implements LLMProxy {
return false;
}
public LLMResp query2sql(LLMReq llmReq, Long viewId) {
public LLMResp query2sql(LLMReq llmReq, Long dataSetId) {
long startTime = System.currentTimeMillis();
log.info("requestLLM request, viewId:{},llmReq:{}", viewId, llmReq);
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
log.info("requestLLM request, dataSetId:{},llmReq:{}", dataSetId, llmReq);
keyPipelineLog.info("dataSetId:{},llmReq:{}", dataSetId, llmReq);
try {
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);

View File

@@ -50,7 +50,7 @@ public class QueryTypeParser implements SemanticParser {
return QueryType.ID;
}
//1. entity queryType
Long viewId = parseInfo.getViewId();
Long dataSetId = parseInfo.getDataSetId();
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
//If all the fields in the SELECT statement are of tag type.
@@ -59,12 +59,12 @@ public class QueryTypeParser implements SemanticParser {
.collect(Collectors.toList());
if (CollectionUtils.isNotEmpty(whereFields)) {
Set<String> ids = semanticSchema.getEntities(viewId).stream().map(SchemaElement::getName)
Set<String> ids = semanticSchema.getEntities(dataSetId).stream().map(SchemaElement::getName)
.collect(Collectors.toSet());
if (CollectionUtils.isNotEmpty(ids) && ids.stream().anyMatch(whereFields::contains)) {
return QueryType.ID;
}
Set<String> tags = semanticSchema.getTags(viewId).stream().map(SchemaElement::getName)
Set<String> tags = semanticSchema.getTags(dataSetId).stream().map(SchemaElement::getName)
.collect(Collectors.toSet());
if (CollectionUtils.isNotEmpty(tags) && tags.containsAll(whereFields)) {
return QueryType.TAG;
@@ -73,7 +73,7 @@ public class QueryTypeParser implements SemanticParser {
}
//2. metric queryType
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
List<SchemaElement> metrics = semanticSchema.getMetrics(viewId);
List<SchemaElement> metrics = semanticSchema.getMetrics(dataSetId);
if (CollectionUtils.isNotEmpty(metrics)) {
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
boolean containMetric = selectFields.stream().anyMatch(metricNameSet::contains);

View File

@@ -55,13 +55,13 @@ public abstract class PluginParser implements SemanticParser {
public void buildQuery(QueryContext queryContext, PluginRecallResult pluginRecallResult) {
Plugin plugin = pluginRecallResult.getPlugin();
Set<Long> viewIds = pluginRecallResult.getViewIds();
Set<Long> dataSetIds = pluginRecallResult.getDataSetIds();
if (plugin.isContainsAllModel()) {
viewIds = Sets.newHashSet(-1L);
dataSetIds = Sets.newHashSet(-1L);
}
for (Long viewId : viewIds) {
for (Long dataSetId : dataSetIds) {
PluginSemanticQuery pluginQuery = QueryManager.createPluginQuery(plugin.getType());
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(viewId, plugin,
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(dataSetId, plugin,
queryContext, pluginRecallResult.getDistance());
semanticParseInfo.setQueryMode(pluginQuery.getQueryMode());
semanticParseInfo.setScore(pluginRecallResult.getScore());
@@ -74,19 +74,19 @@ public abstract class PluginParser implements SemanticParser {
return PluginManager.getPluginAgentCanSupport(queryContext);
}
protected SemanticParseInfo buildSemanticParseInfo(Long viewId, Plugin plugin,
protected SemanticParseInfo buildSemanticParseInfo(Long dataSetId, Plugin plugin,
QueryContext queryContext, double distance) {
List<SchemaElementMatch> schemaElementMatches = queryContext.getMapInfo().getMatchedElements(viewId);
List<SchemaElementMatch> schemaElementMatches = queryContext.getMapInfo().getMatchedElements(dataSetId);
QueryFilters queryFilters = queryContext.getQueryFilters();
if (viewId == null && !CollectionUtils.isEmpty(plugin.getViewList())) {
viewId = plugin.getViewList().get(0);
if (dataSetId == null && !CollectionUtils.isEmpty(plugin.getDataSetList())) {
dataSetId = plugin.getDataSetList().get(0);
}
if (schemaElementMatches == null) {
schemaElementMatches = Lists.newArrayList();
}
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
semanticParseInfo.setElementMatches(schemaElementMatches);
semanticParseInfo.setView(queryContext.getSemanticSchema().getView(viewId));
semanticParseInfo.setDataSet(queryContext.getSemanticSchema().getDataSet(dataSetId));
Map<String, Object> properties = new HashMap<>();
PluginParseResult pluginParseResult = new PluginParseResult();
pluginParseResult.setPlugin(plugin);

View File

@@ -57,15 +57,15 @@ public class EmbeddingRecallParser extends PluginParser {
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, queryContext);
log.info("embedding plugin resolve: {}", pair);
if (pair.getLeft()) {
Set<Long> viewList = pair.getRight();
if (CollectionUtils.isEmpty(viewList)) {
Set<Long> dataSetList = pair.getRight();
if (CollectionUtils.isEmpty(dataSetList)) {
continue;
}
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
double distance = embeddingRetrieval.getDistance();
double score = queryContext.getQueryText().length() * (1 - distance);
return PluginRecallResult.builder()
.plugin(plugin).viewIds(viewList).score(score).distance(distance).build();
.plugin(plugin).dataSetIds(dataSetList).score(score).distance(distance).build();
}
}
return null;

View File

@@ -57,19 +57,19 @@ public class FunctionCallParser extends PluginParser {
plugin.setParseMode(ParseMode.FUNCTION_CALL);
Pair<Boolean, Set<Long>> pluginResolveResult = PluginManager.resolve(plugin, queryContext);
if (pluginResolveResult.getLeft()) {
Set<Long> viewList = pluginResolveResult.getRight();
if (CollectionUtils.isEmpty(viewList)) {
Set<Long> dataSetList = pluginResolveResult.getRight();
if (CollectionUtils.isEmpty(dataSetList)) {
return null;
}
double score = queryContext.getQueryText().length();
return PluginRecallResult.builder().plugin(plugin).viewIds(viewList).score(score).build();
return PluginRecallResult.builder().plugin(plugin).dataSetIds(dataSetList).score(score).build();
}
return null;
}
public FunctionResp functionCall(QueryContext queryContext) {
List<PluginParseConfig> pluginToFunctionCall =
getPluginToFunctionCall(queryContext.getViewId(), queryContext);
getPluginToFunctionCall(queryContext.getDataSetId(), queryContext);
if (CollectionUtils.isEmpty(pluginToFunctionCall)) {
log.info("function call parser, plugin is empty, skip");
return null;

View File

@@ -3,7 +3,7 @@ package com.tencent.supersonic.chat.core.parser.sql.llm;
import lombok.Data;
@Data
public class ViewMatchResult {
public class DataSetMatchResult {
private Integer count = 0;
private double maxSimilarity;
}

View File

@@ -5,7 +5,7 @@ import com.tencent.supersonic.chat.core.pojo.QueryContext;
import java.util.Set;
public interface ViewResolver {
public interface DataSetResolver {
Long resolve(QueryContext queryContext, Set<Long> restrictiveModels);

View File

@@ -0,0 +1,138 @@
package com.tencent.supersonic.chat.core.parser.sql.llm;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
@Slf4j
public class HeuristicDataSetResolver implements DataSetResolver {
protected static Long selectDataSetBySchemaElementMatchScore(Map<Long, SemanticQuery> dataSetQueryModes,
SchemaMapInfo schemaMap) {
//dataSet count priority
Long dataSetIdByDataSetCount = getDataSetIdByMatchDataSetScore(schemaMap);
if (Objects.nonNull(dataSetIdByDataSetCount)) {
log.info("selectDataSet by dataSet count:{}", dataSetIdByDataSetCount);
return dataSetIdByDataSetCount;
}
Map<Long, DataSetMatchResult> dataSetTypeMap = getDataSetTypeMap(schemaMap);
if (dataSetTypeMap.size() == 1) {
Long dataSetSelect = new ArrayList<>(dataSetTypeMap.entrySet()).get(0).getKey();
if (dataSetQueryModes.containsKey(dataSetSelect)) {
log.info("selectDataSet with only one DataSet [{}]", dataSetSelect);
return dataSetSelect;
}
} else {
Map.Entry<Long, DataSetMatchResult> maxDataSet = dataSetTypeMap.entrySet().stream()
.filter(entry -> dataSetQueryModes.containsKey(entry.getKey()))
.sorted((o1, o2) -> {
int difference = o2.getValue().getCount() - o1.getValue().getCount();
if (difference == 0) {
return (int) ((o2.getValue().getMaxSimilarity()
- o1.getValue().getMaxSimilarity()) * 100);
}
return difference;
}).findFirst().orElse(null);
if (maxDataSet != null) {
log.info("selectDataSet with multiple DataSets [{}]", maxDataSet.getKey());
return maxDataSet.getKey();
}
}
return null;
}
private static Long getDataSetIdByMatchDataSetScore(SchemaMapInfo schemaMap) {
Map<Long, List<SchemaElementMatch>> dataSetElementMatches = schemaMap.getDataSetElementMatches();
// calculate dataSet match score, matched element gets 1.0 point, and inherit element gets 0.5 point
Map<Long, Double> dataSetIdToDataSetScore = new HashMap<>();
if (Objects.nonNull(dataSetElementMatches)) {
for (Entry<Long, List<SchemaElementMatch>> dataSetElementMatch : dataSetElementMatches.entrySet()) {
Long dataSetId = dataSetElementMatch.getKey();
List<Double> dataSetMatchesScore = dataSetElementMatch.getValue().stream()
.filter(elementMatch -> elementMatch.getSimilarity() >= 1)
.filter(elementMatch -> SchemaElementType.DATASET.equals(elementMatch.getElement().getType()))
.map(elementMatch -> elementMatch.isInherited() ? 0.5 : 1.0).collect(Collectors.toList());
if (!CollectionUtils.isEmpty(dataSetMatchesScore)) {
// get sum of dataSet match score
double score = dataSetMatchesScore.stream().mapToDouble(Double::doubleValue).sum();
dataSetIdToDataSetScore.put(dataSetId, score);
}
}
Entry<Long, Double> maxDataSetScore = dataSetIdToDataSetScore.entrySet().stream()
.max(Comparator.comparingDouble(Entry::getValue)).orElse(null);
log.info("maxDataSetCount:{},dataSetIdToDataSetCount:{}", maxDataSetScore, dataSetIdToDataSetScore);
if (Objects.nonNull(maxDataSetScore)) {
return maxDataSetScore.getKey();
}
}
return null;
}
public static Map<Long, DataSetMatchResult> getDataSetTypeMap(SchemaMapInfo schemaMap) {
Map<Long, DataSetMatchResult> dataSetCount = new HashMap<>();
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMap.getDataSetElementMatches().entrySet()) {
List<SchemaElementMatch> schemaElementMatches = schemaMap.getMatchedElements(entry.getKey());
if (schemaElementMatches != null && schemaElementMatches.size() > 0) {
if (!dataSetCount.containsKey(entry.getKey())) {
dataSetCount.put(entry.getKey(), new DataSetMatchResult());
}
DataSetMatchResult dataSetMatchResult = dataSetCount.get(entry.getKey());
Set<SchemaElementType> schemaElementTypes = new HashSet<>();
schemaElementMatches.stream()
.forEach(schemaElementMatch -> schemaElementTypes.add(
schemaElementMatch.getElement().getType()));
SchemaElementMatch schemaElementMatchMax = schemaElementMatches.stream()
.sorted((o1, o2) ->
((int) ((o2.getSimilarity() - o1.getSimilarity()) * 100))
).findFirst().orElse(null);
if (schemaElementMatchMax != null) {
dataSetMatchResult.setMaxSimilarity(schemaElementMatchMax.getSimilarity());
}
dataSetMatchResult.setCount(schemaElementTypes.size());
}
}
return dataSetCount;
}
public Long resolve(QueryContext queryContext, Set<Long> agentDataSetIds) {
SchemaMapInfo mapInfo = queryContext.getMapInfo();
Set<Long> matchedDataSets = mapInfo.getMatchedDataSetInfos();
Long dataSetId = queryContext.getDataSetId();
if (Objects.nonNull(dataSetId) && dataSetId > 0) {
if (CollectionUtils.isEmpty(agentDataSetIds) || agentDataSetIds.contains(dataSetId)) {
return dataSetId;
}
return null;
}
if (CollectionUtils.isNotEmpty(agentDataSetIds)) {
matchedDataSets.retainAll(agentDataSetIds);
}
Map<Long, SemanticQuery> dataSetQueryModes = new HashMap<>();
for (Long dataSetIds : matchedDataSets) {
dataSetQueryModes.put(dataSetIds, null);
}
if (dataSetQueryModes.size() == 1) {
return dataSetQueryModes.keySet().stream().findFirst().get();
}
return selectDataSetBySchemaElementMatchScore(dataSetQueryModes, mapInfo);
}
}
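
The precedence in resolve() above is: an explicitly requested dataSet wins (if the agent allows it), then agent restrictions are applied to the mapped dataSets, then a single survivor is returned directly, and only multiple survivors fall through to the match-score selection. A simplified sketch with plain sets in place of QueryContext/SchemaMapInfo (names are illustrative only):

import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

public class ResolvePrecedenceSketch {

    static Long resolve(Long explicitDataSetId, Set<Long> matchedDataSets, Set<Long> agentDataSetIds) {
        // 1. an explicitly requested dataSet wins, as long as the agent may use it
        if (explicitDataSetId != null && explicitDataSetId > 0) {
            return (agentDataSetIds.isEmpty() || agentDataSetIds.contains(explicitDataSetId))
                    ? explicitDataSetId : null;
        }
        // 2. otherwise keep only the mapped dataSets the agent may use
        if (!agentDataSetIds.isEmpty()) {
            matchedDataSets.retainAll(agentDataSetIds);
        }
        // 3. a single remaining candidate is returned directly;
        //    several candidates would go on to the match-score selection
        return matchedDataSets.size() == 1 ? matchedDataSets.iterator().next() : null;
    }

    public static void main(String[] args) {
        Set<Long> matched = new HashSet<>(Arrays.asList(1L, 2L));
        Set<Long> agent = new HashSet<>(Arrays.asList(2L));
        System.out.println(resolve(null, matched, agent)); // prints 2
    }
}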

View File

@@ -1,138 +0,0 @@
package com.tencent.supersonic.chat.core.parser.sql.llm;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
@Slf4j
public class HeuristicViewResolver implements ViewResolver {
protected static Long selectViewBySchemaElementMatchScore(Map<Long, SemanticQuery> viewQueryModes,
SchemaMapInfo schemaMap) {
//view count priority
Long viewIdByViewCount = getViewIdByMatchViewScore(schemaMap);
if (Objects.nonNull(viewIdByViewCount)) {
log.info("selectView by view count:{}", viewIdByViewCount);
return viewIdByViewCount;
}
Map<Long, ViewMatchResult> viewTypeMap = getViewTypeMap(schemaMap);
if (viewTypeMap.size() == 1) {
Long viewSelect = new ArrayList<>(viewTypeMap.entrySet()).get(0).getKey();
if (viewQueryModes.containsKey(viewSelect)) {
log.info("selectView with only one View [{}]", viewSelect);
return viewSelect;
}
} else {
Map.Entry<Long, ViewMatchResult> maxView = viewTypeMap.entrySet().stream()
.filter(entry -> viewQueryModes.containsKey(entry.getKey()))
.sorted((o1, o2) -> {
int difference = o2.getValue().getCount() - o1.getValue().getCount();
if (difference == 0) {
return (int) ((o2.getValue().getMaxSimilarity()
- o1.getValue().getMaxSimilarity()) * 100);
}
return difference;
}).findFirst().orElse(null);
if (maxView != null) {
log.info("selectView with multiple Views [{}]", maxView.getKey());
return maxView.getKey();
}
}
return null;
}
private static Long getViewIdByMatchViewScore(SchemaMapInfo schemaMap) {
Map<Long, List<SchemaElementMatch>> viewElementMatches = schemaMap.getViewElementMatches();
// calculate view match score, matched element gets 1.0 point, and inherit element gets 0.5 point
Map<Long, Double> viewIdToViewScore = new HashMap<>();
if (Objects.nonNull(viewElementMatches)) {
for (Entry<Long, List<SchemaElementMatch>> viewElementMatch : viewElementMatches.entrySet()) {
Long viewId = viewElementMatch.getKey();
List<Double> viewMatchesScore = viewElementMatch.getValue().stream()
.filter(elementMatch -> elementMatch.getSimilarity() >= 1)
.filter(elementMatch -> SchemaElementType.VIEW.equals(elementMatch.getElement().getType()))
.map(elementMatch -> elementMatch.isInherited() ? 0.5 : 1.0).collect(Collectors.toList());
if (!CollectionUtils.isEmpty(viewMatchesScore)) {
// get sum of view match score
double score = viewMatchesScore.stream().mapToDouble(Double::doubleValue).sum();
viewIdToViewScore.put(viewId, score);
}
}
Entry<Long, Double> maxViewScore = viewIdToViewScore.entrySet().stream()
.max(Comparator.comparingDouble(Entry::getValue)).orElse(null);
log.info("maxViewCount:{},viewIdToViewCount:{}", maxViewScore, viewIdToViewScore);
if (Objects.nonNull(maxViewScore)) {
return maxViewScore.getKey();
}
}
return null;
}
public static Map<Long, ViewMatchResult> getViewTypeMap(SchemaMapInfo schemaMap) {
Map<Long, ViewMatchResult> viewCount = new HashMap<>();
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMap.getViewElementMatches().entrySet()) {
List<SchemaElementMatch> schemaElementMatches = schemaMap.getMatchedElements(entry.getKey());
if (schemaElementMatches != null && schemaElementMatches.size() > 0) {
if (!viewCount.containsKey(entry.getKey())) {
viewCount.put(entry.getKey(), new ViewMatchResult());
}
ViewMatchResult viewMatchResult = viewCount.get(entry.getKey());
Set<SchemaElementType> schemaElementTypes = new HashSet<>();
schemaElementMatches.stream()
.forEach(schemaElementMatch -> schemaElementTypes.add(
schemaElementMatch.getElement().getType()));
SchemaElementMatch schemaElementMatchMax = schemaElementMatches.stream()
.sorted((o1, o2) ->
((int) ((o2.getSimilarity() - o1.getSimilarity()) * 100))
).findFirst().orElse(null);
if (schemaElementMatchMax != null) {
viewMatchResult.setMaxSimilarity(schemaElementMatchMax.getSimilarity());
}
viewMatchResult.setCount(schemaElementTypes.size());
}
}
return viewCount;
}
public Long resolve(QueryContext queryContext, Set<Long> agentViewIds) {
SchemaMapInfo mapInfo = queryContext.getMapInfo();
Set<Long> matchedViews = mapInfo.getMatchedViewInfos();
Long viewId = queryContext.getViewId();
if (Objects.nonNull(viewId) && viewId > 0) {
if (CollectionUtils.isEmpty(agentViewIds) || agentViewIds.contains(viewId)) {
return viewId;
}
return null;
}
if (CollectionUtils.isNotEmpty(agentViewIds)) {
matchedViews.retainAll(agentViewIds);
}
Map<Long, SemanticQuery> viewQueryModes = new HashMap<>();
for (Long viewIds : matchedViews) {
viewQueryModes.put(viewIds, null);
}
if (viewQueryModes.size() == 1) {
return viewQueryModes.keySet().stream().findFirst().get();
}
return selectViewBySchemaElementMatchScore(viewQueryModes, mapInfo);
}
}

View File

@@ -22,7 +22,7 @@ import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
import com.tencent.supersonic.headless.api.pojo.response.ViewSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.DataSetSchemaResp;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
@@ -61,20 +61,20 @@ public class LLMRequestService {
return false;
}
public Long getViewId(QueryContext queryCtx) {
public Long getDataSetId(QueryContext queryCtx) {
Agent agent = queryCtx.getAgent();
Set<Long> agentViewIds = new HashSet<>();
Set<Long> agentDataSetIds = new HashSet<>();
if (Objects.nonNull(agent)) {
agentViewIds = agent.getViewIds(AgentToolType.NL2SQL_LLM);
agentDataSetIds = agent.getDataSetIds(AgentToolType.NL2SQL_LLM);
}
if (Agent.containsAllModel(agentViewIds)) {
agentViewIds = new HashSet<>();
if (Agent.containsAllModel(agentDataSetIds)) {
agentDataSetIds = new HashSet<>();
}
ViewResolver viewResolver = ComponentFactory.getModelResolver();
return viewResolver.resolve(queryCtx, agentViewIds);
DataSetResolver dataSetResolver = ComponentFactory.getModelResolver();
return dataSetResolver.resolve(queryCtx, agentDataSetIds);
}
public NL2SQLTool getParserTool(QueryContext queryCtx, Long viewId) {
public NL2SQLTool getParserTool(QueryContext queryCtx, Long dataSetId) {
Agent agent = queryCtx.getAgent();
if (Objects.isNull(agent)) {
return null;
@@ -82,19 +82,19 @@ public class LLMRequestService {
List<NL2SQLTool> commonAgentTools = agent.getParserTools(AgentToolType.NL2SQL_LLM);
Optional<NL2SQLTool> llmParserTool = commonAgentTools.stream()
.filter(tool -> {
List<Long> viewIds = tool.getViewIds();
if (Agent.containsAllModel(new HashSet<>(viewIds))) {
List<Long> dataSetIds = tool.getDataSetIds();
if (Agent.containsAllModel(new HashSet<>(dataSetIds))) {
return true;
}
return viewIds.contains(viewId);
return dataSetIds.contains(dataSetId);
})
.findFirst();
return llmParserTool.orElse(null);
}
public LLMReq getLlmReq(QueryContext queryCtx, Long viewId,
public LLMReq getLlmReq(QueryContext queryCtx, Long dataSetId,
SemanticSchema semanticSchema, List<ElementValue> linkingValues) {
Map<Long, String> viewIdToName = semanticSchema.getViewIdToName();
Map<Long, String> dataSetIdToName = semanticSchema.getDataSetIdToName();
String queryText = queryCtx.getQueryText();
LLMReq llmReq = new LLMReq();
@@ -103,12 +103,12 @@ public class LLMRequestService {
llmReq.setFilterCondition(filterCondition);
LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema();
llmSchema.setViewName(viewIdToName.get(viewId));
llmSchema.setDomainName(viewIdToName.get(viewId));
llmSchema.setDataSetName(dataSetIdToName.get(dataSetId));
llmSchema.setDomainName(dataSetIdToName.get(dataSetId));
List<String> fieldNameList = getFieldNameList(queryCtx, viewId, llmParserConfig);
List<String> fieldNameList = getFieldNameList(queryCtx, dataSetId, llmParserConfig);
String priorExts = getPriorExts(viewId, fieldNameList);
String priorExts = getPriorExts(dataSetId, fieldNameList);
llmReq.setPriorExts(priorExts);
fieldNameList.add(TimeDimensionEnum.DAY.getChName());
@@ -121,7 +121,7 @@ public class LLMRequestService {
}
llmReq.setLinking(linking);
String currentDate = S2SqlDateHelper.getReferenceDate(queryCtx, viewId);
String currentDate = S2SqlDateHelper.getReferenceDate(queryCtx, dataSetId);
if (StringUtils.isEmpty(currentDate)) {
currentDate = DateUtils.getBeforeDate(0);
}
@@ -130,28 +130,28 @@ public class LLMRequestService {
return llmReq;
}
public LLMResp requestLLM(LLMReq llmReq, Long viewId) {
return ComponentFactory.getLLMProxy().query2sql(llmReq, viewId);
public LLMResp requestLLM(LLMReq llmReq, Long dataSetId) {
return ComponentFactory.getLLMProxy().query2sql(llmReq, dataSetId);
}
protected List<String> getFieldNameList(QueryContext queryCtx, Long viewId,
protected List<String> getFieldNameList(QueryContext queryCtx, Long dataSetId,
LLMParserConfig llmParserConfig) {
Set<String> results = getTopNFieldNames(queryCtx, viewId, llmParserConfig);
Set<String> results = getTopNFieldNames(queryCtx, dataSetId, llmParserConfig);
Set<String> fieldNameList = getMatchedFieldNames(queryCtx, viewId);
Set<String> fieldNameList = getMatchedFieldNames(queryCtx, dataSetId);
results.addAll(fieldNameList);
return new ArrayList<>(results);
}
private String getPriorExts(Long viewId, List<String> fieldNameList) {
private String getPriorExts(Long dataSetId, List<String> fieldNameList) {
StringBuilder extraInfoSb = new StringBuilder();
List<ViewSchemaResp> viewSchemaResps = semanticInterpreter.fetchViewSchema(
Lists.newArrayList(viewId), true);
if (!CollectionUtils.isEmpty(viewSchemaResps)) {
ViewSchemaResp viewSchemaResp = viewSchemaResps.get(0);
Map<String, String> fieldNameToDataFormatType = viewSchemaResp.getMetrics()
List<DataSetSchemaResp> dataSetSchemaResps = semanticInterpreter.fetchDataSetSchema(
Lists.newArrayList(dataSetId), true);
if (!CollectionUtils.isEmpty(dataSetSchemaResps)) {
DataSetSchemaResp dataSetSchemaResp = dataSetSchemaResps.get(0);
Map<String, String> fieldNameToDataFormatType = dataSetSchemaResp.getMetrics()
.stream().filter(metricSchemaResp -> Objects.nonNull(metricSchemaResp.getDataFormatType()))
.flatMap(metricSchemaResp -> {
Set<Pair<String, String>> result = new HashSet<>();
@@ -179,9 +179,9 @@ public class LLMRequestService {
return extraInfoSb.toString();
}
protected List<ElementValue> getValueList(QueryContext queryCtx, Long viewId) {
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, viewId);
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(viewId);
protected List<ElementValue> getValueList(QueryContext queryCtx, Long dataSetId) {
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, dataSetId);
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
if (CollectionUtils.isEmpty(matchedElements)) {
return new ArrayList<>();
}
@@ -201,21 +201,21 @@ public class LLMRequestService {
return new ArrayList<>(valueMatches);
}
protected Map<Long, String> getItemIdToName(QueryContext queryCtx, Long viewId) {
protected Map<Long, String> getItemIdToName(QueryContext queryCtx, Long dataSetId) {
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
return semanticSchema.getDimensions(viewId).stream()
return semanticSchema.getDimensions(dataSetId).stream()
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
}
private Set<String> getTopNFieldNames(QueryContext queryCtx, Long viewId, LLMParserConfig llmParserConfig) {
private Set<String> getTopNFieldNames(QueryContext queryCtx, Long dataSetId, LLMParserConfig llmParserConfig) {
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
Set<String> results = semanticSchema.getDimensions(viewId).stream()
Set<String> results = semanticSchema.getDimensions(dataSetId).stream()
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getDimensionTopN())
.map(entry -> entry.getName())
.collect(Collectors.toSet());
Set<String> metrics = semanticSchema.getMetrics(viewId).stream()
Set<String> metrics = semanticSchema.getMetrics(dataSetId).stream()
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getMetricTopN())
.map(entry -> entry.getName())
@@ -225,9 +225,9 @@ public class LLMRequestService {
return results;
}
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, Long viewId) {
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, viewId);
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(viewId);
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, Long dataSetId) {
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, dataSetId);
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
if (CollectionUtils.isEmpty(matchedElements)) {
return new HashSet<>();
}
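
The top-N selection in getTopNFieldNames above simply ranks dimensions and metrics by useCnt and keeps the most-used names. A small sketch, with a local Field class standing in for SchemaElement (names and values are illustrative):

import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

public class TopNFieldSketch {
    static class Field {
        final String name;
        final long useCnt;
        Field(String name, long useCnt) {
            this.name = name;
            this.useCnt = useCnt;
        }
    }

    public static void main(String[] args) {
        List<Field> dimensions = Arrays.asList(
                new Field("song_name", 120), new Field("business_group", 45), new Field("publish_date", 3));

        Set<String> topFields = dimensions.stream()
                .sorted(Comparator.comparingLong((Field f) -> f.useCnt).reversed())
                .limit(2) // dimensionTopN
                .map(f -> f.name)
                .collect(Collectors.toSet());

        System.out.println(topFields); // song_name and business_group
    }
}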

View File

@@ -28,9 +28,9 @@ public class LLMResponseService {
}
LLMSemanticQuery semanticQuery = QueryManager.createLLMQuery(LLMSqlQuery.QUERY_MODE);
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
parseInfo.setView(queryCtx.getSemanticSchema().getView(parseResult.getViewId()));
parseInfo.setDataSet(queryCtx.getSemanticSchema().getDataSet(parseResult.getDataSetId()));
NL2SQLTool commonAgentTool = parseResult.getCommonAgentTool();
parseInfo.getElementMatches().addAll(queryCtx.getMapInfo().getMatchedElements(parseInfo.getViewId()));
parseInfo.getElementMatches().addAll(queryCtx.getMapInfo().getMatchedElements(parseInfo.getDataSetId()));
Map<String, Object> properties = new HashMap<>();
properties.put(Constants.CONTEXT, parseResult);

View File

@@ -29,21 +29,21 @@ public class LLMSqlParser implements SemanticParser {
}
try {
//2.get dataSetId from queryCtx and chatCtx.
Long viewId = requestService.getViewId(queryCtx);
if (viewId == null) {
Long dataSetId = requestService.getDataSetId(queryCtx);
if (dataSetId == null) {
return;
}
//3.get agent tool and determine whether to skip this parser.
NL2SQLTool commonAgentTool = requestService.getParserTool(queryCtx, viewId);
NL2SQLTool commonAgentTool = requestService.getParserTool(queryCtx, dataSetId);
if (Objects.isNull(commonAgentTool)) {
log.info("no tool in this agent, skip {}", LLMSqlParser.class);
return;
}
//4.construct a request, call the API for the large model, and retrieve the results.
List<ElementValue> linkingValues = requestService.getValueList(queryCtx, viewId);
List<ElementValue> linkingValues = requestService.getValueList(queryCtx, dataSetId);
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
LLMReq llmReq = requestService.getLlmReq(queryCtx, viewId, semanticSchema, linkingValues);
LLMResp llmResp = requestService.requestLLM(llmReq, viewId);
LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId, semanticSchema, linkingValues);
LLMResp llmResp = requestService.requestLLM(llmReq, dataSetId);
if (Objects.isNull(llmResp)) {
return;
@@ -52,7 +52,7 @@ public class LLMSqlParser implements SemanticParser {
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
Map<String, LLMSqlResp> deduplicationSqlResp = responseService.getDeduplicationSqlResp(llmResp);
ParseResult parseResult = ParseResult.builder()
.viewId(viewId)
.dataSetId(dataSetId)
.commonAgentTool(commonAgentTool)
.llmReq(llmReq)
.llmResp(llmResp)

View File

@@ -43,9 +43,9 @@ public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean {
private SqlPromptGenerator sqlPromptGenerator;
@Override
public LLMResp generation(LLMReq llmReq, Long viewId) {
public LLMResp generation(LLMReq llmReq, Long dataSetId) {
//1.retrieve sqlExamples and generate exampleListPool
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
keyPipelineLog.info("dataSetId:{},llmReq:{}", dataSetId, llmReq);
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
optimizationConfig.getText2sqlExampleNum());

View File

@@ -41,9 +41,9 @@ public class OnePassSqlGeneration implements SqlGeneration, InitializingBean {
private SqlPromptGenerator sqlPromptGenerator;
@Override
public LLMResp generation(LLMReq llmReq, Long viewId) {
public LLMResp generation(LLMReq llmReq, Long dataSetId) {
//1.retrieve sqlExamples
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
keyPipelineLog.info("dataSetId:{},llmReq:{}", dataSetId, llmReq);
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
optimizationConfig.getText2sqlExampleNum());

View File

@@ -18,7 +18,7 @@ import java.util.List;
@NoArgsConstructor
public class ParseResult {
private Long viewId;
private Long dataSetId;
private LLMReq llmReq;

View File

@@ -12,9 +12,9 @@ public interface SqlGeneration {
/***
* generate llmResp (sql, schemaLink, prompt, etc.) through LLMReq.
* @param llmReq
* @param viewId
* @param dataSetId
* @return
*/
LLMResp generation(LLMReq llmReq, Long viewId);
LLMResp generation(LLMReq llmReq, Long dataSetId);
}

View File

@@ -96,7 +96,7 @@ public class SqlPromptGenerator {
}
public Pair<String, String> transformQuestionPrompt(LLMReq llmReq) {
String modelName = llmReq.getSchema().getViewName();
String modelName = llmReq.getSchema().getDataSetName();
List<String> fieldNameList = llmReq.getSchema().getFieldNameList();
List<ElementValue> linking = llmReq.getLinking();
String currentDate = llmReq.getCurrentDate();

View File

@@ -40,9 +40,9 @@ public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean {
private SqlPromptGenerator sqlPromptGenerator;
@Override
public LLMResp generation(LLMReq llmReq, Long viewId) {
public LLMResp generation(LLMReq llmReq, Long dataSetId) {
//1.retrieve sqlExamples and generate exampleListPool
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
keyPipelineLog.info("dataSetId:{},llmReq:{}", dataSetId, llmReq);
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
optimizationConfig.getText2sqlExampleNum());

View File

@@ -40,8 +40,8 @@ public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean {
private SqlPromptGenerator sqlPromptGenerator;
@Override
public LLMResp generation(LLMReq llmReq, Long viewId) {
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
public LLMResp generation(LLMReq llmReq, Long dataSetId) {
keyPipelineLog.info("dataSetId:{},llmReq:{}", dataSetId, llmReq);
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
optimizationConfig.getText2sqlExampleNum());

View File

@@ -58,13 +58,13 @@ public class AgentCheckParser implements SemanticParser {
}
}
}
if (CollectionUtils.isEmpty(tool.getViewIds())) {
if (CollectionUtils.isEmpty(tool.getDataSetIds())) {
return true;
}
if (tool.isContainsAllModel()) {
return false;
}
return !tool.getViewIds().contains(query.getParseInfo().getViewId());
return !tool.getDataSetIds().contains(query.getParseInfo().getDataSetId());
}
return true;
});

View File

@@ -39,7 +39,7 @@ public class ContextInheritParser implements SemanticParser {
SchemaElementType.VALUE, Arrays.asList(SchemaElementType.VALUE, SchemaElementType.DIMENSION)),
new AbstractMap.SimpleEntry<>(SchemaElementType.ENTITY, Arrays.asList(SchemaElementType.ENTITY)),
new AbstractMap.SimpleEntry<>(SchemaElementType.TAG, Arrays.asList(SchemaElementType.TAG)),
new AbstractMap.SimpleEntry<>(SchemaElementType.VIEW, Arrays.asList(SchemaElementType.VIEW)),
new AbstractMap.SimpleEntry<>(SchemaElementType.DATASET, Arrays.asList(SchemaElementType.DATASET)),
new AbstractMap.SimpleEntry<>(SchemaElementType.ID, Arrays.asList(SchemaElementType.ID))
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
@@ -48,12 +48,12 @@ public class ContextInheritParser implements SemanticParser {
if (!shouldInherit(queryContext)) {
return;
}
Long viewId = getMatchedView(queryContext, chatContext);
if (viewId == null) {
Long dataSetId = getMatchedDataSet(queryContext, chatContext);
if (dataSetId == null) {
return;
}
List<SchemaElementMatch> elementMatches = queryContext.getMapInfo().getMatchedElements(viewId);
List<SchemaElementMatch> elementMatches = queryContext.getMapInfo().getMatchedElements(dataSetId);
List<SchemaElementMatch> matchesToInherit = new ArrayList<>();
for (SchemaElementMatch match : chatContext.getParseInfo().getElementMatches()) {
@@ -70,17 +70,17 @@ public class ContextInheritParser implements SemanticParser {
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(elementMatches, queryContext);
for (RuleSemanticQuery query : queries) {
query.fillParseInfo(queryContext, chatContext);
if (existSameQuery(query.getParseInfo().getViewId(), query.getQueryMode(), queryContext)) {
if (existSameQuery(query.getParseInfo().getDataSetId(), query.getQueryMode(), queryContext)) {
continue;
}
queryContext.getCandidateQueries().add(query);
}
}
private boolean existSameQuery(Long viewId, String queryMode, QueryContext queryContext) {
private boolean existSameQuery(Long dataSetId, String queryMode, QueryContext queryContext) {
for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) {
if (semanticQuery.getQueryMode().equals(queryMode)
&& semanticQuery.getParseInfo().getViewId().equals(viewId)) {
&& semanticQuery.getParseInfo().getDataSetId().equals(dataSetId)) {
return true;
}
}
@@ -109,16 +109,16 @@ public class ContextInheritParser implements SemanticParser {
return metricModelQueries.size() == queryContext.getCandidateQueries().size();
}
protected Long getMatchedView(QueryContext queryContext, ChatContext chatContext) {
Long viewId = chatContext.getParseInfo().getViewId();
if (viewId == null) {
protected Long getMatchedDataSet(QueryContext queryContext, ChatContext chatContext) {
Long dataSetId = chatContext.getParseInfo().getDataSetId();
if (dataSetId == null) {
return null;
}
Set<Long> queryViews = queryContext.getMapInfo().getMatchedViewInfos();
if (queryViews.contains(viewId)) {
return viewId;
Set<Long> queryDataSets = queryContext.getMapInfo().getMatchedDataSetInfos();
if (queryDataSets.contains(dataSetId)) {
return dataSetId;
}
return viewId;
return dataSetId;
}
}

View File

@@ -29,8 +29,8 @@ public class RuleSqlParser implements SemanticParser {
public void parse(QueryContext queryContext, ChatContext chatContext) {
SchemaMapInfo mapInfo = queryContext.getMapInfo();
// iterate all schemaElementMatches to resolve query mode
for (Long viewId : mapInfo.getMatchedViewInfos()) {
List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(viewId);
for (Long dataSetId : mapInfo.getMatchedDataSetInfos()) {
List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(dataSetId);
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(elementMatches, queryContext);
for (RuleSemanticQuery query : queries) {
query.fillParseInfo(queryContext, chatContext);

View File

@@ -20,7 +20,7 @@ public class Plugin extends RecordInfo {
*/
private String type;
private List<Long> viewList = Lists.newArrayList();
private List<Long> dataSetList = Lists.newArrayList();
/**
* description, for parsing
@@ -52,7 +52,7 @@ public class Plugin extends RecordInfo {
}
public boolean isContainsAllModel() {
return CollectionUtils.isNotEmpty(viewList) && viewList.contains(-1L);
return CollectionUtils.isNotEmpty(dataSetList) && dataSetList.contains(-1L);
}
public Long getDefaultMode() {

View File

@@ -266,14 +266,14 @@ public class PluginManager {
}
private static Set<Long> getPluginMatchedModel(Plugin plugin, QueryContext queryContext) {
Set<Long> matchedViews = queryContext.getMapInfo().getMatchedViewInfos();
Set<Long> matchedDataSets = queryContext.getMapInfo().getMatchedDataSetInfos();
if (plugin.isContainsAllModel()) {
return Sets.newHashSet(plugin.getDefaultMode());
}
List<Long> modelIds = plugin.getViewList();
List<Long> modelIds = plugin.getDataSetList();
Set<Long> pluginMatchedModel = Sets.newHashSet();
for (Long modelId : modelIds) {
if (matchedViews.contains(modelId)) {
if (matchedDataSets.contains(modelId)) {
pluginMatchedModel.add(modelId);
}
}

View File

@@ -15,7 +15,7 @@ public class PluginRecallResult {
private Plugin plugin;
private Set<Long> viewIds;
private Set<Long> dataSetIds;
private double score;

View File

@@ -30,7 +30,7 @@ public class QueryContext {
private String queryText;
private Integer chatId;
private Long viewId;
private Long dataSetId;
private User user;
private boolean saveAnswer = true;
private Integer agentId;

View File

@@ -49,7 +49,7 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
explainSqlReq = ExplainSqlReq.builder()
.queryTypeEnum(QueryType.SQL)
.queryReq(QueryReqBuilder.buildS2SQLReq(
sqlInfo.getCorrectS2SQL(), parseInfo.getViewId()
sqlInfo.getCorrectS2SQL(), parseInfo.getDataSetId()
))
.build();
} else {
@@ -84,7 +84,7 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
}
protected void convertBizNameToName(SemanticSchema semanticSchema, QueryStructReq queryStructReq) {
Map<String, String> bizNameToName = semanticSchema.getBizNameToName(queryStructReq.getViewId());
Map<String, String> bizNameToName = semanticSchema.getBizNameToName(queryStructReq.getDataSetId());
bizNameToName.putAll(TimeDimensionEnum.getNameToNameMap());
List<Order> orders = queryStructReq.getOrders();

View File

@@ -36,7 +36,7 @@ public class LLMReq {
private String domainName;
private String viewName;
private String dataSetName;
private List<String> fieldNameList;

View File

@@ -42,7 +42,7 @@ public class LLMSqlQuery extends LLMSemanticQuery {
long startTime = System.currentTimeMillis();
String querySql = parseInfo.getSqlInfo().getCorrectS2SQL();
QuerySqlReq querySQLReq = QueryReqBuilder.buildS2SQLReq(querySql, parseInfo.getViewId());
QuerySqlReq querySQLReq = QueryReqBuilder.buildS2SQLReq(querySql, parseInfo.getDataSetId());
SemanticQueryResp queryResp = semanticInterpreter.queryByS2SQL(querySQLReq, user);
log.info("queryByS2SQL cost:{},querySql:{}", System.currentTimeMillis() - startTime, querySql);

View File

@@ -79,7 +79,7 @@ public abstract class PluginSemanticQuery extends BaseSemanticQuery {
if (!CollectionUtils.isEmpty(webPage.getParamOptions()) && !CollectionUtils.isEmpty(elementValueMap)) {
for (ParamOption paramOption : webPage.getParamOptions()) {
if (paramOption.getModelId() != null
&& !parseInfo.getViewId().equals(paramOption.getModelId())) {
&& !parseInfo.getDataSetId().equals(paramOption.getModelId())) {
continue;
}
paramOptions.add(paramOption);

View File

@@ -25,7 +25,7 @@ public class QueryMatcher {
public QueryMatcher() {
for (SchemaElementType type : SchemaElementType.values()) {
if (type.equals(SchemaElementType.VIEW)) {
if (type.equals(SchemaElementType.DATASET)) {
elementOptionMap.put(type, QueryMatchOption.optional());
} else {
elementOptionMap.put(type, QueryMatchOption.unused());

View File

@@ -102,9 +102,9 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
}
private void fillSchemaElement(SemanticParseInfo parseInfo, SemanticSchema semanticSchema) {
Set<Long> viewIds = parseInfo.getElementMatches().stream().map(SchemaElementMatch::getElement)
.map(SchemaElement::getView).collect(Collectors.toSet());
parseInfo.setView(semanticSchema.getView(viewIds.iterator().next()));
Set<Long> dataSetIds = parseInfo.getElementMatches().stream().map(SchemaElementMatch::getElement)
.map(SchemaElement::getDataSet).collect(Collectors.toSet());
parseInfo.setDataSet(semanticSchema.getDataSet(dataSetIds.iterator().next()));
Map<Long, List<SchemaElementMatch>> dim2Values = new HashMap<>();
Map<Long, List<SchemaElementMatch>> id2Values = new HashMap<>();
@@ -189,7 +189,7 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
public QueryResult execute(User user) {
String queryMode = parseInfo.getQueryMode();
if (parseInfo.getViewId() == null || StringUtils.isEmpty(queryMode)
if (parseInfo.getDataSetId() == null || StringUtils.isEmpty(queryMode)
|| !QueryManager.containsRuleQuery(queryMode)) {
// reach here some error may happen
log.error("not find QueryMode");
@@ -230,7 +230,7 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
public QueryResult multiStructExecute(User user) {
String queryMode = parseInfo.getQueryMode();
if (parseInfo.getViewId() != null || StringUtils.isEmpty(queryMode)
if (parseInfo.getDataSetId() != null || StringUtils.isEmpty(queryMode)
|| !QueryManager.containsRuleQuery(queryMode)) {
// reach here some error may happen
log.error("not find QueryMode");

View File

@@ -1,12 +1,12 @@
package com.tencent.supersonic.chat.core.query.rule.metric;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import org.springframework.stereotype.Component;
import static com.tencent.supersonic.chat.core.query.rule.QueryMatchOption.OptionType.OPTIONAL;
import static com.tencent.supersonic.chat.core.query.rule.QueryMatchOption.RequireNumberType.AT_MOST;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import org.springframework.stereotype.Component;
@Component
public class MetricModelQuery extends MetricSemanticQuery {
@@ -14,7 +14,7 @@ public class MetricModelQuery extends MetricSemanticQuery {
public MetricModelQuery() {
super();
queryMatcher.addOption(SchemaElementType.VIEW, OPTIONAL, AT_MOST, 1);
queryMatcher.addOption(SchemaElementType.DATASET, OPTIONAL, AT_MOST, 1);
}
@Override

View File

@@ -1,21 +1,22 @@
package com.tencent.supersonic.chat.core.query.rule.metric;
import static com.tencent.supersonic.chat.core.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.chat.core.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.METRIC;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.enums.TimeMode;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import lombok.extern.slf4j.Slf4j;
import java.time.LocalDate;
import java.util.List;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import static com.tencent.supersonic.chat.core.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.chat.core.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.METRIC;
@Slf4j
public abstract class MetricSemanticQuery extends RuleSemanticQuery {
@@ -38,8 +39,9 @@ public abstract class MetricSemanticQuery extends RuleSemanticQuery {
super.fillParseInfo(queryContext, chatContext);
parseInfo.setLimit(METRIC_MAX_RESULTS);
if (parseInfo.getDateInfo() == null) {
ViewSchema viewSchema = queryContext.getSemanticSchema().getViewSchemaMap().get(parseInfo.getViewId());
TimeDefaultConfig timeDefaultConfig = viewSchema.getMetricTypeTimeDefaultConfig();
DataSetSchema dataSetSchema =
queryContext.getSemanticSchema().getDataSetSchemaMap().get(parseInfo.getDataSetId());
TimeDefaultConfig timeDefaultConfig = dataSetSchema.getMetricTypeTimeDefaultConfig();
DateConf dateInfo = new DateConf();
if (Objects.nonNull(timeDefaultConfig) && Objects.nonNull(timeDefaultConfig.getUnit())) {
int unit = timeDefaultConfig.getUnit();

View File

@@ -3,7 +3,7 @@ package com.tencent.supersonic.chat.core.query.rule.tag;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.Constants;
@@ -25,19 +25,19 @@ public abstract class TagListQuery extends TagSemanticQuery {
}
private void addEntityDetailAndOrderByMetric(QueryContext queryContext, SemanticParseInfo parseInfo) {
Long viewId = parseInfo.getViewId();
if (Objects.nonNull(viewId) && viewId > 0L) {
ViewSchema viewSchema = queryContext.getSemanticSchema().getViewSchemaMap().get(viewId);
if (viewSchema != null && Objects.nonNull(viewSchema.getEntity())) {
Long dataSetId = parseInfo.getDataSetId();
if (Objects.nonNull(dataSetId) && dataSetId > 0L) {
DataSetSchema dataSetSchema = queryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
if (dataSetSchema != null && Objects.nonNull(dataSetSchema.getEntity())) {
Set<SchemaElement> dimensions = new LinkedHashSet<>();
Set<SchemaElement> metrics = new LinkedHashSet<>();
Set<Order> orders = new LinkedHashSet<>();
TagTypeDefaultConfig tagTypeDefaultConfig = viewSchema.getTagTypeDefaultConfig();
TagTypeDefaultConfig tagTypeDefaultConfig = dataSetSchema.getTagTypeDefaultConfig();
if (tagTypeDefaultConfig != null && tagTypeDefaultConfig.getDefaultDisplayInfo() != null) {
if (CollectionUtils.isNotEmpty(tagTypeDefaultConfig.getDefaultDisplayInfo().getMetricIds())) {
metrics = tagTypeDefaultConfig.getDefaultDisplayInfo().getMetricIds()
.stream().map(id -> {
SchemaElement metric = viewSchema.getElement(SchemaElementType.METRIC, id);
SchemaElement metric = dataSetSchema.getElement(SchemaElementType.METRIC, id);
if (metric != null) {
orders.add(new Order(metric.getBizName(), Constants.DESC_UPPER));
}
@@ -46,7 +46,7 @@ public abstract class TagListQuery extends TagSemanticQuery {
}
if (CollectionUtils.isNotEmpty(tagTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds())) {
dimensions = tagTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds().stream()
.map(id -> viewSchema.getElement(SchemaElementType.DIMENSION, id))
.map(id -> dataSetSchema.getElement(SchemaElementType.DIMENSION, id))
.filter(Objects::nonNull).collect(Collectors.toSet());
}
}

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.chat.core.query.rule.tag;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery;
@@ -42,8 +42,9 @@ public abstract class TagSemanticQuery extends RuleSemanticQuery {
parseInfo.setQueryType(QueryType.TAG);
parseInfo.setLimit(TAG_MAX_RESULTS);
if (parseInfo.getDateInfo() == null) {
ViewSchema viewSchema = queryContext.getSemanticSchema().getViewSchemaMap().get(parseInfo.getViewId());
TimeDefaultConfig timeDefaultConfig = viewSchema.getTagTypeTimeDefaultConfig();
DataSetSchema dataSetSchema =
queryContext.getSemanticSchema().getDataSetSchemaMap().get(parseInfo.getDataSetId());
TimeDefaultConfig timeDefaultConfig = dataSetSchema.getTagTypeTimeDefaultConfig();
DateConf dateInfo = new DateConf();
if (Objects.nonNull(timeDefaultConfig) && Objects.nonNull(timeDefaultConfig.getUnit())) {
int unit = timeDefaultConfig.getUnit();

View File

@@ -2,8 +2,8 @@ package com.tencent.supersonic.chat.core.query.semantic;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
import com.tencent.supersonic.headless.api.pojo.response.ViewSchemaResp;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.response.DataSetSchemaResp;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
@@ -16,53 +16,53 @@ import java.util.concurrent.TimeUnit;
@Slf4j
public abstract class BaseSemanticInterpreter implements SemanticInterpreter {
protected final Cache<String, List<ViewSchemaResp>> viewSchemaCache =
protected final Cache<String, List<DataSetSchemaResp>> dataSetSchemaCache =
CacheBuilder.newBuilder().expireAfterWrite(10, TimeUnit.SECONDS).build();
@SneakyThrows
public List<ViewSchemaResp> fetchViewSchema(List<Long> ids, Boolean cacheEnable) {
public List<DataSetSchemaResp> fetchDataSetSchema(List<Long> ids, Boolean cacheEnable) {
if (cacheEnable) {
return viewSchemaCache.get(String.valueOf(ids), () -> {
List<ViewSchemaResp> data = doFetchViewSchema(ids);
viewSchemaCache.put(String.valueOf(ids), data);
return dataSetSchemaCache.get(String.valueOf(ids), () -> {
List<DataSetSchemaResp> data = doFetchDataSetSchema(ids);
dataSetSchemaCache.put(String.valueOf(ids), data);
return data;
});
}
return doFetchViewSchema(ids);
return doFetchDataSetSchema(ids);
}
@Override
public ViewSchema getViewSchema(Long viewId, Boolean cacheEnable) {
public DataSetSchema getDataSetSchema(Long dataSetId, Boolean cacheEnable) {
List<Long> ids = new ArrayList<>();
ids.add(viewId);
List<ViewSchemaResp> viewSchemaResps = fetchViewSchema(ids, cacheEnable);
if (!CollectionUtils.isEmpty(viewSchemaResps)) {
Optional<ViewSchemaResp> viewSchemaResp = viewSchemaResps.stream()
.filter(d -> d.getId().equals(viewId)).findFirst();
if (viewSchemaResp.isPresent()) {
ViewSchemaResp viewSchema = viewSchemaResp.get();
return ViewSchemaBuilder.build(viewSchema);
ids.add(dataSetId);
List<DataSetSchemaResp> dataSetSchemaResps = fetchDataSetSchema(ids, cacheEnable);
if (!CollectionUtils.isEmpty(dataSetSchemaResps)) {
Optional<DataSetSchemaResp> dataSetSchemaResp = dataSetSchemaResps.stream()
.filter(d -> d.getId().equals(dataSetId)).findFirst();
if (dataSetSchemaResp.isPresent()) {
DataSetSchemaResp dataSetSchema = dataSetSchemaResp.get();
return DataSetSchemaBuilder.build(dataSetSchema);
}
}
return null;
}
@Override
public List<ViewSchema> getViewSchema() {
return getViewSchema(new ArrayList<>());
public List<DataSetSchema> getDataSetSchema() {
return getDataSetSchema(new ArrayList<>());
}
@Override
public List<ViewSchema> getViewSchema(List<Long> ids) {
List<ViewSchema> domainSchemaList = new ArrayList<>();
public List<DataSetSchema> getDataSetSchema(List<Long> ids) {
List<DataSetSchema> domainSchemaList = new ArrayList<>();
for (ViewSchemaResp resp : fetchViewSchema(ids, true)) {
domainSchemaList.add(ViewSchemaBuilder.build(resp));
for (DataSetSchemaResp resp : fetchDataSetSchema(ids, true)) {
domainSchemaList.add(DataSetSchemaBuilder.build(resp));
}
return domainSchemaList;
}
protected abstract List<ViewSchemaResp> doFetchViewSchema(List<Long> ids);
protected abstract List<DataSetSchemaResp> doFetchDataSetSchema(List<Long> ids);
}
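
The caching above follows the standard Guava pattern: schema lookups are keyed by the id list and expire shortly after being written, so repeated fetches within the window never hit the underlying service. A minimal sketch of the same pattern (class and method names are illustrative):

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.TimeUnit;

public class SchemaCacheSketch {
    private final Cache<String, List<String>> cache =
            CacheBuilder.newBuilder().expireAfterWrite(10, TimeUnit.SECONDS).build();

    public List<String> fetch(List<Long> ids) throws Exception {
        // the loader only runs on a cache miss; get() stores its result automatically
        return cache.get(String.valueOf(ids), () -> expensiveLoad(ids));
    }

    private List<String> expensiveLoad(List<Long> ids) {
        System.out.println("loading schema for " + ids);
        return Arrays.asList("schema-for-" + ids);
    }

    public static void main(String[] args) throws Exception {
        SchemaCacheSketch sketch = new SchemaCacheSketch();
        sketch.fetch(Arrays.asList(1L, 2L)); // loads
        sketch.fetch(Arrays.asList(1L, 2L)); // served from cache, no second "loading" line
    }
}

Note that Cache.get(key, loader) already stores the loaded value under the key, so an extra put of the same key inside the loader is not strictly required.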

View File

@@ -5,13 +5,13 @@ import com.tencent.supersonic.headless.api.pojo.RelatedSchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SchemaValueMap;
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.DimValueMap;
import com.tencent.supersonic.headless.api.pojo.RelateDimension;
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.ViewSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.DataSetSchemaResp;
import org.apache.logging.log4j.util.Strings;
import org.springframework.beans.BeanUtils;
import org.springframework.util.CollectionUtils;
@@ -23,19 +23,19 @@ import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
public class ViewSchemaBuilder {
public class DataSetSchemaBuilder {
public static ViewSchema build(ViewSchemaResp resp) {
ViewSchema viewSchema = new ViewSchema();
viewSchema.setQueryConfig(resp.getQueryConfig());
SchemaElement model = SchemaElement.builder()
.view(resp.getId())
public static DataSetSchema build(DataSetSchemaResp resp) {
DataSetSchema dataSetSchema = new DataSetSchema();
dataSetSchema.setQueryConfig(resp.getQueryConfig());
SchemaElement dataSet = SchemaElement.builder()
.dataSet(resp.getId())
.id(resp.getId())
.name(resp.getName())
.bizName(resp.getBizName())
.type(SchemaElementType.VIEW)
.type(SchemaElementType.DATASET)
.build();
viewSchema.setView(model);
dataSetSchema.setDataSet(dataSet);
Set<SchemaElement> metrics = new HashSet<>();
for (MetricSchemaResp metric : resp.getMetrics()) {
@@ -43,7 +43,7 @@ public class ViewSchemaBuilder {
List<String> alias = SchemaItem.getAliasList(metric.getAlias());
SchemaElement metricToAdd = SchemaElement.builder()
.view(resp.getId())
.dataSet(resp.getId())
.model(metric.getModelId())
.id(metric.getId())
.name(metric.getName())
@@ -57,7 +57,7 @@ public class ViewSchemaBuilder {
metrics.add(metricToAdd);
}
viewSchema.getMetrics().addAll(metrics);
dataSetSchema.getMetrics().addAll(metrics);
Set<SchemaElement> dimensions = new HashSet<>();
Set<SchemaElement> dimensionValues = new HashSet<>();
@@ -84,7 +84,7 @@ public class ViewSchemaBuilder {
}
SchemaElement dimToAdd = SchemaElement.builder()
.view(resp.getId())
.dataSet(resp.getId())
.model(dim.getModelId())
.id(dim.getId())
.name(dim.getName())
@@ -97,7 +97,7 @@ public class ViewSchemaBuilder {
dimensions.add(dimToAdd);
SchemaElement dimValueToAdd = SchemaElement.builder()
.view(resp.getId())
.dataSet(resp.getId())
.model(dim.getModelId())
.id(dim.getId())
.name(dim.getName())
@@ -109,7 +109,7 @@ public class ViewSchemaBuilder {
dimensionValues.add(dimValueToAdd);
if (dim.getIsTag() == 1) {
SchemaElement tagToAdd = SchemaElement.builder()
.view(resp.getId())
.dataSet(resp.getId())
.model(dim.getModelId())
.id(dim.getId())
.name(dim.getName())
@@ -122,14 +122,14 @@ public class ViewSchemaBuilder {
tags.add(tagToAdd);
}
}
viewSchema.getDimensions().addAll(dimensions);
viewSchema.getDimensionValues().addAll(dimensionValues);
viewSchema.getTags().addAll(tags);
dataSetSchema.getDimensions().addAll(dimensions);
dataSetSchema.getDimensionValues().addAll(dimensionValues);
dataSetSchema.getTags().addAll(tags);
DimSchemaResp dim = resp.getPrimaryKey();
if (dim != null) {
SchemaElement entity = SchemaElement.builder()
.view(resp.getId())
.dataSet(resp.getId())
.model(dim.getModelId())
.id(dim.getId())
.name(dim.getName())
@@ -138,9 +138,9 @@ public class ViewSchemaBuilder {
.useCnt(dim.getUseCnt())
.alias(dim.getEntityAlias())
.build();
viewSchema.setEntity(entity);
dataSetSchema.setEntity(entity);
}
return viewSchema;
return dataSetSchema;
}
private static List<RelatedSchemaElement> getRelateSchemaElement(MetricSchemaResp metricSchemaResp) {

View File

@@ -10,15 +10,15 @@ import com.tencent.supersonic.headless.api.pojo.request.PageMetricReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.request.ViewFilterReq;
import com.tencent.supersonic.headless.api.pojo.request.DataSetFilterReq;
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
import com.tencent.supersonic.headless.api.pojo.response.DomainResp;
import com.tencent.supersonic.headless.api.pojo.response.ExplainResp;
import com.tencent.supersonic.headless.api.pojo.response.ItemResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.api.pojo.response.ViewResp;
import com.tencent.supersonic.headless.api.pojo.response.ViewSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
import com.tencent.supersonic.headless.api.pojo.response.DataSetSchemaResp;
import com.tencent.supersonic.headless.server.service.DimensionService;
import com.tencent.supersonic.headless.server.service.MetricService;
import com.tencent.supersonic.headless.server.service.QueryService;
@@ -44,7 +44,7 @@ public class LocalSemanticInterpreter extends BaseSemanticInterpreter {
if (StringUtils.isNotBlank(queryStructReq.getCorrectS2SQL())) {
QuerySqlReq querySqlReq = new QuerySqlReq();
querySqlReq.setSql(queryStructReq.getCorrectS2SQL());
querySqlReq.setViewId(queryStructReq.getViewId());
querySqlReq.setDataSetId(queryStructReq.getDataSetId());
querySqlReq.setParams(new ArrayList<>());
return queryByS2SQL(querySqlReq, user);
}
@@ -68,11 +68,11 @@ public class LocalSemanticInterpreter extends BaseSemanticInterpreter {
}
@Override
public List<ViewSchemaResp> doFetchViewSchema(List<Long> ids) {
ViewFilterReq filter = new ViewFilterReq();
filter.setViewIds(ids);
public List<DataSetSchemaResp> doFetchDataSetSchema(List<Long> ids) {
DataSetFilterReq filter = new DataSetFilterReq();
filter.setDataSetIds(ids);
schemaService = ContextUtils.getBean(SchemaService.class);
return schemaService.fetchViewSchema(filter);
return schemaService.fetchDataSetSchema(filter);
}
@Override
@@ -82,9 +82,9 @@ public class LocalSemanticInterpreter extends BaseSemanticInterpreter {
}
@Override
public List<ViewResp> getViewList(Long domainId) {
public List<DataSetResp> getDataSetList(Long domainId) {
schemaService = ContextUtils.getBean(SchemaService.class);
return schemaService.getViewList(domainId);
return schemaService.getDataSetList(domainId);
}
@Override
@@ -106,8 +106,8 @@ public class LocalSemanticInterpreter extends BaseSemanticInterpreter {
}
@Override
public List<ItemResp> getDomainViewTree() {
return schemaService.getDomainViewTree();
public List<ItemResp> getDomainDataSetTree() {
return schemaService.getDomainDataSetTree();
}
}

View File

@@ -24,8 +24,8 @@ import com.tencent.supersonic.headless.api.pojo.response.ExplainResp;
import com.tencent.supersonic.headless.api.pojo.response.ItemResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.api.pojo.response.ViewResp;
import com.tencent.supersonic.headless.api.pojo.response.ViewSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
import com.tencent.supersonic.headless.api.pojo.response.DataSetSchemaResp;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.util.Strings;
@@ -250,17 +250,17 @@ public class RemoteSemanticInterpreter extends BaseSemanticInterpreter {
}
@Override
protected List<ViewSchemaResp> doFetchViewSchema(List<Long> ids) {
protected List<DataSetSchemaResp> doFetchDataSetSchema(List<Long> ids) {
return null;
}
@Override
public List<ItemResp> getDomainViewTree() {
public List<ItemResp> getDomainDataSetTree() {
return null;
}
@Override
public List<ViewResp> getViewList(Long domainId) {
public List<DataSetResp> getDataSetList(Long domainId) {
return null;
}

View File

@@ -2,7 +2,7 @@ package com.tencent.supersonic.chat.core.query.semantic;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.request.ExplainSqlReq;
import com.tencent.supersonic.headless.api.pojo.request.PageDimensionReq;
import com.tencent.supersonic.headless.api.pojo.request.PageMetricReq;
@@ -15,14 +15,14 @@ import com.tencent.supersonic.headless.api.pojo.response.ExplainResp;
import com.tencent.supersonic.headless.api.pojo.response.ItemResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.api.pojo.response.ViewResp;
import com.tencent.supersonic.headless.api.pojo.response.ViewSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
import com.tencent.supersonic.headless.api.pojo.response.DataSetSchemaResp;
import java.util.List;
/**
* A semantic layer provides a simplified and consistent view of data from multiple sources.
* It abstracts away the complexity of the underlying data sources and provides a unified view
* A semantic layer provides a simplified and consistent dataSet view of data from multiple sources.
* It abstracts away the complexity of the underlying data sources and provides a unified dataSet view
* of the data that is easier to understand and use.
* <p>
* The interface defines methods for getting metadata as well as querying data in the semantic layer.
@@ -39,11 +39,11 @@ public interface SemanticInterpreter {
SemanticQueryResp queryByS2SQL(QuerySqlReq querySQLReq, User user);
List<ViewSchema> getViewSchema();
List<DataSetSchema> getDataSetSchema();
List<ViewSchema> getViewSchema(List<Long> ids);
List<DataSetSchema> getDataSetSchema(List<Long> ids);
ViewSchema getViewSchema(Long model, Boolean cacheEnable);
DataSetSchema getDataSetSchema(Long model, Boolean cacheEnable);
PageInfo<DimensionResp> getDimensionPage(PageDimensionReq pageDimensionReq);
@@ -53,10 +53,10 @@ public interface SemanticInterpreter {
<T> ExplainResp explain(ExplainSqlReq<T> explainSqlReq, User user) throws Exception;
List<ViewSchemaResp> fetchViewSchema(List<Long> ids, Boolean cacheEnable);
List<DataSetSchemaResp> fetchDataSetSchema(List<Long> ids, Boolean cacheEnable);
List<ViewResp> getViewList(Long domainId);
List<DataSetResp> getDataSetList(Long domainId);
List<ItemResp> getDomainViewTree();
List<ItemResp> getDomainDataSetTree();
}

View File

@@ -1,22 +1,23 @@
package com.tencent.supersonic.chat.core.utils;
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
import com.tencent.supersonic.chat.core.parser.JavaLLMProxy;
import com.tencent.supersonic.chat.core.parser.LLMProxy;
import com.tencent.supersonic.chat.core.parser.sql.llm.ViewResolver;
import com.tencent.supersonic.chat.core.parser.sql.llm.DataSetResolver;
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
import com.tencent.supersonic.common.util.ContextUtils;
import java.util.Map;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.core.io.support.SpringFactoriesLoader;
import java.util.Map;
import java.util.Objects;
@Slf4j
public class ComponentFactory {
private static SemanticInterpreter semanticInterpreter;
private static LLMProxy llmProxy;
private static ViewResolver modelResolver;
private static DataSetResolver modelResolver;
public static SemanticInterpreter getSemanticLayer() {
if (Objects.isNull(semanticInterpreter)) {
@@ -44,9 +45,9 @@ public class ComponentFactory {
return llmProxy;
}
public static ViewResolver getModelResolver() {
public static DataSetResolver getModelResolver() {
if (Objects.isNull(modelResolver)) {
modelResolver = init(ViewResolver.class);
modelResolver = init(DataSetResolver.class);
}
return modelResolver;
}

View File

@@ -37,8 +37,8 @@ public class QueryReqBuilder {
public static QueryStructReq buildStructReq(SemanticParseInfo parseInfo) {
QueryStructReq queryStructReq = new QueryStructReq();
queryStructReq.setViewId(parseInfo.getViewId());
queryStructReq.setViewName(parseInfo.getView().getName());
queryStructReq.setDataSetId(parseInfo.getDataSetId());
queryStructReq.setDataSetName(parseInfo.getDataSet().getName());
queryStructReq.setQueryType(parseInfo.getQueryType());
queryStructReq.setDateInfo(rewrite2Between(parseInfo.getDateInfo()));
@@ -119,7 +119,7 @@ public class QueryReqBuilder {
for (Filter dimensionFilter : queryStructReq.getDimensionFilters()) {
QueryStructReq req = new QueryStructReq();
BeanUtils.copyProperties(queryStructReq, req);
req.setViewId(parseInfo.getViewId());
req.setDataSetId(parseInfo.getDataSetId());
req.setDimensionFilters(Lists.newArrayList(dimensionFilter));
queryStructReqs.add(req);
}
@@ -131,15 +131,15 @@ public class QueryReqBuilder {
* convert to QueryS2SQLReq
*
* @param querySql
* @param viewId
* @param dataSetId
* @return
*/
public static QuerySqlReq buildS2SQLReq(String querySql, Long viewId) {
public static QuerySqlReq buildS2SQLReq(String querySql, Long dataSetId) {
QuerySqlReq querySQLReq = new QuerySqlReq();
if (Objects.nonNull(querySql)) {
querySQLReq.setSql(querySql);
}
querySQLReq.setViewId(viewId);
querySQLReq.setDataSetId(dataSetId);
return querySQLReq;
}
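
A usage sketch for the builder above; the SQL text and dataSet id are placeholders:

// build an S2SQL request bound to dataSet 1
QuerySqlReq req = QueryReqBuilder.buildS2SQLReq("SELECT song_name FROM t_song LIMIT 10", 1L);
// req now carries both the S2SQL text and the dataSetId it should be resolved against,
// matching how LLMSqlQuery builds its request before calling queryByS2SQL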

View File

@@ -1,43 +1,44 @@
package com.tencent.supersonic.chat.core.utils;
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeMode;
import com.tencent.supersonic.common.util.DatePeriodEnum;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import java.util.Objects;
import org.apache.commons.lang3.tuple.Pair;
import java.util.Objects;
public class S2SqlDateHelper {
public static String getReferenceDate(QueryContext queryContext, Long viewId) {
public static String getReferenceDate(QueryContext queryContext, Long dataSetId) {
String defaultDate = DateUtils.getBeforeDate(0);
if (Objects.isNull(viewId)) {
if (Objects.isNull(dataSetId)) {
return defaultDate;
}
ViewSchema viewSchema = queryContext.getSemanticSchema().getViewSchemaMap().get(viewId);
if (viewSchema == null || viewSchema.getTagTypeTimeDefaultConfig() == null) {
DataSetSchema dataSetSchema = queryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
if (dataSetSchema == null || dataSetSchema.getTagTypeTimeDefaultConfig() == null) {
return defaultDate;
}
TimeDefaultConfig tagTypeTimeDefaultConfig = viewSchema.getTagTypeTimeDefaultConfig();
TimeDefaultConfig tagTypeTimeDefaultConfig = dataSetSchema.getTagTypeTimeDefaultConfig();
return getDefaultDate(defaultDate, tagTypeTimeDefaultConfig).getLeft();
}
public static Pair<String, String> getStartEndDate(QueryContext queryContext,
Long viewId, QueryType queryType) {
Long dataSetId, QueryType queryType) {
String defaultDate = DateUtils.getBeforeDate(0);
if (Objects.isNull(viewId)) {
if (Objects.isNull(dataSetId)) {
return Pair.of(defaultDate, defaultDate);
}
ViewSchema viewSchema = queryContext.getSemanticSchema().getViewSchemaMap().get(viewId);
if (viewSchema == null) {
DataSetSchema dataSetSchema = queryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
if (dataSetSchema == null) {
return Pair.of(defaultDate, defaultDate);
}
TimeDefaultConfig defaultConfig = viewSchema.getMetricTypeTimeDefaultConfig();
TimeDefaultConfig defaultConfig = dataSetSchema.getMetricTypeTimeDefaultConfig();
if (QueryType.TAG.equals(queryType)) {
defaultConfig = viewSchema.getTagTypeTimeDefaultConfig();
defaultConfig = dataSetSchema.getTagTypeTimeDefaultConfig();
}
return getDefaultDate(defaultDate, defaultConfig);
}
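
getDefaultDate is not shown in this hunk; assuming it subtracts `unit` periods from today (day granularity here for simplicity), a rough sketch of turning a TimeDefaultConfig unit into a start/end window might look like:

import java.time.LocalDate;
import org.apache.commons.lang3.tuple.Pair;

public class DefaultDateSketch {
    static Pair<String, String> defaultWindow(int unit) {
        LocalDate today = LocalDate.now();
        String start = today.minusDays(unit).toString(); // e.g. unit=7 -> one week back
        String end = today.toString();
        return Pair.of(start, end);
    }

    public static void main(String[] args) {
        System.out.println(defaultWindow(7)); // (today-7, today)
    }
}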

View File

@@ -57,7 +57,7 @@ public class SimilarQueryManager {
embeddingQuery.setQuery(queryText);
Map<String, Object> metaData = new HashMap<>();
metaData.put("modelId", similarQueryReq.getViewId());
metaData.put("modelId", similarQueryReq.getDataSetId());
metaData.put("agentId", similarQueryReq.getAgentId());
embeddingQuery.setMetadata(metaData);
String solvedQueryCollection = embeddingConfig.getSolvedQueryCollection();

View File

@@ -6,7 +6,7 @@ import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.core.parser.sql.llm.ParseResult;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
@@ -23,7 +23,7 @@ import org.junit.jupiter.api.Test;
class SchemaCorrectorTest {
private String json = "{\n"
+ " \"viewId\": 1,\n"
+ " \"dataSetId\": 1,\n"
+ " \"llmReq\": {\n"
+ " \"queryText\": \"xxx2024年播放量最高的十首歌\",\n"
+ " \"filterCondition\": {\n"
@@ -31,7 +31,7 @@ class SchemaCorrectorTest {
+ " },\n"
+ " \"schema\": {\n"
+ " \"domainName\": \"歌曲\",\n"
+ " \"viewName\": \"歌曲\",\n"
+ " \"dataSetName\": \"歌曲\",\n"
+ " \"fieldNameList\": [\n"
+ " \"商务组\",\n"
+ " \"歌曲名\",\n"
@@ -52,7 +52,7 @@ class SchemaCorrectorTest {
+ " \"id\": \"y3LqVSRL\",\n"
+ " \"name\": \"大模型语义解析\",\n"
+ " \"type\": \"NL2SQL_LLM\",\n"
+ " \"viewIds\": [\n"
+ " \"dataSetIds\": [\n"
+ " 1\n"
+ " ]\n"
+ " },\n"
@@ -63,8 +63,8 @@ class SchemaCorrectorTest {
@Test
void doCorrect() throws JsonProcessingException {
Long viewId = 1L;
QueryContext queryContext = buildQueryContext(viewId);
Long dataSetId = 1L;
QueryContext queryContext = buildQueryContext(dataSetId);
ObjectMapper objectMapper = new ObjectMapper();
ParseResult parseResult = objectMapper.readValue(json, ParseResult.class);
@@ -78,8 +78,8 @@ class SchemaCorrectorTest {
semanticParseInfo.setSqlInfo(sqlInfo);
SchemaElement schemaElement = new SchemaElement();
schemaElement.setView(viewId);
semanticParseInfo.setView(schemaElement);
schemaElement.setDataSet(dataSetId);
semanticParseInfo.setDataSet(schemaElement);
semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult);
@@ -108,35 +108,35 @@ class SchemaCorrectorTest {
}
private QueryContext buildQueryContext(Long viewId) {
private QueryContext buildQueryContext(Long dataSetId) {
QueryContext queryContext = new QueryContext();
List<ViewSchema> viewSchemaList = new ArrayList<>();
ViewSchema viewSchema = new ViewSchema();
List<DataSetSchema> dataSetSchemaList = new ArrayList<>();
DataSetSchema dataSetSchema = new DataSetSchema();
QueryConfig queryConfig = new QueryConfig();
viewSchema.setQueryConfig(queryConfig);
dataSetSchema.setQueryConfig(queryConfig);
SchemaElement schemaElement = new SchemaElement();
schemaElement.setView(viewId);
viewSchema.setView(schemaElement);
schemaElement.setDataSet(dataSetId);
dataSetSchema.setDataSet(schemaElement);
Set<SchemaElement> dimensions = new HashSet<>();
SchemaElement element1 = new SchemaElement();
element1.setView(1L);
element1.setDataSet(1L);
element1.setName("歌曲名");
dimensions.add(element1);
SchemaElement element2 = new SchemaElement();
element2.setView(1L);
element2.setDataSet(1L);
element2.setName("商务组");
dimensions.add(element2);
SchemaElement element3 = new SchemaElement();
element3.setView(1L);
element3.setDataSet(1L);
element3.setName("发行日期");
dimensions.add(element3);
viewSchema.setDimensions(dimensions);
viewSchemaList.add(viewSchema);
dataSetSchema.setDimensions(dimensions);
dataSetSchemaList.add(dataSetSchema);
SemanticSchema semanticSchema = new SemanticSchema(viewSchemaList);
SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList);
queryContext.setSemanticSchema(semanticSchema);
return queryContext;
}

View File

@@ -29,7 +29,7 @@ class LLMSqlParserTest {
SchemaElement schemaElement = SchemaElement.builder()
.bizName("singer_name")
.name("歌手名")
.view(2L)
.dataSet(2L)
.schemaValueMaps(schemaValueMaps)
.build();
dimensions.add(schemaElement);
@@ -37,7 +37,7 @@ class LLMSqlParserTest {
SchemaElement schemaElement2 = SchemaElement.builder()
.bizName("publish_time")
.name("发布时间")
.view(2L)
.dataSet(2L)
.build();
dimensions.add(schemaElement2);
@@ -47,7 +47,7 @@ class LLMSqlParserTest {
SchemaElement metric = SchemaElement.builder()
.bizName("play_count")
.name("播放量")
.view(2L)
.dataSet(2L)
.build();
metrics.add(metric);

View File

@@ -1,8 +1,8 @@
package com.tencent.supersonic.chat.core.utils;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.QueryType;
@@ -11,60 +11,60 @@ import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.lang3.tuple.Pair;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.List;
class S2SqlDateHelperTest {
@Test
void getReferenceDate() {
Long viewId = 1L;
QueryContext queryContext = buildQueryContext(viewId);
Long dataSetId = 1L;
QueryContext queryContext = buildQueryContext(dataSetId);
String referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, null);
Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(0));
referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, viewId);
referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, dataSetId);
Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(0));
ViewSchema viewSchema = queryContext.getSemanticSchema().getViewSchemaMap().get(viewId);
QueryConfig queryConfig = viewSchema.getQueryConfig();
DataSetSchema dataSetSchema = queryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
QueryConfig queryConfig = dataSetSchema.getQueryConfig();
TimeDefaultConfig timeDefaultConfig = new TimeDefaultConfig();
timeDefaultConfig.setTimeMode(TimeMode.LAST);
timeDefaultConfig.setPeriod(Constants.DAY);
timeDefaultConfig.setUnit(20);
queryConfig.getTagTypeDefaultConfig().setTimeDefaultConfig(timeDefaultConfig);
referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, viewId);
referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, dataSetId);
Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(20));
timeDefaultConfig.setUnit(1);
referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, viewId);
referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, dataSetId);
Assert.assertEquals(referenceDate, DateUtils.getBeforeDate(1));
timeDefaultConfig.setUnit(-1);
referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, viewId);
referenceDate = S2SqlDateHelper.getReferenceDate(queryContext, dataSetId);
Assert.assertNull(referenceDate);
}
@Test
void getStartEndDate() {
Long viewId = 1L;
QueryContext queryContext = buildQueryContext(viewId);
Long dataSetId = 1L;
QueryContext queryContext = buildQueryContext(dataSetId);
Pair<String, String> startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, null, QueryType.TAG);
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(0));
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(0));
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, viewId, QueryType.TAG);
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.TAG);
Assert.assertNull(startEndDate.getLeft());
Assert.assertNull(startEndDate.getRight());
ViewSchema viewSchema = queryContext.getSemanticSchema().getViewSchemaMap().get(viewId);
QueryConfig queryConfig = viewSchema.getQueryConfig();
DataSetSchema dataSetSchema = queryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
QueryConfig queryConfig = dataSetSchema.getQueryConfig();
TimeDefaultConfig timeDefaultConfig = new TimeDefaultConfig();
timeDefaultConfig.setTimeMode(TimeMode.LAST);
timeDefaultConfig.setPeriod(Constants.DAY);
@@ -72,49 +72,49 @@ class S2SqlDateHelperTest {
queryConfig.getTagTypeDefaultConfig().setTimeDefaultConfig(timeDefaultConfig);
queryConfig.getMetricTypeDefaultConfig().setTimeDefaultConfig(timeDefaultConfig);
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, viewId, QueryType.TAG);
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.TAG);
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(20));
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(20));
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, viewId, QueryType.METRIC);
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.METRIC);
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(20));
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(20));
timeDefaultConfig.setUnit(2);
timeDefaultConfig.setTimeMode(TimeMode.RECENT);
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, viewId, QueryType.METRIC);
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.METRIC);
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(2));
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(1));
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, viewId, QueryType.TAG);
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.TAG);
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(2));
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(1));
timeDefaultConfig.setUnit(-1);
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, viewId, QueryType.METRIC);
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.METRIC);
Assert.assertNull(startEndDate.getLeft());
Assert.assertNull(startEndDate.getRight());
timeDefaultConfig.setTimeMode(TimeMode.LAST);
timeDefaultConfig.setPeriod(Constants.DAY);
timeDefaultConfig.setUnit(5);
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, viewId, QueryType.METRIC);
startEndDate = S2SqlDateHelper.getStartEndDate(queryContext, dataSetId, QueryType.METRIC);
Assert.assertEquals(startEndDate.getLeft(), DateUtils.getBeforeDate(5));
Assert.assertEquals(startEndDate.getRight(), DateUtils.getBeforeDate(5));
}
private QueryContext buildQueryContext(Long viewId) {
private QueryContext buildQueryContext(Long dataSetId) {
QueryContext queryContext = new QueryContext();
List<ViewSchema> viewSchemaList = new ArrayList<>();
ViewSchema viewSchema = new ViewSchema();
List<DataSetSchema> dataSetSchemaList = new ArrayList<>();
DataSetSchema dataSetSchema = new DataSetSchema();
QueryConfig queryConfig = new QueryConfig();
viewSchema.setQueryConfig(queryConfig);
dataSetSchema.setQueryConfig(queryConfig);
SchemaElement schemaElement = new SchemaElement();
schemaElement.setView(viewId);
viewSchema.setView(schemaElement);
viewSchemaList.add(viewSchema);
schemaElement.setDataSet(dataSetId);
dataSetSchema.setDataSet(schemaElement);
dataSetSchemaList.add(dataSetSchema);
SemanticSchema semanticSchema = new SemanticSchema(viewSchemaList);
SemanticSchema semanticSchema = new SemanticSchema(dataSetSchemaList);
queryContext.setSemanticSchema(semanticSchema);
return queryContext;
}

View File

@@ -15,7 +15,7 @@ public class PluginDO {
private String type;
private String view;
private String dataSet;
private String pattern;

View File

@@ -4,7 +4,7 @@ import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.api.pojo.RelatedSchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.service.SemanticService;
@@ -34,15 +34,15 @@ public class DimensionRecommendProcessor implements ExecuteResultProcessor {
return;
}
SchemaElement element = semanticParseInfo.getMetrics().iterator().next();
List<SchemaElement> dimensionRecommended = getDimensions(element.getId(), element.getView());
List<SchemaElement> dimensionRecommended = getDimensions(element.getId(), element.getDataSet());
queryResult.setRecommendedDimensions(dimensionRecommended);
}
private List<SchemaElement> getDimensions(Long metricId, Long viewId) {
private List<SchemaElement> getDimensions(Long metricId, Long dataSetId) {
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
ViewSchema viewSchema = semanticService.getViewSchema(viewId);
DataSetSchema dataSetSchema = semanticService.getDataSetSchema(dataSetId);
List<Long> drillDownDimensions = Lists.newArrayList();
Set<SchemaElement> metricElements = viewSchema.getMetrics();
Set<SchemaElement> metricElements = dataSetSchema.getMetrics();
if (!CollectionUtils.isEmpty(metricElements)) {
Optional<SchemaElement> metric = metricElements.stream().filter(schemaElement ->
metricId.equals(schemaElement.getId())
@@ -54,7 +54,7 @@ public class DimensionRecommendProcessor implements ExecuteResultProcessor {
}
}
final List<Long> drillDownDimensionsFinal = drillDownDimensions;
return viewSchema.getDimensions().stream()
return dataSetSchema.getDimensions().stream()
.filter(dim -> filterDimension(drillDownDimensionsFinal, dim))
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(recommend_dimension_size)

View File

@@ -1,12 +1,11 @@
package com.tencent.supersonic.chat.server.processor.execute;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.util.ComponentFactory;
import com.tencent.supersonic.common.util.ContextUtils;
@@ -14,6 +13,9 @@ import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.server.service.MetaEmbeddingService;
import org.springframework.util.CollectionUtils;
import java.util.Collections;
@@ -46,15 +48,14 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
}
List<String> metricNames = Collections.singletonList(parseInfo.getMetrics().iterator().next().getName());
Map<String, String> filterCondition = new HashMap<>();
filterCondition.put("modelId", parseInfo.getMetrics().iterator().next().getView().toString());
filterCondition.put("modelId", parseInfo.getMetrics().iterator().next().getDataSet().toString());
filterCondition.put("type", SchemaElementType.METRIC.name());
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(metricNames)
.filterCondition(filterCondition).queryEmbeddings(null).build();
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
List<RetrieveQueryResult> retrieveQueryResults = s2EmbeddingStore.retrieveQuery(
embeddingConfig.getMetaCollectionName(), retrieveQuery, METRIC_RECOMMEND_SIZE + 1);
MetaEmbeddingService metaEmbeddingService = ContextUtils.getBean(MetaEmbeddingService.class);
List<RetrieveQueryResult> retrieveQueryResults =
metaEmbeddingService.retrieveQuery(Lists.newArrayList(parseInfo.getDataSetId()),
retrieveQuery, METRIC_RECOMMEND_SIZE + 1);
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
return;
}
@@ -71,9 +72,10 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
if (!metricIds.contains(Retrieval.getLongId(retrieval.getId()))) {
SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(retrieval.getMetadata()),
SchemaElement.class);
if (retrieval.getMetadata().containsKey("viewId")) {
String viewId = retrieval.getMetadata().get("viewId").toString();
schemaElement.setView(Long.parseLong(viewId));
if (retrieval.getMetadata().containsKey("dataSetId")) {
String dataSetId = retrieval.getMetadata().get("dataSetId").toString()
.replace(Constants.UNDERLINE, "");
schemaElement.setDataSet(Long.parseLong(dataSetId));
}
schemaElement.setOrder(++metricOrder);
parseInfo.getMetrics().add(schemaElement);
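A stand-alone restatement of the new metadata handling above. It assumes Constants.UNDERLINE is the literal "_" and that stored ids may carry that prefix; both are illustrative assumptions, not confirmed by this diff.

import java.util.HashMap;
import java.util.Map;

public class MetadataDataSetIdSketch {
    public static void main(String[] args) {
        Map<String, Object> metadata = new HashMap<>();
        metadata.put("dataSetId", "_1");                                 // hypothetical stored value
        if (metadata.containsKey("dataSetId")) {
            // strip the underscore prefix, then parse the numeric data-set id
            String raw = metadata.get("dataSetId").toString().replace("_", "");
            long dataSetId = Long.parseLong(raw);                        // -> 1
            System.out.println(dataSetId);
        }
    }
}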

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.chat.server.processor.parse;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
@@ -37,9 +37,10 @@ public class EntityInfoProcessor implements ParseResultProcessor {
return;
}
//1. set entity info
ViewSchema viewSchema = queryContext.getSemanticSchema().getViewSchemaMap().get(parseInfo.getViewId());
DataSetSchema dataSetSchema =
queryContext.getSemanticSchema().getDataSetSchemaMap().get(parseInfo.getDataSetId());
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, viewSchema, queryContext.getUser());
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, queryContext.getUser());
if (QueryManager.isTagQuery(queryMode)
|| QueryManager.isMetricQuery(queryMode)) {
parseInfo.setEntityInfo(entityInfo);

View File

@@ -74,9 +74,9 @@ public class ParseInfoProcessor implements ParseResultProcessor {
}
//set filter
Long viewId = parseInfo.getViewId();
Long dataSetId = parseInfo.getDataSetId();
try {
Map<String, SchemaElement> fieldNameToElement = getNameToElement(viewId);
Map<String, SchemaElement> fieldNameToElement = getNameToElement(dataSetId);
List<QueryFilter> result = getDimensionFilter(fieldNameToElement, expressions);
parseInfo.getDimensionFilters().addAll(result);
} catch (Exception e) {
@@ -88,31 +88,31 @@ public class ParseInfoProcessor implements ParseResultProcessor {
return;
}
List<String> allFields = getFieldsExceptDate(SqlSelectHelper.getAllFields(sqlInfo.getCorrectS2SQL()));
Set<SchemaElement> metrics = getElements(viewId, allFields, semanticSchema.getMetrics());
Set<SchemaElement> metrics = getElements(dataSetId, allFields, semanticSchema.getMetrics());
parseInfo.setMetrics(metrics);
if (QueryType.METRIC.equals(parseInfo.getQueryType())) {
List<String> groupByFields = SqlSelectHelper.getGroupByFields(sqlInfo.getCorrectS2SQL());
List<String> groupByDimensions = getFieldsExceptDate(groupByFields);
parseInfo.setDimensions(getElements(viewId, groupByDimensions, semanticSchema.getDimensions()));
parseInfo.setDimensions(getElements(dataSetId, groupByDimensions, semanticSchema.getDimensions()));
} else if (QueryType.TAG.equals(parseInfo.getQueryType())) {
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getCorrectS2SQL());
List<String> selectDimensions = getFieldsExceptDate(selectFields);
parseInfo.setDimensions(getElements(viewId, selectDimensions, semanticSchema.getDimensions()));
parseInfo.setDimensions(getElements(dataSetId, selectDimensions, semanticSchema.getDimensions()));
}
}
private Set<SchemaElement> getElements(Long viewId, List<String> allFields, List<SchemaElement> elements) {
private Set<SchemaElement> getElements(Long dataSetId, List<String> allFields, List<SchemaElement> elements) {
return elements.stream()
.filter(schemaElement -> {
if (CollectionUtils.isEmpty(schemaElement.getAlias())) {
return viewId.equals(schemaElement.getView()) && allFields.contains(
return dataSetId.equals(schemaElement.getDataSet()) && allFields.contains(
schemaElement.getName());
}
Set<String> allFieldsSet = new HashSet<>(allFields);
Set<String> aliasSet = new HashSet<>(schemaElement.getAlias());
List<String> intersection = allFieldsSet.stream()
.filter(aliasSet::contains).collect(Collectors.toList());
return viewId.equals(schemaElement.getView()) && (allFields.contains(
return dataSetId.equals(schemaElement.getDataSet()) && (allFields.contains(
schemaElement.getName()) || !CollectionUtils.isEmpty(intersection));
}
).collect(Collectors.toSet());
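The filter in getElements() above can be restated stand-alone: keep only elements that belong to the given data set and whose name, or any alias, appears in the extracted fields. The Element stub and the sample values below are made up for illustration.

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

public class ElementFilterSketch {
    // Local stub standing in for the project SchemaElement; only the fields the filter touches.
    static class Element {
        Long dataSet; String name; List<String> alias;
        Element(Long dataSet, String name, List<String> alias) {
            this.dataSet = dataSet; this.name = name; this.alias = alias;
        }
    }

    public static void main(String[] args) {
        Long dataSetId = 1L;
        List<String> allFields = Arrays.asList("播放量", "song_name");
        List<Element> elements = Arrays.asList(
                new Element(1L, "播放量", Collections.emptyList()),
                new Element(1L, "歌曲名", Collections.singletonList("song_name")),
                new Element(2L, "播放量", Collections.emptyList()));     // other data set: filtered out

        Set<String> matched = elements.stream()
                .filter(e -> dataSetId.equals(e.dataSet)
                        && (allFields.contains(e.name)
                                || e.alias.stream().anyMatch(allFields::contains)))
                .map(e -> e.name)
                .collect(Collectors.toSet());
        System.out.println(matched);                                     // [歌曲名, 播放量] (order may vary)
    }
}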
@@ -194,10 +194,10 @@ public class ParseInfoProcessor implements ParseResultProcessor {
return dateExpressions.size() > 1 && Objects.nonNull(dateExpressions.get(1).getFieldValue());
}
protected Map<String, SchemaElement> getNameToElement(Long viewId) {
protected Map<String, SchemaElement> getNameToElement(Long dataSetId) {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
List<SchemaElement> dimensions = semanticSchema.getDimensions(viewId);
List<SchemaElement> metrics = semanticSchema.getMetrics(viewId);
List<SchemaElement> dimensions = semanticSchema.getDimensions(dataSetId);
List<SchemaElement> metrics = semanticSchema.getMetrics(dataSetId);
List<SchemaElement> allElements = Lists.newArrayList();
allElements.addAll(dimensions);

View File

@@ -17,7 +17,7 @@ import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
import com.tencent.supersonic.headless.api.pojo.response.DomainResp;
import com.tencent.supersonic.headless.api.pojo.response.ItemResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
import com.tencent.supersonic.headless.api.pojo.response.ViewResp;
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
@@ -84,14 +84,14 @@ public class ChatConfigController {
}
//Compatible with front-end
@GetMapping("/viewList")
public List<ViewResp> getViewList() {
return semanticInterpreter.getViewList(null);
@GetMapping("/dataSetList")
public List<DataSetResp> getDataSetList() {
return semanticInterpreter.getDataSetList(null);
}
@GetMapping("/viewList/{domainId}")
public List<ViewResp> getViewList(@PathVariable("domainId") Long domainId) {
return semanticInterpreter.getViewList(domainId);
@GetMapping("/dataSetList/{domainId}")
public List<DataSetResp> getDataSetList(@PathVariable("domainId") Long domainId) {
return semanticInterpreter.getDataSetList(domainId);
}
@PostMapping("/dimension/page")
@@ -107,9 +107,9 @@ public class ChatConfigController {
return semanticInterpreter.getMetricPage(pageMetricReq, user);
}
@GetMapping("/getDomainViewTree")
public List<ItemResp> getDomainViewTree() {
return semanticInterpreter.getDomainViewTree();
@GetMapping("/getDomainDataSetTree")
public List<ItemResp> getDomainDataSetTree() {
return semanticInterpreter.getDomainDataSetTree();
}
}

View File

@@ -6,11 +6,11 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.response.DataInfo;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.ViewInfo;
import com.tencent.supersonic.chat.api.pojo.response.DataSetInfo;
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
import com.tencent.supersonic.chat.core.utils.QueryReqBuilder;
@@ -50,19 +50,19 @@ public class SemanticService {
return schemaService.getSemanticSchema();
}
public ViewSchema getViewSchema(Long id) {
return schemaService.getViewSchema(id);
public DataSetSchema getDataSetSchema(Long id) {
return schemaService.getDataSetSchema(id);
}
public EntityInfo getEntityInfo(SemanticParseInfo parseInfo, ViewSchema viewSchema, User user) {
if (parseInfo != null && parseInfo.getViewId() > 0) {
EntityInfo entityInfo = getEntityBasicInfo(viewSchema);
if (parseInfo.getDimensionFilters().size() <= 0 || entityInfo.getViewInfo() == null) {
public EntityInfo getEntityInfo(SemanticParseInfo parseInfo, DataSetSchema dataSetSchema, User user) {
if (parseInfo != null && parseInfo.getDataSetId() > 0) {
EntityInfo entityInfo = getEntityBasicInfo(dataSetSchema);
if (parseInfo.getDimensionFilters().size() <= 0 || entityInfo.getDataSetInfo() == null) {
entityInfo.setMetrics(null);
entityInfo.setDimensions(null);
return entityInfo;
}
String primaryKey = entityInfo.getViewInfo().getPrimaryKey();
String primaryKey = entityInfo.getDataSetInfo().getPrimaryKey();
if (StringUtils.isNotBlank(primaryKey)) {
String entityId = "";
for (QueryFilter chatFilter : parseInfo.getDimensionFilters()) {
@@ -75,7 +75,7 @@ public class SemanticService {
}
entityInfo.setEntityId(entityId);
try {
fillEntityInfoValue(entityInfo, viewSchema, user);
fillEntityInfoValue(entityInfo, dataSetSchema, user);
return entityInfo;
} catch (Exception e) {
log.error("setMainModel error", e);
@@ -85,29 +85,29 @@ public class SemanticService {
return null;
}
private EntityInfo getEntityBasicInfo(ViewSchema viewSchema) {
private EntityInfo getEntityBasicInfo(DataSetSchema dataSetSchema) {
EntityInfo entityInfo = new EntityInfo();
if (viewSchema == null) {
if (dataSetSchema == null) {
return entityInfo;
}
Long viewId = viewSchema.getView().getView();
ViewInfo viewInfo = new ViewInfo();
viewInfo.setItemId(viewId.intValue());
viewInfo.setName(viewSchema.getView().getName());
viewInfo.setWords(viewSchema.getView().getAlias());
viewInfo.setBizName(viewSchema.getView().getBizName());
if (Objects.nonNull(viewSchema.getEntity())) {
viewInfo.setPrimaryKey(viewSchema.getEntity().getBizName());
Long dataSetId = dataSetSchema.getDataSet().getDataSet();
DataSetInfo dataSetInfo = new DataSetInfo();
dataSetInfo.setItemId(dataSetId.intValue());
dataSetInfo.setName(dataSetSchema.getDataSet().getName());
dataSetInfo.setWords(dataSetSchema.getDataSet().getAlias());
dataSetInfo.setBizName(dataSetSchema.getDataSet().getBizName());
if (Objects.nonNull(dataSetSchema.getEntity())) {
dataSetInfo.setPrimaryKey(dataSetSchema.getEntity().getBizName());
}
entityInfo.setViewInfo(viewInfo);
TagTypeDefaultConfig tagTypeDefaultConfig = viewSchema.getTagTypeDefaultConfig();
entityInfo.setDataSetInfo(dataSetInfo);
TagTypeDefaultConfig tagTypeDefaultConfig = dataSetSchema.getTagTypeDefaultConfig();
if (tagTypeDefaultConfig == null || tagTypeDefaultConfig.getDefaultDisplayInfo() == null) {
return entityInfo;
}
List<DataInfo> dimensions = tagTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds().stream()
.map(id -> {
SchemaElement element = viewSchema.getElement(SchemaElementType.DIMENSION, id);
SchemaElement element = dataSetSchema.getElement(SchemaElementType.DIMENSION, id);
if (element == null) {
return null;
}
@@ -115,7 +115,7 @@ public class SemanticService {
}).filter(Objects::nonNull).collect(Collectors.toList());
List<DataInfo> metrics = tagTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds().stream()
.map(id -> {
SchemaElement element = viewSchema.getElement(SchemaElementType.METRIC, id);
SchemaElement element = dataSetSchema.getElement(SchemaElementType.METRIC, id);
if (element == null) {
return null;
}
@@ -126,9 +126,9 @@ public class SemanticService {
return entityInfo;
}
public void fillEntityInfoValue(EntityInfo entityInfo, ViewSchema viewSchema, User user) {
public void fillEntityInfoValue(EntityInfo entityInfo, DataSetSchema dataSetSchema, User user) {
SemanticQueryResp queryResultWithColumns =
getQueryResultWithSchemaResp(entityInfo, viewSchema, user);
getQueryResultWithSchemaResp(entityInfo, dataSetSchema, user);
if (queryResultWithColumns != null) {
if (!CollectionUtils.isEmpty(queryResultWithColumns.getResultList())
&& queryResultWithColumns.getResultList().size() > 0) {
@@ -147,15 +147,16 @@ public class SemanticService {
}
}
public SemanticQueryResp getQueryResultWithSchemaResp(EntityInfo entityInfo, ViewSchema viewSchema, User user) {
public SemanticQueryResp getQueryResultWithSchemaResp(EntityInfo entityInfo,
DataSetSchema dataSetSchema, User user) {
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
semanticParseInfo.setView(viewSchema.getView());
semanticParseInfo.setDataSet(dataSetSchema.getDataSet());
semanticParseInfo.setQueryType(QueryType.TAG);
semanticParseInfo.setMetrics(getMetrics(entityInfo));
semanticParseInfo.setDimensions(getDimensions(entityInfo));
DateConf dateInfo = new DateConf();
int unit = 1;
TimeDefaultConfig timeDefaultConfig = viewSchema.getTagTypeTimeDefaultConfig();
TimeDefaultConfig timeDefaultConfig = dataSetSchema.getTagTypeTimeDefaultConfig();
if (Objects.nonNull(timeDefaultConfig)) {
unit = timeDefaultConfig.getUnit();
String date = LocalDate.now().plusDays(-unit).toString();
@@ -222,7 +223,7 @@ public class SemanticService {
}
private String getEntityPrimaryName(EntityInfo entityInfo) {
return entityInfo.getViewInfo().getPrimaryKey();
return entityInfo.getDataSetInfo().getPrimaryKey();
}
}
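A tiny stand-alone illustration of how the tag-type default date in getQueryResultWithSchemaResp() is derived from TimeDefaultConfig.unit; the unit value 7 is a made-up example.

import java.time.LocalDate;

public class DefaultTagDateSketch {
    public static void main(String[] args) {
        int unit = 7;                                               // hypothetical TimeDefaultConfig.unit
        String date = LocalDate.now().plusDays(-unit).toString();   // e.g. 2024-02-26 for a 2024-03-04 run
        System.out.println(date);
    }
}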

View File

@@ -60,8 +60,8 @@ public class ChatServiceImpl implements ChatService {
return null;
}
SemanticParseInfo originalSemanticParse = chatContext.getParseInfo();
if (Objects.nonNull(originalSemanticParse) && Objects.nonNull(originalSemanticParse.getViewId())) {
return originalSemanticParse.getViewId();
if (Objects.nonNull(originalSemanticParse) && Objects.nonNull(originalSemanticParse.getDataSetId())) {
return originalSemanticParse.getDataSetId();
}
return null;
}

View File

@@ -4,7 +4,7 @@ package com.tencent.supersonic.chat.server.service.impl;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigBaseReq;
@@ -177,7 +177,7 @@ public class ConfigServiceImpl implements ConfigService {
}
private ItemVisibilityInfo fetchVisibilityDescByConfig(ItemVisibility visibility,
ViewSchema modelSchema) {
DataSetSchema modelSchema) {
ItemVisibilityInfo itemVisibilityDesc = new ItemVisibilityInfo();
List<Long> dimIdAllList = chatConfigHelper.generateAllDimIdList(modelSchema);
@@ -219,20 +219,20 @@ public class ConfigServiceImpl implements ConfigService {
}
BeanUtils.copyProperties(chatConfigResp, chatConfigRich);
ViewSchema viewSchema = semanticService.getViewSchema(modelId);
if (viewSchema == null) {
DataSetSchema dataSetSchema = semanticService.getDataSetSchema(modelId);
if (dataSetSchema == null) {
return chatConfigRich;
}
chatConfigRich.setBizName(viewSchema.getView().getBizName());
chatConfigRich.setModelName(viewSchema.getView().getName());
chatConfigRich.setBizName(dataSetSchema.getDataSet().getBizName());
chatConfigRich.setModelName(dataSetSchema.getDataSet().getName());
chatConfigRich.setChatAggRichConfig(fillChatAggRichConfig(viewSchema, chatConfigResp));
chatConfigRich.setChatDetailRichConfig(fillChatDetailRichConfig(viewSchema, chatConfigRich, chatConfigResp));
chatConfigRich.setChatAggRichConfig(fillChatAggRichConfig(dataSetSchema, chatConfigResp));
chatConfigRich.setChatDetailRichConfig(fillChatDetailRichConfig(dataSetSchema, chatConfigRich, chatConfigResp));
return chatConfigRich;
}
private ChatDetailRichConfigResp fillChatDetailRichConfig(ViewSchema modelSchema,
private ChatDetailRichConfigResp fillChatDetailRichConfig(DataSetSchema modelSchema,
ChatConfigRichResp chatConfigRich,
ChatConfigResp chatConfigResp) {
if (Objects.isNull(chatConfigResp) || Objects.isNull(chatConfigResp.getChatDetailConfig())) {
@@ -251,7 +251,7 @@ public class ConfigServiceImpl implements ConfigService {
return detailRichConfig;
}
private EntityRichInfoResp generateRichEntity(Entity entity, ViewSchema modelSchema) {
private EntityRichInfoResp generateRichEntity(Entity entity, DataSetSchema modelSchema) {
EntityRichInfoResp entityRichInfo = new EntityRichInfoResp();
if (Objects.isNull(entity) || Objects.isNull(entity.getEntityId())) {
return entityRichInfo;
@@ -264,7 +264,7 @@ public class ConfigServiceImpl implements ConfigService {
return entityRichInfo;
}
private ChatAggRichConfigResp fillChatAggRichConfig(ViewSchema modelSchema, ChatConfigResp chatConfigResp) {
private ChatAggRichConfigResp fillChatAggRichConfig(DataSetSchema modelSchema, ChatConfigResp chatConfigResp) {
if (Objects.isNull(chatConfigResp) || Objects.isNull(chatConfigResp.getChatAggConfig())) {
return null;
}
@@ -281,7 +281,7 @@ public class ConfigServiceImpl implements ConfigService {
}
private ChatDefaultRichConfigResp fetchDefaultConfig(ChatDefaultConfigReq chatDefaultConfig,
ViewSchema modelSchema,
DataSetSchema modelSchema,
ItemVisibilityInfo itemVisibilityInfo) {
ChatDefaultRichConfigResp defaultRichConfig = new ChatDefaultRichConfigResp();
if (Objects.isNull(chatDefaultConfig)) {
@@ -331,7 +331,7 @@ public class ConfigServiceImpl implements ConfigService {
}
private List<KnowledgeInfoReq> fillKnowledgeBizName(List<KnowledgeInfoReq> knowledgeInfos,
ViewSchema modelSchema) {
DataSetSchema modelSchema) {
if (CollectionUtils.isEmpty(knowledgeInfos)) {
return new ArrayList<>();
}
@@ -351,9 +351,9 @@ public class ConfigServiceImpl implements ConfigService {
@Override
public List<ChatConfigRichResp> getAllChatRichConfig() {
List<ChatConfigRichResp> chatConfigRichInfoList = new ArrayList<>();
List<ViewSchema> modelSchemas = semanticInterpreter.getViewSchema();
List<DataSetSchema> modelSchemas = semanticInterpreter.getDataSetSchema();
modelSchemas.stream().forEach(modelSchema -> {
ChatConfigRichResp chatConfigRichInfo = getConfigRichInfo(modelSchema.getView().getId());
ChatConfigRichResp chatConfigRichInfo = getConfigRichInfo(modelSchema.getDataSet().getId());
if (Objects.nonNull(chatConfigRichInfo)) {
chatConfigRichInfoList.add(chatConfigRichInfo);
}

View File

@@ -93,8 +93,8 @@ public class PluginServiceImpl implements PluginService {
if (StringUtils.isNotBlank(pluginQueryReq.getType())) {
queryWrapper.lambda().eq(PluginDO::getType, pluginQueryReq.getType());
}
if (StringUtils.isNotBlank(pluginQueryReq.getView())) {
queryWrapper.lambda().like(PluginDO::getView, pluginQueryReq.getView());
if (StringUtils.isNotBlank(pluginQueryReq.getDataSet())) {
queryWrapper.lambda().like(PluginDO::getDataSet, pluginQueryReq.getDataSet());
}
if (StringUtils.isNotBlank(pluginQueryReq.getParseMode())) {
queryWrapper.lambda().eq(PluginDO::getParseMode, pluginQueryReq.getParseMode());
@@ -180,8 +180,8 @@ public class PluginServiceImpl implements PluginService {
public Plugin convert(PluginDO pluginDO) {
Plugin plugin = new Plugin();
BeanUtils.copyProperties(pluginDO, plugin);
if (StringUtils.isNotBlank(pluginDO.getView())) {
plugin.setViewList(Arrays.stream(pluginDO.getView().split(","))
if (StringUtils.isNotBlank(pluginDO.getDataSet())) {
plugin.setDataSetList(Arrays.stream(pluginDO.getDataSet().split(","))
.map(Long::parseLong).collect(Collectors.toList()));
}
return plugin;
@@ -194,7 +194,7 @@ public class PluginServiceImpl implements PluginService {
pluginDO.setCreatedBy(user.getName());
pluginDO.setUpdatedAt(new Date());
pluginDO.setUpdatedBy(user.getName());
pluginDO.setView(StringUtils.join(plugin.getViewList(), ","));
pluginDO.setDataSet(StringUtils.join(plugin.getDataSetList(), ","));
return pluginDO;
}
@@ -202,7 +202,7 @@ public class PluginServiceImpl implements PluginService {
BeanUtils.copyProperties(plugin, pluginDO);
pluginDO.setUpdatedAt(new Date());
pluginDO.setUpdatedBy(user.getName());
pluginDO.setView(StringUtils.join(plugin.getViewList(), ","));
pluginDO.setDataSet(StringUtils.join(plugin.getDataSetList(), ","));
return pluginDO;
}
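A stand-alone round-trip of the comma-joined dataSet column handled by convert() and the DO builders above, using plain JDK streams in place of StringUtils.join; the id list is a made-up example.

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

public class PluginDataSetColumnSketch {
    public static void main(String[] args) {
        String stored = "1,2,3";                                 // hypothetical PluginDO.dataSet value
        List<Long> dataSetIds = Arrays.stream(stored.split(","))
                .map(Long::parseLong)
                .collect(Collectors.toList());                   // reading: "1,2,3" -> [1, 2, 3]
        String persisted = dataSetIds.stream()
                .map(String::valueOf)
                .collect(Collectors.joining(","));               // writing: [1, 2, 3] -> "1,2,3"
        System.out.println(dataSetIds + " / " + persisted);
    }
}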

View File

@@ -1,12 +1,11 @@
package com.tencent.supersonic.chat.server.service.impl;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq;
@@ -296,7 +295,7 @@ public class QueryServiceImpl implements QueryService {
similarQueryManager.saveSimilarQuery(SimilarQueryReq.builder().parseId(queryReq.getParseId())
.queryId(queryReq.getQueryId())
.agentId(chatQueryDO.getAgentId())
.viewId(parseInfo.getViewId())
.dataSetId(parseInfo.getDataSetId())
.queryText(queryReq.getQueryText()).build());
}
@@ -352,9 +351,9 @@ public class QueryServiceImpl implements QueryService {
}
QueryResult queryResult = semanticQuery.execute(user);
queryResult.setChatContext(semanticQuery.getParseInfo());
ViewSchema viewSchema = semanticSchema.getViewSchemaMap().get(parseInfo.getViewId());
DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(parseInfo.getDataSetId());
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, viewSchema, user);
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, user);
queryResult.setEntityInfo(entityInfo);
return queryResult;
}
@@ -422,8 +421,9 @@ public class QueryServiceImpl implements QueryService {
ChatParseDO chatParseDO = chatService.getParseInfo(queryId, parseId);
SemanticParseInfo parseInfo = JsonUtil.toObject(chatParseDO.getParseInfo(), SemanticParseInfo.class);
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
ViewSchema viewSchema = schemaService.getSemanticSchema().getViewSchemaMap().get(parseInfo.getViewId());
return semanticService.getEntityInfo(parseInfo, viewSchema, user);
DataSetSchema dataSetSchema =
schemaService.getSemanticSchema().getDataSetSchemaMap().get(parseInfo.getDataSetId());
return semanticService.getEntityInfo(parseInfo, dataSetSchema, user);
}
private void updateDateInfo(QueryDataReq queryData, SemanticParseInfo parseInfo,
@@ -641,10 +641,10 @@ public class QueryServiceImpl implements QueryService {
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
SemanticSchema semanticSchema = semanticService.getSemanticSchema();
SchemaElement schemaElement = semanticSchema.getDimension(dimensionValueReq.getElementID());
Set<Long> detectViewIds = new HashSet<>();
detectViewIds.add(schemaElement.getView());
dimensionValueReq.setModelId(schemaElement.getView());
List<String> dimensionValues = getDimensionValues(dimensionValueReq, detectViewIds);
Set<Long> detectDataSetIds = new HashSet<>();
detectDataSetIds.add(schemaElement.getDataSet());
dimensionValueReq.setModelId(schemaElement.getDataSet());
List<String> dimensionValues = getDimensionValues(dimensionValueReq, detectDataSetIds);
// if the search result is empty, look up the dimension value from the database
if (CollectionUtils.isEmpty(dimensionValues)) {
semanticQueryResp = queryDatabase(dimensionValueReq, user);
@@ -668,14 +668,14 @@ public class QueryServiceImpl implements QueryService {
return semanticQueryResp;
}
private List<String> getDimensionValues(DimensionValueReq dimensionValueReq, Set<Long> viewIds) {
private List<String> getDimensionValues(DimensionValueReq dimensionValueReq, Set<Long> dataSetIds) {
// if the value is blank, then search from NATURE_TO_VALUES
if (StringUtils.isBlank(dimensionValueReq.getValue())) {
return SearchService.getDimensionValue(dimensionValueReq);
}
//search from prefixSearch
List<HanlpMapResult> hanlpMapResultList = knowledgeService.prefixSearch(dimensionValueReq.getValue(),
2000, viewIds);
2000, dataSetIds);
HanlpHelper.transLetterOriginal(hanlpMapResultList);
return hanlpMapResultList.stream()
.filter(o -> {
@@ -700,7 +700,7 @@ public class QueryServiceImpl implements QueryService {
dateConf.setPeriod("DAY");
queryStructReq.setDateInfo(dateConf);
queryStructReq.setLimit(20L);
queryStructReq.setViewId(dimensionValueReq.getModelId());
queryStructReq.setDataSetId(dimensionValueReq.getModelId());
queryStructReq.setQueryType(QueryType.ID);
List<String> groups = new ArrayList<>();
groups.add(dimensionValueReq.getBizName());

View File

@@ -4,7 +4,7 @@ package com.tencent.supersonic.chat.server.service.impl;
import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.api.pojo.RelatedSchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
import com.tencent.supersonic.chat.api.pojo.request.RecommendReq;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
@@ -48,7 +48,7 @@ public class RecommendServiceImpl implements RecommendService {
if (Objects.isNull(modelId)) {
return new RecommendResp();
}
ViewSchema modelSchema = semanticService.getViewSchema(modelId);
DataSetSchema modelSchema = semanticService.getDataSetSchema(modelId);
if (Objects.isNull(modelSchema)) {
return new RecommendResp();
}
@@ -80,7 +80,7 @@ public class RecommendServiceImpl implements RecommendService {
.limit(limit)
.map(dimSchemaDesc -> {
SchemaElement item = new SchemaElement();
item.setView(modelId);
item.setDataSet(modelId);
item.setName(dimSchemaDesc.getName());
item.setBizName(dimSchemaDesc.getBizName());
item.setId(dimSchemaDesc.getId());
@@ -94,7 +94,7 @@ public class RecommendServiceImpl implements RecommendService {
.limit(limit)
.map(metricSchemaDesc -> {
SchemaElement item = new SchemaElement();
item.setView(modelId);
item.setDataSet(modelId);
item.setName(metricSchemaDesc.getName());
item.setBizName(metricSchemaDesc.getBizName());
item.setId(metricSchemaDesc.getId());

View File

@@ -4,7 +4,7 @@ import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
import lombok.extern.slf4j.Slf4j;
@@ -28,13 +28,13 @@ public class SchemaService {
@Override
public SemanticSchema load(String key) {
log.info("load getDomainSchemaInfo cache [{}]", key);
return new SemanticSchema(semanticInterpreter.getViewSchema());
return new SemanticSchema(semanticInterpreter.getDataSetSchema());
}
}
);
public ViewSchema getViewSchema(Long id) {
return semanticInterpreter.getViewSchema(id, true);
public DataSetSchema getDataSetSchema(Long id) {
return semanticInterpreter.getDataSetSchema(id, true);
}
public SemanticSchema getSemanticSchema() {

View File

@@ -14,7 +14,7 @@ import com.tencent.supersonic.chat.core.agent.Agent;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.knowledge.DictWord;
import com.tencent.supersonic.headless.core.knowledge.HanlpMapResult;
import com.tencent.supersonic.headless.core.knowledge.ViewInfoStat;
import com.tencent.supersonic.headless.core.knowledge.DataSetInfoStat;
import com.tencent.supersonic.chat.core.mapper.MapperHelper;
import com.tencent.supersonic.chat.core.mapper.MatchText;
import com.tencent.supersonic.chat.core.mapper.ModelWithSemanticType;
@@ -91,18 +91,19 @@ public class SearchServiceImpl implements SearchService {
// 2.get meta info
SemanticSchema semanticSchemaDb = schemaService.getSemanticSchema();
List<SchemaElement> metricsDb = semanticSchemaDb.getMetrics();
final Map<Long, String> modelToName = semanticSchemaDb.getViewIdToName();
final Map<Long, String> modelToName = semanticSchemaDb.getDataSetIdToName();
// 3.detect by segment
List<S2Term> originals = knowledgeService.getTerms(queryText);
log.info("hanlp parse result: {}", originals);
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
Set<Long> detectViewIds = mapperHelper.getViewIds(queryReq.getViewId(), agentService.getAgent(agentId));
Set<Long> detectDataSetIds = mapperHelper.getDataSetIds(queryReq.getDataSetId(),
agentService.getAgent(agentId));
QueryContext queryContext = new QueryContext();
BeanUtils.copyProperties(queryReq, queryContext);
Map<MatchText, List<HanlpMapResult>> regTextMap =
searchMatchStrategy.match(queryContext, originals, detectViewIds);
searchMatchStrategy.match(queryContext, originals, detectDataSetIds);
regTextMap.entrySet().stream().forEach(m -> HanlpHelper.transLetterOriginal(m.getValue()));
// 4.get the most matching data
@@ -121,9 +122,9 @@ public class SearchServiceImpl implements SearchService {
log.info("searchTextEntry:{},queryReq:{}", searchTextEntry, queryReq);
Set<SearchResult> searchResults = new LinkedHashSet();
ViewInfoStat modelStat = NatureHelper.getViewStat(originals);
DataSetInfoStat modelStat = NatureHelper.getDataSetStat(originals);
List<Long> possibleModels = getPossibleModels(queryReq, originals, modelStat, queryReq.getViewId());
List<Long> possibleModels = getPossibleModels(queryReq, originals, modelStat, queryReq.getDataSetId());
// 5.1 prioritize dimensions and metrics
boolean existMetricAndDimension = searchMetricAndDimension(new HashSet<>(possibleModels), modelToName,
@@ -137,7 +138,7 @@ public class SearchServiceImpl implements SearchService {
for (Map.Entry<String, String> natureToNameEntry : natureToNameMap.entrySet()) {
Set<SearchResult> searchResultSet = searchDimensionValue(metricsDb, modelToName,
modelStat.getMetricViewCount(), existMetricAndDimension,
modelStat.getMetricDataSetCount(), existMetricAndDimension,
matchText, natureToNameMap, natureToNameEntry, queryReq.getQueryFilters());
searchResults.addAll(searchResultSet);
@@ -146,7 +147,7 @@ public class SearchServiceImpl implements SearchService {
}
private List<Long> getPossibleModels(QueryReq queryCtx, List<S2Term> originals,
ViewInfoStat modelStat, Long webModelId) {
DataSetInfoStat modelStat, Long webModelId) {
if (Objects.nonNull(webModelId) && webModelId > 0) {
List<Long> result = new ArrayList<>();
@@ -154,7 +155,7 @@ public class SearchServiceImpl implements SearchService {
return result;
}
List<Long> possibleModels = NatureHelper.selectPossibleViews(originals);
List<Long> possibleModels = NatureHelper.selectPossibleDataSets(originals);
Long contextModel = chatService.getContextModel(queryCtx.getChatId());
@@ -167,9 +168,9 @@ public class SearchServiceImpl implements SearchService {
return possibleModels;
}
private boolean nothingOrOnlyMetric(ViewInfoStat modelStat) {
return modelStat.getMetricViewCount() >= 0 && modelStat.getDimensionViewCount() <= 0
&& modelStat.getDimensionValueViewCount() <= 0 && modelStat.getViewCount() <= 0;
private boolean nothingOrOnlyMetric(DataSetInfoStat modelStat) {
return modelStat.getMetricDataSetCount() >= 0 && modelStat.getDimensionDataSetCount() <= 0
&& modelStat.getDimensionValueDataSetCount() <= 0 && modelStat.getDataSetCount() <= 0;
}
private boolean effectiveModel(Long contextModel) {
@@ -189,7 +190,7 @@ public class SearchServiceImpl implements SearchService {
String nature = natureToNameEntry.getKey();
String wordName = natureToNameEntry.getValue();
Long modelId = NatureHelper.getViewId(nature);
Long modelId = NatureHelper.getDataSetId(nature);
SchemaElementType schemaElementType = NatureHelper.convertToElementType(nature);
if (SchemaElementType.ENTITY.equals(schemaElementType)) {
@@ -266,7 +267,7 @@ public class SearchServiceImpl implements SearchService {
return Lists.newArrayList();
}
return metricsDb.stream()
.filter(mapDO -> Objects.nonNull(mapDO) && model.equals(mapDO.getView()))
.filter(mapDO -> Objects.nonNull(mapDO) && model.equals(mapDO.getDataSet()))
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.flatMap(entry -> {
List<String> result = new ArrayList<>();
@@ -290,7 +291,7 @@ public class SearchServiceImpl implements SearchService {
if (CollectionUtils.isEmpty(possibleModels)) {
return true;
}
Long model = NatureHelper.getViewId(nature);
Long model = NatureHelper.getDataSetId(nature);
return possibleModels.contains(model);
})
.map(nature -> {
@@ -313,7 +314,7 @@ public class SearchServiceImpl implements SearchService {
for (HanlpMapResult hanlpMapResult : hanlpMapResults) {
List<ModelWithSemanticType> dimensionMetricClassIds = hanlpMapResult.getNatures().stream()
.map(nature -> new ModelWithSemanticType(NatureHelper.getViewId(nature),
.map(nature -> new ModelWithSemanticType(NatureHelper.getDataSetId(nature),
NatureHelper.convertToElementType(nature)))
.filter(entry -> matchCondition(entry, possibleModels)).collect(Collectors.toList());

View File

@@ -24,7 +24,7 @@ public class WordService {
public List<DictWord> getAllDictWords() {
SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
SemanticSchema semanticSchema = new SemanticSchema(semanticInterpreter.getViewSchema());
SemanticSchema semanticSchema = new SemanticSchema(semanticInterpreter.getDataSetSchema());
List<DictWord> words = new ArrayList<>();

View File

@@ -3,7 +3,7 @@ package com.tencent.supersonic.chat.server.util;
import static com.tencent.supersonic.common.pojo.Constants.ADMIN_LOWER;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigBaseReq;
@@ -56,7 +56,7 @@ public class ChatConfigHelper {
return chatConfig;
}
public List<Long> generateAllDimIdList(ViewSchema modelSchema) {
public List<Long> generateAllDimIdList(DataSetSchema modelSchema) {
if (Objects.isNull(modelSchema) || CollectionUtils.isEmpty(modelSchema.getDimensions())) {
return new ArrayList<>();
}
@@ -65,7 +65,7 @@ public class ChatConfigHelper {
return new ArrayList<>(dimIdAndDescPair.keySet());
}
public List<Long> generateAllMetricIdList(ViewSchema modelSchema) {
public List<Long> generateAllMetricIdList(DataSetSchema modelSchema) {
if (Objects.isNull(modelSchema) || CollectionUtils.isEmpty(modelSchema.getMetrics())) {
return new ArrayList<>();
}

View File

@@ -30,8 +30,8 @@ class QueryReqBuilderTest {
void buildS2SQLReq() {
init();
QueryStructReq queryStructReq = new QueryStructReq();
queryStructReq.setViewId(1L);
queryStructReq.setViewName("内容库");
queryStructReq.setDataSetId(1L);
queryStructReq.setDataSetName("内容库");
queryStructReq.setQueryType(QueryType.METRIC);
Aggregator aggregator = new Aggregator();

View File

@@ -29,7 +29,7 @@ public class LoadRemoveService {
if (Objects.isNull(nature)) {
return false;
}
Long modelId = getViewId(nature);
Long modelId = getDataSetId(nature);
if (Objects.nonNull(modelId)) {
return !detectModelIds.contains(modelId);
}
@@ -47,7 +47,7 @@ public class LoadRemoveService {
return resultList;
}
public Long getViewId(String nature) {
public Long getDataSetId(String nature) {
try {
String[] split = nature.split(DictWordType.NATURE_SPILT);
if (split.length <= 1) {

View File

@@ -7,7 +7,7 @@ public enum TypeEnums {
TAG,
DOMAIN,
ENTITY,
VIEW,
DATASET,
MODEL,
UNKNOWN;

View File

@@ -4,8 +4,8 @@ import lombok.Data;
import java.util.List;
@Data
public class ViewDetail {
public class DataSetDetail {
private List<ViewModelConfig> viewModelConfigs;
private List<DataSetModelConfig> dataSetModelConfigs;
}

View File

@@ -10,7 +10,7 @@ import java.util.List;
@Data
@AllArgsConstructor
@NoArgsConstructor
public class ViewModelConfig {
public class DataSetModelConfig {
private Long id;
@@ -20,7 +20,7 @@ public class ViewModelConfig {
private List<Long> dimensions = Lists.newArrayList();
public ViewModelConfig(Long id, List<Long> dimensions, List<Long> metrics) {
public DataSetModelConfig(Long id, List<Long> dimensions, List<Long> metrics) {
this.id = id;
this.metrics = metrics;
this.dimensions = dimensions;

Some files were not shown because too many files have changed in this diff.