mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:00:23 +00:00
(improvement)(headless)(chat) Add views and adapt chat and headless (#700)
* (improvement)(headless)(chat) Add views and adapt chat and headless --------- Co-authored-by: jolunoluo
This commit is contained in:
@@ -17,7 +17,7 @@ import java.util.List;
|
||||
@NoArgsConstructor
|
||||
public class SchemaElement implements Serializable {
|
||||
|
||||
private Long model;
|
||||
private Long view;
|
||||
private Long id;
|
||||
private String name;
|
||||
private String bizName;
|
||||
@@ -40,7 +40,7 @@ public class SchemaElement implements Serializable {
|
||||
return false;
|
||||
}
|
||||
SchemaElement schemaElement = (SchemaElement) o;
|
||||
return Objects.equal(model, schemaElement.model) && Objects.equal(id,
|
||||
return Objects.equal(view, schemaElement.view) && Objects.equal(id,
|
||||
schemaElement.id) && Objects.equal(name, schemaElement.name)
|
||||
&& Objects.equal(bizName, schemaElement.bizName)
|
||||
&& Objects.equal(type, schemaElement.type);
|
||||
@@ -48,7 +48,7 @@ public class SchemaElement implements Serializable {
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hashCode(model, id, name, bizName, type);
|
||||
return Objects.hashCode(view, id, name, bizName, type);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.chat.api.pojo;
|
||||
|
||||
public enum SchemaElementType {
|
||||
MODEL,
|
||||
VIEW,
|
||||
METRIC,
|
||||
DIMENSION,
|
||||
VALUE,
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package com.tencent.supersonic.chat.api.pojo;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -7,25 +9,25 @@ import java.util.Set;
|
||||
|
||||
public class SchemaMapInfo {
|
||||
|
||||
private Map<Long, List<SchemaElementMatch>> modelElementMatches = new HashMap<>();
|
||||
private Map<Long, List<SchemaElementMatch>> viewElementMatches = new HashMap<>();
|
||||
|
||||
public Set<Long> getMatchedModels() {
|
||||
return modelElementMatches.keySet();
|
||||
public Set<Long> getMatchedViewInfos() {
|
||||
return viewElementMatches.keySet();
|
||||
}
|
||||
|
||||
public List<SchemaElementMatch> getMatchedElements(Long model) {
|
||||
return modelElementMatches.get(model);
|
||||
public List<SchemaElementMatch> getMatchedElements(Long view) {
|
||||
return viewElementMatches.getOrDefault(view, Lists.newArrayList());
|
||||
}
|
||||
|
||||
public Map<Long, List<SchemaElementMatch>> getModelElementMatches() {
|
||||
return modelElementMatches;
|
||||
public Map<Long, List<SchemaElementMatch>> getViewElementMatches() {
|
||||
return viewElementMatches;
|
||||
}
|
||||
|
||||
public void setModelElementMatches(Map<Long, List<SchemaElementMatch>> modelElementMatches) {
|
||||
this.modelElementMatches = modelElementMatches;
|
||||
public void setViewElementMatches(Map<Long, List<SchemaElementMatch>> viewElementMatches) {
|
||||
this.viewElementMatches = viewElementMatches;
|
||||
}
|
||||
|
||||
public void setMatchedElements(Long model, List<SchemaElementMatch> elementMatches) {
|
||||
modelElementMatches.put(model, elementMatches);
|
||||
public void setMatchedElements(Long view, List<SchemaElementMatch> elementMatches) {
|
||||
viewElementMatches.put(view, elementMatches);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,61 +0,0 @@
|
||||
package com.tencent.supersonic.chat.api.pojo;
|
||||
|
||||
import com.clickhouse.client.internal.apache.commons.compress.utils.Lists;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import lombok.Data;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
@Data
|
||||
public class SchemaModelClusterMapInfo {
|
||||
|
||||
private Map<String, List<SchemaElementMatch>> modelElementMatches = new HashMap<>();
|
||||
|
||||
public Set<String> getMatchedModelClusters() {
|
||||
return modelElementMatches.keySet();
|
||||
}
|
||||
|
||||
public List<SchemaElementMatch> getMatchedElements(Long modelId) {
|
||||
for (String key : modelElementMatches.keySet()) {
|
||||
if (ModelCluster.getModelIdFromKey(key).contains(modelId)) {
|
||||
return modelElementMatches.get(key);
|
||||
}
|
||||
}
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
|
||||
public List<SchemaElementMatch> getMatchedElements(String modelCluster) {
|
||||
return modelElementMatches.get(modelCluster);
|
||||
}
|
||||
|
||||
public Map<String, List<SchemaElementMatch>> getModelElementMatches() {
|
||||
return modelElementMatches;
|
||||
}
|
||||
|
||||
public Map<String, List<SchemaElementMatch>> getElementMatchesByModelIds(Set<Long> modelIds) {
|
||||
if (CollectionUtils.isEmpty(modelIds)) {
|
||||
return modelElementMatches;
|
||||
}
|
||||
Map<String, List<SchemaElementMatch>> modelElementMatchesFiltered = new HashMap<>();
|
||||
for (String key : modelElementMatches.keySet()) {
|
||||
for (Long modelId : modelIds) {
|
||||
if (ModelCluster.getModelIdFromKey(key).contains(modelId)) {
|
||||
modelElementMatchesFiltered.put(key, modelElementMatches.get(key));
|
||||
}
|
||||
}
|
||||
}
|
||||
return modelElementMatchesFiltered;
|
||||
}
|
||||
|
||||
public void setModelElementMatches(Map<String, List<SchemaElementMatch>> modelElementMatches) {
|
||||
this.modelElementMatches = modelElementMatches;
|
||||
}
|
||||
|
||||
public void setMatchedElements(String modelCluster, List<SchemaElementMatch> elementMatches) {
|
||||
modelElementMatches.put(modelCluster, elementMatches);
|
||||
}
|
||||
}
|
||||
@@ -5,7 +5,6 @@ import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import com.tencent.supersonic.common.pojo.Order;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
@@ -26,7 +25,7 @@ public class SemanticParseInfo {
|
||||
|
||||
private Integer id;
|
||||
private String queryMode;
|
||||
private ModelCluster model = new ModelCluster();
|
||||
private SchemaElement view;
|
||||
private Set<SchemaElement> metrics = new TreeSet<>(new SchemaNameLengthComparator());
|
||||
private Set<SchemaElement> dimensions = new LinkedHashSet();
|
||||
private SchemaElement entity;
|
||||
@@ -44,20 +43,6 @@ public class SemanticParseInfo {
|
||||
private SqlInfo sqlInfo = new SqlInfo();
|
||||
private QueryType queryType = QueryType.ID;
|
||||
|
||||
public String getModelClusterKey() {
|
||||
if (model == null) {
|
||||
return "";
|
||||
}
|
||||
return model.getKey();
|
||||
}
|
||||
|
||||
public String getModelName() {
|
||||
if (model == null) {
|
||||
return "";
|
||||
}
|
||||
return model.getName();
|
||||
}
|
||||
|
||||
private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {
|
||||
|
||||
@Override
|
||||
@@ -86,27 +71,11 @@ public class SemanticParseInfo {
|
||||
return metrics;
|
||||
}
|
||||
|
||||
private Map<Long, Integer> getModelElementCountMap() {
|
||||
Map<Long, Integer> elementCountMap = new HashMap<>();
|
||||
elementMatches.stream().filter(element -> element.getElement().getModel() != null)
|
||||
.forEach(element -> {
|
||||
int count = elementCountMap.getOrDefault(element.getElement().getModel(), 0);
|
||||
elementCountMap.put(element.getElement().getModel(), count + 1);
|
||||
});
|
||||
return elementCountMap;
|
||||
}
|
||||
|
||||
public Long getModelId() {
|
||||
Map<Long, Integer> elementCountMap = getModelElementCountMap();
|
||||
Long modelId = -1L;
|
||||
int maxCnt = 0;
|
||||
for (Long model : elementCountMap.keySet()) {
|
||||
if (elementCountMap.get(model) > maxCnt) {
|
||||
maxCnt = elementCountMap.get(model);
|
||||
modelId = model;
|
||||
}
|
||||
public Long getViewId() {
|
||||
if (view == null) {
|
||||
return null;
|
||||
}
|
||||
return modelId;
|
||||
return view.getView();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package com.tencent.supersonic.chat.api.pojo;
|
||||
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
@@ -7,20 +9,18 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
public class SemanticSchema implements Serializable {
|
||||
|
||||
private List<ModelSchema> modelSchemaList;
|
||||
private List<ViewSchema> viewSchemaList;
|
||||
|
||||
public SemanticSchema(List<ModelSchema> modelSchemaList) {
|
||||
this.modelSchemaList = modelSchemaList;
|
||||
public SemanticSchema(List<ViewSchema> viewSchemaList) {
|
||||
this.viewSchemaList = viewSchemaList;
|
||||
}
|
||||
|
||||
public void add(ModelSchema schema) {
|
||||
modelSchemaList.add(schema);
|
||||
public void add(ViewSchema schema) {
|
||||
viewSchemaList.add(schema);
|
||||
}
|
||||
|
||||
public SchemaElement getElement(SchemaElementType elementType, long elementID) {
|
||||
@@ -30,8 +30,8 @@ public class SemanticSchema implements Serializable {
|
||||
case ENTITY:
|
||||
element = getElementsById(elementID, getEntities());
|
||||
break;
|
||||
case MODEL:
|
||||
element = getElementsById(elementID, getModels());
|
||||
case VIEW:
|
||||
element = getElementsById(elementID, getViews());
|
||||
break;
|
||||
case METRIC:
|
||||
element = getElementsById(elementID, getMetrics());
|
||||
@@ -59,8 +59,8 @@ public class SemanticSchema implements Serializable {
|
||||
case ENTITY:
|
||||
element = getElementsByNameOrAlias(name, getEntities());
|
||||
break;
|
||||
case MODEL:
|
||||
element = getElementsByNameOrAlias(name, getModels());
|
||||
case VIEW:
|
||||
element = getElementsByNameOrAlias(name, getViews());
|
||||
break;
|
||||
case METRIC:
|
||||
element = getElementsByNameOrAlias(name, getMetrics());
|
||||
@@ -81,29 +81,29 @@ public class SemanticSchema implements Serializable {
|
||||
}
|
||||
}
|
||||
|
||||
public Map<Long, String> getModelIdToName() {
|
||||
return modelSchemaList.stream()
|
||||
.collect(Collectors.toMap(a -> a.getModel().getId(), a -> a.getModel().getName(), (k1, k2) -> k1));
|
||||
public Map<Long, String> getViewIdToName() {
|
||||
return viewSchemaList.stream()
|
||||
.collect(Collectors.toMap(a -> a.getView().getId(), a -> a.getView().getName(), (k1, k2) -> k1));
|
||||
}
|
||||
|
||||
public List<SchemaElement> getDimensionValues() {
|
||||
List<SchemaElement> dimensionValues = new ArrayList<>();
|
||||
modelSchemaList.stream().forEach(d -> dimensionValues.addAll(d.getDimensionValues()));
|
||||
viewSchemaList.stream().forEach(d -> dimensionValues.addAll(d.getDimensionValues()));
|
||||
return dimensionValues;
|
||||
}
|
||||
|
||||
public List<SchemaElement> getDimensions() {
|
||||
List<SchemaElement> dimensions = new ArrayList<>();
|
||||
modelSchemaList.stream().forEach(d -> dimensions.addAll(d.getDimensions()));
|
||||
viewSchemaList.stream().forEach(d -> dimensions.addAll(d.getDimensions()));
|
||||
return dimensions;
|
||||
}
|
||||
|
||||
public List<SchemaElement> getDimensions(Set<Long> modelIds) {
|
||||
public List<SchemaElement> getDimensions(Long viewId) {
|
||||
List<SchemaElement> dimensions = getDimensions();
|
||||
return getElementsByModelId(modelIds, dimensions);
|
||||
return getElementsByViewId(viewId, dimensions);
|
||||
}
|
||||
|
||||
public SchemaElement getDimensions(Long id) {
|
||||
public SchemaElement getDimension(Long id) {
|
||||
List<SchemaElement> dimensions = getDimensions();
|
||||
Optional<SchemaElement> dimension = getElementsById(id, dimensions);
|
||||
return dimension.orElse(null);
|
||||
@@ -111,43 +111,43 @@ public class SemanticSchema implements Serializable {
|
||||
|
||||
public List<SchemaElement> getTags() {
|
||||
List<SchemaElement> tags = new ArrayList<>();
|
||||
modelSchemaList.stream().forEach(d -> tags.addAll(d.getTags()));
|
||||
viewSchemaList.stream().forEach(d -> tags.addAll(d.getTags()));
|
||||
return tags;
|
||||
}
|
||||
|
||||
public List<SchemaElement> getTags(Set<Long> modelIds) {
|
||||
public List<SchemaElement> getTags(Long viewId) {
|
||||
List<SchemaElement> tags = new ArrayList<>();
|
||||
modelSchemaList.stream().filter(schemaElement ->
|
||||
modelIds.contains(schemaElement.getModel().getModel()))
|
||||
viewSchemaList.stream().filter(schemaElement ->
|
||||
viewId.equals(schemaElement.getView().getView()))
|
||||
.forEach(d -> tags.addAll(d.getTags()));
|
||||
return tags;
|
||||
}
|
||||
|
||||
public List<SchemaElement> getMetrics() {
|
||||
List<SchemaElement> metrics = new ArrayList<>();
|
||||
modelSchemaList.stream().forEach(d -> metrics.addAll(d.getMetrics()));
|
||||
viewSchemaList.stream().forEach(d -> metrics.addAll(d.getMetrics()));
|
||||
return metrics;
|
||||
}
|
||||
|
||||
public List<SchemaElement> getMetrics(Set<Long> modelIds) {
|
||||
public List<SchemaElement> getMetrics(Long viewId) {
|
||||
List<SchemaElement> metrics = getMetrics();
|
||||
return getElementsByModelId(modelIds, metrics);
|
||||
return getElementsByViewId(viewId, metrics);
|
||||
}
|
||||
|
||||
public List<SchemaElement> getEntities() {
|
||||
List<SchemaElement> entities = new ArrayList<>();
|
||||
modelSchemaList.stream().forEach(d -> entities.add(d.getEntity()));
|
||||
viewSchemaList.stream().forEach(d -> entities.add(d.getEntity()));
|
||||
return entities;
|
||||
}
|
||||
|
||||
public List<SchemaElement> getEntities(Set<Long> modelIds) {
|
||||
public List<SchemaElement> getEntities(Long viewId) {
|
||||
List<SchemaElement> entities = getEntities();
|
||||
return getElementsByModelId(modelIds, entities);
|
||||
return getElementsByViewId(viewId, entities);
|
||||
}
|
||||
|
||||
private List<SchemaElement> getElementsByModelId(Set<Long> modelIds, List<SchemaElement> elements) {
|
||||
private List<SchemaElement> getElementsByViewId(Long viewId, List<SchemaElement> elements) {
|
||||
return elements.stream()
|
||||
.filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
|
||||
.filter(schemaElement -> viewId.equals(schemaElement.getView()))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@@ -165,25 +165,30 @@ public class SemanticSchema implements Serializable {
|
||||
).findFirst();
|
||||
}
|
||||
|
||||
public List<SchemaElement> getModels() {
|
||||
List<SchemaElement> models = new ArrayList<>();
|
||||
modelSchemaList.stream().forEach(d -> models.add(d.getModel()));
|
||||
return models;
|
||||
public SchemaElement getView(Long viewId) {
|
||||
List<SchemaElement> views = getViews();
|
||||
return getElementsById(viewId, views).orElse(null);
|
||||
}
|
||||
|
||||
public Map<String, String> getBizNameToName(Set<Long> modelIds) {
|
||||
public List<SchemaElement> getViews() {
|
||||
List<SchemaElement> views = new ArrayList<>();
|
||||
viewSchemaList.stream().forEach(d -> views.add(d.getView()));
|
||||
return views;
|
||||
}
|
||||
|
||||
public Map<String, String> getBizNameToName(Long viewId) {
|
||||
List<SchemaElement> allElements = new ArrayList<>();
|
||||
allElements.addAll(getDimensions(modelIds));
|
||||
allElements.addAll(getMetrics(modelIds));
|
||||
allElements.addAll(getDimensions(viewId));
|
||||
allElements.addAll(getMetrics(viewId));
|
||||
return allElements.stream()
|
||||
.collect(Collectors.toMap(SchemaElement::getBizName, SchemaElement::getName, (k1, k2) -> k1));
|
||||
}
|
||||
|
||||
public Map<Long, ModelSchema> getModelSchemaMap() {
|
||||
if (CollectionUtils.isEmpty(modelSchemaList)) {
|
||||
public Map<Long, ViewSchema> getViewSchemaMap() {
|
||||
if (CollectionUtils.isEmpty(viewSchemaList)) {
|
||||
return new HashMap<>();
|
||||
}
|
||||
return modelSchemaList.stream().collect(Collectors.toMap(modelSchema
|
||||
-> modelSchema.getModel().getModel(), modelSchema -> modelSchema));
|
||||
return viewSchemaList.stream().collect(Collectors.toMap(viewSchema
|
||||
-> viewSchema.getView().getView(), viewSchema -> viewSchema));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,26 +1,24 @@
|
||||
package com.tencent.supersonic.chat.api.pojo;
|
||||
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.common.pojo.ModelRela;
|
||||
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
|
||||
import lombok.Data;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
|
||||
@Data
|
||||
public class ModelSchema {
|
||||
public class ViewSchema {
|
||||
|
||||
private SchemaElement model;
|
||||
private SchemaElement view;
|
||||
private Set<SchemaElement> metrics = new HashSet<>();
|
||||
private Set<SchemaElement> dimensions = new HashSet<>();
|
||||
private Set<SchemaElement> dimensionValues = new HashSet<>();
|
||||
private Set<SchemaElement> tags = new HashSet<>();
|
||||
private SchemaElement entity = new SchemaElement();
|
||||
private List<ModelRela> modelRelas = new ArrayList<>();
|
||||
private QueryConfig queryConfig;
|
||||
|
||||
public SchemaElement getElement(SchemaElementType elementType, long elementID) {
|
||||
Optional<SchemaElement> element = Optional.empty();
|
||||
@@ -29,8 +27,8 @@ public class ModelSchema {
|
||||
case ENTITY:
|
||||
element = Optional.ofNullable(entity);
|
||||
break;
|
||||
case MODEL:
|
||||
element = Optional.of(model);
|
||||
case VIEW:
|
||||
element = Optional.of(view);
|
||||
break;
|
||||
case METRIC:
|
||||
element = metrics.stream().filter(e -> e.getId() == elementID).findFirst();
|
||||
@@ -61,8 +59,8 @@ public class ModelSchema {
|
||||
case ENTITY:
|
||||
element = Optional.ofNullable(entity);
|
||||
break;
|
||||
case MODEL:
|
||||
element = Optional.of(model);
|
||||
case VIEW:
|
||||
element = Optional.of(view);
|
||||
break;
|
||||
case METRIC:
|
||||
element = metrics.stream().filter(e -> name.equals(e.getName())).findFirst();
|
||||
@@ -83,16 +81,31 @@ public class ModelSchema {
|
||||
}
|
||||
}
|
||||
|
||||
public Set<Long> getModelClusterSet() {
|
||||
if (CollectionUtils.isEmpty(modelRelas)) {
|
||||
return Sets.newHashSet();
|
||||
public TimeDefaultConfig getTagTypeTimeDefaultConfig() {
|
||||
if (queryConfig == null) {
|
||||
return null;
|
||||
}
|
||||
Set<Long> modelClusterSet = new HashSet<>();
|
||||
modelRelas.forEach(modelRela -> {
|
||||
modelClusterSet.add(modelRela.getToModelId());
|
||||
modelClusterSet.add(modelRela.getFromModelId());
|
||||
});
|
||||
return modelClusterSet;
|
||||
if (queryConfig.getTagTypeDefaultConfig() == null) {
|
||||
return null;
|
||||
}
|
||||
return queryConfig.getTagTypeDefaultConfig().getTimeDefaultConfig();
|
||||
}
|
||||
|
||||
public TimeDefaultConfig getMetricTypeTimeDefaultConfig() {
|
||||
if (queryConfig == null) {
|
||||
return null;
|
||||
}
|
||||
if (queryConfig.getMetricTypeDefaultConfig() == null) {
|
||||
return null;
|
||||
}
|
||||
return queryConfig.getMetricTypeDefaultConfig().getTimeDefaultConfig();
|
||||
}
|
||||
|
||||
public TagTypeDefaultConfig getTagTypeDefaultConfig() {
|
||||
if (queryConfig == null) {
|
||||
return null;
|
||||
}
|
||||
return queryConfig.getTagTypeDefaultConfig();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.request;
|
||||
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.ArrayList;
|
||||
@@ -13,26 +12,5 @@ public class ChatDefaultConfigReq {
|
||||
private List<Long> dimensionIds = new ArrayList<>();
|
||||
private List<Long> metricIds = new ArrayList<>();
|
||||
|
||||
/**
|
||||
* default time span unit
|
||||
*/
|
||||
private Integer unit = 1;
|
||||
|
||||
/**
|
||||
* default time type: day
|
||||
* DAY, WEEK, MONTH, YEAR
|
||||
*/
|
||||
private String period = Constants.DAY;
|
||||
|
||||
private TimeMode timeMode = TimeMode.LAST;
|
||||
|
||||
public enum TimeMode {
|
||||
/**
|
||||
* date mode
|
||||
* LAST - a certain time
|
||||
* RECENT - a period time
|
||||
*/
|
||||
LAST, RECENT
|
||||
}
|
||||
|
||||
}
|
||||
@@ -13,7 +13,7 @@ public class PluginQueryReq {
|
||||
|
||||
private String type;
|
||||
|
||||
private String model;
|
||||
private String view;
|
||||
|
||||
private String pattern;
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ public class SimilarQueryReq {
|
||||
|
||||
private String queryText;
|
||||
|
||||
private String modelId;
|
||||
private Long viewId;
|
||||
|
||||
private Integer agentId;
|
||||
|
||||
|
||||
@@ -2,8 +2,8 @@ package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq.TimeMode;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeMode;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
@@ -21,7 +21,7 @@ public class ChatDefaultRichConfigResp {
|
||||
private Integer unit = 1;
|
||||
|
||||
/**
|
||||
* default time type: day
|
||||
* default time type:
|
||||
* DAY, WEEK, MONTH, YEAR
|
||||
*/
|
||||
private String period = Constants.DAY;
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class DataInfo {
|
||||
|
||||
private Integer itemId;
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class EntityInfo {
|
||||
|
||||
private ModelInfo modelInfo = new ModelInfo();
|
||||
private ViewInfo viewInfo = new ViewInfo();
|
||||
private List<DataInfo> dimensions = new ArrayList<>();
|
||||
private List<DataInfo> metrics = new ArrayList<>();
|
||||
private String entityId;
|
||||
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import java.io.Serializable;
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class ModelInfo extends DataInfo implements Serializable {
|
||||
public class ViewInfo extends DataInfo implements Serializable {
|
||||
|
||||
private List<String> words;
|
||||
private String primaryKey;
|
||||
@@ -4,6 +4,9 @@ package com.tencent.supersonic.chat.core.agent;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.RecordInfo;
|
||||
import lombok.Data;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
@@ -11,8 +14,6 @@ import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.Data;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Data
|
||||
public class Agent extends RecordInfo {
|
||||
@@ -51,8 +52,8 @@ public class Agent extends RecordInfo {
|
||||
return enableSearch != null && enableSearch == 1;
|
||||
}
|
||||
|
||||
public static boolean containsAllModel(Set<Long> detectModelIds) {
|
||||
return !CollectionUtils.isEmpty(detectModelIds) && detectModelIds.contains(-1L);
|
||||
public static boolean containsAllModel(Set<Long> detectViewIds) {
|
||||
return !CollectionUtils.isEmpty(detectViewIds) && detectViewIds.contains(-1L);
|
||||
}
|
||||
|
||||
public List<NL2SQLTool> getParserTools(AgentToolType agentToolType) {
|
||||
@@ -64,12 +65,12 @@ public class Agent extends RecordInfo {
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public Set<Long> getModelIds(AgentToolType agentToolType) {
|
||||
public Set<Long> getViewIds(AgentToolType agentToolType) {
|
||||
List<NL2SQLTool> commonAgentTools = getParserTools(agentToolType);
|
||||
if (CollectionUtils.isEmpty(commonAgentTools)) {
|
||||
return new HashSet<>();
|
||||
}
|
||||
return commonAgentTools.stream().map(NL2SQLTool::getModelIds)
|
||||
return commonAgentTools.stream().map(NL2SQLTool::getViewIds)
|
||||
.filter(modelIds -> !CollectionUtils.isEmpty(modelIds))
|
||||
.flatMap(Collection::stream)
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.agent;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
|
||||
@Data
|
||||
public class DataAnalyticsTool extends AgentTool {
|
||||
|
||||
private Long modelId;
|
||||
|
||||
}
|
||||
@@ -1,16 +1,17 @@
|
||||
package com.tencent.supersonic.chat.core.agent;
|
||||
|
||||
|
||||
import java.util.List;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class NL2SQLTool extends AgentTool {
|
||||
|
||||
protected List<Long> modelIds;
|
||||
protected List<Long> viewIds;
|
||||
|
||||
}
|
||||
@@ -15,7 +15,7 @@ public class RuleParserTool extends NL2SQLTool {
|
||||
private List<String> queryTypes;
|
||||
|
||||
public boolean isContainsAllModel() {
|
||||
return CollectionUtils.isNotEmpty(modelIds) && modelIds.contains(-1L);
|
||||
return CollectionUtils.isNotEmpty(viewIds) && viewIds.contains(-1L);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,14 +1,19 @@
|
||||
package com.tencent.supersonic.chat.core.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
@@ -16,10 +21,6 @@ import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* basic semantic correction functionality, offering common methods and an
|
||||
@@ -42,7 +43,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
|
||||
public abstract void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo);
|
||||
|
||||
protected Map<String, String> getFieldNameMap(QueryContext queryContext, Set<Long> modelIds) {
|
||||
protected Map<String, String> getFieldNameMap(QueryContext queryContext, Long viewId) {
|
||||
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
|
||||
@@ -52,7 +53,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
|
||||
// support fieldName and field alias
|
||||
Map<String, String> result = dbAllFields.stream()
|
||||
.filter(entry -> modelIds.contains(entry.getModel()))
|
||||
.filter(entry -> viewId.equals(entry.getView()))
|
||||
.flatMap(schemaElement -> {
|
||||
Set<String> elements = new HashSet<>();
|
||||
elements.add(schemaElement.getName());
|
||||
@@ -100,9 +101,8 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
protected void addAggregateToMetric(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
//add aggregate to all metric
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
|
||||
|
||||
List<SchemaElement> metrics = getMetricElements(queryContext, modelIds);
|
||||
Long viewId = semanticParseInfo.getView().getView();
|
||||
List<SchemaElement> metrics = getMetricElements(queryContext, viewId);
|
||||
|
||||
Map<String, String> metricToAggregate = metrics.stream()
|
||||
.map(schemaElement -> {
|
||||
@@ -127,9 +127,9 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql);
|
||||
}
|
||||
|
||||
protected List<SchemaElement> getMetricElements(QueryContext queryContext, Set<Long> modelIds) {
|
||||
protected List<SchemaElement> getMetricElements(QueryContext queryContext, Long viewId) {
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
return semanticSchema.getMetrics(modelIds);
|
||||
return semanticSchema.getMetrics(viewId);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
/**
|
||||
* Perform SQL corrections on the "From" section in S2SQL.
|
||||
*/
|
||||
@Slf4j
|
||||
public class FromCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
String modelName = semanticParseInfo.getModel().getName();
|
||||
String correctSql = SqlParserReplaceHelper
|
||||
.replaceTable(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), modelName);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctSql);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,18 +1,19 @@
|
||||
package com.tencent.supersonic.chat.core.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* Perform SQL corrections on the "Group by" section in S2SQL.
|
||||
@@ -28,8 +29,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
}
|
||||
|
||||
private void addGroupByFields(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
|
||||
|
||||
Long viewId = semanticParseInfo.getViewId();
|
||||
//add dimension group by
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
||||
@@ -41,7 +41,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
return;
|
||||
}
|
||||
//add alias field name
|
||||
Set<String> dimensions = semanticSchema.getDimensions(modelIds).stream()
|
||||
Set<String> dimensions = semanticSchema.getDimensions(viewId).stream()
|
||||
.flatMap(
|
||||
schemaElement -> {
|
||||
Set<String> elements = new HashSet<>();
|
||||
|
||||
@@ -1,18 +1,19 @@
|
||||
package com.tencent.supersonic.chat.core.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Perform SQL corrections on the "Having" section in S2SQL.
|
||||
*/
|
||||
@@ -31,11 +32,11 @@ public class HavingCorrector extends BaseSemanticCorrector {
|
||||
}
|
||||
|
||||
private void addHaving(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
|
||||
Long viewId = semanticParseInfo.getView().getView();
|
||||
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
|
||||
Set<String> metrics = semanticSchema.getMetrics(modelIds).stream()
|
||||
Set<String> metrics = semanticSchema.getMetrics(viewId).stream()
|
||||
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
|
||||
|
||||
if (CollectionUtils.isEmpty(metrics)) {
|
||||
|
||||
@@ -1,21 +1,22 @@
|
||||
package com.tencent.supersonic.chat.core.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
import com.tencent.supersonic.chat.core.parser.sql.llm.ParseResult;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.AggregateEnum;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* Perform schema corrections on the Schema information in S2QL.
|
||||
@@ -51,7 +52,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
}
|
||||
|
||||
private void correctFieldName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
Map<String, String> fieldNameMap = getFieldNameMap(queryContext, semanticParseInfo.getModel().getModelIds());
|
||||
Map<String, String> fieldNameMap = getFieldNameMap(queryContext, semanticParseInfo.getViewId());
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String sql = SqlParserReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
|
||||
@@ -1,24 +1,18 @@
|
||||
package com.tencent.supersonic.chat.core.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.chat.core.parser.sql.llm.S2SqlDateHelper;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.StringUtil;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
@@ -27,6 +21,12 @@ import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.logging.log4j.util.Strings;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Perform SQL corrections on the "Where" section in S2SQL.
|
||||
*/
|
||||
@@ -73,7 +73,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
List<String> whereFields = SqlParserSelectHelper.getWhereFields(correctS2SQL);
|
||||
if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) {
|
||||
String currentDate = S2SqlDateHelper.getReferenceDate(queryContext, semanticParseInfo.getModelId());
|
||||
String currentDate = S2SqlDateHelper.getReferenceDate(queryContext, semanticParseInfo.getViewId());
|
||||
if (StringUtils.isNotBlank(currentDate)) {
|
||||
correctS2SQL = SqlParserAddHelper.addParenthesisToWhere(correctS2SQL);
|
||||
correctS2SQL = SqlParserAddHelper.addWhere(
|
||||
@@ -99,8 +99,8 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
|
||||
private void updateFieldValueByTechName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
|
||||
List<SchemaElement> dimensions = semanticSchema.getDimensions(modelIds);
|
||||
Long viewId = semanticParseInfo.getViewId();
|
||||
List<SchemaElement> dimensions = semanticSchema.getDimensions(viewId);
|
||||
|
||||
if (CollectionUtils.isEmpty(dimensions)) {
|
||||
return;
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge;
|
||||
|
||||
import com.tencent.supersonic.chat.core.utils.NatureHelper;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import lombok.Data;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
|
||||
@Data
|
||||
@Service
|
||||
public class LoadRemoveService {
|
||||
@@ -31,7 +32,7 @@ public class LoadRemoveService {
|
||||
if (Objects.isNull(nature)) {
|
||||
return false;
|
||||
}
|
||||
Long modelId = NatureHelper.getModelId(nature);
|
||||
Long modelId = NatureHelper.getViewId(nature);
|
||||
if (Objects.nonNull(modelId)) {
|
||||
return !detectModelIds.contains(modelId);
|
||||
}
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge;
|
||||
|
||||
import java.io.Serializable;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
|
||||
@Data
|
||||
@ToString
|
||||
@Builder
|
||||
public class ModelInfoStat implements Serializable {
|
||||
|
||||
private long modelCount;
|
||||
|
||||
private long metricModelCount;
|
||||
|
||||
private long dimensionModelCount;
|
||||
|
||||
private long dimensionValueModelCount;
|
||||
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge;
|
||||
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
@Data
|
||||
@ToString
|
||||
@Builder
|
||||
public class ViewInfoStat implements Serializable {
|
||||
|
||||
private long viewCount;
|
||||
|
||||
private long metricViewCount;
|
||||
|
||||
private long dimensionViewCount;
|
||||
|
||||
private long dimensionValueViewCount;
|
||||
|
||||
}
|
||||
@@ -4,13 +4,14 @@ import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.core.knowledge.DictWord;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* dimension word nature
|
||||
*/
|
||||
@@ -37,11 +38,11 @@ public class DimensionWordBuilder extends BaseWordBuilder {
|
||||
private DictWord getOnwWordNature(String word, SchemaElement schemaElement, boolean isSuffix) {
|
||||
DictWord dictWord = new DictWord();
|
||||
dictWord.setWord(word);
|
||||
Long domainId = schemaElement.getModel();
|
||||
String nature = DictWordType.NATURE_SPILT + domainId + DictWordType.NATURE_SPILT + schemaElement.getId()
|
||||
Long viewId = schemaElement.getView();
|
||||
String nature = DictWordType.NATURE_SPILT + viewId + DictWordType.NATURE_SPILT + schemaElement.getId()
|
||||
+ DictWordType.DIMENSION.getType();
|
||||
if (isSuffix) {
|
||||
nature = DictWordType.NATURE_SPILT + domainId + DictWordType.NATURE_SPILT + schemaElement.getId()
|
||||
nature = DictWordType.NATURE_SPILT + viewId + DictWordType.NATURE_SPILT + schemaElement.getId()
|
||||
+ DictWordType.SUFFIX.getType() + DictWordType.DIMENSION.getType();
|
||||
}
|
||||
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature));
|
||||
|
||||
@@ -5,12 +5,13 @@ import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.core.knowledge.DictWord;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* dimension value wordNature
|
||||
*/
|
||||
@@ -26,8 +27,8 @@ public class EntityWordBuilder extends BaseWordBuilder {
|
||||
return result;
|
||||
}
|
||||
|
||||
Long domain = schemaElement.getModel();
|
||||
String nature = DictWordType.NATURE_SPILT + domain + DictWordType.NATURE_SPILT + schemaElement.getId()
|
||||
Long view = schemaElement.getView();
|
||||
String nature = DictWordType.NATURE_SPILT + view + DictWordType.NATURE_SPILT + schemaElement.getId()
|
||||
+ DictWordType.ENTITY.getType();
|
||||
|
||||
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
|
||||
|
||||
@@ -4,13 +4,14 @@ import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.core.knowledge.DictWord;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Metric DictWord
|
||||
*/
|
||||
@@ -37,11 +38,11 @@ public class MetricWordBuilder extends BaseWordBuilder {
|
||||
private DictWord getOnwWordNature(String word, SchemaElement schemaElement, boolean isSuffix) {
|
||||
DictWord dictWord = new DictWord();
|
||||
dictWord.setWord(word);
|
||||
Long modelId = schemaElement.getModel();
|
||||
String nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + schemaElement.getId()
|
||||
Long viewId = schemaElement.getView();
|
||||
String nature = DictWordType.NATURE_SPILT + viewId + DictWordType.NATURE_SPILT + schemaElement.getId()
|
||||
+ DictWordType.METRIC.getType();
|
||||
if (isSuffix) {
|
||||
nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + schemaElement.getId()
|
||||
nature = DictWordType.NATURE_SPILT + viewId + DictWordType.NATURE_SPILT + schemaElement.getId()
|
||||
+ DictWordType.SUFFIX.getType() + DictWordType.METRIC.getType();
|
||||
}
|
||||
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature));
|
||||
|
||||
@@ -4,11 +4,12 @@ import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.core.knowledge.DictWord;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import java.util.List;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* model word nature
|
||||
*/
|
||||
@@ -20,13 +21,13 @@ public class ModelWordBuilder extends BaseWordBuilder {
|
||||
public List<DictWord> doGet(String word, SchemaElement schemaElement) {
|
||||
List<DictWord> result = Lists.newArrayList();
|
||||
//modelName
|
||||
DictWord dictWord = buildDictWord(word, schemaElement.getModel());
|
||||
DictWord dictWord = buildDictWord(word, schemaElement.getView());
|
||||
result.add(dictWord);
|
||||
//alias
|
||||
List<String> aliasList = schemaElement.getAlias();
|
||||
if (CollectionUtils.isNotEmpty(aliasList)) {
|
||||
for (String alias : aliasList) {
|
||||
result.add(buildDictWord(alias, schemaElement.getModel()));
|
||||
result.add(buildDictWord(alias, schemaElement.getView()));
|
||||
}
|
||||
}
|
||||
return result;
|
||||
|
||||
@@ -5,12 +5,13 @@ import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.core.knowledge.DictWord;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* dimension value wordNature
|
||||
*/
|
||||
@@ -26,8 +27,8 @@ public class ValueWordBuilder extends BaseWordBuilder {
|
||||
|
||||
schemaElement.getAlias().stream().forEach(value -> {
|
||||
DictWord dictWord = new DictWord();
|
||||
Long modelId = schemaElement.getModel();
|
||||
String nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + schemaElement.getId();
|
||||
Long viewId = schemaElement.getView();
|
||||
String nature = DictWordType.NATURE_SPILT + viewId + DictWordType.NATURE_SPILT + schemaElement.getId();
|
||||
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature));
|
||||
dictWord.setWord(value);
|
||||
result.add(dictWord);
|
||||
|
||||
@@ -15,7 +15,7 @@ public class WordBuilderFactory {
|
||||
static {
|
||||
wordNatures.put(DictWordType.DIMENSION, new DimensionWordBuilder());
|
||||
wordNatures.put(DictWordType.METRIC, new MetricWordBuilder());
|
||||
wordNatures.put(DictWordType.MODEL, new ModelWordBuilder());
|
||||
wordNatures.put(DictWordType.VIEW, new ModelWordBuilder());
|
||||
wordNatures.put(DictWordType.ENTITY, new EntityWordBuilder());
|
||||
wordNatures.put(DictWordType.VALUE, new ValueWordBuilder());
|
||||
}
|
||||
|
||||
@@ -2,66 +2,67 @@ package com.tencent.supersonic.chat.core.knowledge.semantic;
|
||||
|
||||
import com.google.common.cache.Cache;
|
||||
import com.google.common.cache.CacheBuilder;
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ModelSchemaResp;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ViewSchemaResp;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
@Slf4j
|
||||
public abstract class BaseSemanticInterpreter implements SemanticInterpreter {
|
||||
|
||||
protected final Cache<String, List<ModelSchemaResp>> modelSchemaCache =
|
||||
protected final Cache<String, List<ViewSchemaResp>> viewSchemaCache =
|
||||
CacheBuilder.newBuilder().expireAfterWrite(10, TimeUnit.SECONDS).build();
|
||||
|
||||
@SneakyThrows
|
||||
public List<ModelSchemaResp> fetchModelSchema(List<Long> ids, Boolean cacheEnable) {
|
||||
public List<ViewSchemaResp> fetchViewSchema(List<Long> ids, Boolean cacheEnable) {
|
||||
if (cacheEnable) {
|
||||
return modelSchemaCache.get(String.valueOf(ids), () -> {
|
||||
List<ModelSchemaResp> data = doFetchModelSchema(ids);
|
||||
modelSchemaCache.put(String.valueOf(ids), data);
|
||||
return viewSchemaCache.get(String.valueOf(ids), () -> {
|
||||
List<ViewSchemaResp> data = doFetchViewSchema(ids);
|
||||
viewSchemaCache.put(String.valueOf(ids), data);
|
||||
return data;
|
||||
});
|
||||
}
|
||||
List<ModelSchemaResp> data = doFetchModelSchema(ids);
|
||||
return data;
|
||||
return doFetchViewSchema(ids);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ModelSchema getModelSchema(Long model, Boolean cacheEnable) {
|
||||
public ViewSchema getViewSchema(Long viewId, Boolean cacheEnable) {
|
||||
List<Long> ids = new ArrayList<>();
|
||||
ids.add(model);
|
||||
List<ModelSchemaResp> modelSchemaResps = fetchModelSchema(ids, cacheEnable);
|
||||
if (!CollectionUtils.isEmpty(modelSchemaResps)) {
|
||||
Optional<ModelSchemaResp> modelSchemaResp = modelSchemaResps.stream()
|
||||
.filter(d -> d.getId().equals(model)).findFirst();
|
||||
if (modelSchemaResp.isPresent()) {
|
||||
ModelSchemaResp modelSchema = modelSchemaResp.get();
|
||||
return ModelSchemaBuilder.build(modelSchema);
|
||||
ids.add(viewId);
|
||||
List<ViewSchemaResp> viewSchemaResps = fetchViewSchema(ids, cacheEnable);
|
||||
if (!CollectionUtils.isEmpty(viewSchemaResps)) {
|
||||
Optional<ViewSchemaResp> viewSchemaResp = viewSchemaResps.stream()
|
||||
.filter(d -> d.getId().equals(viewId)).findFirst();
|
||||
if (viewSchemaResp.isPresent()) {
|
||||
ViewSchemaResp viewSchema = viewSchemaResp.get();
|
||||
return ViewSchemaBuilder.build(viewSchema);
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ModelSchema> getModelSchema() {
|
||||
return getModelSchema(new ArrayList<>());
|
||||
public List<ViewSchema> getViewSchema() {
|
||||
return getViewSchema(new ArrayList<>());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ModelSchema> getModelSchema(List<Long> ids) {
|
||||
List<ModelSchema> domainSchemaList = new ArrayList<>();
|
||||
public List<ViewSchema> getViewSchema(List<Long> ids) {
|
||||
List<ViewSchema> domainSchemaList = new ArrayList<>();
|
||||
|
||||
for (ModelSchemaResp resp : fetchModelSchema(ids, true)) {
|
||||
domainSchemaList.add(ModelSchemaBuilder.build(resp));
|
||||
for (ViewSchemaResp resp : fetchViewSchema(ids, true)) {
|
||||
domainSchemaList.add(ViewSchemaBuilder.build(resp));
|
||||
}
|
||||
|
||||
return domainSchemaList;
|
||||
}
|
||||
|
||||
protected abstract List<ModelSchemaResp> doFetchModelSchema(List<Long> ids);
|
||||
protected abstract List<ViewSchemaResp> doFetchViewSchema(List<Long> ids);
|
||||
|
||||
}
|
||||
|
||||
@@ -2,34 +2,33 @@ package com.tencent.supersonic.chat.core.knowledge.semantic;
|
||||
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.AuthType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ExplainSqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaFilterReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.PageDimensionReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.PageMetricReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ViewFilterReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DomainResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ExplainResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ModelSchemaResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ViewResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ViewSchemaResp;
|
||||
import com.tencent.supersonic.headless.server.service.DimensionService;
|
||||
import com.tencent.supersonic.headless.server.service.MetricService;
|
||||
import com.tencent.supersonic.headless.server.service.QueryService;
|
||||
import com.tencent.supersonic.headless.server.service.SchemaService;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@Slf4j
|
||||
public class LocalSemanticInterpreter extends BaseSemanticInterpreter {
|
||||
|
||||
@@ -44,7 +43,7 @@ public class LocalSemanticInterpreter extends BaseSemanticInterpreter {
|
||||
if (StringUtils.isNotBlank(queryStructReq.getCorrectS2SQL())) {
|
||||
QuerySqlReq querySqlReq = new QuerySqlReq();
|
||||
querySqlReq.setSql(queryStructReq.getCorrectS2SQL());
|
||||
querySqlReq.setModelIds(queryStructReq.getModelIdSet());
|
||||
querySqlReq.setViewId(queryStructReq.getViewId());
|
||||
querySqlReq.setParams(new ArrayList<>());
|
||||
return queryByS2SQL(querySqlReq, user);
|
||||
}
|
||||
@@ -68,19 +67,11 @@ public class LocalSemanticInterpreter extends BaseSemanticInterpreter {
|
||||
}
|
||||
|
||||
@Override
|
||||
@SneakyThrows
|
||||
public SemanticQueryResp queryDimValue(QueryDimValueReq queryDimValueReq, User user) {
|
||||
queryService = ContextUtils.getBean(QueryService.class);
|
||||
return queryService.queryDimValue(queryDimValueReq, user);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ModelSchemaResp> doFetchModelSchema(List<Long> ids) {
|
||||
ModelSchemaFilterReq filter = new ModelSchemaFilterReq();
|
||||
filter.setModelIds(ids);
|
||||
public List<ViewSchemaResp> doFetchViewSchema(List<Long> ids) {
|
||||
ViewFilterReq filter = new ViewFilterReq();
|
||||
filter.setViewIds(ids);
|
||||
schemaService = ContextUtils.getBean(SchemaService.class);
|
||||
User user = User.getFakeUser();
|
||||
return schemaService.fetchModelSchema(filter, user);
|
||||
return schemaService.fetchViewSchema(filter);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -90,9 +81,9 @@ public class LocalSemanticInterpreter extends BaseSemanticInterpreter {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ModelResp> getModelList(AuthType authType, Long domainId, User user) {
|
||||
public List<ViewResp> getViewList(Long domainId) {
|
||||
schemaService = ContextUtils.getBean(SchemaService.class);
|
||||
return schemaService.getModelList(user, authType, domainId);
|
||||
return schemaService.getViewList(domainId);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -1,19 +1,11 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge.semantic;
|
||||
|
||||
import static com.tencent.supersonic.common.pojo.Constants.LIST_LOWER;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.PAGESIZE_LOWER;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.TOTAL_LOWER;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.TRUE_LOWER;
|
||||
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.google.gson.Gson;
|
||||
import com.tencent.supersonic.auth.api.authentication.config.AuthenticationConfig;
|
||||
import com.tencent.supersonic.auth.api.authentication.constant.UserConstants;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.core.config.DefaultSemanticConfig;
|
||||
import com.tencent.supersonic.common.pojo.ResultData;
|
||||
import com.tencent.supersonic.common.pojo.enums.AuthType;
|
||||
import com.tencent.supersonic.common.pojo.enums.ReturnCode;
|
||||
import com.tencent.supersonic.common.pojo.exception.CommonException;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
@@ -21,10 +13,8 @@ import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.common.util.S2ThreadContext;
|
||||
import com.tencent.supersonic.common.util.ThreadContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ExplainSqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ModelSchemaFilterReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.PageDimensionReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.PageMetricReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
@@ -32,15 +22,9 @@ import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DomainResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ExplainResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ModelSchemaResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import java.net.URI;
|
||||
import java.net.URL;
|
||||
import java.util.ArrayList;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ViewResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ViewSchemaResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.logging.log4j.util.Strings;
|
||||
@@ -54,6 +38,17 @@ import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
import org.springframework.web.util.UriComponentsBuilder;
|
||||
|
||||
import java.net.URI;
|
||||
import java.net.URL;
|
||||
import java.util.ArrayList;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import static com.tencent.supersonic.common.pojo.Constants.LIST_LOWER;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.PAGESIZE_LOWER;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.TOTAL_LOWER;
|
||||
|
||||
@Slf4j
|
||||
public class RemoteSemanticInterpreter extends BaseSemanticInterpreter {
|
||||
|
||||
@@ -130,50 +125,6 @@ public class RemoteSemanticInterpreter extends BaseSemanticInterpreter {
|
||||
throw new CommonException(responseBody.getCode(), responseBody.getMsg());
|
||||
}
|
||||
|
||||
@Override
|
||||
public SemanticQueryResp queryDimValue(QueryDimValueReq queryDimValueReq, User user) {
|
||||
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
|
||||
return searchByRestTemplate(defaultSemanticConfig.getSemanticUrl()
|
||||
+ defaultSemanticConfig.getQueryDimValuePath(),
|
||||
new Gson().toJson(queryDimValueReq));
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ModelSchemaResp> doFetchModelSchema(List<Long> ids) {
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.set(UserConstants.INTERNAL, TRUE_LOWER);
|
||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||
fillToken(headers);
|
||||
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
|
||||
|
||||
String semanticUrl = defaultSemanticConfig.getSemanticUrl();
|
||||
String fetchModelSchemaPath = defaultSemanticConfig.getFetchModelSchemaPath();
|
||||
URI requestUrl = UriComponentsBuilder.fromHttpUrl(semanticUrl + fetchModelSchemaPath)
|
||||
.build().encode().toUri();
|
||||
ModelSchemaFilterReq filter = new ModelSchemaFilterReq();
|
||||
filter.setModelIds(ids);
|
||||
ParameterizedTypeReference<ResultData<List<ModelSchemaResp>>> responseTypeRef =
|
||||
new ParameterizedTypeReference<ResultData<List<ModelSchemaResp>>>() {
|
||||
};
|
||||
|
||||
HttpEntity<String> entity = new HttpEntity<>(JSON.toJSONString(filter), headers);
|
||||
|
||||
try {
|
||||
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
|
||||
ResponseEntity<ResultData<List<ModelSchemaResp>>> responseEntity = restTemplate.exchange(
|
||||
requestUrl, HttpMethod.POST, entity, responseTypeRef);
|
||||
ResultData<List<ModelSchemaResp>> responseBody = responseEntity.getBody();
|
||||
log.debug("ApiResponse<fetchModelSchema> responseBody:{}", responseBody);
|
||||
if (ReturnCode.SUCCESS.getCode() == responseBody.getCode()) {
|
||||
List<ModelSchemaResp> data = responseBody.getData();
|
||||
return data;
|
||||
}
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("fetchModelSchema interface error", e);
|
||||
}
|
||||
throw new RuntimeException("fetchModelSchema interface error");
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DomainResp> getDomainList(User user) {
|
||||
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
|
||||
@@ -183,19 +134,6 @@ public class RemoteSemanticInterpreter extends BaseSemanticInterpreter {
|
||||
return JsonUtil.toList(JsonUtil.toString(domainDescListObject), DomainResp.class);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ModelResp> getModelList(AuthType authType, Long domainId, User user) {
|
||||
if (domainId == null) {
|
||||
domainId = 0L;
|
||||
}
|
||||
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
|
||||
String url = String.format("%s?domainId=%s&authType=%s",
|
||||
defaultSemanticConfig.getSemanticUrl() + defaultSemanticConfig.getFetchModelListPath(),
|
||||
domainId, authType.toString());
|
||||
Object domainDescListObject = fetchHttpResult(url, null, HttpMethod.GET);
|
||||
return JsonUtil.toList(JsonUtil.toString(domainDescListObject), ModelResp.class);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> ExplainResp explain(ExplainSqlReq<T> explainResp, User user) throws Exception {
|
||||
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
|
||||
@@ -310,4 +248,13 @@ public class RemoteSemanticInterpreter extends BaseSemanticInterpreter {
|
||||
return pageInfo;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected List<ViewSchemaResp> doFetchViewSchema(List<Long> ids) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ViewResp> getViewList(Long domainId) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,22 +2,20 @@ package com.tencent.supersonic.chat.core.knowledge.semantic;
|
||||
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.common.pojo.enums.AuthType;
|
||||
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ExplainSqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.PageDimensionReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.PageMetricReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DomainResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DomainResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ExplainResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ModelSchemaResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ExplainSqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ViewResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ViewSchemaResp;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@@ -38,15 +36,13 @@ public interface SemanticInterpreter {
|
||||
|
||||
SemanticQueryResp queryByMultiStruct(QueryMultiStructReq queryMultiStructReq, User user);
|
||||
|
||||
SemanticQueryResp queryByS2SQL(QuerySqlReq querySqlReq, User user);
|
||||
SemanticQueryResp queryByS2SQL(QuerySqlReq querySQLReq, User user);
|
||||
|
||||
SemanticQueryResp queryDimValue(QueryDimValueReq queryDimValueReq, User user);
|
||||
List<ViewSchema> getViewSchema();
|
||||
|
||||
List<ModelSchema> getModelSchema();
|
||||
List<ViewSchema> getViewSchema(List<Long> ids);
|
||||
|
||||
List<ModelSchema> getModelSchema(List<Long> ids);
|
||||
|
||||
ModelSchema getModelSchema(Long model, Boolean cacheEnable);
|
||||
ViewSchema getViewSchema(Long model, Boolean cacheEnable);
|
||||
|
||||
PageInfo<DimensionResp> getDimensionPage(PageDimensionReq pageDimensionReq);
|
||||
|
||||
@@ -54,10 +50,10 @@ public interface SemanticInterpreter {
|
||||
|
||||
List<DomainResp> getDomainList(User user);
|
||||
|
||||
List<ModelResp> getModelList(AuthType authType, Long domainId, User user);
|
||||
|
||||
<T> ExplainResp explain(ExplainSqlReq<T> explainSqlReq, User user) throws Exception;
|
||||
|
||||
List<ModelSchemaResp> fetchModelSchema(List<Long> ids, Boolean cacheEnable);
|
||||
List<ViewSchemaResp> fetchViewSchema(List<Long> ids, Boolean cacheEnable);
|
||||
|
||||
List<ViewResp> getViewList(Long domainId);
|
||||
|
||||
}
|
||||
|
||||
@@ -1,20 +1,21 @@
|
||||
package com.tencent.supersonic.chat.core.knowledge.semantic;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.RelatedSchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
|
||||
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.DimValueMap;
|
||||
import com.tencent.supersonic.headless.api.pojo.RelateDimension;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ModelSchemaResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ViewSchemaResp;
|
||||
import org.apache.logging.log4j.util.Strings;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashSet;
|
||||
@@ -22,20 +23,19 @@ import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class ModelSchemaBuilder {
|
||||
public class ViewSchemaBuilder {
|
||||
|
||||
public static ModelSchema build(ModelSchemaResp resp) {
|
||||
ModelSchema modelSchema = new ModelSchema();
|
||||
public static ViewSchema build(ViewSchemaResp resp) {
|
||||
ViewSchema viewSchema = new ViewSchema();
|
||||
viewSchema.setQueryConfig(resp.getQueryConfig());
|
||||
SchemaElement model = SchemaElement.builder()
|
||||
.model(resp.getId())
|
||||
.view(resp.getId())
|
||||
.id(resp.getId())
|
||||
.name(resp.getName())
|
||||
.bizName(resp.getBizName())
|
||||
.type(SchemaElementType.MODEL)
|
||||
.alias(SchemaItem.getAliasList(resp.getAlias()))
|
||||
.type(SchemaElementType.VIEW)
|
||||
.build();
|
||||
modelSchema.setModel(model);
|
||||
modelSchema.setModelRelas(resp.getModelRelas());
|
||||
viewSchema.setView(model);
|
||||
|
||||
Set<SchemaElement> metrics = new HashSet<>();
|
||||
for (MetricSchemaResp metric : resp.getMetrics()) {
|
||||
@@ -43,7 +43,7 @@ public class ModelSchemaBuilder {
|
||||
List<String> alias = SchemaItem.getAliasList(metric.getAlias());
|
||||
|
||||
SchemaElement metricToAdd = SchemaElement.builder()
|
||||
.model(resp.getId())
|
||||
.view(resp.getId())
|
||||
.id(metric.getId())
|
||||
.name(metric.getName())
|
||||
.bizName(metric.getBizName())
|
||||
@@ -56,7 +56,7 @@ public class ModelSchemaBuilder {
|
||||
metrics.add(metricToAdd);
|
||||
|
||||
}
|
||||
modelSchema.getMetrics().addAll(metrics);
|
||||
viewSchema.getMetrics().addAll(metrics);
|
||||
|
||||
Set<SchemaElement> dimensions = new HashSet<>();
|
||||
Set<SchemaElement> dimensionValues = new HashSet<>();
|
||||
@@ -83,7 +83,7 @@ public class ModelSchemaBuilder {
|
||||
|
||||
}
|
||||
SchemaElement dimToAdd = SchemaElement.builder()
|
||||
.model(resp.getId())
|
||||
.view(resp.getId())
|
||||
.id(dim.getId())
|
||||
.name(dim.getName())
|
||||
.bizName(dim.getBizName())
|
||||
@@ -95,7 +95,7 @@ public class ModelSchemaBuilder {
|
||||
dimensions.add(dimToAdd);
|
||||
|
||||
SchemaElement dimValueToAdd = SchemaElement.builder()
|
||||
.model(resp.getId())
|
||||
.view(resp.getId())
|
||||
.id(dim.getId())
|
||||
.name(dim.getName())
|
||||
.bizName(dim.getBizName())
|
||||
@@ -106,7 +106,7 @@ public class ModelSchemaBuilder {
|
||||
dimensionValues.add(dimValueToAdd);
|
||||
if (dim.getIsTag() == 1) {
|
||||
SchemaElement tagToAdd = SchemaElement.builder()
|
||||
.model(resp.getId())
|
||||
.view(resp.getId())
|
||||
.id(dim.getId())
|
||||
.name(dim.getName())
|
||||
.bizName(dim.getBizName())
|
||||
@@ -118,14 +118,14 @@ public class ModelSchemaBuilder {
|
||||
tags.add(tagToAdd);
|
||||
}
|
||||
}
|
||||
modelSchema.getDimensions().addAll(dimensions);
|
||||
modelSchema.getDimensionValues().addAll(dimensionValues);
|
||||
modelSchema.getTags().addAll(tags);
|
||||
viewSchema.getDimensions().addAll(dimensions);
|
||||
viewSchema.getDimensionValues().addAll(dimensionValues);
|
||||
viewSchema.getTags().addAll(tags);
|
||||
|
||||
DimSchemaResp dim = resp.getPrimaryKey();
|
||||
if (dim != null) {
|
||||
SchemaElement entity = SchemaElement.builder()
|
||||
.model(resp.getId())
|
||||
.view(resp.getId())
|
||||
.id(dim.getId())
|
||||
.name(dim.getName())
|
||||
.bizName(dim.getBizName())
|
||||
@@ -133,9 +133,9 @@ public class ModelSchemaBuilder {
|
||||
.useCnt(dim.getUseCnt())
|
||||
.alias(dim.getEntityAlias())
|
||||
.build();
|
||||
modelSchema.setEntity(entity);
|
||||
viewSchema.setEntity(entity);
|
||||
}
|
||||
return modelSchema;
|
||||
return viewSchema;
|
||||
}
|
||||
|
||||
private static List<RelatedSchemaElement> getRelateSchemaElement(MetricSchemaResp metricSchemaResp) {
|
||||
@@ -1,21 +1,22 @@
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
|
||||
@Slf4j
|
||||
public abstract class BaseMapper implements SchemaMapper {
|
||||
@@ -25,7 +26,7 @@ public abstract class BaseMapper implements SchemaMapper {
|
||||
|
||||
String simpleName = this.getClass().getSimpleName();
|
||||
long startTime = System.currentTimeMillis();
|
||||
log.debug("before {},mapInfo:{}", simpleName, queryContext.getMapInfo().getModelElementMatches());
|
||||
log.debug("before {},mapInfo:{}", simpleName, queryContext.getMapInfo().getViewElementMatches());
|
||||
|
||||
try {
|
||||
doMap(queryContext);
|
||||
@@ -34,13 +35,13 @@ public abstract class BaseMapper implements SchemaMapper {
|
||||
}
|
||||
|
||||
long cost = System.currentTimeMillis() - startTime;
|
||||
log.debug("after {},cost:{},mapInfo:{}", simpleName, cost, queryContext.getMapInfo().getModelElementMatches());
|
||||
log.debug("after {},cost:{},mapInfo:{}", simpleName, cost, queryContext.getMapInfo().getViewElementMatches());
|
||||
}
|
||||
|
||||
public abstract void doMap(QueryContext queryContext);
|
||||
|
||||
public void addToSchemaMap(SchemaMapInfo schemaMap, Long modelId, SchemaElementMatch newElementMatch) {
|
||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getModelElementMatches();
|
||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getViewElementMatches();
|
||||
List<SchemaElementMatch> schemaElementMatches = modelElementMatches.putIfAbsent(modelId, new ArrayList<>());
|
||||
if (schemaElementMatches == null) {
|
||||
schemaElementMatches = modelElementMatches.get(modelId);
|
||||
@@ -66,14 +67,14 @@ public abstract class BaseMapper implements SchemaMapper {
|
||||
}
|
||||
}
|
||||
|
||||
public SchemaElement getSchemaElement(Long modelId, SchemaElementType elementType, Long elementID,
|
||||
public SchemaElement getSchemaElement(Long viewId, SchemaElementType elementType, Long elementID,
|
||||
SemanticSchema semanticSchema) {
|
||||
SchemaElement element = new SchemaElement();
|
||||
ModelSchema modelSchema = semanticSchema.getModelSchemaMap().get(modelId);
|
||||
if (Objects.isNull(modelSchema)) {
|
||||
ViewSchema viewSchema = semanticSchema.getViewSchemaMap().get(viewId);
|
||||
if (Objects.isNull(viewSchema)) {
|
||||
return null;
|
||||
}
|
||||
SchemaElement elementDb = modelSchema.getElement(elementType, elementID);
|
||||
SchemaElement elementDb = viewSchema.getElement(elementType, elementID);
|
||||
if (Objects.isNull(elementDb)) {
|
||||
log.info("element is null, elementType:{},elementID:{}", elementType, elementID);
|
||||
return null;
|
||||
|
||||
@@ -3,6 +3,12 @@ package com.tencent.supersonic.chat.core.mapper;
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.utils.NatureHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
@@ -13,11 +19,6 @@ import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
@@ -27,15 +28,15 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
private MapperHelper mapperHelper;
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<T>> match(QueryContext queryContext, List<Term> terms, Set<Long> detectModelIds) {
|
||||
public Map<MatchText, List<T>> match(QueryContext queryContext, List<Term> terms, Set<Long> detectViewIds) {
|
||||
String text = queryContext.getQueryText();
|
||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
log.debug("terms:{},,detectModelIds:{}", terms, detectModelIds);
|
||||
log.debug("terms:{},,detectViewIds:{}", terms, detectViewIds);
|
||||
|
||||
List<T> detects = detect(queryContext, terms, detectModelIds);
|
||||
List<T> detects = detect(queryContext, terms, detectViewIds);
|
||||
Map<MatchText, List<T>> result = new HashMap<>();
|
||||
|
||||
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
||||
@@ -102,9 +103,9 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
}
|
||||
|
||||
public List<T> getMatches(QueryContext queryContext, List<Term> terms) {
|
||||
Set<Long> detectModelIds = mapperHelper.getModelIds(queryContext.getModelId(), queryContext.getAgent());
|
||||
terms = filterByModelIds(terms, detectModelIds);
|
||||
Map<MatchText, List<T>> matchResult = match(queryContext, terms, detectModelIds);
|
||||
Set<Long> viewIds = mapperHelper.getViewIds(queryContext.getViewId(), queryContext.getAgent());
|
||||
terms = filterByViewId(terms, viewIds);
|
||||
Map<MatchText, List<T>> matchResult = match(queryContext, terms, viewIds);
|
||||
List<T> matches = new ArrayList<>();
|
||||
if (Objects.isNull(matchResult)) {
|
||||
return matches;
|
||||
@@ -119,17 +120,17 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
return matches;
|
||||
}
|
||||
|
||||
public List<Term> filterByModelIds(List<Term> terms, Set<Long> detectModelIds) {
|
||||
public List<Term> filterByViewId(List<Term> terms, Set<Long> viewIds) {
|
||||
logTerms(terms);
|
||||
if (CollectionUtils.isNotEmpty(detectModelIds)) {
|
||||
if (CollectionUtils.isNotEmpty(viewIds)) {
|
||||
terms = terms.stream().filter(term -> {
|
||||
Long modelId = NatureHelper.getModelId(term.getNature().toString());
|
||||
if (Objects.nonNull(modelId)) {
|
||||
return detectModelIds.contains(modelId);
|
||||
Long viewId = NatureHelper.getViewId(term.getNature().toString());
|
||||
if (Objects.nonNull(viewId)) {
|
||||
return viewIds.contains(viewId);
|
||||
}
|
||||
return false;
|
||||
}).collect(Collectors.toList());
|
||||
log.info("terms filter by modelIds:{}", detectModelIds);
|
||||
log.info("terms filter by viewId:{}", viewIds);
|
||||
logTerms(terms);
|
||||
}
|
||||
return terms;
|
||||
|
||||
@@ -1,12 +1,18 @@
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.core.knowledge.DatabaseMapResult;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
@@ -14,11 +20,6 @@ import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* DatabaseMatchStrategy uses SQL LIKE operator to match schema elements.
|
||||
@@ -59,7 +60,7 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
||||
if (StringUtils.isBlank(detectSegment)) {
|
||||
return;
|
||||
}
|
||||
Set<Long> modelIds = mapperHelper.getModelIds(queryContext.getModelId(), queryContext.getAgent());
|
||||
Set<Long> viewIds = mapperHelper.getViewIds(queryContext.getViewId(), queryContext.getAgent());
|
||||
|
||||
Double metricDimensionThresholdConfig = getThreshold(queryContext);
|
||||
|
||||
@@ -72,9 +73,9 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
||||
continue;
|
||||
}
|
||||
Set<SchemaElement> schemaElements = entry.getValue();
|
||||
if (!CollectionUtils.isEmpty(modelIds)) {
|
||||
if (!CollectionUtils.isEmpty(viewIds)) {
|
||||
schemaElements = schemaElements.stream()
|
||||
.filter(schemaElement -> modelIds.contains(schemaElement.getModel()))
|
||||
.filter(schemaElement -> viewIds.contains(schemaElement.getView()))
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
for (SchemaElement schemaElement : schemaElements) {
|
||||
@@ -98,7 +99,7 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
||||
Double metricDimensionThresholdConfig = optimizationConfig.getMetricDimensionThresholdConfig();
|
||||
Double metricDimensionMinThresholdConfig = optimizationConfig.getMetricDimensionMinThresholdConfig();
|
||||
|
||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo().getModelElementMatches();
|
||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo().getViewElementMatches();
|
||||
|
||||
boolean existElement = modelElementMatches.entrySet().stream().anyMatch(entry -> entry.getValue().size() >= 1);
|
||||
|
||||
|
||||
@@ -1,19 +1,20 @@
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* A mapper capable of converting the VALUE of entity dimension values into ID types.
|
||||
*/
|
||||
@@ -23,12 +24,12 @@ public class EntityMapper extends BaseMapper {
|
||||
@Override
|
||||
public void doMap(QueryContext queryContext) {
|
||||
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
||||
for (Long modelId : schemaMapInfo.getMatchedModels()) {
|
||||
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(modelId);
|
||||
for (Long viewId : schemaMapInfo.getMatchedViewInfos()) {
|
||||
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(viewId);
|
||||
if (CollectionUtils.isEmpty(schemaElementMatchList)) {
|
||||
continue;
|
||||
}
|
||||
SchemaElement entity = getEntity(modelId, queryContext);
|
||||
SchemaElement entity = getEntity(viewId, queryContext);
|
||||
if (entity == null || entity.getId() == null) {
|
||||
continue;
|
||||
}
|
||||
@@ -64,9 +65,9 @@ public class EntityMapper extends BaseMapper {
|
||||
return false;
|
||||
}
|
||||
|
||||
private SchemaElement getEntity(Long modelId, QueryContext queryContext) {
|
||||
private SchemaElement getEntity(Long viewId, QueryContext queryContext) {
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
ModelSchema modelSchema = semanticSchema.getModelSchemaMap().get(modelId);
|
||||
ViewSchema modelSchema = semanticSchema.getViewSchemaMap().get(viewId);
|
||||
if (modelSchema != null && modelSchema.getEntity() != null) {
|
||||
return modelSchema.getEntity();
|
||||
}
|
||||
|
||||
@@ -1,24 +1,25 @@
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.core.knowledge.DatabaseMapResult;
|
||||
import com.tencent.supersonic.chat.core.knowledge.HanlpMapResult;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.utils.HanlpHelper;
|
||||
import com.tencent.supersonic.chat.core.utils.NatureHelper;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/***
|
||||
* A mapper that recognizes schema elements with keyword.
|
||||
@@ -56,8 +57,8 @@ public class KeywordMapper extends BaseMapper {
|
||||
|
||||
for (HanlpMapResult hanlpMapResult : mapResults) {
|
||||
for (String nature : hanlpMapResult.getNatures()) {
|
||||
Long modelId = NatureHelper.getModelId(nature);
|
||||
if (Objects.isNull(modelId)) {
|
||||
Long viewId = NatureHelper.getViewId(nature);
|
||||
if (Objects.isNull(viewId)) {
|
||||
continue;
|
||||
}
|
||||
SchemaElementType elementType = NatureHelper.convertToElementType(nature);
|
||||
@@ -65,8 +66,8 @@ public class KeywordMapper extends BaseMapper {
|
||||
continue;
|
||||
}
|
||||
Long elementID = NatureHelper.getElementID(nature);
|
||||
SchemaElement element = getSchemaElement(modelId, elementType, elementID,
|
||||
queryContext.getSemanticSchema());
|
||||
SchemaElement element = getSchemaElement(viewId, elementType,
|
||||
elementID, queryContext.getSemanticSchema());
|
||||
if (element == null) {
|
||||
continue;
|
||||
}
|
||||
@@ -82,7 +83,7 @@ public class KeywordMapper extends BaseMapper {
|
||||
.detectWord(hanlpMapResult.getDetectWord())
|
||||
.build();
|
||||
|
||||
addToSchemaMap(queryContext.getMapInfo(), modelId, schemaElementMatch);
|
||||
addToSchemaMap(queryContext.getMapInfo(), viewId, schemaElementMatch);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -103,12 +104,12 @@ public class KeywordMapper extends BaseMapper {
|
||||
.similarity(mapperHelper.getSimilarity(match.getDetectWord(), schemaElement.getName()))
|
||||
.build();
|
||||
log.info("add to schema, elementMatch {}", schemaElementMatch);
|
||||
addToSchemaMap(queryContext.getMapInfo(), schemaElement.getModel(), schemaElementMatch);
|
||||
addToSchemaMap(queryContext.getMapInfo(), schemaElement.getView(), schemaElementMatch);
|
||||
}
|
||||
}
|
||||
|
||||
private Set<Long> getRegElementSet(SchemaMapInfo schemaMap, SchemaElement schemaElement) {
|
||||
List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getModel());
|
||||
List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getView());
|
||||
if (CollectionUtils.isEmpty(elements)) {
|
||||
return new HashSet<>();
|
||||
}
|
||||
|
||||
@@ -5,6 +5,11 @@ import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.core.agent.Agent;
|
||||
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.core.utils.NatureHelper;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.Comparator;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
@@ -12,10 +17,6 @@ import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Data
|
||||
@Service
|
||||
@@ -61,7 +62,7 @@ public class MapperHelper {
|
||||
*/
|
||||
public boolean existDimensionValues(List<String> natures) {
|
||||
for (String nature : natures) {
|
||||
if (NatureHelper.isDimensionValueModelId(nature)) {
|
||||
if (NatureHelper.isDimensionValueViewId(nature)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
@@ -81,33 +82,33 @@ public class MapperHelper {
|
||||
detectSegment.length());
|
||||
}
|
||||
|
||||
public Set<Long> getModelIds(Long modelId, Agent agent) {
|
||||
public Set<Long> getViewIds(Long viewId, Agent agent) {
|
||||
|
||||
Set<Long> detectModelIds = new HashSet<>();
|
||||
Set<Long> detectViewIds = new HashSet<>();
|
||||
if (Objects.nonNull(agent)) {
|
||||
detectModelIds = agent.getModelIds(null);
|
||||
detectViewIds = agent.getViewIds(null);
|
||||
}
|
||||
//contains all
|
||||
if (Agent.containsAllModel(detectModelIds)) {
|
||||
if (Objects.nonNull(modelId) && modelId > 0) {
|
||||
if (Agent.containsAllModel(detectViewIds)) {
|
||||
if (Objects.nonNull(viewId) && viewId > 0) {
|
||||
Set<Long> result = new HashSet<>();
|
||||
result.add(modelId);
|
||||
result.add(viewId);
|
||||
return result;
|
||||
}
|
||||
return new HashSet<>();
|
||||
}
|
||||
|
||||
if (Objects.nonNull(detectModelIds)) {
|
||||
detectModelIds = detectModelIds.stream().filter(entry -> entry > 0).collect(Collectors.toSet());
|
||||
if (Objects.nonNull(detectViewIds)) {
|
||||
detectViewIds = detectViewIds.stream().filter(entry -> entry > 0).collect(Collectors.toSet());
|
||||
}
|
||||
|
||||
if (Objects.nonNull(modelId) && modelId > 0 && Objects.nonNull(detectModelIds)) {
|
||||
if (detectModelIds.contains(modelId)) {
|
||||
if (Objects.nonNull(viewId) && viewId > 0 && Objects.nonNull(detectViewIds)) {
|
||||
if (detectViewIds.contains(viewId)) {
|
||||
Set<Long> result = new HashSet<>();
|
||||
result.add(modelId);
|
||||
result.add(viewId);
|
||||
return result;
|
||||
}
|
||||
}
|
||||
return detectModelIds;
|
||||
return detectViewIds;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,52 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaModelClusterMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.utils.ModelClusterBuilder;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/***
|
||||
* ModelClusterMapper build a cluster from
|
||||
* connectable data models based on model-rela configuration
|
||||
* and generate SchemaModelClusterMapInfo
|
||||
*/
|
||||
public class ModelClusterMapper implements SchemaMapper {
|
||||
|
||||
@Override
|
||||
public void map(QueryContext queryContext) {
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
||||
List<ModelCluster> modelClusters = buildModelClusterMatched(schemaMapInfo, semanticSchema);
|
||||
Map<String, List<SchemaElementMatch>> modelClusterElementMatches = new HashMap<>();
|
||||
for (ModelCluster modelCluster : modelClusters) {
|
||||
for (Long modelId : schemaMapInfo.getMatchedModels()) {
|
||||
if (modelCluster.getModelIds().contains(modelId)) {
|
||||
modelClusterElementMatches.computeIfAbsent(modelCluster.getKey(), k -> new ArrayList<>())
|
||||
.addAll(schemaMapInfo.getMatchedElements(modelId));
|
||||
}
|
||||
}
|
||||
}
|
||||
SchemaModelClusterMapInfo modelClusterMapInfo = new SchemaModelClusterMapInfo();
|
||||
modelClusterMapInfo.setModelElementMatches(modelClusterElementMatches);
|
||||
queryContext.setModelClusterMapInfo(modelClusterMapInfo);
|
||||
}
|
||||
|
||||
private List<ModelCluster> buildModelClusterMatched(SchemaMapInfo schemaMapInfo,
|
||||
SemanticSchema semanticSchema) {
|
||||
Set<Long> matchedModels = schemaMapInfo.getMatchedModels();
|
||||
List<ModelCluster> modelClusters = ModelClusterBuilder.buildModelClusters(semanticSchema);
|
||||
return modelClusters.stream().map(ModelCluster::getModelIds).peek(modelCluster -> {
|
||||
modelCluster.removeIf(model -> !matchedModels.contains(model));
|
||||
}).filter(modelCluster -> modelCluster.size() > 0).map(ModelCluster::build).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -10,11 +10,12 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.chat.core.knowledge.builder.BaseWordBuilder;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public class QueryFilterMapper implements SchemaMapper {
|
||||
@@ -23,22 +24,22 @@ public class QueryFilterMapper implements SchemaMapper {
|
||||
|
||||
@Override
|
||||
public void map(QueryContext queryContext) {
|
||||
Long modelId = queryContext.getModelId();
|
||||
if (modelId == null || modelId <= 0) {
|
||||
Long viewId = queryContext.getViewId();
|
||||
if (viewId == null || viewId <= 0) {
|
||||
return;
|
||||
}
|
||||
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
||||
clearOtherSchemaElementMatch(modelId, schemaMapInfo);
|
||||
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(modelId);
|
||||
clearOtherSchemaElementMatch(viewId, schemaMapInfo);
|
||||
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(viewId);
|
||||
if (schemaElementMatches == null) {
|
||||
schemaElementMatches = Lists.newArrayList();
|
||||
schemaMapInfo.setMatchedElements(modelId, schemaElementMatches);
|
||||
schemaMapInfo.setMatchedElements(viewId, schemaElementMatches);
|
||||
}
|
||||
addValueSchemaElementMatch(queryContext, schemaElementMatches);
|
||||
}
|
||||
|
||||
private void clearOtherSchemaElementMatch(Long modelId, SchemaMapInfo schemaMapInfo) {
|
||||
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMapInfo.getModelElementMatches().entrySet()) {
|
||||
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMapInfo.getViewElementMatches().entrySet()) {
|
||||
if (!entry.getKey().equals(modelId)) {
|
||||
entry.getValue().clear();
|
||||
}
|
||||
@@ -60,7 +61,7 @@ public class QueryFilterMapper implements SchemaMapper {
|
||||
.name(String.valueOf(filter.getValue()))
|
||||
.type(SchemaElementType.VALUE)
|
||||
.bizName(filter.getBizName())
|
||||
.model(queryContext.getModelId())
|
||||
.view(queryContext.getViewId())
|
||||
.build();
|
||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||
.element(element)
|
||||
|
||||
@@ -1,23 +1,24 @@
|
||||
package com.tencent.supersonic.chat.core.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionPromptGenerator;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionReq;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionResp;
|
||||
import com.tencent.supersonic.chat.core.parser.sql.llm.OutputFormat;
|
||||
import com.tencent.supersonic.chat.core.parser.sql.llm.SqlGeneration;
|
||||
import com.tencent.supersonic.chat.core.parser.sql.llm.SqlGenerationFactory;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionReq;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionResp;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* LLMProxy based on langchain4j Java version.
|
||||
*/
|
||||
@@ -37,12 +38,12 @@ public class JavaLLMProxy implements LLMProxy {
|
||||
return false;
|
||||
}
|
||||
|
||||
public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
|
||||
public LLMResp query2sql(LLMReq llmReq, Long viewId) {
|
||||
|
||||
SqlGeneration sqlGeneration = SqlGenerationFactory.get(
|
||||
SqlGenerationMode.getMode(llmReq.getSqlGenerationMode()));
|
||||
String modelName = llmReq.getSchema().getModelName();
|
||||
LLMResp result = sqlGeneration.generation(llmReq, modelClusterKey);
|
||||
String modelName = llmReq.getSchema().getViewName();
|
||||
LLMResp result = sqlGeneration.generation(llmReq, viewId);
|
||||
result.setQuery(llmReq.getQueryText());
|
||||
result.setModelName(modelName);
|
||||
return result;
|
||||
|
||||
@@ -15,7 +15,7 @@ public interface LLMProxy {
|
||||
|
||||
boolean isSkip(QueryContext queryContext);
|
||||
|
||||
LLMResp query2sql(LLMReq llmReq, String modelClusterKey);
|
||||
LLMResp query2sql(LLMReq llmReq, Long viewId);
|
||||
|
||||
FunctionResp requestFunction(FunctionReq functionReq);
|
||||
|
||||
|
||||
@@ -1,19 +1,16 @@
|
||||
package com.tencent.supersonic.chat.core.parser;
|
||||
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.config.LLMParserConfig;
|
||||
import com.tencent.supersonic.chat.core.parser.sql.llm.OutputFormat;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionCallConfig;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionReq;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionResp;
|
||||
import com.tencent.supersonic.chat.core.parser.sql.llm.OutputFormat;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import java.net.URI;
|
||||
import java.net.URL;
|
||||
import java.util.ArrayList;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.MapUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
@@ -28,6 +25,10 @@ import org.springframework.stereotype.Component;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
import org.springframework.web.util.UriComponentsBuilder;
|
||||
|
||||
import java.net.URI;
|
||||
import java.net.URL;
|
||||
import java.util.ArrayList;
|
||||
|
||||
/**
|
||||
* PythonLLMProxy sends requests to LangChain-based python service.
|
||||
*/
|
||||
@@ -47,10 +48,10 @@ public class PythonLLMProxy implements LLMProxy {
|
||||
return false;
|
||||
}
|
||||
|
||||
public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
|
||||
public LLMResp query2sql(LLMReq llmReq, Long viewId) {
|
||||
long startTime = System.currentTimeMillis();
|
||||
log.info("requestLLM request, modelId:{},llmReq:{}", modelClusterKey, llmReq);
|
||||
keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq);
|
||||
log.info("requestLLM request, viewId:{},llmReq:{}", viewId, llmReq);
|
||||
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
|
||||
try {
|
||||
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);
|
||||
|
||||
|
||||
@@ -13,13 +13,14 @@ import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
/**
|
||||
* QueryTypeParser resolves query type as either METRIC or TAG, or ID.
|
||||
@@ -49,7 +50,7 @@ public class QueryTypeParser implements SemanticParser {
|
||||
return QueryType.ID;
|
||||
}
|
||||
//1. entity queryType
|
||||
Set<Long> modelIds = parseInfo.getModel().getModelIds();
|
||||
Long viewId = parseInfo.getViewId();
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
|
||||
//If all the fields in the SELECT statement are of tag type.
|
||||
@@ -58,12 +59,12 @@ public class QueryTypeParser implements SemanticParser {
|
||||
.collect(Collectors.toList());
|
||||
|
||||
if (CollectionUtils.isNotEmpty(whereFields)) {
|
||||
Set<String> ids = semanticSchema.getEntities(modelIds).stream().map(SchemaElement::getName)
|
||||
Set<String> ids = semanticSchema.getEntities(viewId).stream().map(SchemaElement::getName)
|
||||
.collect(Collectors.toSet());
|
||||
if (CollectionUtils.isNotEmpty(ids) && ids.stream().anyMatch(whereFields::contains)) {
|
||||
return QueryType.ID;
|
||||
}
|
||||
Set<String> tags = semanticSchema.getTags(modelIds).stream().map(SchemaElement::getName)
|
||||
Set<String> tags = semanticSchema.getTags(viewId).stream().map(SchemaElement::getName)
|
||||
.collect(Collectors.toSet());
|
||||
if (CollectionUtils.isNotEmpty(tags) && tags.containsAll(whereFields)) {
|
||||
return QueryType.TAG;
|
||||
@@ -72,7 +73,7 @@ public class QueryTypeParser implements SemanticParser {
|
||||
}
|
||||
//2. metric queryType
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||
List<SchemaElement> metrics = semanticSchema.getMetrics(modelIds);
|
||||
List<SchemaElement> metrics = semanticSchema.getMetrics(viewId);
|
||||
if (CollectionUtils.isNotEmpty(metrics)) {
|
||||
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
|
||||
boolean containMetric = selectFields.stream().anyMatch(metricNameSet::contains);
|
||||
|
||||
@@ -18,7 +18,6 @@ import com.tencent.supersonic.chat.core.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.core.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import java.util.HashMap;
|
||||
@@ -56,13 +55,13 @@ public abstract class PluginParser implements SemanticParser {
|
||||
|
||||
public void buildQuery(QueryContext queryContext, PluginRecallResult pluginRecallResult) {
|
||||
Plugin plugin = pluginRecallResult.getPlugin();
|
||||
Set<Long> modelIds = pluginRecallResult.getModelIds();
|
||||
Set<Long> viewIds = pluginRecallResult.getViewIds();
|
||||
if (plugin.isContainsAllModel()) {
|
||||
modelIds = Sets.newHashSet(-1L);
|
||||
viewIds = Sets.newHashSet(-1L);
|
||||
}
|
||||
for (Long modelId : modelIds) {
|
||||
for (Long viewId : viewIds) {
|
||||
PluginSemanticQuery pluginQuery = QueryManager.createPluginQuery(plugin.getType());
|
||||
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(modelId, plugin,
|
||||
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(viewId, plugin,
|
||||
queryContext, pluginRecallResult.getDistance());
|
||||
semanticParseInfo.setQueryMode(pluginQuery.getQueryMode());
|
||||
semanticParseInfo.setScore(pluginRecallResult.getScore());
|
||||
@@ -75,20 +74,19 @@ public abstract class PluginParser implements SemanticParser {
|
||||
return PluginManager.getPluginAgentCanSupport(queryContext);
|
||||
}
|
||||
|
||||
protected SemanticParseInfo buildSemanticParseInfo(Long modelId, Plugin plugin,
|
||||
protected SemanticParseInfo buildSemanticParseInfo(Long viewId, Plugin plugin,
|
||||
QueryContext queryContext, double distance) {
|
||||
List<SchemaElementMatch> schemaElementMatches =
|
||||
queryContext.getModelClusterMapInfo().getMatchedElements(modelId);
|
||||
List<SchemaElementMatch> schemaElementMatches = queryContext.getMapInfo().getMatchedElements(viewId);
|
||||
QueryFilters queryFilters = queryContext.getQueryFilters();
|
||||
if (modelId == null && !CollectionUtils.isEmpty(plugin.getModelList())) {
|
||||
modelId = plugin.getModelList().get(0);
|
||||
if (viewId == null && !CollectionUtils.isEmpty(plugin.getViewList())) {
|
||||
viewId = plugin.getViewList().get(0);
|
||||
}
|
||||
if (schemaElementMatches == null) {
|
||||
schemaElementMatches = Lists.newArrayList();
|
||||
}
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
semanticParseInfo.setElementMatches(schemaElementMatches);
|
||||
semanticParseInfo.setModel(ModelCluster.build(Sets.newHashSet(modelId)));
|
||||
semanticParseInfo.setView(queryContext.getSemanticSchema().getView(viewId));
|
||||
Map<String, Object> properties = new HashMap<>();
|
||||
PluginParseResult pluginParseResult = new PluginParseResult();
|
||||
pluginParseResult.setPlugin(plugin);
|
||||
|
||||
@@ -57,15 +57,15 @@ public class EmbeddingRecallParser extends PluginParser {
|
||||
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, queryContext);
|
||||
log.info("embedding plugin resolve: {}", pair);
|
||||
if (pair.getLeft()) {
|
||||
Set<Long> modelList = pair.getRight();
|
||||
if (CollectionUtils.isEmpty(modelList)) {
|
||||
Set<Long> viewList = pair.getRight();
|
||||
if (CollectionUtils.isEmpty(viewList)) {
|
||||
continue;
|
||||
}
|
||||
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
|
||||
double distance = embeddingRetrieval.getDistance();
|
||||
double score = queryContext.getQueryText().length() * (1 - distance);
|
||||
return PluginRecallResult.builder()
|
||||
.plugin(plugin).modelIds(modelList).score(score).distance(distance).build();
|
||||
.plugin(plugin).viewIds(viewList).score(score).distance(distance).build();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
|
||||
@@ -12,15 +12,16 @@ import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* FunctionCallParser is an implementation of a recall plugin based on FunctionCall
|
||||
*/
|
||||
@@ -56,19 +57,19 @@ public class FunctionCallParser extends PluginParser {
|
||||
plugin.setParseMode(ParseMode.FUNCTION_CALL);
|
||||
Pair<Boolean, Set<Long>> pluginResolveResult = PluginManager.resolve(plugin, queryContext);
|
||||
if (pluginResolveResult.getLeft()) {
|
||||
Set<Long> modelList = pluginResolveResult.getRight();
|
||||
if (CollectionUtils.isEmpty(modelList)) {
|
||||
Set<Long> viewList = pluginResolveResult.getRight();
|
||||
if (CollectionUtils.isEmpty(viewList)) {
|
||||
return null;
|
||||
}
|
||||
double score = queryContext.getQueryText().length();
|
||||
return PluginRecallResult.builder().plugin(plugin).modelIds(modelList).score(score).build();
|
||||
return PluginRecallResult.builder().plugin(plugin).viewIds(viewList).score(score).build();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public FunctionResp functionCall(QueryContext queryContext) {
|
||||
List<PluginParseConfig> pluginToFunctionCall =
|
||||
getPluginToFunctionCall(queryContext.getModelId(), queryContext);
|
||||
getPluginToFunctionCall(queryContext.getViewId(), queryContext);
|
||||
if (CollectionUtils.isEmpty(pluginToFunctionCall)) {
|
||||
log.info("function call parser, plugin is empty, skip");
|
||||
return null;
|
||||
|
||||
@@ -1,146 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.parser.sql.llm;
|
||||
|
||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaModelClusterMapInfo;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
public class HeuristicModelResolver implements ModelResolver {
|
||||
|
||||
protected static String selectModelBySchemaElementMatchScore(Map<String, SemanticQuery> modelQueryModes,
|
||||
SchemaModelClusterMapInfo schemaMap) {
|
||||
//model count priority
|
||||
String modelIdByModelCount = getModelIdByMatchModelScore(schemaMap);
|
||||
if (Objects.nonNull(modelIdByModelCount)) {
|
||||
log.info("selectModel by model count:{}", modelIdByModelCount);
|
||||
return modelIdByModelCount;
|
||||
}
|
||||
|
||||
Map<String, ModelMatchResult> modelTypeMap = getModelTypeMap(schemaMap);
|
||||
if (modelTypeMap.size() == 1) {
|
||||
String modelSelect = modelTypeMap.entrySet().stream().collect(Collectors.toList()).get(0).getKey();
|
||||
if (modelQueryModes.containsKey(modelSelect)) {
|
||||
log.info("selectModel with only one Model [{}]", modelSelect);
|
||||
return modelSelect;
|
||||
}
|
||||
} else {
|
||||
Map.Entry<String, ModelMatchResult> maxModel = modelTypeMap.entrySet().stream()
|
||||
.filter(entry -> modelQueryModes.containsKey(entry.getKey()))
|
||||
.sorted((o1, o2) -> {
|
||||
int difference = o2.getValue().getCount() - o1.getValue().getCount();
|
||||
if (difference == 0) {
|
||||
return (int) ((o2.getValue().getMaxSimilarity()
|
||||
- o1.getValue().getMaxSimilarity()) * 100);
|
||||
}
|
||||
return difference;
|
||||
}).findFirst().orElse(null);
|
||||
if (maxModel != null) {
|
||||
log.info("selectModel with multiple Models [{}]", maxModel.getKey());
|
||||
return maxModel.getKey();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private static String getModelIdByMatchModelScore(SchemaModelClusterMapInfo schemaMap) {
|
||||
Map<String, List<SchemaElementMatch>> modelElementMatches = schemaMap.getModelElementMatches();
|
||||
// calculate model match score, matched element gets 1.0 point, and inherit element gets 0.5 point
|
||||
Map<String, Double> modelIdToModelScore = new HashMap<>();
|
||||
if (Objects.nonNull(modelElementMatches)) {
|
||||
for (Entry<String, List<SchemaElementMatch>> modelElementMatch : modelElementMatches.entrySet()) {
|
||||
String modelId = modelElementMatch.getKey();
|
||||
List<Double> modelMatchesScore = modelElementMatch.getValue().stream()
|
||||
.filter(elementMatch -> elementMatch.getSimilarity() >= 1)
|
||||
.filter(elementMatch -> SchemaElementType.MODEL.equals(elementMatch.getElement().getType()))
|
||||
.map(elementMatch -> elementMatch.isInherited() ? 0.5 : 1.0).collect(Collectors.toList());
|
||||
|
||||
if (!CollectionUtils.isEmpty(modelMatchesScore)) {
|
||||
// get sum of model match score
|
||||
double score = modelMatchesScore.stream().mapToDouble(Double::doubleValue).sum();
|
||||
modelIdToModelScore.put(modelId, score);
|
||||
}
|
||||
}
|
||||
Entry<String, Double> maxModelScore = modelIdToModelScore.entrySet().stream()
|
||||
.max(Comparator.comparingDouble(o -> o.getValue())).orElse(null);
|
||||
log.info("maxModelCount:{},modelIdToModelCount:{}", maxModelScore, modelIdToModelScore);
|
||||
if (Objects.nonNull(maxModelScore)) {
|
||||
return maxModelScore.getKey();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public static Map<String, ModelMatchResult> getModelTypeMap(SchemaModelClusterMapInfo schemaMap) {
|
||||
Map<String, ModelMatchResult> modelCount = new HashMap<>();
|
||||
for (Map.Entry<String, List<SchemaElementMatch>> entry : schemaMap.getModelElementMatches().entrySet()) {
|
||||
List<SchemaElementMatch> schemaElementMatches = schemaMap.getMatchedElements(entry.getKey());
|
||||
if (schemaElementMatches != null && schemaElementMatches.size() > 0) {
|
||||
if (!modelCount.containsKey(entry.getKey())) {
|
||||
modelCount.put(entry.getKey(), new ModelMatchResult());
|
||||
}
|
||||
ModelMatchResult modelMatchResult = modelCount.get(entry.getKey());
|
||||
Set<SchemaElementType> schemaElementTypes = new HashSet<>();
|
||||
schemaElementMatches.stream()
|
||||
.forEach(schemaElementMatch -> schemaElementTypes.add(
|
||||
schemaElementMatch.getElement().getType()));
|
||||
SchemaElementMatch schemaElementMatchMax = schemaElementMatches.stream()
|
||||
.sorted((o1, o2) ->
|
||||
((int) ((o2.getSimilarity() - o1.getSimilarity()) * 100))
|
||||
).findFirst().orElse(null);
|
||||
if (schemaElementMatchMax != null) {
|
||||
modelMatchResult.setMaxSimilarity(schemaElementMatchMax.getSimilarity());
|
||||
}
|
||||
modelMatchResult.setCount(schemaElementTypes.size());
|
||||
|
||||
}
|
||||
}
|
||||
return modelCount;
|
||||
}
|
||||
|
||||
public String resolve(QueryContext queryContext, ChatContext chatCtx, Set<Long> restrictiveModels) {
|
||||
SchemaModelClusterMapInfo mapInfo = queryContext.getModelClusterMapInfo();
|
||||
Set<String> matchedModelClusters = mapInfo.getElementMatchesByModelIds(restrictiveModels).keySet();
|
||||
Long modelId = queryContext.getModelId();
|
||||
if (Objects.nonNull(modelId) && modelId > 0) {
|
||||
if (CollectionUtils.isEmpty(restrictiveModels) || restrictiveModels.contains(modelId)) {
|
||||
return getModelClusterByModelId(modelId, matchedModelClusters);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
Map<String, SemanticQuery> modelQueryModes = new HashMap<>();
|
||||
for (String matchedModel : matchedModelClusters) {
|
||||
modelQueryModes.put(matchedModel, null);
|
||||
}
|
||||
if (modelQueryModes.size() == 1) {
|
||||
return modelQueryModes.keySet().stream().findFirst().get();
|
||||
}
|
||||
return selectModelBySchemaElementMatchScore(modelQueryModes, mapInfo);
|
||||
}
|
||||
|
||||
private String getModelClusterByModelId(Long modelId, Set<String> modelClusterKeySet) {
|
||||
for (String modelClusterKey : modelClusterKeySet) {
|
||||
if (ModelCluster.build(modelClusterKey).getModelIds().contains(modelId)) {
|
||||
return modelClusterKey;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,136 @@
|
||||
package com.tencent.supersonic.chat.core.parser.sql.llm;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
public class HeuristicViewResolver implements ViewResolver {
|
||||
|
||||
protected static Long selectViewBySchemaElementMatchScore(Map<Long, SemanticQuery> viewQueryModes,
|
||||
SchemaMapInfo schemaMap) {
|
||||
//view count priority
|
||||
Long viewIdByViewCount = getViewIdByMatchViewScore(schemaMap);
|
||||
if (Objects.nonNull(viewIdByViewCount)) {
|
||||
log.info("selectView by view count:{}", viewIdByViewCount);
|
||||
return viewIdByViewCount;
|
||||
}
|
||||
|
||||
Map<Long, ViewMatchResult> viewTypeMap = getViewTypeMap(schemaMap);
|
||||
if (viewTypeMap.size() == 1) {
|
||||
Long viewSelect = new ArrayList<>(viewTypeMap.entrySet()).get(0).getKey();
|
||||
if (viewQueryModes.containsKey(viewSelect)) {
|
||||
log.info("selectView with only one View [{}]", viewSelect);
|
||||
return viewSelect;
|
||||
}
|
||||
} else {
|
||||
Map.Entry<Long, ViewMatchResult> maxView = viewTypeMap.entrySet().stream()
|
||||
.filter(entry -> viewQueryModes.containsKey(entry.getKey()))
|
||||
.sorted((o1, o2) -> {
|
||||
int difference = o2.getValue().getCount() - o1.getValue().getCount();
|
||||
if (difference == 0) {
|
||||
return (int) ((o2.getValue().getMaxSimilarity()
|
||||
- o1.getValue().getMaxSimilarity()) * 100);
|
||||
}
|
||||
return difference;
|
||||
}).findFirst().orElse(null);
|
||||
if (maxView != null) {
|
||||
log.info("selectView with multiple Views [{}]", maxView.getKey());
|
||||
return maxView.getKey();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private static Long getViewIdByMatchViewScore(SchemaMapInfo schemaMap) {
|
||||
Map<Long, List<SchemaElementMatch>> viewElementMatches = schemaMap.getViewElementMatches();
|
||||
// calculate view match score, matched element gets 1.0 point, and inherit element gets 0.5 point
|
||||
Map<Long, Double> viewIdToViewScore = new HashMap<>();
|
||||
if (Objects.nonNull(viewElementMatches)) {
|
||||
for (Entry<Long, List<SchemaElementMatch>> viewElementMatch : viewElementMatches.entrySet()) {
|
||||
Long viewId = viewElementMatch.getKey();
|
||||
List<Double> viewMatchesScore = viewElementMatch.getValue().stream()
|
||||
.filter(elementMatch -> elementMatch.getSimilarity() >= 1)
|
||||
.filter(elementMatch -> SchemaElementType.VIEW.equals(elementMatch.getElement().getType()))
|
||||
.map(elementMatch -> elementMatch.isInherited() ? 0.5 : 1.0).collect(Collectors.toList());
|
||||
|
||||
if (!CollectionUtils.isEmpty(viewMatchesScore)) {
|
||||
// get sum of view match score
|
||||
double score = viewMatchesScore.stream().mapToDouble(Double::doubleValue).sum();
|
||||
viewIdToViewScore.put(viewId, score);
|
||||
}
|
||||
}
|
||||
Entry<Long, Double> maxViewScore = viewIdToViewScore.entrySet().stream()
|
||||
.max(Comparator.comparingDouble(Entry::getValue)).orElse(null);
|
||||
log.info("maxViewCount:{},viewIdToViewCount:{}", maxViewScore, viewIdToViewScore);
|
||||
if (Objects.nonNull(maxViewScore)) {
|
||||
return maxViewScore.getKey();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public static Map<Long, ViewMatchResult> getViewTypeMap(SchemaMapInfo schemaMap) {
|
||||
Map<Long, ViewMatchResult> viewCount = new HashMap<>();
|
||||
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMap.getViewElementMatches().entrySet()) {
|
||||
List<SchemaElementMatch> schemaElementMatches = schemaMap.getMatchedElements(entry.getKey());
|
||||
if (schemaElementMatches != null && schemaElementMatches.size() > 0) {
|
||||
if (!viewCount.containsKey(entry.getKey())) {
|
||||
viewCount.put(entry.getKey(), new ViewMatchResult());
|
||||
}
|
||||
ViewMatchResult viewMatchResult = viewCount.get(entry.getKey());
|
||||
Set<SchemaElementType> schemaElementTypes = new HashSet<>();
|
||||
schemaElementMatches.stream()
|
||||
.forEach(schemaElementMatch -> schemaElementTypes.add(
|
||||
schemaElementMatch.getElement().getType()));
|
||||
SchemaElementMatch schemaElementMatchMax = schemaElementMatches.stream()
|
||||
.sorted((o1, o2) ->
|
||||
((int) ((o2.getSimilarity() - o1.getSimilarity()) * 100))
|
||||
).findFirst().orElse(null);
|
||||
if (schemaElementMatchMax != null) {
|
||||
viewMatchResult.setMaxSimilarity(schemaElementMatchMax.getSimilarity());
|
||||
}
|
||||
viewMatchResult.setCount(schemaElementTypes.size());
|
||||
|
||||
}
|
||||
}
|
||||
return viewCount;
|
||||
}
|
||||
|
||||
public Long resolve(QueryContext queryContext, Set<Long> agentViewIds) {
|
||||
SchemaMapInfo mapInfo = queryContext.getMapInfo();
|
||||
Set<Long> matchedViews = mapInfo.getMatchedViewInfos();
|
||||
Long viewId = queryContext.getViewId();
|
||||
if (Objects.nonNull(viewId) && viewId > 0) {
|
||||
if (CollectionUtils.isEmpty(agentViewIds) || agentViewIds.contains(viewId)) {
|
||||
return viewId;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
matchedViews.retainAll(agentViewIds);
|
||||
Map<Long, SemanticQuery> viewQueryModes = new HashMap<>();
|
||||
for (Long viewIds : matchedViews) {
|
||||
viewQueryModes.put(viewIds, null);
|
||||
}
|
||||
if (viewQueryModes.size() == 1) {
|
||||
return viewQueryModes.keySet().stream().findFirst().get();
|
||||
}
|
||||
return selectViewBySchemaElementMatchScore(viewQueryModes, mapInfo);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.tencent.supersonic.chat.core.parser.sql.llm;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
@@ -11,21 +12,23 @@ import com.tencent.supersonic.chat.core.config.LLMParserConfig;
|
||||
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.core.knowledge.semantic.SemanticInterpreter;
|
||||
import com.tencent.supersonic.chat.core.parser.SatisfactionChecker;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ModelSchemaResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ViewSchemaResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashSet;
|
||||
@@ -35,12 +38,6 @@ import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
@@ -63,42 +60,20 @@ public class LLMRequestService {
|
||||
return false;
|
||||
}
|
||||
|
||||
public ModelCluster getModelCluster(QueryContext queryCtx, ChatContext chatCtx) {
|
||||
public Long getViewId(QueryContext queryCtx) {
|
||||
Agent agent = queryCtx.getAgent();
|
||||
Set<Long> distinctModelIds = new HashSet<>();
|
||||
Set<Long> agentViewIds = new HashSet<>();
|
||||
if (Objects.nonNull(agent)) {
|
||||
distinctModelIds = agent.getModelIds(AgentToolType.NL2SQL_LLM);
|
||||
agentViewIds = agent.getViewIds(AgentToolType.NL2SQL_LLM);
|
||||
}
|
||||
if (llmParserConfig.getAllModel()) {
|
||||
ModelCluster modelCluster = ModelCluster.build(distinctModelIds);
|
||||
if (!CollectionUtils.isEmpty(queryCtx.getCandidateQueries())) {
|
||||
queryCtx.getCandidateQueries().stream().forEach(o -> {
|
||||
if (LLMSqlQuery.QUERY_MODE.equals(o.getParseInfo().getQueryMode())) {
|
||||
o.getParseInfo().setModel(modelCluster);
|
||||
}
|
||||
});
|
||||
}
|
||||
SemanticQuery semanticQuery = QueryManager.createQuery(LLMSqlQuery.QUERY_MODE);
|
||||
semanticQuery.getParseInfo().setModel(modelCluster);
|
||||
List<SchemaElementMatch> schemaElementMatches = new ArrayList<>();
|
||||
distinctModelIds.stream().forEach(o -> {
|
||||
if (!CollectionUtils.isEmpty(queryCtx.getMapInfo().getMatchedElements(o))) {
|
||||
schemaElementMatches.addAll(queryCtx.getMapInfo().getMatchedElements(o));
|
||||
}
|
||||
});
|
||||
queryCtx.getModelClusterMapInfo().setMatchedElements(modelCluster.getKey(), schemaElementMatches);
|
||||
return modelCluster;
|
||||
if (Agent.containsAllModel(agentViewIds)) {
|
||||
agentViewIds = new HashSet<>();
|
||||
}
|
||||
if (Agent.containsAllModel(distinctModelIds)) {
|
||||
distinctModelIds = new HashSet<>();
|
||||
}
|
||||
ModelResolver modelResolver = ComponentFactory.getModelResolver();
|
||||
String modelCluster = modelResolver.resolve(queryCtx, chatCtx, distinctModelIds);
|
||||
log.info("resolve modelId:{},llmParser Models:{}", modelCluster, distinctModelIds);
|
||||
return ModelCluster.build(modelCluster);
|
||||
ViewResolver viewResolver = ComponentFactory.getModelResolver();
|
||||
return viewResolver.resolve(queryCtx, agentViewIds);
|
||||
}
|
||||
|
||||
public NL2SQLTool getParserTool(QueryContext queryCtx, Set<Long> modelIdSet) {
|
||||
public NL2SQLTool getParserTool(QueryContext queryCtx, Long viewId) {
|
||||
Agent agent = queryCtx.getAgent();
|
||||
if (Objects.isNull(agent)) {
|
||||
return null;
|
||||
@@ -106,39 +81,33 @@ public class LLMRequestService {
|
||||
List<NL2SQLTool> commonAgentTools = agent.getParserTools(AgentToolType.NL2SQL_LLM);
|
||||
Optional<NL2SQLTool> llmParserTool = commonAgentTools.stream()
|
||||
.filter(tool -> {
|
||||
List<Long> modelIds = tool.getModelIds();
|
||||
if (Agent.containsAllModel(new HashSet<>(modelIds))) {
|
||||
List<Long> viewIds = tool.getViewIds();
|
||||
if (Agent.containsAllModel(new HashSet<>(viewIds))) {
|
||||
return true;
|
||||
}
|
||||
for (Long modelId : modelIdSet) {
|
||||
if (modelIds.contains(modelId)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
return viewIds.contains(viewId);
|
||||
})
|
||||
.findFirst();
|
||||
return llmParserTool.orElse(null);
|
||||
}
|
||||
|
||||
public LLMReq getLlmReq(QueryContext queryCtx, SemanticSchema semanticSchema,
|
||||
ModelCluster modelCluster, List<ElementValue> linkingValues) {
|
||||
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
|
||||
public LLMReq getLlmReq(QueryContext queryCtx, Long viewId,
|
||||
SemanticSchema semanticSchema, List<ElementValue> linkingValues) {
|
||||
Map<Long, String> viewIdToName = semanticSchema.getViewIdToName();
|
||||
String queryText = queryCtx.getQueryText();
|
||||
|
||||
LLMReq llmReq = new LLMReq();
|
||||
llmReq.setQueryText(queryText);
|
||||
Long firstModelId = modelCluster.getFirstModel();
|
||||
LLMReq.FilterCondition filterCondition = new LLMReq.FilterCondition();
|
||||
llmReq.setFilterCondition(filterCondition);
|
||||
|
||||
LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema();
|
||||
llmSchema.setModelName(modelIdToName.get(firstModelId));
|
||||
llmSchema.setDomainName(modelIdToName.get(firstModelId));
|
||||
llmSchema.setViewName(viewIdToName.get(viewId));
|
||||
llmSchema.setDomainName(viewIdToName.get(viewId));
|
||||
|
||||
List<String> fieldNameList = getFieldNameList(queryCtx, modelCluster, llmParserConfig);
|
||||
List<String> fieldNameList = getFieldNameList(queryCtx, viewId, llmParserConfig);
|
||||
|
||||
String priorExts = getPriorExts(modelCluster.getModelIds(), fieldNameList);
|
||||
String priorExts = getPriorExts(viewId, fieldNameList);
|
||||
llmReq.setPriorExts(priorExts);
|
||||
|
||||
fieldNameList.add(TimeDimensionEnum.DAY.getChName());
|
||||
@@ -151,7 +120,7 @@ public class LLMRequestService {
|
||||
}
|
||||
llmReq.setLinking(linking);
|
||||
|
||||
String currentDate = S2SqlDateHelper.getReferenceDate(queryCtx, firstModelId);
|
||||
String currentDate = S2SqlDateHelper.getReferenceDate(queryCtx, viewId);
|
||||
if (StringUtils.isEmpty(currentDate)) {
|
||||
currentDate = DateUtils.getBeforeDate(0);
|
||||
}
|
||||
@@ -160,29 +129,28 @@ public class LLMRequestService {
|
||||
return llmReq;
|
||||
}
|
||||
|
||||
public LLMResp requestLLM(LLMReq llmReq, String modelClusterKey) {
|
||||
return ComponentFactory.getLLMProxy().query2sql(llmReq, modelClusterKey);
|
||||
public LLMResp requestLLM(LLMReq llmReq, Long viewId) {
|
||||
return ComponentFactory.getLLMProxy().query2sql(llmReq, viewId);
|
||||
}
|
||||
|
||||
protected List<String> getFieldNameList(QueryContext queryCtx, ModelCluster modelCluster,
|
||||
protected List<String> getFieldNameList(QueryContext queryCtx, Long viewId,
|
||||
LLMParserConfig llmParserConfig) {
|
||||
|
||||
Set<String> results = getTopNFieldNames(queryCtx, modelCluster, llmParserConfig);
|
||||
Set<String> results = getTopNFieldNames(queryCtx, viewId, llmParserConfig);
|
||||
|
||||
Set<String> fieldNameList = getMatchedFieldNames(queryCtx, modelCluster);
|
||||
Set<String> fieldNameList = getMatchedFieldNames(queryCtx, viewId);
|
||||
|
||||
results.addAll(fieldNameList);
|
||||
return new ArrayList<>(results);
|
||||
}
|
||||
|
||||
private String getPriorExts(Set<Long> modelIds, List<String> fieldNameList) {
|
||||
private String getPriorExts(Long viewId, List<String> fieldNameList) {
|
||||
StringBuilder extraInfoSb = new StringBuilder();
|
||||
List<ModelSchemaResp> modelSchemaResps = semanticInterpreter.fetchModelSchema(
|
||||
new ArrayList<>(modelIds), true);
|
||||
if (!CollectionUtils.isEmpty(modelSchemaResps)) {
|
||||
|
||||
ModelSchemaResp modelSchemaResp = modelSchemaResps.get(0);
|
||||
Map<String, String> fieldNameToDataFormatType = modelSchemaResp.getMetrics()
|
||||
List<ViewSchemaResp> viewSchemaResps = semanticInterpreter.fetchViewSchema(
|
||||
Lists.newArrayList(viewId), true);
|
||||
if (!CollectionUtils.isEmpty(viewSchemaResps)) {
|
||||
ViewSchemaResp viewSchemaResp = viewSchemaResps.get(0);
|
||||
Map<String, String> fieldNameToDataFormatType = viewSchemaResp.getMetrics()
|
||||
.stream().filter(metricSchemaResp -> Objects.nonNull(metricSchemaResp.getDataFormatType()))
|
||||
.flatMap(metricSchemaResp -> {
|
||||
Set<Pair<String, String>> result = new HashSet<>();
|
||||
@@ -210,11 +178,9 @@ public class LLMRequestService {
|
||||
return extraInfoSb.toString();
|
||||
}
|
||||
|
||||
protected List<ElementValue> getValueList(QueryContext queryCtx, ModelCluster modelCluster) {
|
||||
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, modelCluster);
|
||||
|
||||
List<SchemaElementMatch> matchedElements = queryCtx.getModelClusterMapInfo()
|
||||
.getMatchedElements(modelCluster.getKey());
|
||||
protected List<ElementValue> getValueList(QueryContext queryCtx, Long viewId) {
|
||||
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, viewId);
|
||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(viewId);
|
||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
@@ -234,22 +200,21 @@ public class LLMRequestService {
|
||||
return new ArrayList<>(valueMatches);
|
||||
}
|
||||
|
||||
protected Map<Long, String> getItemIdToName(QueryContext queryCtx, ModelCluster modelCluster) {
|
||||
protected Map<Long, String> getItemIdToName(QueryContext queryCtx, Long viewId) {
|
||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||
return semanticSchema.getDimensions(modelCluster.getModelIds()).stream()
|
||||
return semanticSchema.getDimensions(viewId).stream()
|
||||
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
|
||||
}
|
||||
|
||||
private Set<String> getTopNFieldNames(QueryContext queryCtx, ModelCluster modelCluster,
|
||||
LLMParserConfig llmParserConfig) {
|
||||
private Set<String> getTopNFieldNames(QueryContext queryCtx, Long viewId, LLMParserConfig llmParserConfig) {
|
||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||
Set<String> results = semanticSchema.getDimensions(modelCluster.getModelIds()).stream()
|
||||
Set<String> results = semanticSchema.getDimensions(viewId).stream()
|
||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||
.limit(llmParserConfig.getDimensionTopN())
|
||||
.map(entry -> entry.getName())
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
Set<String> metrics = semanticSchema.getMetrics(modelCluster.getModelIds()).stream()
|
||||
Set<String> metrics = semanticSchema.getMetrics(viewId).stream()
|
||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||
.limit(llmParserConfig.getMetricTopN())
|
||||
.map(entry -> entry.getName())
|
||||
@@ -259,10 +224,9 @@ public class LLMRequestService {
|
||||
return results;
|
||||
}
|
||||
|
||||
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, ModelCluster modelCluster) {
|
||||
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, modelCluster);
|
||||
List<SchemaElementMatch> matchedElements = queryCtx.getModelClusterMapInfo()
|
||||
.getMatchedElements(modelCluster.getKey());
|
||||
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, Long viewId) {
|
||||
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, viewId);
|
||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(viewId);
|
||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||
return new HashSet<>();
|
||||
}
|
||||
|
||||
@@ -28,10 +28,9 @@ public class LLMResponseService {
|
||||
}
|
||||
LLMSemanticQuery semanticQuery = QueryManager.createLLMQuery(LLMSqlQuery.QUERY_MODE);
|
||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||
parseInfo.setModel(parseResult.getModelCluster());
|
||||
parseInfo.setView(queryCtx.getSemanticSchema().getView(parseResult.getViewId()));
|
||||
NL2SQLTool commonAgentTool = parseResult.getCommonAgentTool();
|
||||
parseInfo.getElementMatches().addAll(queryCtx.getModelClusterMapInfo()
|
||||
.getMatchedElements(parseInfo.getModelClusterKey()));
|
||||
parseInfo.getElementMatches().addAll(queryCtx.getMapInfo().getMatchedElements(parseInfo.getViewId()));
|
||||
|
||||
Map<String, Object> properties = new HashMap<>();
|
||||
properties.put(Constants.CONTEXT, parseResult);
|
||||
@@ -42,7 +41,6 @@ public class LLMResponseService {
|
||||
parseInfo.setScore(queryCtx.getQueryText().length() * (1 + weight));
|
||||
parseInfo.setQueryMode(semanticQuery.getQueryMode());
|
||||
parseInfo.getSqlInfo().setS2SQL(s2SQL);
|
||||
parseInfo.setModel(parseResult.getModelCluster());
|
||||
queryCtx.getCandidateQueries().add(semanticQuery);
|
||||
return parseInfo;
|
||||
}
|
||||
|
||||
@@ -9,14 +9,13 @@ import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlResp;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.MapUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.MapUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
@Slf4j
|
||||
public class LLMSqlParser implements SemanticParser {
|
||||
@@ -30,31 +29,30 @@ public class LLMSqlParser implements SemanticParser {
|
||||
}
|
||||
try {
|
||||
//2.get modelId from queryCtx and chatCtx.
|
||||
ModelCluster modelCluster = requestService.getModelCluster(queryCtx, chatCtx);
|
||||
if (StringUtils.isBlank(modelCluster.getKey())) {
|
||||
Long viewId = requestService.getViewId(queryCtx);
|
||||
if (viewId == null) {
|
||||
return;
|
||||
}
|
||||
//3.get agent tool and determine whether to skip this parser.
|
||||
NL2SQLTool commonAgentTool = requestService.getParserTool(queryCtx, modelCluster.getModelIds());
|
||||
NL2SQLTool commonAgentTool = requestService.getParserTool(queryCtx, viewId);
|
||||
if (Objects.isNull(commonAgentTool)) {
|
||||
log.info("no tool in this agent, skip {}", LLMSqlParser.class);
|
||||
return;
|
||||
}
|
||||
//4.construct a request, call the API for the large model, and retrieve the results.
|
||||
List<ElementValue> linkingValues = requestService.getValueList(queryCtx, modelCluster);
|
||||
List<ElementValue> linkingValues = requestService.getValueList(queryCtx, viewId);
|
||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||
LLMReq llmReq = requestService.getLlmReq(queryCtx, semanticSchema, modelCluster, linkingValues);
|
||||
LLMResp llmResp = requestService.requestLLM(llmReq, modelCluster.getKey());
|
||||
LLMReq llmReq = requestService.getLlmReq(queryCtx, viewId, semanticSchema, linkingValues);
|
||||
LLMResp llmResp = requestService.requestLLM(llmReq, viewId);
|
||||
|
||||
if (Objects.isNull(llmResp)) {
|
||||
return;
|
||||
}
|
||||
//5. deduplicate the SQL result list and build parserInfo
|
||||
modelCluster.buildName(semanticSchema.getModelIdToName());
|
||||
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
|
||||
Map<String, LLMSqlResp> deduplicationSqlResp = responseService.getDeduplicationSqlResp(llmResp);
|
||||
ParseResult parseResult = ParseResult.builder()
|
||||
.modelCluster(modelCluster)
|
||||
.viewId(viewId)
|
||||
.commonAgentTool(commonAgentTool)
|
||||
.llmReq(llmReq)
|
||||
.llmResp(llmResp)
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.parser.sql.llm;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import java.util.Set;
|
||||
|
||||
public interface ModelResolver {
|
||||
|
||||
String resolve(QueryContext queryContext, ChatContext chatCtx, Set<Long> restrictiveModels);
|
||||
|
||||
}
|
||||
@@ -11,11 +11,6 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.slf4j.Logger;
|
||||
@@ -24,6 +19,12 @@ import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
@@ -42,9 +43,9 @@ public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
private SqlPromptGenerator sqlPromptGenerator;
|
||||
|
||||
@Override
|
||||
public LLMResp generation(LLMReq llmReq, String modelClusterKey) {
|
||||
public LLMResp generation(LLMReq llmReq, Long viewId) {
|
||||
//1.retriever sqlExamples and generate exampleListPool
|
||||
keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq);
|
||||
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
|
||||
|
||||
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
|
||||
|
||||
@@ -2,19 +2,16 @@ package com.tencent.supersonic.chat.core.parser.sql.llm;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlResp;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlResp;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
@@ -22,6 +19,10 @@ import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class OnePassSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
@@ -31,7 +32,7 @@ public class OnePassSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
private ChatLanguageModel chatLanguageModel;
|
||||
|
||||
@Autowired
|
||||
private SqlExamplarLoader sqlExamplarLoader;
|
||||
private SqlExamplarLoader sqlExampleLoader;
|
||||
|
||||
@Autowired
|
||||
private OptimizationConfig optimizationConfig;
|
||||
@@ -40,10 +41,10 @@ public class OnePassSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
private SqlPromptGenerator sqlPromptGenerator;
|
||||
|
||||
@Override
|
||||
public LLMResp generation(LLMReq llmReq, String modelClusterKey) {
|
||||
public LLMResp generation(LLMReq llmReq, Long viewId) {
|
||||
//1.retriever sqlExamples
|
||||
keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq);
|
||||
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
|
||||
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
|
||||
|
||||
//2.generator linking and sql prompt by sqlExamples,and generate response.
|
||||
|
||||
@@ -5,7 +5,6 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
@@ -19,7 +18,7 @@ import java.util.List;
|
||||
@NoArgsConstructor
|
||||
public class ParseResult {
|
||||
|
||||
private ModelCluster modelCluster;
|
||||
private Long viewId;
|
||||
|
||||
private LLMReq llmReq;
|
||||
|
||||
|
||||
@@ -12,9 +12,9 @@ public interface SqlGeneration {
|
||||
/***
|
||||
* generate llmResp (sql, schemaLink, prompt, etc.) through LLMReq.
|
||||
* @param llmReq
|
||||
* @param modelClusterKey
|
||||
* @param viewId
|
||||
* @return
|
||||
*/
|
||||
LLMResp generation(LLMReq llmReq, String modelClusterKey);
|
||||
LLMResp generation(LLMReq llmReq, Long viewId);
|
||||
|
||||
}
|
||||
|
||||
@@ -2,14 +2,15 @@ package com.tencent.supersonic.chat.core.parser.sql.llm;
|
||||
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
@@ -95,7 +96,7 @@ public class SqlPromptGenerator {
|
||||
}
|
||||
|
||||
public Pair<String, String> transformQuestionPrompt(LLMReq llmReq) {
|
||||
String modelName = llmReq.getSchema().getModelName();
|
||||
String modelName = llmReq.getSchema().getViewName();
|
||||
List<String> fieldNameList = llmReq.getSchema().getFieldNameList();
|
||||
List<ElementValue> linking = llmReq.getLinking();
|
||||
String currentDate = llmReq.getCurrentDate();
|
||||
|
||||
@@ -11,10 +11,6 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
@@ -22,6 +18,11 @@ import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
|
||||
@Service
|
||||
public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
|
||||
@@ -39,9 +40,9 @@ public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
private SqlPromptGenerator sqlPromptGenerator;
|
||||
|
||||
@Override
|
||||
public LLMResp generation(LLMReq llmReq, String modelClusterKey) {
|
||||
public LLMResp generation(LLMReq llmReq, Long viewId) {
|
||||
//1.retriever sqlExamples and generate exampleListPool
|
||||
keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq);
|
||||
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
|
||||
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
|
||||
|
||||
|
||||
@@ -11,9 +11,6 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
@@ -21,6 +18,10 @@ import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
@@ -39,8 +40,8 @@ public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean {
|
||||
private SqlPromptGenerator sqlPromptGenerator;
|
||||
|
||||
@Override
|
||||
public LLMResp generation(LLMReq llmReq, String modelClusterKey) {
|
||||
keyPipelineLog.info("modelClusterKey:{},llmReq:{}", modelClusterKey, llmReq);
|
||||
public LLMResp generation(LLMReq llmReq, Long viewId) {
|
||||
keyPipelineLog.info("viewId:{},llmReq:{}", viewId, llmReq);
|
||||
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlCollectionName(), optimizationConfig.getText2sqlExampleNum());
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ package com.tencent.supersonic.chat.core.parser.sql.llm;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class ModelMatchResult {
|
||||
public class ViewMatchResult {
|
||||
private Integer count = 0;
|
||||
private double maxSimilarity;
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
package com.tencent.supersonic.chat.core.parser.sql.llm;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
|
||||
import java.util.Set;
|
||||
|
||||
public interface ViewResolver {
|
||||
|
||||
Long resolve(QueryContext queryContext, Set<Long> restrictiveModels);
|
||||
|
||||
}
|
||||
@@ -11,12 +11,12 @@ import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
public class AgentCheckParser implements SemanticParser {
|
||||
|
||||
@@ -52,16 +52,13 @@ public class AgentCheckParser implements SemanticParser {
|
||||
return !tool.getQueryTypes().contains(QueryType.METRIC.name());
|
||||
}
|
||||
}
|
||||
if (CollectionUtils.isEmpty(tool.getModelIds())) {
|
||||
if (CollectionUtils.isEmpty(tool.getViewIds())) {
|
||||
return true;
|
||||
}
|
||||
if (tool.isContainsAllModel()) {
|
||||
return false;
|
||||
}
|
||||
if (new HashSet<>(tool.getModelIds())
|
||||
.containsAll(query.getParseInfo().getModel().getModelIds())) {
|
||||
return false;
|
||||
}
|
||||
return !tool.getViewIds().contains(query.getParseInfo().getViewId());
|
||||
}
|
||||
return true;
|
||||
});
|
||||
|
||||
@@ -2,7 +2,6 @@ package com.tencent.supersonic.chat.core.parser.sql.rule;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.core.parser.SemanticParser;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
@@ -12,8 +11,8 @@ import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.chat.core.query.rule.metric.MetricModelQuery;
|
||||
import com.tencent.supersonic.chat.core.query.rule.metric.MetricSemanticQuery;
|
||||
import com.tencent.supersonic.chat.core.query.rule.metric.MetricTagQuery;
|
||||
import com.tencent.supersonic.chat.core.utils.ModelClusterBuilder;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.AbstractMap;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
@@ -23,8 +22,6 @@ import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
/**
|
||||
* ContextInheritParser tries to inherit certain schema elements from context
|
||||
@@ -42,7 +39,7 @@ public class ContextInheritParser implements SemanticParser {
|
||||
SchemaElementType.VALUE, Arrays.asList(SchemaElementType.VALUE, SchemaElementType.DIMENSION)),
|
||||
new AbstractMap.SimpleEntry<>(SchemaElementType.ENTITY, Arrays.asList(SchemaElementType.ENTITY)),
|
||||
new AbstractMap.SimpleEntry<>(SchemaElementType.TAG, Arrays.asList(SchemaElementType.TAG)),
|
||||
new AbstractMap.SimpleEntry<>(SchemaElementType.MODEL, Arrays.asList(SchemaElementType.MODEL)),
|
||||
new AbstractMap.SimpleEntry<>(SchemaElementType.VIEW, Arrays.asList(SchemaElementType.VIEW)),
|
||||
new AbstractMap.SimpleEntry<>(SchemaElementType.ID, Arrays.asList(SchemaElementType.ID))
|
||||
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
|
||||
|
||||
@@ -51,12 +48,13 @@ public class ContextInheritParser implements SemanticParser {
|
||||
if (!shouldInherit(queryContext)) {
|
||||
return;
|
||||
}
|
||||
ModelCluster modelCluster = getMatchedModelCluster(queryContext, chatContext);
|
||||
if (modelCluster == null) {
|
||||
Long viewId = getMatchedView(queryContext, chatContext);
|
||||
if (viewId == null) {
|
||||
return;
|
||||
}
|
||||
List<SchemaElementMatch> elementMatches = queryContext.getModelClusterMapInfo()
|
||||
.getMatchedElements(modelCluster.getKey());
|
||||
|
||||
List<SchemaElementMatch> elementMatches = queryContext.getMapInfo().getMatchedElements(viewId);
|
||||
|
||||
List<SchemaElementMatch> matchesToInherit = new ArrayList<>();
|
||||
for (SchemaElementMatch match : chatContext.getParseInfo().getElementMatches()) {
|
||||
SchemaElementType matchType = match.getElement().getType();
|
||||
@@ -72,17 +70,17 @@ public class ContextInheritParser implements SemanticParser {
|
||||
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(elementMatches, queryContext);
|
||||
for (RuleSemanticQuery query : queries) {
|
||||
query.fillParseInfo(queryContext, chatContext);
|
||||
if (existSameQuery(query.getParseInfo().getModelClusterKey(), query.getQueryMode(), queryContext)) {
|
||||
if (existSameQuery(query.getParseInfo().getViewId(), query.getQueryMode(), queryContext)) {
|
||||
continue;
|
||||
}
|
||||
queryContext.getCandidateQueries().add(query);
|
||||
}
|
||||
}
|
||||
|
||||
private boolean existSameQuery(String modelClusterKey, String queryMode, QueryContext queryContext) {
|
||||
private boolean existSameQuery(Long viewId, String queryMode, QueryContext queryContext) {
|
||||
for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) {
|
||||
if (semanticQuery.getQueryMode().equals(queryMode)
|
||||
&& semanticQuery.getParseInfo().getModelClusterKey().equals(modelClusterKey)) {
|
||||
&& semanticQuery.getParseInfo().getViewId().equals(viewId)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
@@ -111,25 +109,16 @@ public class ContextInheritParser implements SemanticParser {
|
||||
return metricModelQueries.size() == queryContext.getCandidateQueries().size();
|
||||
}
|
||||
|
||||
protected ModelCluster getMatchedModelCluster(QueryContext queryContext, ChatContext chatContext) {
|
||||
String contextModelClusterKey = chatContext.getParseInfo().getModelClusterKey();
|
||||
if (StringUtils.isBlank(contextModelClusterKey)) {
|
||||
protected Long getMatchedView(QueryContext queryContext, ChatContext chatContext) {
|
||||
Long viewId = chatContext.getParseInfo().getViewId();
|
||||
if (viewId == null) {
|
||||
return null;
|
||||
}
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
List<ModelCluster> allModelClusters = ModelClusterBuilder.buildModelClusters(semanticSchema);
|
||||
Set<String> queryModelClusters = queryContext.getModelClusterMapInfo().getMatchedModelClusters();
|
||||
ModelCluster contextModelCluster = ModelCluster.build(contextModelClusterKey);
|
||||
for (String cluster : queryModelClusters) {
|
||||
ModelCluster queryModelCluster = ModelCluster.build(cluster);
|
||||
for (ModelCluster modelCluster : allModelClusters) {
|
||||
if (modelCluster.getModelIds().containsAll(contextModelCluster.getModelIds())
|
||||
&& modelCluster.getModelIds().containsAll(queryModelCluster.getModelIds())) {
|
||||
return queryModelCluster;
|
||||
}
|
||||
}
|
||||
Set<Long> queryViews = queryContext.getMapInfo().getMatchedViewInfos();
|
||||
if (queryViews.contains(viewId)) {
|
||||
return viewId;
|
||||
}
|
||||
return null;
|
||||
return viewId;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package com.tencent.supersonic.chat.core.parser.sql.rule;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.core.parser.SemanticParser;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaModelClusterMapInfo;
|
||||
import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@@ -27,10 +27,10 @@ public class RuleSqlParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
SchemaModelClusterMapInfo modelClusterMapInfo = queryContext.getModelClusterMapInfo();
|
||||
SchemaMapInfo mapInfo = queryContext.getMapInfo();
|
||||
// iterate all schemaElementMatches to resolve query mode
|
||||
for (String modelClusterKey : modelClusterMapInfo.getMatchedModelClusters()) {
|
||||
List<SchemaElementMatch> elementMatches = modelClusterMapInfo.getMatchedElements(modelClusterKey);
|
||||
for (Long viewId : mapInfo.getMatchedViewInfos()) {
|
||||
List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(viewId);
|
||||
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(elementMatches, queryContext);
|
||||
for (RuleSemanticQuery query : queries) {
|
||||
query.fillParseInfo(queryContext, chatContext);
|
||||
|
||||
@@ -20,7 +20,7 @@ public class Plugin extends RecordInfo {
|
||||
*/
|
||||
private String type;
|
||||
|
||||
private List<Long> modelList = Lists.newArrayList();
|
||||
private List<Long> viewList = Lists.newArrayList();
|
||||
|
||||
/**
|
||||
* description, for parsing
|
||||
@@ -52,7 +52,7 @@ public class Plugin extends RecordInfo {
|
||||
}
|
||||
|
||||
public boolean isContainsAllModel() {
|
||||
return CollectionUtils.isNotEmpty(modelList) && modelList.contains(-1L);
|
||||
return CollectionUtils.isNotEmpty(viewList) && viewList.contains(-1L);
|
||||
}
|
||||
|
||||
public Long getDefaultMode() {
|
||||
|
||||
@@ -23,6 +23,12 @@ import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.context.event.EventListener;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
@@ -32,11 +38,6 @@ import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.context.event.EventListener;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
@@ -265,14 +266,14 @@ public class PluginManager {
|
||||
}
|
||||
|
||||
private static Set<Long> getPluginMatchedModel(Plugin plugin, QueryContext queryContext) {
|
||||
Set<Long> matchedModel = queryContext.getMapInfo().getMatchedModels();
|
||||
Set<Long> matchedViews = queryContext.getMapInfo().getMatchedViewInfos();
|
||||
if (plugin.isContainsAllModel()) {
|
||||
return Sets.newHashSet(plugin.getDefaultMode());
|
||||
}
|
||||
List<Long> modelIds = plugin.getModelList();
|
||||
List<Long> modelIds = plugin.getViewList();
|
||||
Set<Long> pluginMatchedModel = Sets.newHashSet();
|
||||
for (Long modelId : modelIds) {
|
||||
if (matchedModel.contains(modelId)) {
|
||||
if (matchedViews.contains(modelId)) {
|
||||
pluginMatchedModel.add(modelId);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.Set;
|
||||
|
||||
@Data
|
||||
@@ -14,7 +15,7 @@ public class PluginRecallResult {
|
||||
|
||||
private Plugin plugin;
|
||||
|
||||
private Set<Long> modelIds;
|
||||
private Set<Long> viewIds;
|
||||
|
||||
private double score;
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ package com.tencent.supersonic.chat.core.pojo;
|
||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaModelClusterMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
|
||||
@@ -12,15 +11,16 @@ import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.core.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@@ -30,14 +30,13 @@ public class QueryContext {
|
||||
|
||||
private String queryText;
|
||||
private Integer chatId;
|
||||
private Long modelId;
|
||||
private Long viewId;
|
||||
private User user;
|
||||
private boolean saveAnswer = true;
|
||||
private Integer agentId;
|
||||
private QueryFilters queryFilters;
|
||||
private List<SemanticQuery> candidateQueries = new ArrayList<>();
|
||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||
private SchemaModelClusterMapInfo modelClusterMapInfo = new SchemaModelClusterMapInfo();
|
||||
@JsonIgnore
|
||||
private SemanticSchema semanticSchema;
|
||||
@JsonIgnore
|
||||
|
||||
@@ -19,15 +19,16 @@ import com.tencent.supersonic.headless.api.pojo.request.ExplainSqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ExplainResp;
|
||||
import lombok.ToString;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.ToString;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
@Slf4j
|
||||
@ToString
|
||||
@@ -48,7 +49,7 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
|
||||
explainSqlReq = ExplainSqlReq.builder()
|
||||
.queryTypeEnum(QueryType.SQL)
|
||||
.queryReq(QueryReqBuilder.buildS2SQLReq(
|
||||
sqlInfo.getCorrectS2SQL(), parseInfo.getModel().getModelIds()
|
||||
sqlInfo.getCorrectS2SQL(), parseInfo.getViewId()
|
||||
))
|
||||
.build();
|
||||
} else {
|
||||
@@ -83,7 +84,7 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
|
||||
}
|
||||
|
||||
protected void convertBizNameToName(SemanticSchema semanticSchema, QueryStructReq queryStructReq) {
|
||||
Map<String, String> bizNameToName = semanticSchema.getBizNameToName(queryStructReq.getModelIdSet());
|
||||
Map<String, String> bizNameToName = semanticSchema.getBizNameToName(queryStructReq.getViewId());
|
||||
bizNameToName.putAll(TimeDimensionEnum.getNameToNameMap());
|
||||
|
||||
List<Order> orders = queryStructReq.getOrders();
|
||||
@@ -100,18 +101,17 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
|
||||
}
|
||||
List<String> groups = queryStructReq.getGroups();
|
||||
if (CollectionUtils.isNotEmpty(groups)) {
|
||||
groups = groups.stream().map(group -> bizNameToName.get(group)).collect(Collectors.toList());
|
||||
groups = groups.stream().map(bizNameToName::get).collect(Collectors.toList());
|
||||
queryStructReq.setGroups(groups);
|
||||
}
|
||||
List<Filter> dimensionFilters = queryStructReq.getDimensionFilters();
|
||||
if (CollectionUtils.isNotEmpty(dimensionFilters)) {
|
||||
dimensionFilters.stream().forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
|
||||
dimensionFilters.forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
|
||||
}
|
||||
List<Filter> metricFilters = queryStructReq.getMetricFilters();
|
||||
if (CollectionUtils.isNotEmpty(dimensionFilters)) {
|
||||
metricFilters.stream().forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
|
||||
metricFilters.forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
|
||||
}
|
||||
queryStructReq.setModelName(parseInfo.getModelName());
|
||||
}
|
||||
|
||||
protected void initS2SqlByStruct(SemanticSchema semanticSchema) {
|
||||
@@ -121,9 +121,9 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
|
||||
}
|
||||
QueryStructReq queryStructReq = convertQueryStruct();
|
||||
convertBizNameToName(semanticSchema, queryStructReq);
|
||||
QuerySqlReq querySqlReq = queryStructReq.convert(queryStructReq);
|
||||
parseInfo.getSqlInfo().setS2SQL(querySqlReq.getSql());
|
||||
parseInfo.getSqlInfo().setCorrectS2SQL(querySqlReq.getSql());
|
||||
QuerySqlReq querySQLReq = queryStructReq.convert(queryStructReq);
|
||||
parseInfo.getSqlInfo().setS2SQL(querySQLReq.getSql());
|
||||
parseInfo.getSqlInfo().setCorrectS2SQL(querySQLReq.getSql());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
package com.tencent.supersonic.chat.core.query.llm.s2sql;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonValue;
|
||||
import java.util.List;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
public class LLMReq {
|
||||
|
||||
@@ -35,7 +36,7 @@ public class LLMReq {
|
||||
|
||||
private String domainName;
|
||||
|
||||
private String modelName;
|
||||
private String viewName;
|
||||
|
||||
private List<String> fieldNameList;
|
||||
|
||||
|
||||
@@ -42,8 +42,8 @@ public class LLMSqlQuery extends LLMSemanticQuery {
|
||||
|
||||
long startTime = System.currentTimeMillis();
|
||||
String querySql = parseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
QuerySqlReq querySqlReq = QueryReqBuilder.buildS2SQLReq(querySql, parseInfo.getModel().getModelIds());
|
||||
SemanticQueryResp queryResp = semanticInterpreter.queryByS2SQL(querySqlReq, user);
|
||||
QuerySqlReq querySQLReq = QueryReqBuilder.buildS2SQLReq(querySql, parseInfo.getViewId());
|
||||
SemanticQueryResp queryResp = semanticInterpreter.queryByS2SQL(querySQLReq, user);
|
||||
|
||||
log.info("queryByS2SQL cost:{},querySql:{}", System.currentTimeMillis() - startTime, querySql);
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ public abstract class PluginSemanticQuery extends BaseSemanticQuery {
|
||||
if (!CollectionUtils.isEmpty(webPage.getParamOptions()) && !CollectionUtils.isEmpty(elementValueMap)) {
|
||||
for (ParamOption paramOption : webPage.getParamOptions()) {
|
||||
if (paramOption.getModelId() != null
|
||||
&& !parseInfo.getModel().getModelIds().contains(paramOption.getModelId())) {
|
||||
&& !parseInfo.getViewId().equals(paramOption.getModelId())) {
|
||||
continue;
|
||||
}
|
||||
paramOptions.add(paramOption);
|
||||
|
||||
@@ -3,14 +3,15 @@ package com.tencent.supersonic.chat.core.query.rule;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
|
||||
@Data
|
||||
@ToString
|
||||
@@ -24,7 +25,7 @@ public class QueryMatcher {
|
||||
|
||||
public QueryMatcher() {
|
||||
for (SchemaElementType type : SchemaElementType.values()) {
|
||||
if (type.equals(SchemaElementType.MODEL)) {
|
||||
if (type.equals(SchemaElementType.VIEW)) {
|
||||
elementOptionMap.put(type, QueryMatchOption.optional());
|
||||
} else {
|
||||
elementOptionMap.put(type, QueryMatchOption.unused());
|
||||
|
||||
@@ -2,9 +2,6 @@
|
||||
package com.tencent.supersonic.chat.core.query.rule;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.core.knowledge.semantic.SemanticInterpreter;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
@@ -14,17 +11,19 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.core.knowledge.semantic.SemanticInterpreter;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.BaseSemanticQuery;
|
||||
import com.tencent.supersonic.chat.core.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.core.utils.QueryReqBuilder;
|
||||
import com.tencent.supersonic.chat.core.query.QueryManager;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import lombok.ToString;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
@@ -103,11 +102,9 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
}
|
||||
|
||||
private void fillSchemaElement(SemanticParseInfo parseInfo, SemanticSchema semanticSchema) {
|
||||
Set<Long> modelIds = parseInfo.getElementMatches().stream().map(SchemaElementMatch::getElement)
|
||||
.map(SchemaElement::getModel).collect(Collectors.toSet());
|
||||
ModelCluster modelCluster = ModelCluster.build(modelIds);
|
||||
modelCluster.buildName(semanticSchema.getModelIdToName());
|
||||
parseInfo.setModel(modelCluster);
|
||||
Set<Long> viewIds = parseInfo.getElementMatches().stream().map(SchemaElementMatch::getElement)
|
||||
.map(SchemaElement::getView).collect(Collectors.toSet());
|
||||
parseInfo.setView(semanticSchema.getView(viewIds.iterator().next()));
|
||||
Map<Long, List<SchemaElementMatch>> dim2Values = new HashMap<>();
|
||||
Map<Long, List<SchemaElementMatch>> id2Values = new HashMap<>();
|
||||
|
||||
@@ -192,7 +189,7 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
public QueryResult execute(User user) {
|
||||
String queryMode = parseInfo.getQueryMode();
|
||||
|
||||
if (StringUtils.isBlank(parseInfo.getModelClusterKey()) || StringUtils.isEmpty(queryMode)
|
||||
if (parseInfo.getViewId() == null || StringUtils.isEmpty(queryMode)
|
||||
|| !QueryManager.containsRuleQuery(queryMode)) {
|
||||
// reach here some error may happen
|
||||
log.error("not find QueryMode");
|
||||
@@ -233,7 +230,7 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
public QueryResult multiStructExecute(User user) {
|
||||
String queryMode = parseInfo.getQueryMode();
|
||||
|
||||
if (StringUtils.isBlank(parseInfo.getModelClusterKey()) || StringUtils.isEmpty(queryMode)
|
||||
if (parseInfo.getViewId() != null || StringUtils.isEmpty(queryMode)
|
||||
|| !QueryManager.containsRuleQuery(queryMode)) {
|
||||
// reach here some error may happen
|
||||
log.error("not find QueryMode");
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
package com.tencent.supersonic.chat.core.query.rule.metric;
|
||||
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.MODEL;
|
||||
import static com.tencent.supersonic.chat.core.query.rule.QueryMatchOption.OptionType.OPTIONAL;
|
||||
import static com.tencent.supersonic.chat.core.query.rule.QueryMatchOption.RequireNumberType.AT_MOST;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import static com.tencent.supersonic.chat.core.query.rule.QueryMatchOption.OptionType.OPTIONAL;
|
||||
import static com.tencent.supersonic.chat.core.query.rule.QueryMatchOption.RequireNumberType.AT_MOST;
|
||||
@Component
|
||||
public class MetricModelQuery extends MetricSemanticQuery {
|
||||
|
||||
@@ -15,7 +14,7 @@ public class MetricModelQuery extends MetricSemanticQuery {
|
||||
|
||||
public MetricModelQuery() {
|
||||
super();
|
||||
queryMatcher.addOption(MODEL, OPTIONAL, AT_MOST, 1);
|
||||
queryMatcher.addOption(SchemaElementType.VIEW, OPTIONAL, AT_MOST, 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -1,26 +1,11 @@
|
||||
package com.tencent.supersonic.chat.core.query.rule.metric;
|
||||
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.METRIC;
|
||||
import static com.tencent.supersonic.chat.core.query.rule.QueryMatchOption.OptionType.REQUIRED;
|
||||
import static com.tencent.supersonic.chat.core.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.DAY;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.DAY_FORMAT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.DAY_FORMAT_INT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.MONTH;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.MONTH_FORMAT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.MONTH_FORMAT_INT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.TIMES_FORMAT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.TIME_FORMAT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.WEEK;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.AggregateInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatDefaultRichConfigResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.MetricInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.core.config.AggregatorConfig;
|
||||
@@ -33,10 +18,15 @@ import com.tencent.supersonic.common.pojo.DateConf.DateMode;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.RatioOverType;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeMode;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.text.DecimalFormat;
|
||||
import java.time.DayOfWeek;
|
||||
import java.time.LocalDate;
|
||||
@@ -53,8 +43,19 @@ import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.METRIC;
|
||||
import static com.tencent.supersonic.chat.core.query.rule.QueryMatchOption.OptionType.REQUIRED;
|
||||
import static com.tencent.supersonic.chat.core.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.DAY;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.DAY_FORMAT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.DAY_FORMAT_INT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.MONTH;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.MONTH_FORMAT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.MONTH_FORMAT_INT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.TIMES_FORMAT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.TIME_FORMAT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.WEEK;
|
||||
|
||||
@Slf4j
|
||||
public abstract class MetricSemanticQuery extends RuleSemanticQuery {
|
||||
@@ -75,30 +76,26 @@ public abstract class MetricSemanticQuery extends RuleSemanticQuery {
|
||||
@Override
|
||||
public void fillParseInfo(QueryContext queryContext, ChatContext chatContext) {
|
||||
super.fillParseInfo(queryContext, chatContext);
|
||||
|
||||
parseInfo.setLimit(METRIC_MAX_RESULTS);
|
||||
if (parseInfo.getDateInfo() == null) {
|
||||
ChatConfigRichResp chatConfig = queryContext.getModelIdToChatRichConfig().get(parseInfo.getModelId());
|
||||
ChatDefaultRichConfigResp defaultConfig = chatConfig.getChatAggRichConfig().getChatDefaultConfig();
|
||||
ViewSchema viewSchema = queryContext.getSemanticSchema().getViewSchemaMap().get(parseInfo.getViewId());
|
||||
TimeDefaultConfig timeDefaultConfig = viewSchema.getMetricTypeTimeDefaultConfig();
|
||||
DateConf dateInfo = new DateConf();
|
||||
int unit = 1;
|
||||
if (Objects.nonNull(defaultConfig) && Objects.nonNull(defaultConfig.getUnit())) {
|
||||
unit = defaultConfig.getUnit();
|
||||
if (Objects.nonNull(timeDefaultConfig) && Objects.nonNull(timeDefaultConfig.getUnit())) {
|
||||
int unit = timeDefaultConfig.getUnit();
|
||||
String startDate = LocalDate.now().plusDays(-unit).toString();
|
||||
String endDate = startDate;
|
||||
if (TimeMode.LAST.equals(timeDefaultConfig.getTimeMode())) {
|
||||
dateInfo.setDateMode(DateConf.DateMode.BETWEEN);
|
||||
} else if (TimeMode.RECENT.equals(timeDefaultConfig.getTimeMode())) {
|
||||
dateInfo.setDateMode(DateConf.DateMode.RECENT);
|
||||
endDate = LocalDate.now().plusDays(-1).toString();
|
||||
}
|
||||
dateInfo.setUnit(unit);
|
||||
dateInfo.setPeriod(timeDefaultConfig.getPeriod());
|
||||
dateInfo.setStartDate(startDate);
|
||||
dateInfo.setEndDate(endDate);
|
||||
}
|
||||
String startDate = LocalDate.now().plusDays(-unit).toString();
|
||||
String endDate = startDate;
|
||||
|
||||
if (ChatDefaultConfigReq.TimeMode.LAST.equals(defaultConfig.getTimeMode())) {
|
||||
dateInfo.setDateMode(DateConf.DateMode.BETWEEN);
|
||||
} else if (ChatDefaultConfigReq.TimeMode.RECENT.equals(defaultConfig.getTimeMode())) {
|
||||
dateInfo.setDateMode(DateConf.DateMode.RECENT);
|
||||
endDate = LocalDate.now().plusDays(-1).toString();
|
||||
}
|
||||
dateInfo.setUnit(unit);
|
||||
dateInfo.setPeriod(defaultConfig.getPeriod());
|
||||
dateInfo.setStartDate(startDate);
|
||||
dateInfo.setEndDate(endDate);
|
||||
|
||||
parseInfo.setDateInfo(dateInfo);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,18 +1,20 @@
|
||||
package com.tencent.supersonic.chat.core.query.rule.tag;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatDefaultRichConfigResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.Order;
|
||||
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public abstract class TagListQuery extends TagSemanticQuery {
|
||||
|
||||
@@ -23,28 +25,29 @@ public abstract class TagListQuery extends TagSemanticQuery {
|
||||
}
|
||||
|
||||
private void addEntityDetailAndOrderByMetric(QueryContext queryContext, SemanticParseInfo parseInfo) {
|
||||
Long modelId = parseInfo.getModelId();
|
||||
if (Objects.nonNull(modelId) && modelId > 0L) {
|
||||
ChatConfigRichResp chaConfigRichDesc = queryContext.getModelIdToChatRichConfig().get(modelId);
|
||||
ModelSchema modelSchema = queryContext.getSemanticSchema().getModelSchemaMap().get(parseInfo.getModelId());
|
||||
if (chaConfigRichDesc != null && chaConfigRichDesc.getChatDetailRichConfig() != null
|
||||
&& Objects.nonNull(modelSchema) && Objects.nonNull(modelSchema.getEntity())) {
|
||||
Long viewId = parseInfo.getViewId();
|
||||
if (Objects.nonNull(viewId) && viewId > 0L) {
|
||||
ViewSchema viewSchema = queryContext.getSemanticSchema().getViewSchemaMap().get(viewId);
|
||||
if (viewSchema != null && Objects.nonNull(viewSchema.getEntity())) {
|
||||
Set<SchemaElement> dimensions = new LinkedHashSet<>();
|
||||
Set<SchemaElement> metrics = new LinkedHashSet();
|
||||
Set<Order> orders = new LinkedHashSet();
|
||||
ChatDefaultRichConfigResp chatDefaultConfig = chaConfigRichDesc
|
||||
.getChatDetailRichConfig().getChatDefaultConfig();
|
||||
if (chatDefaultConfig != null) {
|
||||
if (CollectionUtils.isNotEmpty(chatDefaultConfig.getMetrics())) {
|
||||
chatDefaultConfig.getMetrics().stream()
|
||||
.forEach(metric -> {
|
||||
metrics.add(metric);
|
||||
orders.add(new Order(metric.getBizName(), Constants.DESC_UPPER));
|
||||
});
|
||||
Set<SchemaElement> metrics = new LinkedHashSet<>();
|
||||
Set<Order> orders = new LinkedHashSet<>();
|
||||
TagTypeDefaultConfig tagTypeDefaultConfig = viewSchema.getTagTypeDefaultConfig();
|
||||
if (tagTypeDefaultConfig != null) {
|
||||
if (CollectionUtils.isNotEmpty(tagTypeDefaultConfig.getMetricIds())) {
|
||||
metrics = tagTypeDefaultConfig.getMetricIds()
|
||||
.stream().map(id -> {
|
||||
SchemaElement metric = viewSchema.getElement(SchemaElementType.METRIC, id);
|
||||
if (metric != null) {
|
||||
orders.add(new Order(metric.getBizName(), Constants.DESC_UPPER));
|
||||
}
|
||||
return metric;
|
||||
}).filter(Objects::nonNull).collect(Collectors.toSet());
|
||||
}
|
||||
if (CollectionUtils.isNotEmpty(chatDefaultConfig.getDimensions())) {
|
||||
chatDefaultConfig.getDimensions().stream()
|
||||
.forEach(dimension -> dimensions.add(dimension));
|
||||
if (CollectionUtils.isNotEmpty(tagTypeDefaultConfig.getDimensionIds())) {
|
||||
dimensions = tagTypeDefaultConfig.getDimensionIds().stream()
|
||||
.map(id -> viewSchema.getElement(SchemaElementType.DIMENSION, id))
|
||||
.filter(Objects::nonNull).collect(Collectors.toSet());
|
||||
}
|
||||
}
|
||||
parseInfo.setDimensions(dimensions);
|
||||
|
||||
@@ -1,21 +1,23 @@
|
||||
package com.tencent.supersonic.chat.core.query.rule.tag;
|
||||
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ENTITY;
|
||||
import static com.tencent.supersonic.chat.core.query.rule.QueryMatchOption.OptionType.REQUIRED;
|
||||
import static com.tencent.supersonic.chat.core.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatDefaultRichConfigResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeMode;
|
||||
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.time.LocalDate;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ENTITY;
|
||||
import static com.tencent.supersonic.chat.core.query.rule.QueryMatchOption.OptionType.REQUIRED;
|
||||
import static com.tencent.supersonic.chat.core.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
|
||||
|
||||
@Slf4j
|
||||
public abstract class TagSemanticQuery extends RuleSemanticQuery {
|
||||
@@ -40,19 +42,24 @@ public abstract class TagSemanticQuery extends RuleSemanticQuery {
|
||||
parseInfo.setQueryType(QueryType.TAG);
|
||||
parseInfo.setLimit(TAG_MAX_RESULTS);
|
||||
if (parseInfo.getDateInfo() == null) {
|
||||
ChatConfigRichResp chatConfig = queryContext.getModelIdToChatRichConfig().get(parseInfo.getModelId());
|
||||
ChatDefaultRichConfigResp defaultConfig = chatConfig.getChatDetailRichConfig().getChatDefaultConfig();
|
||||
|
||||
int unit = 1;
|
||||
if (Objects.nonNull(defaultConfig) && Objects.nonNull(defaultConfig.getUnit())) {
|
||||
unit = defaultConfig.getUnit();
|
||||
}
|
||||
String date = LocalDate.now().plusDays(-unit).toString();
|
||||
ViewSchema viewSchema = queryContext.getSemanticSchema().getViewSchemaMap().get(parseInfo.getViewId());
|
||||
TimeDefaultConfig timeDefaultConfig = viewSchema.getTagTypeTimeDefaultConfig();
|
||||
DateConf dateInfo = new DateConf();
|
||||
dateInfo.setDateMode(DateConf.DateMode.BETWEEN);
|
||||
dateInfo.setStartDate(date);
|
||||
dateInfo.setEndDate(date);
|
||||
|
||||
if (Objects.nonNull(timeDefaultConfig) && Objects.nonNull(timeDefaultConfig.getUnit())) {
|
||||
int unit = timeDefaultConfig.getUnit();
|
||||
String startDate = LocalDate.now().plusDays(-unit).toString();
|
||||
String endDate = startDate;
|
||||
if (TimeMode.LAST.equals(timeDefaultConfig.getTimeMode())) {
|
||||
dateInfo.setDateMode(DateConf.DateMode.BETWEEN);
|
||||
} else if (TimeMode.RECENT.equals(timeDefaultConfig.getTimeMode())) {
|
||||
dateInfo.setDateMode(DateConf.DateMode.RECENT);
|
||||
endDate = LocalDate.now().plusDays(-1).toString();
|
||||
}
|
||||
dateInfo.setUnit(unit);
|
||||
dateInfo.setPeriod(timeDefaultConfig.getPeriod());
|
||||
dateInfo.setStartDate(startDate);
|
||||
dateInfo.setEndDate(endDate);
|
||||
}
|
||||
parseInfo.setDateInfo(dateInfo);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ package com.tencent.supersonic.chat.core.utils;
|
||||
import com.tencent.supersonic.chat.core.knowledge.semantic.SemanticInterpreter;
|
||||
import com.tencent.supersonic.chat.core.parser.JavaLLMProxy;
|
||||
import com.tencent.supersonic.chat.core.parser.LLMProxy;
|
||||
import com.tencent.supersonic.chat.core.parser.sql.llm.ModelResolver;
|
||||
import com.tencent.supersonic.chat.core.parser.sql.llm.ViewResolver;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
@@ -16,7 +16,7 @@ public class ComponentFactory {
|
||||
|
||||
private static SemanticInterpreter semanticInterpreter;
|
||||
private static LLMProxy llmProxy;
|
||||
private static ModelResolver modelResolver;
|
||||
private static ViewResolver modelResolver;
|
||||
|
||||
public static SemanticInterpreter getSemanticLayer() {
|
||||
if (Objects.isNull(semanticInterpreter)) {
|
||||
@@ -44,9 +44,9 @@ public class ComponentFactory {
|
||||
return llmProxy;
|
||||
}
|
||||
|
||||
public static ModelResolver getModelResolver() {
|
||||
public static ViewResolver getModelResolver() {
|
||||
if (Objects.isNull(modelResolver)) {
|
||||
modelResolver = init(ModelResolver.class);
|
||||
modelResolver = init(ViewResolver.class);
|
||||
}
|
||||
return modelResolver;
|
||||
}
|
||||
|
||||
@@ -1,15 +1,9 @@
|
||||
package com.tencent.supersonic.chat.core.utils;
|
||||
|
||||
import static com.tencent.supersonic.common.pojo.Constants.AND_UPPER;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.APOSTROPHE;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.COMMA;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.SPACE;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.UNDERLINE_DOUBLE;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.core.knowledge.semantic.SemanticInterpreter;
|
||||
import com.tencent.supersonic.chat.core.config.DefaultMetric;
|
||||
import com.tencent.supersonic.chat.core.config.Dim4Dict;
|
||||
import com.tencent.supersonic.chat.core.knowledge.semantic.SemanticInterpreter;
|
||||
import com.tencent.supersonic.common.pojo.Aggregator;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
@@ -18,8 +12,15 @@ import com.tencent.supersonic.common.pojo.Order;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.logging.log4j.util.Strings;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
@@ -27,11 +28,12 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.StringJoiner;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.logging.log4j.util.Strings;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import static com.tencent.supersonic.common.pojo.Constants.AND_UPPER;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.APOSTROPHE;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.COMMA;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.SPACE;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.UNDERLINE_DOUBLE;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
@@ -82,7 +84,7 @@ public class DictQueryHelper {
|
||||
|
||||
if (!CollectionUtils.isEmpty(columns)) {
|
||||
for (QueryColumn column : columns) {
|
||||
if (Strings.isNotEmpty(column.getNameEn())) {
|
||||
if (StringUtils.isNotEmpty(column.getNameEn())) {
|
||||
String nameEn = column.getNameEn();
|
||||
if (nameEn.endsWith(UNDERLINE_DOUBLE + bizName)) {
|
||||
dimNameRewrite = nameEn;
|
||||
@@ -159,9 +161,6 @@ public class DictQueryHelper {
|
||||
private QueryStructReq generateQueryStructCmd(Long modelId, DefaultMetric defaultMetricDesc, Dim4Dict dim4Dict) {
|
||||
QueryStructReq queryStructCmd = new QueryStructReq();
|
||||
|
||||
queryStructCmd.addModelId(modelId);
|
||||
queryStructCmd.setGroups(Arrays.asList(dim4Dict.getBizName()));
|
||||
|
||||
List<Filter> filters = generateFilters(dim4Dict, queryStructCmd);
|
||||
queryStructCmd.setDimensionFilters(filters);
|
||||
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.utils;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.common.pojo.ModelCluster;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class ModelClusterBuilder {
|
||||
|
||||
public static List<ModelCluster> buildModelClusters(SemanticSchema semanticSchema) {
|
||||
Map<Long, ModelSchema> modelMap = semanticSchema.getModelSchemaMap();
|
||||
Set<Long> visited = new HashSet<>();
|
||||
List<Set<Long>> modelClusters = new ArrayList<>();
|
||||
for (ModelSchema model : modelMap.values()) {
|
||||
if (!visited.contains(model.getModel().getModel())) {
|
||||
Set<Long> modelCluster = new HashSet<>();
|
||||
dfs(model, modelMap, visited, modelCluster);
|
||||
modelClusters.add(modelCluster);
|
||||
}
|
||||
}
|
||||
return modelClusters.stream().map(ModelCluster::build).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private static void dfs(ModelSchema model, Map<Long, ModelSchema> modelMap,
|
||||
Set<Long> visited, Set<Long> modelCluster) {
|
||||
visited.add(model.getModel().getModel());
|
||||
modelCluster.add(model.getModel().getModel());
|
||||
for (Long neighborId : model.getModelClusterSet()) {
|
||||
if (!visited.contains(neighborId)) {
|
||||
dfs(modelMap.get(neighborId), modelMap, visited, modelCluster);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -3,8 +3,11 @@ package com.tencent.supersonic.chat.core.utils;
|
||||
import com.hankcs.hanlp.corpus.tag.Nature;
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.core.knowledge.ModelInfoStat;
|
||||
import com.tencent.supersonic.chat.core.knowledge.ViewInfoStat;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
@@ -12,8 +15,6 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
/**
|
||||
* nature parse helper
|
||||
@@ -37,8 +38,8 @@ public class NatureHelper {
|
||||
case ENTITY:
|
||||
result = SchemaElementType.ENTITY;
|
||||
break;
|
||||
case MODEL:
|
||||
result = SchemaElementType.MODEL;
|
||||
case VIEW:
|
||||
result = SchemaElementType.VIEW;
|
||||
break;
|
||||
case VALUE:
|
||||
result = SchemaElementType.VALUE;
|
||||
@@ -52,12 +53,12 @@ public class NatureHelper {
|
||||
return result;
|
||||
}
|
||||
|
||||
private static boolean isModelOrEntity(Term term, Integer model) {
|
||||
private static boolean isViewOrEntity(Term term, Integer model) {
|
||||
return (DictWordType.NATURE_SPILT + model).equals(term.nature.toString()) || term.nature.toString()
|
||||
.endsWith(DictWordType.ENTITY.getType());
|
||||
}
|
||||
|
||||
public static Integer getModelByNature(Nature nature) {
|
||||
public static Integer getViewByNature(Nature nature) {
|
||||
if (nature.startsWith(DictWordType.NATURE_SPILT)) {
|
||||
String[] dimensionValues = nature.toString().split(DictWordType.NATURE_SPILT);
|
||||
if (StringUtils.isNumeric(dimensionValues[1])) {
|
||||
@@ -67,7 +68,7 @@ public class NatureHelper {
|
||||
return 0;
|
||||
}
|
||||
|
||||
public static Long getModelId(String nature) {
|
||||
public static Long getViewId(String nature) {
|
||||
try {
|
||||
String[] split = nature.split(DictWordType.NATURE_SPILT);
|
||||
if (split.length <= 1) {
|
||||
@@ -80,7 +81,7 @@ public class NatureHelper {
|
||||
return null;
|
||||
}
|
||||
|
||||
public static boolean isDimensionValueModelId(String nature) {
|
||||
public static boolean isDimensionValueViewId(String nature) {
|
||||
if (StringUtils.isEmpty(nature)) {
|
||||
return false;
|
||||
}
|
||||
@@ -95,21 +96,21 @@ public class NatureHelper {
|
||||
&& StringUtils.isNumeric(split[1]);
|
||||
}
|
||||
|
||||
public static ModelInfoStat getModelStat(List<Term> terms) {
|
||||
return ModelInfoStat.builder()
|
||||
.modelCount(getModelCount(terms))
|
||||
.dimensionModelCount(getDimensionCount(terms))
|
||||
.metricModelCount(getMetricCount(terms))
|
||||
.dimensionValueModelCount(getDimensionValueCount(terms))
|
||||
public static ViewInfoStat getViewStat(List<Term> terms) {
|
||||
return ViewInfoStat.builder()
|
||||
.viewCount(getViewCount(terms))
|
||||
.dimensionViewCount(getDimensionCount(terms))
|
||||
.metricViewCount(getMetricCount(terms))
|
||||
.dimensionValueViewCount(getDimensionValueCount(terms))
|
||||
.build();
|
||||
}
|
||||
|
||||
private static long getModelCount(List<Term> terms) {
|
||||
return terms.stream().filter(term -> isModelOrEntity(term, getModelByNature(term.nature))).count();
|
||||
private static long getViewCount(List<Term> terms) {
|
||||
return terms.stream().filter(term -> isViewOrEntity(term, getViewByNature(term.nature))).count();
|
||||
}
|
||||
|
||||
private static long getDimensionValueCount(List<Term> terms) {
|
||||
return terms.stream().filter(term -> isDimensionValueModelId(term.nature.toString())).count();
|
||||
return terms.stream().filter(term -> isDimensionValueViewId(term.nature.toString())).count();
|
||||
}
|
||||
|
||||
private static long getDimensionCount(List<Term> terms) {
|
||||
@@ -129,13 +130,13 @@ public class NatureHelper {
|
||||
* @param terms
|
||||
* @return
|
||||
*/
|
||||
public static Map<Long, Map<DictWordType, Integer>> getModelToNatureStat(List<Term> terms) {
|
||||
public static Map<Long, Map<DictWordType, Integer>> getViewToNatureStat(List<Term> terms) {
|
||||
Map<Long, Map<DictWordType, Integer>> modelToNature = new HashMap<>();
|
||||
terms.stream().filter(
|
||||
term -> term.nature.startsWith(DictWordType.NATURE_SPILT)
|
||||
).forEach(term -> {
|
||||
DictWordType dictWordType = DictWordType.getNatureType(String.valueOf(term.nature));
|
||||
Long model = getModelId(String.valueOf(term.nature));
|
||||
Long model = getViewId(String.valueOf(term.nature));
|
||||
|
||||
Map<DictWordType, Integer> natureTypeMap = new HashMap<>();
|
||||
natureTypeMap.put(dictWordType, 1);
|
||||
@@ -156,15 +157,15 @@ public class NatureHelper {
|
||||
return modelToNature;
|
||||
}
|
||||
|
||||
public static List<Long> selectPossibleModels(List<Term> terms) {
|
||||
Map<Long, Map<DictWordType, Integer>> modelToNatureStat = getModelToNatureStat(terms);
|
||||
Integer maxModelTypeSize = modelToNatureStat.entrySet().stream()
|
||||
public static List<Long> selectPossibleViews(List<Term> terms) {
|
||||
Map<Long, Map<DictWordType, Integer>> modelToNatureStat = getViewToNatureStat(terms);
|
||||
Integer maxViewTypeSize = modelToNatureStat.entrySet().stream()
|
||||
.max(Comparator.comparingInt(o -> o.getValue().size())).map(entry -> entry.getValue().size())
|
||||
.orElse(null);
|
||||
if (Objects.isNull(maxModelTypeSize) || maxModelTypeSize == 0) {
|
||||
if (Objects.isNull(maxViewTypeSize) || maxViewTypeSize == 0) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
return modelToNatureStat.entrySet().stream().filter(entry -> entry.getValue().size() == maxModelTypeSize)
|
||||
return modelToNatureStat.entrySet().stream().filter(entry -> entry.getValue().size() == maxViewTypeSize)
|
||||
.map(entry -> entry.getKey()).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
|
||||
@@ -36,39 +36,40 @@ import java.util.stream.Collectors;
|
||||
public class QueryReqBuilder {
|
||||
|
||||
public static QueryStructReq buildStructReq(SemanticParseInfo parseInfo) {
|
||||
QueryStructReq queryStructCmd = new QueryStructReq();
|
||||
queryStructCmd.setModelIds(parseInfo.getModel().getModelIds());
|
||||
queryStructCmd.setQueryType(parseInfo.getQueryType());
|
||||
queryStructCmd.setDateInfo(rewrite2Between(parseInfo.getDateInfo()));
|
||||
QueryStructReq queryStructReq = new QueryStructReq();
|
||||
queryStructReq.setViewId(parseInfo.getViewId());
|
||||
queryStructReq.setViewName(parseInfo.getView().getName());
|
||||
queryStructReq.setQueryType(parseInfo.getQueryType());
|
||||
queryStructReq.setDateInfo(rewrite2Between(parseInfo.getDateInfo()));
|
||||
|
||||
List<Filter> dimensionFilters = parseInfo.getDimensionFilters().stream()
|
||||
.filter(chatFilter -> Strings.isNotEmpty(chatFilter.getBizName()))
|
||||
.map(chatFilter -> new Filter(chatFilter.getBizName(), chatFilter.getOperator(), chatFilter.getValue()))
|
||||
.collect(Collectors.toList());
|
||||
queryStructCmd.setDimensionFilters(dimensionFilters);
|
||||
queryStructReq.setDimensionFilters(dimensionFilters);
|
||||
|
||||
List<Filter> metricFilters = parseInfo.getMetricFilters().stream()
|
||||
.map(chatFilter -> new Filter(chatFilter.getBizName(), chatFilter.getOperator(), chatFilter.getValue()))
|
||||
.collect(Collectors.toList());
|
||||
queryStructCmd.setMetricFilters(metricFilters);
|
||||
queryStructReq.setMetricFilters(metricFilters);
|
||||
|
||||
addDateDimension(parseInfo);
|
||||
List<String> dimensions = parseInfo.getDimensions().stream().map(SchemaElement::getBizName)
|
||||
.collect(Collectors.toList());
|
||||
queryStructCmd.setGroups(dimensions);
|
||||
queryStructCmd.setLimit(parseInfo.getLimit());
|
||||
queryStructReq.setGroups(dimensions);
|
||||
queryStructReq.setLimit(parseInfo.getLimit());
|
||||
// only one metric is queried at once
|
||||
Set<SchemaElement> metrics = parseInfo.getMetrics();
|
||||
if (!CollectionUtils.isEmpty(metrics)) {
|
||||
SchemaElement metricElement = parseInfo.getMetrics().iterator().next();
|
||||
Set<Order> order = getOrder(parseInfo.getOrders(), parseInfo.getAggType(), metricElement);
|
||||
queryStructCmd.setAggregators(getAggregatorByMetric(parseInfo.getAggType(), metricElement));
|
||||
queryStructCmd.setOrders(new ArrayList<>(order));
|
||||
queryStructReq.setAggregators(getAggregatorByMetric(parseInfo.getAggType(), metricElement));
|
||||
queryStructReq.setOrders(new ArrayList<>(order));
|
||||
}
|
||||
|
||||
deletionDuplicated(queryStructCmd);
|
||||
deletionDuplicated(queryStructReq);
|
||||
|
||||
return queryStructCmd;
|
||||
return queryStructReq;
|
||||
}
|
||||
|
||||
private static void deletionDuplicated(QueryStructReq queryStructReq) {
|
||||
@@ -118,7 +119,7 @@ public class QueryReqBuilder {
|
||||
for (Filter dimensionFilter : queryStructReq.getDimensionFilters()) {
|
||||
QueryStructReq req = new QueryStructReq();
|
||||
BeanUtils.copyProperties(queryStructReq, req);
|
||||
req.setModelIds(new HashSet<>(queryStructReq.getModelIds()));
|
||||
req.setViewId(parseInfo.getViewId());
|
||||
req.setDimensionFilters(Lists.newArrayList(dimensionFilter));
|
||||
queryStructReqs.add(req);
|
||||
}
|
||||
@@ -130,16 +131,16 @@ public class QueryReqBuilder {
|
||||
* convert to QueryS2SQLReq
|
||||
*
|
||||
* @param querySql
|
||||
* @param modelIds
|
||||
* @param viewId
|
||||
* @return
|
||||
*/
|
||||
public static QuerySqlReq buildS2SQLReq(String querySql, Set<Long> modelIds) {
|
||||
QuerySqlReq querySqlReq = new QuerySqlReq();
|
||||
public static QuerySqlReq buildS2SQLReq(String querySql, Long viewId) {
|
||||
QuerySqlReq querySQLReq = new QuerySqlReq();
|
||||
if (Objects.nonNull(querySql)) {
|
||||
querySqlReq.setSql(querySql);
|
||||
querySQLReq.setSql(querySql);
|
||||
}
|
||||
querySqlReq.setModelIds(modelIds);
|
||||
return querySqlReq;
|
||||
querySQLReq.setViewId(viewId);
|
||||
return querySQLReq;
|
||||
}
|
||||
|
||||
private static List<Aggregator> getAggregatorByMetric(AggregateTypeEnum aggregateType, SchemaElement metric) {
|
||||
@@ -234,14 +235,14 @@ public class QueryReqBuilder {
|
||||
|
||||
public static QueryStructReq buildStructRatioReq(SemanticParseInfo parseInfo, SchemaElement metric,
|
||||
AggOperatorEnum aggOperatorEnum) {
|
||||
QueryStructReq queryStructCmd = buildStructReq(parseInfo);
|
||||
queryStructCmd.setQueryType(QueryType.METRIC);
|
||||
queryStructCmd.setOrders(new ArrayList<>());
|
||||
QueryStructReq queryStructReq = buildStructReq(parseInfo);
|
||||
queryStructReq.setQueryType(QueryType.METRIC);
|
||||
queryStructReq.setOrders(new ArrayList<>());
|
||||
List<Aggregator> aggregators = new ArrayList<>();
|
||||
Aggregator ratioRoll = new Aggregator(metric.getBizName(), aggOperatorEnum);
|
||||
aggregators.add(ratioRoll);
|
||||
queryStructCmd.setAggregators(aggregators);
|
||||
return queryStructCmd;
|
||||
queryStructReq.setAggregators(aggregators);
|
||||
return queryStructReq;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -57,7 +57,7 @@ public class SimilarQueryManager {
|
||||
embeddingQuery.setQuery(queryText);
|
||||
|
||||
Map<String, Object> metaData = new HashMap<>();
|
||||
metaData.put("modelId", (similarQueryReq.getModelId()));
|
||||
metaData.put("modelId", similarQueryReq.getViewId());
|
||||
metaData.put("agentId", similarQueryReq.getAgentId());
|
||||
embeddingQuery.setMetadata(metaData);
|
||||
String solvedQueryCollection = embeddingConfig.getSolvedQueryCollection();
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
package com.tencent.supersonic.chat.core.s2sql;
|
||||
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.mockito.Mockito;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.mockito.Mockito;
|
||||
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
class LLMSqlParserTest {
|
||||
|
||||
@@ -28,7 +29,7 @@ class LLMSqlParserTest {
|
||||
SchemaElement schemaElement = SchemaElement.builder()
|
||||
.bizName("singer_name")
|
||||
.name("歌手名")
|
||||
.model(2L)
|
||||
.view(2L)
|
||||
.schemaValueMaps(schemaValueMaps)
|
||||
.build();
|
||||
dimensions.add(schemaElement);
|
||||
@@ -36,7 +37,7 @@ class LLMSqlParserTest {
|
||||
SchemaElement schemaElement2 = SchemaElement.builder()
|
||||
.bizName("publish_time")
|
||||
.name("发布时间")
|
||||
.model(2L)
|
||||
.view(2L)
|
||||
.build();
|
||||
dimensions.add(schemaElement2);
|
||||
|
||||
@@ -46,7 +47,7 @@ class LLMSqlParserTest {
|
||||
SchemaElement metric = SchemaElement.builder()
|
||||
.bizName("play_count")
|
||||
.name("播放量")
|
||||
.model(2L)
|
||||
.view(2L)
|
||||
.build();
|
||||
metrics.add(metric);
|
||||
|
||||
|
||||
@@ -32,8 +32,8 @@ public class SchemaDictUpdateListener implements ApplicationListener<DataEvent>
|
||||
DictWord dictWord = new DictWord();
|
||||
dictWord.setWord(dataItem.getName());
|
||||
String sign = DictWordType.NATURE_SPILT;
|
||||
String nature = sign + dataItem.getModelId() + sign + dataItem.getId()
|
||||
+ sign + dataItem.getType().getName();
|
||||
String nature = sign + 1 + sign + dataItem.getId()
|
||||
+ sign + dataItem.getType().name().toLowerCase();
|
||||
String natureWithFrequency = nature + " " + Constants.DEFAULT_FREQUENCY;
|
||||
dictWord.setNature(nature);
|
||||
dictWord.setNatureWithFrequency(natureWithFrequency);
|
||||
|
||||
@@ -1,256 +1,40 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.dataobject;
|
||||
|
||||
import com.baomidou.mybatisplus.annotation.IdType;
|
||||
import com.baomidou.mybatisplus.annotation.TableId;
|
||||
import com.baomidou.mybatisplus.annotation.TableName;
|
||||
import lombok.Data;
|
||||
import java.util.Date;
|
||||
|
||||
@Data
|
||||
@TableName("s2_plugin")
|
||||
public class PluginDO {
|
||||
/**
|
||||
*
|
||||
*/
|
||||
|
||||
@TableId(type = IdType.AUTO)
|
||||
private Long id;
|
||||
|
||||
/**
|
||||
* DASHBOARD,WIDGET,URL
|
||||
*/
|
||||
private String type;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private String model;
|
||||
private String view;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private String pattern;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private String parseMode;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private String name;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private Date createdAt;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private String createdBy;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private Date updatedAt;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private String updatedBy;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private String parseModeConfig;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private String config;
|
||||
|
||||
/**
|
||||
*
|
||||
*/
|
||||
private String comment;
|
||||
|
||||
/**
|
||||
* @return id
|
||||
*/
|
||||
public Long getId() {
|
||||
return id;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param id
|
||||
*/
|
||||
public void setId(Long id) {
|
||||
this.id = id;
|
||||
}
|
||||
|
||||
/**
|
||||
* DASHBOARD,WIDGET,URL
|
||||
*
|
||||
* @return type DASHBOARD,WIDGET,URL
|
||||
*/
|
||||
public String getType() {
|
||||
return type;
|
||||
}
|
||||
|
||||
/**
|
||||
* DASHBOARD,WIDGET,URL
|
||||
*
|
||||
* @param type DASHBOARD,WIDGET,URL
|
||||
*/
|
||||
public void setType(String type) {
|
||||
this.type = type == null ? null : type.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
* @return model
|
||||
*/
|
||||
public String getModel() {
|
||||
return model;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param model
|
||||
*/
|
||||
public void setModel(String model) {
|
||||
this.model = model == null ? null : model.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
* @return pattern
|
||||
*/
|
||||
public String getPattern() {
|
||||
return pattern;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param pattern
|
||||
*/
|
||||
public void setPattern(String pattern) {
|
||||
this.pattern = pattern == null ? null : pattern.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
* @return parse_mode
|
||||
*/
|
||||
public String getParseMode() {
|
||||
return parseMode;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param parseMode
|
||||
*/
|
||||
public void setParseMode(String parseMode) {
|
||||
this.parseMode = parseMode == null ? null : parseMode.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
* @return name
|
||||
*/
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param name
|
||||
*/
|
||||
public void setName(String name) {
|
||||
this.name = name == null ? null : name.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
* @return created_at
|
||||
*/
|
||||
public Date getCreatedAt() {
|
||||
return createdAt;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param createdAt
|
||||
*/
|
||||
public void setCreatedAt(Date createdAt) {
|
||||
this.createdAt = createdAt;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return created_by
|
||||
*/
|
||||
public String getCreatedBy() {
|
||||
return createdBy;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param createdBy
|
||||
*/
|
||||
public void setCreatedBy(String createdBy) {
|
||||
this.createdBy = createdBy == null ? null : createdBy.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
* @return updated_at
|
||||
*/
|
||||
public Date getUpdatedAt() {
|
||||
return updatedAt;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param updatedAt
|
||||
*/
|
||||
public void setUpdatedAt(Date updatedAt) {
|
||||
this.updatedAt = updatedAt;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return updated_by
|
||||
*/
|
||||
public String getUpdatedBy() {
|
||||
return updatedBy;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param updatedBy
|
||||
*/
|
||||
public void setUpdatedBy(String updatedBy) {
|
||||
this.updatedBy = updatedBy == null ? null : updatedBy.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
* @return parse_mode_config
|
||||
*/
|
||||
public String getParseModeConfig() {
|
||||
return parseModeConfig;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param parseModeConfig
|
||||
*/
|
||||
public void setParseModeConfig(String parseModeConfig) {
|
||||
this.parseModeConfig = parseModeConfig == null ? null : parseModeConfig.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
* @return config
|
||||
*/
|
||||
public String getConfig() {
|
||||
return config;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param config
|
||||
*/
|
||||
public void setConfig(String config) {
|
||||
this.config = config == null ? null : config.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
* @return comment
|
||||
*/
|
||||
public String getComment() {
|
||||
return comment;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param comment
|
||||
*/
|
||||
public void setComment(String comment) {
|
||||
this.comment = comment == null ? null : comment.trim();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,69 +1,10 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.mapper;
|
||||
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.PluginDOExample;
|
||||
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.PluginDO;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
import java.util.List;
|
||||
|
||||
@Mapper
|
||||
public interface PluginDOMapper {
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
long countByExample(PluginDOExample example);
|
||||
public interface PluginDOMapper extends BaseMapper<PluginDO> {
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
int deleteByPrimaryKey(Long id);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
int insert(PluginDO record);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
int insertSelective(PluginDO record);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
List<PluginDO> selectByExampleWithBLOBs(PluginDOExample example);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
List<PluginDO> selectByExample(PluginDOExample example);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
PluginDO selectByPrimaryKey(Long id);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
int updateByPrimaryKeySelective(PluginDO record);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
int updateByPrimaryKeyWithBLOBs(PluginDO record);
|
||||
|
||||
/**
|
||||
*
|
||||
* @mbg.generated
|
||||
*/
|
||||
int updateByPrimaryKey(PluginDO record);
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.repository;
|
||||
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.PluginDOExample;
|
||||
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.PluginDO;
|
||||
|
||||
import java.util.List;
|
||||
@@ -16,7 +16,7 @@ public interface PluginRepository {
|
||||
|
||||
PluginDO getPlugin(Long id);
|
||||
|
||||
List<PluginDO> query(PluginDOExample pluginDOExample);
|
||||
List<PluginDO> query(QueryWrapper<PluginDO> queryWrapper);
|
||||
|
||||
void deletePlugin(Long id);
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.repository.impl;
|
||||
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.PluginDOExample;
|
||||
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.PluginDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.mapper.PluginDOMapper;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.PluginRepository;
|
||||
@@ -26,7 +26,7 @@ public class PluginRepositoryImpl implements PluginRepository {
|
||||
|
||||
@Override
|
||||
public List<PluginDO> getPlugins() {
|
||||
return pluginDOMapper.selectByExampleWithBLOBs(new PluginDOExample());
|
||||
return pluginDOMapper.selectList(new QueryWrapper<>());
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -60,22 +60,22 @@ public class PluginRepositoryImpl implements PluginRepository {
|
||||
|
||||
@Override
|
||||
public void updatePlugin(PluginDO pluginDO) {
|
||||
pluginDOMapper.updateByPrimaryKeyWithBLOBs(pluginDO);
|
||||
pluginDOMapper.updateById(pluginDO);
|
||||
}
|
||||
|
||||
@Override
|
||||
public PluginDO getPlugin(Long id) {
|
||||
return pluginDOMapper.selectByPrimaryKey(id);
|
||||
return pluginDOMapper.selectById(id);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<PluginDO> query(PluginDOExample pluginDOExample) {
|
||||
return pluginDOMapper.selectByExampleWithBLOBs(pluginDOExample);
|
||||
public List<PluginDO> query(QueryWrapper<PluginDO> queryWrapper) {
|
||||
return pluginDOMapper.selectList(queryWrapper);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deletePlugin(Long id) {
|
||||
pluginDOMapper.deleteByPrimaryKey(id);
|
||||
pluginDOMapper.deleteById(id);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,22 +1,23 @@
|
||||
package com.tencent.supersonic.chat.server.processor.execute;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.RelatedSchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.service.SemanticService;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* DimensionRecommendProcessor recommend some dimensions
|
||||
@@ -33,13 +34,13 @@ public class DimensionRecommendProcessor implements ExecuteResultProcessor {
|
||||
return;
|
||||
}
|
||||
SchemaElement element = semanticParseInfo.getMetrics().iterator().next();
|
||||
List<SchemaElement> dimensionRecommended = getDimensions(element.getId(), element.getModel());
|
||||
List<SchemaElement> dimensionRecommended = getDimensions(element.getId(), element.getView());
|
||||
queryResult.setRecommendedDimensions(dimensionRecommended);
|
||||
}
|
||||
|
||||
private List<SchemaElement> getDimensions(Long metricId, Long modelId) {
|
||||
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
||||
ModelSchema modelSchema = semanticService.getModelSchema(modelId);
|
||||
ViewSchema modelSchema = semanticService.getModelSchema(modelId);
|
||||
List<Long> drillDownDimensions = Lists.newArrayList();
|
||||
Set<SchemaElement> metricElements = modelSchema.getMetrics();
|
||||
if (!CollectionUtils.isEmpty(metricElements)) {
|
||||
|
||||
@@ -14,6 +14,8 @@ import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
@@ -21,7 +23,6 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* MetricRecommendProcessor fills recommended metrics based on embedding similarity.
|
||||
@@ -45,7 +46,7 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
|
||||
}
|
||||
List<String> metricNames = Collections.singletonList(parseInfo.getMetrics().iterator().next().getName());
|
||||
Map<String, String> filterCondition = new HashMap<>();
|
||||
filterCondition.put("modelId", parseInfo.getMetrics().iterator().next().getModel().toString());
|
||||
filterCondition.put("modelId", parseInfo.getMetrics().iterator().next().getView().toString());
|
||||
filterCondition.put("type", SchemaElementType.METRIC.name());
|
||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(metricNames)
|
||||
.filterCondition(filterCondition).queryEmbeddings(null).build();
|
||||
@@ -70,9 +71,9 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
|
||||
if (!metricIds.contains(Retrieval.getLongId(retrieval.getId()))) {
|
||||
SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(retrieval.getMetadata()),
|
||||
SchemaElement.class);
|
||||
if (retrieval.getMetadata().containsKey("modelId")) {
|
||||
String modelId = retrieval.getMetadata().get("modelId").toString();
|
||||
schemaElement.setModel(Long.parseLong(modelId));
|
||||
if (retrieval.getMetadata().containsKey("viewId")) {
|
||||
String viewId = retrieval.getMetadata().get("viewId").toString();
|
||||
schemaElement.setView(Long.parseLong(viewId));
|
||||
}
|
||||
schemaElement.setOrder(++metricOrder);
|
||||
parseInfo.getMetrics().add(schemaElement);
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.chat.server.processor.parse;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
@@ -10,9 +11,10 @@ import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.core.query.llm.analytics.MetricAnalyzeQuery;
|
||||
import com.tencent.supersonic.chat.server.service.SemanticService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* EntityInfoProcessor fills core attributes of an entity so that
|
||||
@@ -35,8 +37,9 @@ public class EntityInfoProcessor implements ParseResultProcessor {
|
||||
return;
|
||||
}
|
||||
//1. set entity info
|
||||
ViewSchema viewSchema = queryContext.getSemanticSchema().getViewSchemaMap().get(parseInfo.getViewId());
|
||||
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
||||
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, queryContext.getUser());
|
||||
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, viewSchema, queryContext.getUser());
|
||||
if (QueryManager.isTagQuery(queryMode)
|
||||
|| QueryManager.isMetricQuery(queryMode)) {
|
||||
parseInfo.setEntityInfo(entityInfo);
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user