mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-15 22:46:49 +00:00
(improvement) Move out the datasource and merge the datasource with the model, and adapt the chat module (#423)
Co-authored-by: jolunoluo <jolunoluo@tencent.com>
This commit is contained in:
@@ -9,8 +9,8 @@ import com.tencent.supersonic.semantic.api.model.request.PageMetricReq;
|
||||
import com.tencent.supersonic.semantic.api.model.response.DomainResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.ModelResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.ModelResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
|
||||
import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq;
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
package com.tencent.supersonic.chat.api.pojo;
|
||||
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.common.pojo.ModelRela;
|
||||
import lombok.Data;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
|
||||
@@ -14,8 +19,8 @@ public class ModelSchema {
|
||||
private Set<SchemaElement> dimensions = new HashSet<>();
|
||||
private Set<SchemaElement> dimensionValues = new HashSet<>();
|
||||
private Set<SchemaElement> tags = new HashSet<>();
|
||||
@Deprecated
|
||||
private SchemaElement entity = new SchemaElement();
|
||||
private List<ModelRela> modelRelas = new ArrayList<>();
|
||||
|
||||
public SchemaElement getElement(SchemaElementType elementType, long elementID) {
|
||||
Optional<SchemaElement> element = Optional.empty();
|
||||
@@ -75,4 +80,16 @@ public class ModelSchema {
|
||||
}
|
||||
}
|
||||
|
||||
public Set<Long> getModelClusterSet() {
|
||||
if (CollectionUtils.isEmpty(modelRelas)) {
|
||||
return Sets.newHashSet();
|
||||
}
|
||||
Set<Long> modelClusterSet = new HashSet<>();
|
||||
modelRelas.forEach(modelRela -> {
|
||||
modelClusterSet.add(modelRela.getToModelId());
|
||||
modelClusterSet.add(modelRela.getFromModelId());
|
||||
});
|
||||
return modelClusterSet;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ public class QueryContext {
|
||||
private QueryReq request;
|
||||
private List<SemanticQuery> candidateQueries = new ArrayList<>();
|
||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||
private SchemaModelClusterMapInfo modelClusterMapInfo = new SchemaModelClusterMapInfo();
|
||||
|
||||
public QueryContext(QueryReq request) {
|
||||
this.request = request;
|
||||
|
||||
@@ -6,7 +6,7 @@ public enum SchemaElementType {
|
||||
DIMENSION,
|
||||
VALUE,
|
||||
ENTITY,
|
||||
TAG,
|
||||
ID,
|
||||
DATE,
|
||||
TAG
|
||||
DATE
|
||||
}
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
package com.tencent.supersonic.chat.api.pojo;
|
||||
|
||||
import com.clickhouse.client.internal.apache.commons.compress.utils.Lists;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import lombok.Data;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
@Data
|
||||
public class SchemaModelClusterMapInfo {
|
||||
|
||||
private Map<String, List<SchemaElementMatch>> modelElementMatches = new HashMap<>();
|
||||
|
||||
public Set<String> getMatchedModelClusters() {
|
||||
return modelElementMatches.keySet();
|
||||
}
|
||||
|
||||
public List<SchemaElementMatch> getMatchedElements(Long modelId) {
|
||||
for (String key : modelElementMatches.keySet()) {
|
||||
if (ModelCluster.getModelIdFromKey(key).contains(modelId)) {
|
||||
return modelElementMatches.get(key);
|
||||
}
|
||||
}
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
|
||||
public List<SchemaElementMatch> getMatchedElements(String modelCluster) {
|
||||
return modelElementMatches.get(modelCluster);
|
||||
}
|
||||
|
||||
public Map<String, List<SchemaElementMatch>> getModelElementMatches() {
|
||||
return modelElementMatches;
|
||||
}
|
||||
|
||||
public Map<String, List<SchemaElementMatch>> getElementMatchesByModelIds(Set<Long> modelIds) {
|
||||
if (CollectionUtils.isEmpty(modelIds)) {
|
||||
return modelElementMatches;
|
||||
}
|
||||
Map<String, List<SchemaElementMatch>> modelElementMatchesFiltered = new HashMap<>();
|
||||
for (String key : modelElementMatches.keySet()) {
|
||||
for (Long modelId : modelIds) {
|
||||
if (ModelCluster.getModelIdFromKey(key).contains(modelId)) {
|
||||
modelElementMatchesFiltered.put(key, modelElementMatches.get(key));
|
||||
}
|
||||
}
|
||||
}
|
||||
return modelElementMatchesFiltered;
|
||||
}
|
||||
|
||||
public void setModelElementMatches(Map<String, List<SchemaElementMatch>> modelElementMatches) {
|
||||
this.modelElementMatches = modelElementMatches;
|
||||
}
|
||||
|
||||
public void setMatchedElements(String modelCluster, List<SchemaElementMatch> elementMatches) {
|
||||
modelElementMatches.put(modelCluster, elementMatches);
|
||||
}
|
||||
}
|
||||
@@ -5,9 +5,13 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import com.tencent.supersonic.common.pojo.Order;
|
||||
import com.tencent.supersonic.common.pojo.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterType;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
@@ -16,15 +20,13 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.TreeSet;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterType;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class SemanticParseInfo {
|
||||
|
||||
private Integer id;
|
||||
private String queryMode;
|
||||
private SchemaElement model;
|
||||
private ModelCluster model = new ModelCluster();
|
||||
private Set<SchemaElement> metrics = new TreeSet<>(new SchemaNameLengthComparator());
|
||||
private Set<SchemaElement> dimensions = new LinkedHashSet();
|
||||
private SchemaElement entity;
|
||||
@@ -42,12 +44,18 @@ public class SemanticParseInfo {
|
||||
private SqlInfo sqlInfo = new SqlInfo();
|
||||
private QueryType queryType = QueryType.OTHER;
|
||||
|
||||
public Long getModelId() {
|
||||
return model != null ? model.getId() : 0L;
|
||||
public String getModelClusterKey() {
|
||||
if (model == null) {
|
||||
return "";
|
||||
}
|
||||
return model.getKey();
|
||||
}
|
||||
|
||||
public String getModelName() {
|
||||
return model != null ? model.getName() : "null";
|
||||
if (model == null) {
|
||||
return "";
|
||||
}
|
||||
return model.getName();
|
||||
}
|
||||
|
||||
private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {
|
||||
@@ -78,4 +86,26 @@ public class SemanticParseInfo {
|
||||
return metrics;
|
||||
}
|
||||
|
||||
private Map<Long, Integer> getModelElementCountMap() {
|
||||
Map<Long, Integer> elementCountMap = new HashMap<>();
|
||||
elementMatches.forEach(element -> {
|
||||
int count = elementCountMap.getOrDefault(element.getElement().getModel(), 0);
|
||||
elementCountMap.put(element.getElement().getModel(), count + 1);
|
||||
});
|
||||
return elementCountMap;
|
||||
}
|
||||
|
||||
public Long getModelId() {
|
||||
Map<Long, Integer> elementCountMap = getModelElementCountMap();
|
||||
Long modelId = -1L;
|
||||
int maxCnt = 0;
|
||||
for (Long model : elementCountMap.keySet()) {
|
||||
if (elementCountMap.get(model) > maxCnt) {
|
||||
maxCnt = elementCountMap.get(model);
|
||||
modelId = model;
|
||||
}
|
||||
}
|
||||
return modelId;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
package com.tencent.supersonic.chat.api.pojo;
|
||||
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class SemanticSchema implements Serializable {
|
||||
@@ -18,6 +23,64 @@ public class SemanticSchema implements Serializable {
|
||||
modelSchemaList.add(schema);
|
||||
}
|
||||
|
||||
public SchemaElement getElement(SchemaElementType elementType, long elementID) {
|
||||
Optional<SchemaElement> element = Optional.empty();
|
||||
|
||||
switch (elementType) {
|
||||
case ENTITY:
|
||||
element = getElementsById(elementID, getEntities());
|
||||
break;
|
||||
case MODEL:
|
||||
element = getElementsById(elementID, getModels());
|
||||
break;
|
||||
case METRIC:
|
||||
element = getElementsById(elementID, getMetrics());
|
||||
break;
|
||||
case DIMENSION:
|
||||
element = getElementsById(elementID, getDimensions());
|
||||
break;
|
||||
case VALUE:
|
||||
element = getElementsById(elementID, getDimensionValues());
|
||||
break;
|
||||
default:
|
||||
}
|
||||
|
||||
if (element.isPresent()) {
|
||||
return element.get();
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
public SchemaElement getElementByName(SchemaElementType elementType, String name) {
|
||||
Optional<SchemaElement> element = Optional.empty();
|
||||
|
||||
switch (elementType) {
|
||||
case ENTITY:
|
||||
element = getElementsByName(name, getEntities());
|
||||
break;
|
||||
case MODEL:
|
||||
element = getElementsByName(name, getModels());
|
||||
break;
|
||||
case METRIC:
|
||||
element = getElementsByName(name, getMetrics());
|
||||
break;
|
||||
case DIMENSION:
|
||||
element = getElementsByName(name, getDimensions());
|
||||
break;
|
||||
case VALUE:
|
||||
element = getElementsByName(name, getDimensionValues());
|
||||
break;
|
||||
default:
|
||||
}
|
||||
|
||||
if (element.isPresent()) {
|
||||
return element.get();
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
public Map<Long, String> getModelIdToName() {
|
||||
return modelSchemaList.stream()
|
||||
.collect(Collectors.toMap(a -> a.getModel().getId(), a -> a.getModel().getName(), (k1, k2) -> k1));
|
||||
@@ -35,9 +98,21 @@ public class SemanticSchema implements Serializable {
|
||||
return dimensions;
|
||||
}
|
||||
|
||||
public List<SchemaElement> getDimensions(Long modelId) {
|
||||
public List<SchemaElement> getDimensions(Set<Long> modelIds) {
|
||||
List<SchemaElement> dimensions = getDimensions();
|
||||
return getElementsByModelId(modelId, dimensions);
|
||||
return getElementsByModelId(modelIds, dimensions);
|
||||
}
|
||||
|
||||
public SchemaElement getDimensions(Long id) {
|
||||
List<SchemaElement> dimensions = getDimensions();
|
||||
Optional<SchemaElement> dimension = getElementsById(id, dimensions);
|
||||
return dimension.orElse(null);
|
||||
}
|
||||
|
||||
public List<SchemaElement> getTags() {
|
||||
List<SchemaElement> tags = new ArrayList<>();
|
||||
modelSchemaList.stream().forEach(d -> tags.addAll(d.getTags()));
|
||||
return tags;
|
||||
}
|
||||
|
||||
public List<SchemaElement> getMetrics() {
|
||||
@@ -46,21 +121,9 @@ public class SemanticSchema implements Serializable {
|
||||
return metrics;
|
||||
}
|
||||
|
||||
public List<SchemaElement> getMetrics(Long modelId) {
|
||||
public List<SchemaElement> getMetrics(Set<Long> modelIds) {
|
||||
List<SchemaElement> metrics = getMetrics();
|
||||
return getElementsByModelId(modelId, metrics);
|
||||
}
|
||||
|
||||
private List<SchemaElement> getElementsByModelId(Long modelId, List<SchemaElement> elements) {
|
||||
return elements.stream()
|
||||
.filter(schemaElement -> modelId.equals(schemaElement.getModel()))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public List<SchemaElement> getModels() {
|
||||
List<SchemaElement> models = new ArrayList<>();
|
||||
modelSchemaList.stream().forEach(d -> models.add(d.getModel()));
|
||||
return models;
|
||||
return getElementsByModelId(modelIds, metrics);
|
||||
}
|
||||
|
||||
public List<SchemaElement> getEntities() {
|
||||
@@ -69,11 +132,43 @@ public class SemanticSchema implements Serializable {
|
||||
return entities;
|
||||
}
|
||||
|
||||
public Map<String, String> getBizNameToName(Long modelId) {
|
||||
private List<SchemaElement> getElementsByModelId(Set<Long> modelIds, List<SchemaElement> elements) {
|
||||
return elements.stream()
|
||||
.filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private Optional<SchemaElement> getElementsById(Long id, List<SchemaElement> elements) {
|
||||
return elements.stream()
|
||||
.filter(schemaElement -> id.equals(schemaElement.getId()))
|
||||
.findFirst();
|
||||
}
|
||||
|
||||
private Optional<SchemaElement> getElementsByName(String name, List<SchemaElement> elements) {
|
||||
return elements.stream()
|
||||
.filter(schemaElement -> name.equals(schemaElement.getName()))
|
||||
.findFirst();
|
||||
}
|
||||
|
||||
public List<SchemaElement> getModels() {
|
||||
List<SchemaElement> models = new ArrayList<>();
|
||||
modelSchemaList.stream().forEach(d -> models.add(d.getModel()));
|
||||
return models;
|
||||
}
|
||||
|
||||
public Map<String, String> getBizNameToName(Set<Long> modelIds) {
|
||||
List<SchemaElement> allElements = new ArrayList<>();
|
||||
allElements.addAll(getDimensions(modelId));
|
||||
allElements.addAll(getMetrics(modelId));
|
||||
allElements.addAll(getDimensions(modelIds));
|
||||
allElements.addAll(getMetrics(modelIds));
|
||||
return allElements.stream()
|
||||
.collect(Collectors.toMap(a -> a.getBizName(), a -> a.getName(), (k1, k2) -> k1));
|
||||
.collect(Collectors.toMap(SchemaElement::getBizName, SchemaElement::getName, (k1, k2) -> k1));
|
||||
}
|
||||
|
||||
public Map<Long, ModelSchema> getModelSchemaMap() {
|
||||
if (CollectionUtils.isEmpty(modelSchemaList)) {
|
||||
return new HashMap<>();
|
||||
}
|
||||
return modelSchemaList.stream().collect(Collectors.toMap(modelSchema
|
||||
-> modelSchema.getModel().getModel(), modelSchema -> modelSchema));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import lombok.Data;
|
||||
public class QueryReq {
|
||||
private String queryText;
|
||||
private Integer chatId;
|
||||
private Long modelId = 0L;
|
||||
private Long modelId;
|
||||
private User user;
|
||||
private QueryFilters queryFilters;
|
||||
private boolean saveAnswer = true;
|
||||
|
||||
@@ -18,7 +18,7 @@ public class SolvedQueryReq {
|
||||
|
||||
private String queryText;
|
||||
|
||||
private Long modelId;
|
||||
private String modelId;
|
||||
|
||||
private Integer agentId;
|
||||
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.List;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class ModelInfo extends DataInfo implements Serializable {
|
||||
|
||||
private List<String> words;
|
||||
private String primaryEntityName;
|
||||
private String primaryEntityBizName;
|
||||
private String primaryKey;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user