(improvement) Move out the datasource and merge the datasource with the model, and adapt the chat module (#423)

Co-authored-by: jolunoluo <jolunoluo@tencent.com>
This commit is contained in:
jipeli
2023-11-27 11:05:24 +08:00
committed by GitHub
parent 0534053ff9
commit 27bb1b322e
190 changed files with 3900 additions and 10561 deletions

View File

@@ -9,8 +9,8 @@ import com.tencent.supersonic.semantic.api.model.request.PageMetricReq;
import com.tencent.supersonic.semantic.api.model.response.DomainResp;
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import com.tencent.supersonic.semantic.api.model.response.ModelResp;
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
import com.tencent.supersonic.semantic.api.model.response.ModelResp;
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq;

View File

@@ -1,8 +1,13 @@
package com.tencent.supersonic.chat.api.pojo;
import com.google.common.collect.Sets;
import com.tencent.supersonic.common.pojo.ModelRela;
import lombok.Data;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
@@ -14,8 +19,8 @@ public class ModelSchema {
private Set<SchemaElement> dimensions = new HashSet<>();
private Set<SchemaElement> dimensionValues = new HashSet<>();
private Set<SchemaElement> tags = new HashSet<>();
@Deprecated
private SchemaElement entity = new SchemaElement();
private List<ModelRela> modelRelas = new ArrayList<>();
public SchemaElement getElement(SchemaElementType elementType, long elementID) {
Optional<SchemaElement> element = Optional.empty();
@@ -75,4 +80,16 @@ public class ModelSchema {
}
}
/**
 * Collects the ids of every model that appears on either side of a relation in
 * {@code modelRelas}.
 *
 * @return a new mutable set with the "to" and "from" model ids of all relations; empty when
 *         {@code modelRelas} is null or empty
 */
public Set<Long> getModelClusterSet() {
    Set<Long> clusterModelIds = Sets.newHashSet();
    if (!CollectionUtils.isEmpty(modelRelas)) {
        modelRelas.forEach(rela -> {
            clusterModelIds.add(rela.getToModelId());
            clusterModelIds.add(rela.getFromModelId());
        });
    }
    return clusterModelIds;
}
}

View File

@@ -15,6 +15,7 @@ public class QueryContext {
private QueryReq request;
private List<SemanticQuery> candidateQueries = new ArrayList<>();
private SchemaMapInfo mapInfo = new SchemaMapInfo();
private SchemaModelClusterMapInfo modelClusterMapInfo = new SchemaModelClusterMapInfo();
public QueryContext(QueryReq request) {
this.request = request;

View File

@@ -6,7 +6,7 @@ public enum SchemaElementType {
DIMENSION,
VALUE,
ENTITY,
TAG,
ID,
DATE,
TAG
DATE
}

View File

@@ -0,0 +1,61 @@
package com.tencent.supersonic.chat.api.pojo;
import com.clickhouse.client.internal.apache.commons.compress.utils.Lists;
import com.tencent.supersonic.common.pojo.ModelCluster;
import lombok.Data;
import org.springframework.util.CollectionUtils;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
 * Holds the schema element matches collected for each model cluster, keyed by the cluster's
 * string key.
 */
@Data
public class SchemaModelClusterMapInfo {

    /** Matched schema elements, indexed by model-cluster key. */
    private Map<String, List<SchemaElementMatch>> modelElementMatches = new HashMap<>();

    /**
     * @return the keys of all model clusters that currently have an entry (the map's own key-set view)
     */
    public Set<String> getMatchedModelClusters() {
        return modelElementMatches.keySet();
    }

    /**
     * Looks up the matches of the first cluster (in map iteration order) whose key contains the
     * given model id.
     *
     * @param modelId id of the model to search for
     * @return the stored match list of that cluster, or a new empty list when no cluster key
     *         contains {@code modelId}
     */
    public List<SchemaElementMatch> getMatchedElements(Long modelId) {
        for (Map.Entry<String, List<SchemaElementMatch>> entry : modelElementMatches.entrySet()) {
            boolean clusterHasModel = ModelCluster.getModelIdFromKey(entry.getKey()).contains(modelId);
            if (clusterHasModel) {
                return entry.getValue();
            }
        }
        return Lists.newArrayList();
    }

    /**
     * @param modelCluster model-cluster key
     * @return the match list stored under that key, or {@code null} when the key is absent
     */
    public List<SchemaElementMatch> getMatchedElements(String modelCluster) {
        return modelElementMatches.get(modelCluster);
    }

    /**
     * @return the backing map itself (not a copy)
     */
    public Map<String, List<SchemaElementMatch>> getModelElementMatches() {
        return modelElementMatches;
    }

    /**
     * Restricts the matches to clusters whose key contains at least one of the given model ids.
     *
     * @param modelIds model ids to keep; when null or empty no filtering is applied
     * @return the backing map when {@code modelIds} is null or empty, otherwise a new map holding
     *         only the entries of clusters that contain one of the ids
     */
    public Map<String, List<SchemaElementMatch>> getElementMatchesByModelIds(Set<Long> modelIds) {
        if (CollectionUtils.isEmpty(modelIds)) {
            return modelElementMatches;
        }
        Map<String, List<SchemaElementMatch>> filtered = new HashMap<>();
        modelElementMatches.forEach((clusterKey, matches) -> {
            boolean wanted = modelIds.stream()
                    .anyMatch(modelId -> ModelCluster.getModelIdFromKey(clusterKey).contains(modelId));
            if (wanted) {
                filtered.put(clusterKey, matches);
            }
        });
        return filtered;
    }

    /**
     * Replaces the whole cluster-to-matches map.
     *
     * @param modelElementMatches new backing map
     */
    public void setModelElementMatches(Map<String, List<SchemaElementMatch>> modelElementMatches) {
        this.modelElementMatches = modelElementMatches;
    }

    /**
     * Stores (or overwrites) the match list of one model cluster.
     *
     * @param modelCluster   model-cluster key
     * @param elementMatches matches to associate with that key
     */
    public void setMatchedElements(String modelCluster, List<SchemaElementMatch> elementMatches) {
        modelElementMatches.put(modelCluster, elementMatches);
    }
}

View File

@@ -5,9 +5,13 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.ModelCluster;
import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.pojo.QueryType;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.enums.FilterType;
import lombok.Data;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
@@ -16,15 +20,13 @@ import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import com.tencent.supersonic.common.pojo.enums.FilterType;
import lombok.Data;
@Data
public class SemanticParseInfo {
private Integer id;
private String queryMode;
private SchemaElement model;
private ModelCluster model = new ModelCluster();
private Set<SchemaElement> metrics = new TreeSet<>(new SchemaNameLengthComparator());
private Set<SchemaElement> dimensions = new LinkedHashSet();
private SchemaElement entity;
@@ -42,12 +44,18 @@ public class SemanticParseInfo {
private SqlInfo sqlInfo = new SqlInfo();
private QueryType queryType = QueryType.OTHER;
public Long getModelId() {
return model != null ? model.getId() : 0L;
public String getModelClusterKey() {
if (model == null) {
return "";
}
return model.getKey();
}
public String getModelName() {
return model != null ? model.getName() : "null";
if (model == null) {
return "";
}
return model.getName();
}
private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {
@@ -78,4 +86,26 @@ public class SemanticParseInfo {
return metrics;
}
/**
 * Counts how many entries of {@code elementMatches} reference each model.
 *
 * @return a new map from the model value of each matched element to its number of occurrences
 */
private Map<Long, Integer> getModelElementCountMap() {
    Map<Long, Integer> countsByModel = new HashMap<>();
    // merge() starts a model at 1 and adds 1 on every further occurrence.
    elementMatches.forEach(match -> countsByModel.merge(match.getElement().getModel(), 1, Integer::sum));
    return countsByModel;
}
/**
 * Picks the model referenced by the largest number of element matches.
 *
 * @return the model with the highest match count (on a tie, the one met first in map iteration
 *         order), or {@code -1} when there are no element matches
 */
public Long getModelId() {
    Long dominantModelId = -1L;
    int highestCount = 0;
    for (Map.Entry<Long, Integer> entry : getModelElementCountMap().entrySet()) {
        int count = entry.getValue();
        // Strictly greater: an equal count never displaces the current winner.
        if (count > highestCount) {
            highestCount = count;
            dominantModelId = entry.getKey();
        }
    }
    return dominantModelId;
}
}

View File

@@ -1,9 +1,14 @@
package com.tencent.supersonic.chat.api.pojo;
import org.springframework.util.CollectionUtils;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
public class SemanticSchema implements Serializable {
@@ -18,6 +23,64 @@ public class SemanticSchema implements Serializable {
modelSchemaList.add(schema);
}
/**
 * Finds a schema element of the given type by its id across all model schemas.
 *
 * @param elementType kind of element to search (ENTITY, MODEL, METRIC, DIMENSION or VALUE)
 * @param elementID   id of the element
 * @return the first element of that type with a matching id, or {@code null} when none matches
 *         or the type is not one of the supported kinds
 */
public SchemaElement getElement(SchemaElementType elementType, long elementID) {
    switch (elementType) {
        case ENTITY:
            return getElementsById(elementID, getEntities()).orElse(null);
        case MODEL:
            return getElementsById(elementID, getModels()).orElse(null);
        case METRIC:
            return getElementsById(elementID, getMetrics()).orElse(null);
        case DIMENSION:
            return getElementsById(elementID, getDimensions()).orElse(null);
        case VALUE:
            return getElementsById(elementID, getDimensionValues()).orElse(null);
        default:
            // Other element types are not searchable here.
            return null;
    }
}
/**
 * Finds a schema element of the given type by its name across all model schemas.
 *
 * @param elementType kind of element to search (ENTITY, MODEL, METRIC, DIMENSION or VALUE)
 * @param name        name of the element
 * @return the first element of that type with a matching name, or {@code null} when none
 *         matches or the type is not one of the supported kinds
 */
public SchemaElement getElementByName(SchemaElementType elementType, String name) {
    switch (elementType) {
        case ENTITY:
            return getElementsByName(name, getEntities()).orElse(null);
        case MODEL:
            return getElementsByName(name, getModels()).orElse(null);
        case METRIC:
            return getElementsByName(name, getMetrics()).orElse(null);
        case DIMENSION:
            return getElementsByName(name, getDimensions()).orElse(null);
        case VALUE:
            return getElementsByName(name, getDimensionValues()).orElse(null);
        default:
            // Other element types are not searchable here.
            return null;
    }
}
public Map<Long, String> getModelIdToName() {
return modelSchemaList.stream()
.collect(Collectors.toMap(a -> a.getModel().getId(), a -> a.getModel().getName(), (k1, k2) -> k1));
@@ -35,9 +98,21 @@ public class SemanticSchema implements Serializable {
return dimensions;
}
public List<SchemaElement> getDimensions(Long modelId) {
public List<SchemaElement> getDimensions(Set<Long> modelIds) {
List<SchemaElement> dimensions = getDimensions();
return getElementsByModelId(modelId, dimensions);
return getElementsByModelId(modelIds, dimensions);
}
/**
 * Looks up a single dimension by its id among the dimensions of all model schemas.
 *
 * @param id dimension id
 * @return the first dimension whose id equals {@code id}, or {@code null} when there is none
 */
public SchemaElement getDimensions(Long id) {
    return getElementsById(id, getDimensions()).orElse(null);
}
/**
 * Gathers the tags of every model schema into one list.
 *
 * @return a new list with the tags of all model schemas, in schema order
 */
public List<SchemaElement> getTags() {
    return modelSchemaList.stream()
            .flatMap(schema -> schema.getTags().stream())
            .collect(Collectors.toCollection(ArrayList::new));
}
public List<SchemaElement> getMetrics() {
@@ -46,21 +121,9 @@ public class SemanticSchema implements Serializable {
return metrics;
}
public List<SchemaElement> getMetrics(Long modelId) {
public List<SchemaElement> getMetrics(Set<Long> modelIds) {
List<SchemaElement> metrics = getMetrics();
return getElementsByModelId(modelId, metrics);
}
private List<SchemaElement> getElementsByModelId(Long modelId, List<SchemaElement> elements) {
return elements.stream()
.filter(schemaElement -> modelId.equals(schemaElement.getModel()))
.collect(Collectors.toList());
}
public List<SchemaElement> getModels() {
List<SchemaElement> models = new ArrayList<>();
modelSchemaList.stream().forEach(d -> models.add(d.getModel()));
return models;
return getElementsByModelId(modelIds, metrics);
}
public List<SchemaElement> getEntities() {
@@ -69,11 +132,43 @@ public class SemanticSchema implements Serializable {
return entities;
}
public Map<String, String> getBizNameToName(Long modelId) {
/**
 * Keeps only the elements that belong to one of the given models.
 *
 * @param modelIds ids of the models to keep
 * @param elements elements to filter
 * @return a new list with the elements whose model is contained in {@code modelIds}, in their
 *         original order
 */
private List<SchemaElement> getElementsByModelId(Set<Long> modelIds, List<SchemaElement> elements) {
    List<SchemaElement> matched = new ArrayList<>();
    for (SchemaElement candidate : elements) {
        if (modelIds.contains(candidate.getModel())) {
            matched.add(candidate);
        }
    }
    return matched;
}
/**
 * Finds the first element with the given id.
 *
 * @param id       id to look for
 * @param elements elements to search
 * @return the first element whose id equals {@code id}, or an empty Optional when none does
 */
private Optional<SchemaElement> getElementsById(Long id, List<SchemaElement> elements) {
    for (SchemaElement candidate : elements) {
        if (id.equals(candidate.getId())) {
            return Optional.of(candidate);
        }
    }
    return Optional.empty();
}
/**
 * Finds the first element with the given name.
 *
 * @param name     name to look for
 * @param elements elements to search
 * @return the first element whose name equals {@code name}, or an empty Optional when none does
 */
private Optional<SchemaElement> getElementsByName(String name, List<SchemaElement> elements) {
    for (SchemaElement candidate : elements) {
        if (name.equals(candidate.getName())) {
            return Optional.of(candidate);
        }
    }
    return Optional.empty();
}
/**
 * Collects the model element of every model schema.
 *
 * @return a new list with one model element per model schema, in schema order
 */
public List<SchemaElement> getModels() {
    return modelSchemaList.stream()
            .map(schema -> schema.getModel())
            .collect(Collectors.toCollection(ArrayList::new));
}
public Map<String, String> getBizNameToName(Set<Long> modelIds) {
List<SchemaElement> allElements = new ArrayList<>();
allElements.addAll(getDimensions(modelId));
allElements.addAll(getMetrics(modelId));
allElements.addAll(getDimensions(modelIds));
allElements.addAll(getMetrics(modelIds));
return allElements.stream()
.collect(Collectors.toMap(a -> a.getBizName(), a -> a.getName(), (k1, k2) -> k1));
.collect(Collectors.toMap(SchemaElement::getBizName, SchemaElement::getName, (k1, k2) -> k1));
}
/**
 * Indexes the model schemas by the model value of their model element.
 *
 * <p>When two schemas map to the same key the first one is kept, matching the merge rule the
 * other {@code toMap} collectors in this class use. Without a merge function
 * {@code Collectors.toMap} would throw {@code IllegalStateException} on a duplicate key.
 *
 * @return a new map from model key to its schema; an empty map when there are no model schemas
 */
public Map<Long, ModelSchema> getModelSchemaMap() {
    if (CollectionUtils.isEmpty(modelSchemaList)) {
        return new HashMap<>();
    }
    return modelSchemaList.stream().collect(Collectors.toMap(
            modelSchema -> modelSchema.getModel().getModel(),
            modelSchema -> modelSchema,
            (first, second) -> first));
}
}

View File

@@ -7,7 +7,7 @@ import lombok.Data;
public class QueryReq {
private String queryText;
private Integer chatId;
private Long modelId = 0L;
private Long modelId;
private User user;
private QueryFilters queryFilters;
private boolean saveAnswer = true;

View File

@@ -18,7 +18,7 @@ public class SolvedQueryReq {
private String queryText;
private Long modelId;
private String modelId;
private Integer agentId;

View File

@@ -1,13 +1,13 @@
package com.tencent.supersonic.chat.api.pojo.response;
import lombok.Data;
import java.io.Serializable;
import java.util.List;
import lombok.Data;
@Data
public class ModelInfo extends DataInfo implements Serializable {
private List<String> words;
private String primaryEntityName;
private String primaryEntityBizName;
private String primaryKey;
}