(improvement) Move out the datasource and merge the datasource with the model, and adapt the chat module (#423)

Co-authored-by: jolunoluo <jolunoluo@tencent.com>
This commit is contained in:
jipeli
2023-11-27 11:05:24 +08:00
committed by GitHub
parent 0534053ff9
commit 27bb1b322e
190 changed files with 3900 additions and 10561 deletions

View File

@@ -9,8 +9,8 @@ import com.tencent.supersonic.semantic.api.model.request.PageMetricReq;
import com.tencent.supersonic.semantic.api.model.response.DomainResp;
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import com.tencent.supersonic.semantic.api.model.response.ModelResp;
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
import com.tencent.supersonic.semantic.api.model.response.ModelResp;
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq;

View File

@@ -1,8 +1,13 @@
package com.tencent.supersonic.chat.api.pojo;
import com.google.common.collect.Sets;
import com.tencent.supersonic.common.pojo.ModelRela;
import lombok.Data;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
@@ -14,8 +19,8 @@ public class ModelSchema {
private Set<SchemaElement> dimensions = new HashSet<>();
private Set<SchemaElement> dimensionValues = new HashSet<>();
private Set<SchemaElement> tags = new HashSet<>();
@Deprecated
private SchemaElement entity = new SchemaElement();
private List<ModelRela> modelRelas = new ArrayList<>();
public SchemaElement getElement(SchemaElementType elementType, long elementID) {
Optional<SchemaElement> element = Optional.empty();
@@ -75,4 +80,16 @@ public class ModelSchema {
}
}
/**
 * Collects the ids of every model referenced by the relations of this schema.
 *
 * @return a new mutable set holding the "to" and "from" model id of each entry in
 *         {@code modelRelas}; empty when there are no relations
 */
public Set<Long> getModelClusterSet() {
    Set<Long> clusterModelIds = new HashSet<>();
    if (CollectionUtils.isEmpty(modelRelas)) {
        return clusterModelIds;
    }
    for (ModelRela relation : modelRelas) {
        clusterModelIds.add(relation.getToModelId());
        clusterModelIds.add(relation.getFromModelId());
    }
    return clusterModelIds;
}
}

View File

@@ -15,6 +15,7 @@ public class QueryContext {
private QueryReq request;
private List<SemanticQuery> candidateQueries = new ArrayList<>();
private SchemaMapInfo mapInfo = new SchemaMapInfo();
private SchemaModelClusterMapInfo modelClusterMapInfo = new SchemaModelClusterMapInfo();
public QueryContext(QueryReq request) {
this.request = request;

View File

@@ -6,7 +6,7 @@ public enum SchemaElementType {
DIMENSION,
VALUE,
ENTITY,
TAG,
ID,
DATE,
TAG
DATE
}

View File

@@ -0,0 +1,61 @@
package com.tencent.supersonic.chat.api.pojo;
import com.clickhouse.client.internal.apache.commons.compress.utils.Lists;
import com.tencent.supersonic.common.pojo.ModelCluster;
import lombok.Data;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
 * Holds the schema element matches of a query, grouped by model cluster key.
 * Keys are the strings understood by {@link ModelCluster#getModelIdFromKey(String)}.
 */
@Data
public class SchemaModelClusterMapInfo {

    /** Element matches indexed by model cluster key. */
    private Map<String, List<SchemaElementMatch>> modelElementMatches = new HashMap<>();

    /**
     * @return the keys of all model clusters that currently have matches; this is the
     *         live key view of the backing map, not a copy
     */
    public Set<String> getMatchedModelClusters() {
        return modelElementMatches.keySet();
    }

    /**
     * Looks up the matches of the first cluster whose key contains the given model id.
     *
     * @param modelId id of a model that may belong to one of the stored clusters
     * @return the stored match list of the first containing cluster, or a new empty
     *         list when no cluster contains the model
     */
    public List<SchemaElementMatch> getMatchedElements(Long modelId) {
        for (Map.Entry<String, List<SchemaElementMatch>> entry : modelElementMatches.entrySet()) {
            if (ModelCluster.getModelIdFromKey(entry.getKey()).contains(modelId)) {
                return entry.getValue();
            }
        }
        // Plain JDK list: no need to reach into a shaded third-party utility class for this.
        return new ArrayList<>();
    }

    /**
     * @param modelCluster key of the model cluster
     * @return the match list stored under the key, or {@code null} when the key is absent
     */
    public List<SchemaElementMatch> getMatchedElements(String modelCluster) {
        return modelElementMatches.get(modelCluster);
    }

    /**
     * @return the backing map of cluster key to element matches (not a copy)
     */
    public Map<String, List<SchemaElementMatch>> getModelElementMatches() {
        return modelElementMatches;
    }

    /**
     * Restricts the matches to the clusters that contain at least one of the given models.
     *
     * @param modelIds ids of the models of interest; when {@code null} or empty the
     *                 backing map itself is returned unfiltered
     * @return a new map with the entries of every cluster containing one of the ids
     */
    public Map<String, List<SchemaElementMatch>> getElementMatchesByModelIds(Set<Long> modelIds) {
        if (CollectionUtils.isEmpty(modelIds)) {
            return modelElementMatches;
        }
        Map<String, List<SchemaElementMatch>> filtered = new HashMap<>();
        for (Map.Entry<String, List<SchemaElementMatch>> entry : modelElementMatches.entrySet()) {
            String clusterKey = entry.getKey();
            // anyMatch stops at the first hit, so the key is not re-parsed for the remaining ids.
            boolean containsRequestedModel = modelIds.stream()
                    .anyMatch(modelId -> ModelCluster.getModelIdFromKey(clusterKey).contains(modelId));
            if (containsRequestedModel) {
                filtered.put(clusterKey, entry.getValue());
            }
        }
        return filtered;
    }

    /**
     * Replaces the whole backing map.
     *
     * @param modelElementMatches new mapping of cluster key to element matches
     */
    public void setModelElementMatches(Map<String, List<SchemaElementMatch>> modelElementMatches) {
        this.modelElementMatches = modelElementMatches;
    }

    /**
     * Stores (or overwrites) the matches of a single model cluster.
     *
     * @param modelCluster   key of the model cluster
     * @param elementMatches matches to associate with the key
     */
    public void setMatchedElements(String modelCluster, List<SchemaElementMatch> elementMatches) {
        modelElementMatches.put(modelCluster, elementMatches);
    }
}

View File

@@ -5,9 +5,13 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.ModelCluster;
import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.pojo.QueryType;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.enums.FilterType;
import lombok.Data;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
@@ -16,15 +20,13 @@ import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import com.tencent.supersonic.common.pojo.enums.FilterType;
import lombok.Data;
@Data
public class SemanticParseInfo {
private Integer id;
private String queryMode;
private SchemaElement model;
private ModelCluster model = new ModelCluster();
private Set<SchemaElement> metrics = new TreeSet<>(new SchemaNameLengthComparator());
private Set<SchemaElement> dimensions = new LinkedHashSet();
private SchemaElement entity;
@@ -42,12 +44,18 @@ public class SemanticParseInfo {
private SqlInfo sqlInfo = new SqlInfo();
private QueryType queryType = QueryType.OTHER;
public Long getModelId() {
return model != null ? model.getId() : 0L;
public String getModelClusterKey() {
if (model == null) {
return "";
}
return model.getKey();
}
public String getModelName() {
return model != null ? model.getName() : "null";
if (model == null) {
return "";
}
return model.getName();
}
private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {
@@ -78,4 +86,26 @@ public class SemanticParseInfo {
return metrics;
}
/**
 * Counts how many element matches belong to each model.
 *
 * @return a new map from the model value of each matched element to its number of matches
 */
private Map<Long, Integer> getModelElementCountMap() {
    Map<Long, Integer> matchesPerModel = new HashMap<>();
    for (SchemaElementMatch match : elementMatches) {
        // merge() starts a model at 1 and adds 1 for every further match.
        matchesPerModel.merge(match.getElement().getModel(), 1, Integer::sum);
    }
    return matchesPerModel;
}
/**
 * Picks the model that owns the largest number of element matches.
 *
 * @return the model with the strictly highest match count (the first one met in map
 *         iteration order wins a tie), or {@code -1} when there are no matches
 */
public Long getModelId() {
    Long dominantModel = -1L;
    int highestCount = 0;
    for (Map.Entry<Long, Integer> entry : getModelElementCountMap().entrySet()) {
        int occurrences = entry.getValue();
        if (occurrences > highestCount) {
            highestCount = occurrences;
            dominantModel = entry.getKey();
        }
    }
    return dominantModel;
}
}

View File

@@ -1,9 +1,14 @@
package com.tencent.supersonic.chat.api.pojo;
import org.springframework.util.CollectionUtils;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
public class SemanticSchema implements Serializable {
@@ -18,6 +23,64 @@ public class SemanticSchema implements Serializable {
modelSchemaList.add(schema);
}
/**
 * Finds a schema element of the given type by its id.
 *
 * @param elementType kind of element to search (ENTITY, MODEL, METRIC, DIMENSION or VALUE)
 * @param elementID   id the element must carry
 * @return the first element of that type with a matching id, or {@code null} when none
 *         is found or the type is not one of the supported kinds
 */
public SchemaElement getElement(SchemaElementType elementType, long elementID) {
    List<SchemaElement> candidates;
    switch (elementType) {
        case ENTITY:
            candidates = getEntities();
            break;
        case MODEL:
            candidates = getModels();
            break;
        case METRIC:
            candidates = getMetrics();
            break;
        case DIMENSION:
            candidates = getDimensions();
            break;
        case VALUE:
            candidates = getDimensionValues();
            break;
        default:
            // Unsupported element kinds never resolve to an element.
            return null;
    }
    return getElementsById(elementID, candidates).orElse(null);
}
/**
 * Finds a schema element of the given type by its name.
 *
 * @param elementType kind of element to search (ENTITY, MODEL, METRIC, DIMENSION or VALUE)
 * @param name        name the element must carry
 * @return the first element of that type with a matching name, or {@code null} when none
 *         is found or the type is not one of the supported kinds
 */
public SchemaElement getElementByName(SchemaElementType elementType, String name) {
    List<SchemaElement> candidates;
    switch (elementType) {
        case ENTITY:
            candidates = getEntities();
            break;
        case MODEL:
            candidates = getModels();
            break;
        case METRIC:
            candidates = getMetrics();
            break;
        case DIMENSION:
            candidates = getDimensions();
            break;
        case VALUE:
            candidates = getDimensionValues();
            break;
        default:
            // Unsupported element kinds never resolve to an element.
            return null;
    }
    return getElementsByName(name, candidates).orElse(null);
}
public Map<Long, String> getModelIdToName() {
return modelSchemaList.stream()
.collect(Collectors.toMap(a -> a.getModel().getId(), a -> a.getModel().getName(), (k1, k2) -> k1));
@@ -35,9 +98,21 @@ public class SemanticSchema implements Serializable {
return dimensions;
}
public List<SchemaElement> getDimensions(Long modelId) {
public List<SchemaElement> getDimensions(Set<Long> modelIds) {
List<SchemaElement> dimensions = getDimensions();
return getElementsByModelId(modelId, dimensions);
return getElementsByModelId(modelIds, dimensions);
}
/**
 * Looks up a single dimension by its id across all model schemas.
 *
 * @param id id of the dimension
 * @return the first dimension with that id, or {@code null} when there is none
 */
public SchemaElement getDimensions(Long id) {
    return getElementsById(id, getDimensions()).orElse(null);
}
/**
 * @return a new list with the tags of every model schema, in schema order
 */
public List<SchemaElement> getTags() {
    List<SchemaElement> allTags = new ArrayList<>();
    for (ModelSchema schema : modelSchemaList) {
        allTags.addAll(schema.getTags());
    }
    return allTags;
}
public List<SchemaElement> getMetrics() {
@@ -46,21 +121,9 @@ public class SemanticSchema implements Serializable {
return metrics;
}
public List<SchemaElement> getMetrics(Long modelId) {
public List<SchemaElement> getMetrics(Set<Long> modelIds) {
List<SchemaElement> metrics = getMetrics();
return getElementsByModelId(modelId, metrics);
}
private List<SchemaElement> getElementsByModelId(Long modelId, List<SchemaElement> elements) {
return elements.stream()
.filter(schemaElement -> modelId.equals(schemaElement.getModel()))
.collect(Collectors.toList());
}
public List<SchemaElement> getModels() {
List<SchemaElement> models = new ArrayList<>();
modelSchemaList.stream().forEach(d -> models.add(d.getModel()));
return models;
return getElementsByModelId(modelIds, metrics);
}
public List<SchemaElement> getEntities() {
@@ -69,11 +132,43 @@ public class SemanticSchema implements Serializable {
return entities;
}
public Map<String, String> getBizNameToName(Long modelId) {
/**
 * Keeps only the elements whose model is one of the given model ids.
 *
 * @param modelIds ids of the models to keep
 * @param elements elements to filter
 * @return a new list with the matching elements, in their original order
 */
private List<SchemaElement> getElementsByModelId(Set<Long> modelIds, List<SchemaElement> elements) {
    List<SchemaElement> selected = new ArrayList<>();
    for (SchemaElement candidate : elements) {
        if (modelIds.contains(candidate.getModel())) {
            selected.add(candidate);
        }
    }
    return selected;
}
/**
 * @param id       id to look for
 * @param elements elements to search
 * @return the first element whose id equals {@code id}, or an empty optional
 */
private Optional<SchemaElement> getElementsById(Long id, List<SchemaElement> elements) {
    for (SchemaElement candidate : elements) {
        if (id.equals(candidate.getId())) {
            return Optional.of(candidate);
        }
    }
    return Optional.empty();
}
/**
 * @param name     name to look for
 * @param elements elements to search
 * @return the first element whose name equals {@code name}, or an empty optional
 */
private Optional<SchemaElement> getElementsByName(String name, List<SchemaElement> elements) {
    for (SchemaElement candidate : elements) {
        if (name.equals(candidate.getName())) {
            return Optional.of(candidate);
        }
    }
    return Optional.empty();
}
/**
 * @return a new list with the model element of every model schema, in schema order
 */
public List<SchemaElement> getModels() {
    List<SchemaElement> modelElements = new ArrayList<>();
    for (ModelSchema schema : modelSchemaList) {
        modelElements.add(schema.getModel());
    }
    return modelElements;
}
public Map<String, String> getBizNameToName(Set<Long> modelIds) {
List<SchemaElement> allElements = new ArrayList<>();
allElements.addAll(getDimensions(modelId));
allElements.addAll(getMetrics(modelId));
allElements.addAll(getDimensions(modelIds));
allElements.addAll(getMetrics(modelIds));
return allElements.stream()
.collect(Collectors.toMap(a -> a.getBizName(), a -> a.getName(), (k1, k2) -> k1));
.collect(Collectors.toMap(SchemaElement::getBizName, SchemaElement::getName, (k1, k2) -> k1));
}
/**
 * Indexes the model schemas by the model value of their model element.
 *
 * @return a new map from model value to schema; empty when no schema is registered.
 *         Two schemas sharing the same key make {@code Collectors.toMap} throw an
 *         {@link IllegalStateException}, since no merge function is supplied.
 */
public Map<Long, ModelSchema> getModelSchemaMap() {
    if (!CollectionUtils.isEmpty(modelSchemaList)) {
        return modelSchemaList.stream()
                .collect(Collectors.toMap(schema -> schema.getModel().getModel(), schema -> schema));
    }
    return new HashMap<>();
}
}

View File

@@ -7,7 +7,7 @@ import lombok.Data;
public class QueryReq {
private String queryText;
private Integer chatId;
private Long modelId = 0L;
private Long modelId;
private User user;
private QueryFilters queryFilters;
private boolean saveAnswer = true;

View File

@@ -18,7 +18,7 @@ public class SolvedQueryReq {
private String queryText;
private Long modelId;
private String modelId;
private Integer agentId;

View File

@@ -1,13 +1,13 @@
package com.tencent.supersonic.chat.api.pojo.response;
import lombok.Data;
import java.io.Serializable;
import java.util.List;
import lombok.Data;
@Data
public class ModelInfo extends DataInfo implements Serializable {
private List<String> words;
private String primaryEntityName;
private String primaryEntityBizName;
private String primaryKey;
}

View File

@@ -12,6 +12,10 @@ import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
@@ -19,9 +23,6 @@ import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
/**
* basic semantic correction functionality, offering common methods and an
@@ -45,7 +46,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
public abstract void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo);
protected Map<String, String> getFieldNameMap(Long modelId) {
protected Map<String, String> getFieldNameMap(Set<Long> modelIds) {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
@@ -55,7 +56,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
// support fieldName and field alias
Map<String, String> result = dbAllFields.stream()
.filter(entry -> entry.getModel().equals(modelId))
.filter(entry -> modelIds.contains(entry.getModel()))
.flatMap(schemaElement -> {
Set<String> elements = new HashSet<>();
elements.add(schemaElement.getName());
@@ -103,9 +104,9 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
protected void addAggregateToMetric(SemanticParseInfo semanticParseInfo) {
//add aggregate to all metric
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
Long modelId = semanticParseInfo.getModel().getModel();
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
List<SchemaElement> metrics = getMetricElements(modelId);
List<SchemaElement> metrics = getMetricElements(modelIds);
Map<String, String> metricToAggregate = metrics.stream()
.map(schemaElement -> {
@@ -122,9 +123,9 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql);
}
protected List<SchemaElement> getMetricElements(Long modelId) {
protected List<SchemaElement> getMetricElements(Set<Long> modelIds) {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
return semanticSchema.getMetrics(modelId);
return semanticSchema.getMetrics(modelIds);
}
}

View File

@@ -0,0 +1,17 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
import lombok.extern.slf4j.Slf4j;
/**
 * Corrector that passes the S2SQL statement through
 * {@link SqlParserReplaceHelper#replaceTable} using the name of the model cluster
 * attached to the parse info.
 */
@Slf4j
public class FromCorrector extends BaseSemanticCorrector {

    /**
     * Replaces the table of the current corrected S2SQL with the model name and stores
     * the result back on the parse info.
     *
     * @param queryReq          the incoming query request (not used by this corrector)
     * @param semanticParseInfo parse info whose {@code SqlInfo} carries the S2SQL to correct
     */
    @Override
    public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
        String modelName = semanticParseInfo.getModel().getName();
        // Strings are immutable, so the rewritten statement can only come back through the
        // return value; it has to be written back or the correction is silently lost.
        // NOTE(review): assumes replaceTable returns the rewritten SQL, like replaceFields — confirm.
        String correctedSql = SqlParserReplaceHelper.replaceTable(
                semanticParseInfo.getSqlInfo().getCorrectS2SQL(), modelName);
        semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctedSql);
    }
}

View File

@@ -9,12 +9,13 @@ import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
/**
* Perform SQL corrections on the "group by" section in S2SQL.
@@ -30,14 +31,14 @@ public class GroupByCorrector extends BaseSemanticCorrector {
}
private void addGroupByFields(SemanticParseInfo semanticParseInfo) {
Long modelId = semanticParseInfo.getModel().getModel();
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
//add dimension group by
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String correctS2SQL = sqlInfo.getCorrectS2SQL();
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
//add alias field name
Set<String> dimensions = semanticSchema.getDimensions(modelId).stream()
Set<String> dimensions = semanticSchema.getDimensions(modelIds).stream()
.flatMap(
schemaElement -> {
Set<String> elements = new HashSet<>();

View File

@@ -10,13 +10,14 @@ import com.tencent.supersonic.common.util.jsqlparser.SqlParserRemoveHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import org.springframework.util.CollectionUtils;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
/**
* Perform SQL corrections on the "Having" section in S2SQL.
*/
@@ -37,11 +38,11 @@ public class HavingCorrector extends BaseSemanticCorrector {
}
private void addHaving(SemanticParseInfo semanticParseInfo) {
Long modelId = semanticParseInfo.getModel().getModel();
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
Set<String> metrics = semanticSchema.getMetrics(modelId).stream()
Set<String> metrics = semanticSchema.getMetrics(modelIds).stream()
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
if (CollectionUtils.isEmpty(metrics)) {

View File

@@ -51,7 +51,7 @@ public class SchemaCorrector extends BaseSemanticCorrector {
}
private void correctFieldName(SemanticParseInfo semanticParseInfo) {
Map<String, String> fieldNameMap = getFieldNameMap(semanticParseInfo.getModelId());
Map<String, String> fieldNameMap = getFieldNameMap(semanticParseInfo.getModel().getModelIds());
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String sql = SqlParserReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap);
sqlInfo.setCorrectS2SQL(sql);

View File

@@ -16,11 +16,6 @@ import com.tencent.supersonic.common.util.jsqlparser.SqlParserRemoveHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserReplaceHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
@@ -29,6 +24,13 @@ import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.util.Strings;
import org.springframework.util.CollectionUtils;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
/**
* Perform SQL corrections on the "Where" section in S2SQL.
*/
@@ -102,10 +104,8 @@ public class WhereCorrector extends BaseSemanticCorrector {
private void updateFieldValueByTechName(SemanticParseInfo semanticParseInfo) {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
Long modelId = semanticParseInfo.getModel().getId();
List<SchemaElement> dimensions = semanticSchema.getDimensions().stream()
.filter(schemaElement -> modelId.equals(schemaElement.getModel()))
.collect(Collectors.toList());
Set<Long> modelIds = semanticParseInfo.getModel().getModelIds();
List<SchemaElement> dimensions = semanticSchema.getDimensions(modelIds);
if (CollectionUtils.isEmpty(dimensions)) {
return;

View File

@@ -17,15 +17,16 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import lombok.extern.slf4j.Slf4j;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class EmbedLLMInterpreter implements LLMInterpreter {
public LLMResp query2sql(LLMReq llmReq, Long modelId) {
public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);

View File

@@ -9,8 +9,6 @@ import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import java.net.URI;
import java.net.URL;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
@@ -19,14 +17,16 @@ import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;
import java.net.URI;
import java.net.URL;
@Slf4j
public class HttpLLMInterpreter implements LLMInterpreter {
public LLMResp query2sql(LLMReq llmReq, Long modelId) {
public LLMResp query2sql(LLMReq llmReq, String modelClusterKey) {
long startTime = System.currentTimeMillis();
log.info("requestLLM request, modelId:{},llmReq:{}", modelId, llmReq);
log.info("requestLLM request, modelId:{},llmReq:{}", modelClusterKey, llmReq);
try {
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);

View File

@@ -11,7 +11,7 @@ import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
public interface LLMInterpreter {
LLMResp query2sql(LLMReq llmReq, Long modelId);
LLMResp query2sql(LLMReq llmReq, String modelClusterKey);
FunctionResp requestFunction(FunctionReq functionReq);

View File

@@ -9,12 +9,13 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.util.ContextUtils;
import java.util.List;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.stream.Collectors;
/**
* A mapper capable of converting the VALUE of entity dimension values into ID types.
*/
@@ -33,7 +34,8 @@ public class EntityMapper extends BaseMapper {
if (entity == null || entity.getId() == null) {
continue;
}
List<SchemaElementMatch> valueSchemaElements = schemaElementMatchList.stream().filter(schemaElementMatch ->
List<SchemaElementMatch> valueSchemaElements = schemaElementMatchList.stream()
.filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()))
.collect(Collectors.toList());
for (SchemaElementMatch schemaElementMatch : valueSchemaElements) {

View File

@@ -0,0 +1,52 @@
package com.tencent.supersonic.chat.mapper;
import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaModelClusterMapInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.utils.ModelClusterBuilder;
import com.tencent.supersonic.common.pojo.ModelCluster;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.service.SchemaService;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
/**
 * Schema mapper that regroups the per-model element matches of a query by model
 * cluster and stores the result on the query context.
 */
public class ModelClusterMapper implements SchemaMapper {

    /**
     * Builds the cluster-level match info from the model-level match info of the context.
     *
     * @param queryContext context whose {@code mapInfo} is read and whose
     *                     {@code modelClusterMapInfo} is replaced
     */
    @Override
    public void map(QueryContext queryContext) {
        SchemaService schemaService = ContextUtils.getBean(SchemaService.class);
        SemanticSchema semanticSchema = schemaService.getSemanticSchema();
        SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();

        Map<String, List<SchemaElementMatch>> matchesByClusterKey = new HashMap<>();
        for (ModelCluster cluster : buildModelClusterMatched(schemaMapInfo, semanticSchema)) {
            for (Long modelId : schemaMapInfo.getMatchedModels()) {
                if (!cluster.getModelIds().contains(modelId)) {
                    continue;
                }
                List<SchemaElementMatch> clusterMatches =
                        matchesByClusterKey.computeIfAbsent(cluster.getKey(), k -> new ArrayList<>());
                clusterMatches.addAll(schemaMapInfo.getMatchedElements(modelId));
            }
        }

        SchemaModelClusterMapInfo clusterMapInfo = new SchemaModelClusterMapInfo();
        clusterMapInfo.setModelElementMatches(matchesByClusterKey);
        queryContext.setModelClusterMapInfo(clusterMapInfo);
    }

    /**
     * Reduces every model cluster of the schema to the models that were actually matched.
     *
     * @param schemaMapInfo  model-level match info providing the matched model ids
     * @param semanticSchema schema the clusters are built from
     * @return one rebuilt cluster per original cluster that still has at least one model
     */
    private List<ModelCluster> buildModelClusterMatched(SchemaMapInfo schemaMapInfo,
            SemanticSchema semanticSchema) {
        Set<Long> matchedModels = schemaMapInfo.getMatchedModels();
        List<ModelCluster> allClusters = ModelClusterBuilder.buildModelClusters(semanticSchema);
        List<ModelCluster> matchedClusters = new ArrayList<>();
        for (ModelCluster cluster : allClusters) {
            Set<Long> clusterModelIds = cluster.getModelIds();
            // Unmatched ids are removed from the cluster's own id set, in place.
            clusterModelIds.removeIf(modelId -> !matchedModels.contains(modelId));
            if (clusterModelIds.size() > 0) {
                matchedClusters.add(ModelCluster.build(clusterModelIds));
            }
        }
        return matchedClusters;
    }
}

View File

@@ -4,7 +4,6 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
@@ -17,13 +16,14 @@ import com.tencent.supersonic.common.pojo.QueryType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
/**
* Query type parser, determine if the query is a metric query, an entity query,
@@ -58,11 +58,10 @@ public class QueryTypeParser implements SemanticParser {
//If all the fields in the SELECT statement are of tag type.
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL());
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
ModelSchema modelSchema = semanticService.getModelSchema(parseInfo.getModelId());
SemanticSchema semanticSchema = semanticService.getSemanticSchema();
if (CollectionUtils.isNotEmpty(selectFields) && Objects.nonNull(modelSchema) && CollectionUtils.isNotEmpty(
modelSchema.getTags())) {
Set<String> tags = modelSchema.getTags().stream().map(schemaElement -> schemaElement.getName())
if (CollectionUtils.isNotEmpty(selectFields)) {
Set<String> tags = semanticSchema.getTags().stream().map(SchemaElement::getName)
.collect(Collectors.toSet());
if (tags.containsAll(selectFields)) {
return QueryType.TAG;
@@ -72,10 +71,10 @@ public class QueryTypeParser implements SemanticParser {
//2. metric queryType
List<String> selectFields = SqlParserSelectHelper.getSelectFields(sqlInfo.getS2SQL());
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
List<SchemaElement> metrics = semanticSchema.getMetrics(parseInfo.getModelId());
List<SchemaElement> metrics = semanticSchema.getMetrics(parseInfo.getModel().getModelIds());
if (CollectionUtils.isNotEmpty(metrics)) {
Set<String> metricNameSet = metrics.stream().map(metric -> metric.getName()).collect(Collectors.toSet());
boolean containMetric = selectFields.stream().anyMatch(selectField -> metricNameSet.contains(selectField));
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
boolean containMetric = selectFields.stream().anyMatch(metricNameSet::contains);
if (containMetric) {
return QueryType.METRIC;
}

View File

@@ -5,10 +5,8 @@ import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.MetricInterpretTool;
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
@@ -21,18 +19,19 @@ import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
import com.tencent.supersonic.chat.query.llm.interpret.MetricInterpretQuery;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.ModelCluster;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
@Slf4j
public class MetricInterpretParser implements SemanticParser {
@@ -81,9 +80,8 @@ public class MetricInterpretParser implements SemanticParser {
}
public Set<SchemaElement> getMetrics(List<Long> metricIds, Long modelId) {
SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
ModelSchema modelSchema = semanticInterpreter.getModelSchema(modelId, true);
Set<SchemaElement> metrics = modelSchema.getMetrics();
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
List<SchemaElement> metrics = semanticService.getSemanticSchema().getMetrics();
return metrics.stream().filter(schemaElement -> metricIds.contains(schemaElement.getId()))
.collect(Collectors.toSet());
}
@@ -112,16 +110,13 @@ public class MetricInterpretParser implements SemanticParser {
private SemanticParseInfo buildSemanticParseInfo(Long modelId, QueryReq queryReq, Set<SchemaElement> metrics,
List<SchemaElementMatch> schemaElementMatches, String toolName) {
SchemaElement model = new SchemaElement();
model.setModel(modelId);
model.setId(modelId);
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
semanticParseInfo.setMetrics(metrics);
SchemaElement dimension = new SchemaElement();
dimension.setBizName(TimeDimensionEnum.DAY.getName());
semanticParseInfo.setDimensions(Sets.newHashSet(dimension));
semanticParseInfo.setElementMatches(schemaElementMatches);
semanticParseInfo.setModel(model);
semanticParseInfo.setModel(ModelCluster.build(Sets.newHashSet(modelId)));
semanticParseInfo.setScore(queryReq.getQueryText().length());
DateConf dateConf = new DateConf();
dateConf.setDateMode(DateConf.DateMode.RECENT);

View File

@@ -5,9 +5,11 @@ import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.SchemaModelClusterMapInfo;
import com.tencent.supersonic.common.pojo.ModelCluster;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
@@ -17,31 +19,28 @@ import java.util.Map.Entry;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
@Slf4j
public class HeuristicModelResolver implements ModelResolver {
protected static Long selectModelBySchemaElementMatchScore(Map<Long, SemanticQuery> modelQueryModes,
SchemaMapInfo schemaMap) {
protected static String selectModelBySchemaElementMatchScore(Map<String, SemanticQuery> modelQueryModes,
SchemaModelClusterMapInfo schemaMap) {
//model count priority
Long modelIdByModelCount = getModelIdByMatchModelScore(schemaMap);
String modelIdByModelCount = getModelIdByMatchModelScore(schemaMap);
if (Objects.nonNull(modelIdByModelCount)) {
log.info("selectModel by model count:{}", modelIdByModelCount);
return modelIdByModelCount;
}
Map<Long, ModelMatchResult> modelTypeMap = getModelTypeMap(schemaMap);
Map<String, ModelMatchResult> modelTypeMap = getModelTypeMap(schemaMap);
if (modelTypeMap.size() == 1) {
Long modelSelect = modelTypeMap.entrySet().stream().collect(Collectors.toList()).get(0).getKey();
String modelSelect = modelTypeMap.entrySet().stream().collect(Collectors.toList()).get(0).getKey();
if (modelQueryModes.containsKey(modelSelect)) {
log.info("selectModel with only one Model [{}]", modelSelect);
return modelSelect;
}
} else {
Map.Entry<Long, ModelMatchResult> maxModel = modelTypeMap.entrySet().stream()
Map.Entry<String, ModelMatchResult> maxModel = modelTypeMap.entrySet().stream()
.filter(entry -> modelQueryModes.containsKey(entry.getKey()))
.sorted((o1, o2) -> {
int difference = o2.getValue().getCount() - o1.getValue().getCount();
@@ -56,16 +55,16 @@ public class HeuristicModelResolver implements ModelResolver {
return maxModel.getKey();
}
}
return 0L;
return null;
}
private static Long getModelIdByMatchModelScore(SchemaMapInfo schemaMap) {
Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getModelElementMatches();
private static String getModelIdByMatchModelScore(SchemaModelClusterMapInfo schemaMap) {
Map<String, List<SchemaElementMatch>> modelElementMatches = schemaMap.getModelElementMatches();
// calculate model match score, matched element gets 1.0 point, and inherit element gets 0.5 point
Map<Long, Double> modelIdToModelScore = new HashMap<>();
Map<String, Double> modelIdToModelScore = new HashMap<>();
if (Objects.nonNull(modelElementMatches)) {
for (Entry<Long, List<SchemaElementMatch>> modelElementMatch : modelElementMatches.entrySet()) {
Long modelId = modelElementMatch.getKey();
for (Entry<String, List<SchemaElementMatch>> modelElementMatch : modelElementMatches.entrySet()) {
String modelId = modelElementMatch.getKey();
List<Double> modelMatchesScore = modelElementMatch.getValue().stream()
.filter(elementMatch -> elementMatch.getSimilarity() >= 1)
.filter(elementMatch -> SchemaElementType.MODEL.equals(elementMatch.getElement().getType()))
@@ -77,7 +76,7 @@ public class HeuristicModelResolver implements ModelResolver {
modelIdToModelScore.put(modelId, score);
}
}
Entry<Long, Double> maxModelScore = modelIdToModelScore.entrySet().stream()
Entry<String, Double> maxModelScore = modelIdToModelScore.entrySet().stream()
.max(Comparator.comparingDouble(o -> o.getValue())).orElse(null);
log.info("maxModelCount:{},modelIdToModelCount:{}", maxModelScore, modelIdToModelScore);
if (Objects.nonNull(maxModelScore)) {
@@ -87,64 +86,10 @@ public class HeuristicModelResolver implements ModelResolver {
return null;
}
/**
* Decides whether model selection may move away from the model already held in the chat context.
*
* <p>The checks run in order and the first one that applies decides the result:
* <ol>
* <li>{@code modelId} is null or not positive (no usable context model): switch;</li>
* <li>at least one other candidate model (a key of {@code modelQueryModes} different from
* {@code modelId}) has a {@code ModelMatchResult} count above 1: switch;</li>
* <li>a candidate query's detected date word equals the whole query text, ignoring case
* (the query contains nothing but a time expression): do not switch;</li>
* <li>{@code restrictiveModels} is non-empty and does not contain {@code modelId}: switch;</li>
* <li>{@code schemaMap} has no matched elements for {@code modelId}: switch;</li>
* <li>otherwise: do not switch.</li>
* </ol>
*
* @param modelQueryModes candidate queries keyed by model id; null queries and null parse infos are skipped
* @param schemaMap schema matches used to build the per-model counts and to look up matches for {@code modelId}
* @param chatCtx chat context; not read by this method
* @param searchCtx current request; only its query text is read
* @param modelId id of the model held in the chat context
* @param restrictiveModels model ids the selection is restricted to; an empty set imposes no restriction here
* @return {@code false} to keep the context model, {@code true} to allow choosing among the candidates
*         (which may still include the context model)
*/
protected static boolean isAllowSwitch(Map<Long, SemanticQuery> modelQueryModes, SchemaMapInfo schemaMap,
ChatContext chatCtx, QueryReq searchCtx,
Long modelId, Set<Long> restrictiveModels) {
// no valid context model to stick to, so switching is always allowed
if (!Objects.nonNull(modelId) || modelId <= 0) {
return true;
}
// excluding the context model, look at the match count of every candidate model;
// if none of them has a count above 1, this check does not trigger a switch
// NOTE(review): the count is presumably the number of matched element types - confirm in ModelMatchResult
Map<Long, ModelMatchResult> modelTypeMap = getModelTypeMap(schemaMap);
log.info("isAllowSwitch ModelTypeMap [{}]", modelTypeMap);
long otherModelTypeNumBigOneCount = modelTypeMap.entrySet().stream()
.filter(entry -> modelQueryModes.containsKey(entry.getKey()) && !entry.getKey().equals(modelId))
.filter(entry -> entry.getValue().getCount() > 1).count();
if (otherModelTypeNumBigOneCount >= 1) {
return true;
}
// if the query text consists only of a time expression, keep the context model
if (!CollectionUtils.isEmpty(modelQueryModes.values())) {
for (SemanticQuery semanticQuery : modelQueryModes.values()) {
if (semanticQuery == null) {
continue;
}
SemanticParseInfo semanticParseInfo = semanticQuery.getParseInfo();
if (semanticParseInfo == null) {
continue;
}
if (searchCtx.getQueryText() != null && semanticParseInfo.getDateInfo() != null) {
if (semanticParseInfo.getDateInfo().getDetectWord() != null) {
// the detected date word covers the entire query text
if (semanticParseInfo.getDateInfo().getDetectWord()
.equalsIgnoreCase(searchCtx.getQueryText())) {
log.info("timeParseResults is not null , can not switch context , timeParseResults:{},",
semanticParseInfo.getDateInfo());
return false;
}
}
}
}
}
// the context model is outside the allowed set, so it cannot be kept
if (CollectionUtils.isNotEmpty(restrictiveModels) && !restrictiveModels.contains(modelId)) {
return true;
}
// if the context model has no matched elements in schemaMap, switch
if (schemaMap.getMatchedElements(modelId) == null || schemaMap.getMatchedElements(modelId).size() <= 0) {
log.info("modelId not in schemaMap ");
return true;
}
// in every other case, keep the context model
return false;
}
public static Map<Long, ModelMatchResult> getModelTypeMap(SchemaMapInfo schemaMap) {
Map<Long, ModelMatchResult> modelCount = new HashMap<>();
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMap.getModelElementMatches().entrySet()) {
public static Map<String, ModelMatchResult> getModelTypeMap(SchemaModelClusterMapInfo schemaMap) {
Map<String, ModelMatchResult> modelCount = new HashMap<>();
for (Map.Entry<String, List<SchemaElementMatch>> entry : schemaMap.getModelElementMatches().entrySet()) {
List<SchemaElementMatch> schemaElementMatches = schemaMap.getMatchedElements(entry.getKey());
if (schemaElementMatches != null && schemaElementMatches.size() > 0) {
if (!modelCount.containsKey(entry.getKey())) {
@@ -170,65 +115,34 @@ public class HeuristicModelResolver implements ModelResolver {
}
public Long resolve(QueryContext queryContext, ChatContext chatCtx, Set<Long> restrictiveModels) {
public String resolve(QueryContext queryContext, ChatContext chatCtx, Set<Long> restrictiveModels) {
SchemaModelClusterMapInfo mapInfo = queryContext.getModelClusterMapInfo();
Set<String> matchedModelClusters = mapInfo.getElementMatchesByModelIds(restrictiveModels).keySet();
Long modelId = queryContext.getRequest().getModelId();
if (Objects.nonNull(modelId) && modelId > 0) {
if (CollectionUtils.isEmpty(restrictiveModels)) {
return modelId;
}
if (restrictiveModels.contains(modelId)) {
return modelId;
} else {
return null;
if (CollectionUtils.isEmpty(restrictiveModels) || restrictiveModels.contains(modelId)) {
return getModelClusterByModelId(modelId, matchedModelClusters);
}
return null;
}
SchemaMapInfo mapInfo = queryContext.getMapInfo();
Set<Long> matchedModels = mapInfo.getMatchedModels();
if (CollectionUtils.isNotEmpty(restrictiveModels)) {
matchedModels = matchedModels.stream()
.filter(restrictiveModels::contains)
.collect(Collectors.toSet());
}
Map<Long, SemanticQuery> modelQueryModes = new HashMap<>();
for (Long matchedModel : matchedModels) {
Map<String, SemanticQuery> modelQueryModes = new HashMap<>();
for (String matchedModel : matchedModelClusters) {
modelQueryModes.put(matchedModel, null);
}
if (modelQueryModes.size() == 1) {
return modelQueryModes.keySet().stream().findFirst().get();
}
return resolve(modelQueryModes, queryContext, chatCtx,
queryContext.getMapInfo(), restrictiveModels);
return selectModelBySchemaElementMatchScore(modelQueryModes, mapInfo);
}
/**
* Selects a model id from the given candidates in two stages: first via {@code selectModel}, and, when
* that does not yield a positive id, via {@code selectModelBySchemaElementMatchScore}.
*
* @param modelQueryModes candidate queries keyed by model id
* @param queryContext current query context; its request is handed to {@code selectModel}
* @param chatCtx chat context handed to {@code selectModel}
* @param schemaMap schema matches used by both selection stages
* @param restrictiveModels model ids handed to {@code selectModel} as the restriction set
* @return the positive id chosen by {@code selectModel}, otherwise the result of
*         {@code selectModelBySchemaElementMatchScore}
*/
public Long resolve(Map<Long, SemanticQuery> modelQueryModes, QueryContext queryContext,
ChatContext chatCtx, SchemaMapInfo schemaMap, Set<Long> restrictiveModels) {
// stage 1: selection driven by the request and the chat context
Long selectModel = selectModel(modelQueryModes, queryContext.getRequest(),
chatCtx, schemaMap, restrictiveModels);
if (selectModel > 0) {
log.info("selectModel {} ", selectModel);
return selectModel;
}
// stage 2: fall back to the candidate with the best schema element match score
return selectModelBySchemaElementMatchScore(modelQueryModes, schemaMap);
}
public Long selectModel(Map<Long, SemanticQuery> modelQueryModes, QueryReq queryContext,
ChatContext chatCtx,
SchemaMapInfo schemaMap, Set<Long> restrictiveModels) {
// if QueryContext has modelId and in ModelQueryModes
if (modelQueryModes.containsKey(queryContext.getModelId())) {
log.info("selectModel from QueryContext [{}]", queryContext.getModelId());
return queryContext.getModelId();
}
// if ChatContext has modelId and in ModelQueryModes
if (chatCtx.getParseInfo().getModelId() > 0) {
Long modelId = chatCtx.getParseInfo().getModelId();
if (!isAllowSwitch(modelQueryModes, schemaMap, chatCtx, queryContext, modelId, restrictiveModels)) {
log.info("selectModel from ChatContext [{}]", modelId);
return modelId;
private String getModelClusterByModelId(Long modelId, Set<String> modelClusterKeySet) {
for (String modelClusterKey : modelClusterKeySet) {
if (ModelCluster.build(modelClusterKey).getModelIds().contains(modelId)) {
return modelClusterKey;
}
}
// default 0
return 0L;
return null;
}
}

View File

@@ -12,21 +12,28 @@ import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.config.LLMParserConfig;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.llm.LLMInterpreter;
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.llm.LLMInterpreter;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.ModelCluster;
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.model.pojo.SchemaItem;
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
@@ -35,12 +42,6 @@ import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
@Slf4j
@Service
@@ -72,18 +73,18 @@ public class LLMRequestService {
return false;
}
public Long getModelId(QueryContext queryCtx, ChatContext chatCtx, Integer agentId) {
public ModelCluster getModelCluster(QueryContext queryCtx, ChatContext chatCtx, Integer agentId) {
Set<Long> distinctModelIds = agentService.getModelIds(agentId, AgentToolType.LLM_S2SQL);
if (agentService.containsAllModel(distinctModelIds)) {
distinctModelIds = new HashSet<>();
}
ModelResolver modelResolver = ComponentFactory.getModelResolver();
Long modelId = modelResolver.resolve(queryCtx, chatCtx, distinctModelIds);
log.info("resolve modelId:{},llmParser Models:{}", modelId, distinctModelIds);
return modelId;
String modelCluster = modelResolver.resolve(queryCtx, chatCtx, distinctModelIds);
log.info("resolve modelId:{},llmParser Models:{}", modelCluster, distinctModelIds);
return ModelCluster.build(modelCluster);
}
public CommonAgentTool getParserTool(QueryReq request, Long modelId) {
public CommonAgentTool getParserTool(QueryReq request, Set<Long> modelIdSet) {
List<CommonAgentTool> commonAgentTools = agentService.getParserTools(request.getAgentId(),
AgentToolType.LLM_S2SQL);
Optional<CommonAgentTool> llmParserTool = commonAgentTools.stream()
@@ -92,31 +93,36 @@ public class LLMRequestService {
if (agentService.containsAllModel(new HashSet<>(modelIds))) {
return true;
}
return modelIds.contains(modelId);
for (Long modelId : modelIdSet) {
if (modelIds.contains(modelId)) {
return true;
}
}
return false;
})
.findFirst();
return llmParserTool.orElse(null);
}
public LLMReq getLlmReq(QueryContext queryCtx, Long modelId, List<ElementValue> linkingValues) {
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
public LLMReq getLlmReq(QueryContext queryCtx, SemanticSchema semanticSchema,
ModelCluster modelCluster, List<ElementValue> linkingValues) {
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
String queryText = queryCtx.getRequest().getQueryText();
LLMReq llmReq = new LLMReq();
llmReq.setQueryText(queryText);
Long firstModelId = modelCluster.getFirstModel();
LLMReq.FilterCondition filterCondition = new LLMReq.FilterCondition();
filterCondition.setTableName(modelIdToName.get(modelId));
filterCondition.setTableName(modelIdToName.get(firstModelId));
llmReq.setFilterCondition(filterCondition);
LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema();
llmSchema.setModelName(modelIdToName.get(modelId));
llmSchema.setDomainName(modelIdToName.get(modelId));
llmSchema.setModelName(modelIdToName.get(firstModelId));
llmSchema.setDomainName(modelIdToName.get(firstModelId));
List<String> fieldNameList = getFieldNameList(queryCtx, modelId, llmParserConfig);
List<String> fieldNameList = getFieldNameList(queryCtx, modelCluster, llmParserConfig);
String priorExts = getPriorExts(modelId, fieldNameList);
String priorExts = getPriorExts(modelCluster.getModelIds(), fieldNameList);
llmReq.setPriorExts(priorExts);
fieldNameList.add(TimeDimensionEnum.DAY.getChName());
@@ -129,7 +135,7 @@ public class LLMRequestService {
}
llmReq.setLinking(linking);
String currentDate = S2SQLDateHelper.getReferenceDate(modelId);
String currentDate = S2SQLDateHelper.getReferenceDate(firstModelId);
if (StringUtils.isEmpty(currentDate)) {
currentDate = DateUtils.getBeforeDate(0);
}
@@ -137,24 +143,25 @@ public class LLMRequestService {
return llmReq;
}
public LLMResp requestLLM(LLMReq llmReq, Long modelId) {
return llmInterpreter.query2sql(llmReq, modelId);
public LLMResp requestLLM(LLMReq llmReq, String modelClusterKey) {
return llmInterpreter.query2sql(llmReq, modelClusterKey);
}
protected List<String> getFieldNameList(QueryContext queryCtx, Long modelId, LLMParserConfig llmParserConfig) {
protected List<String> getFieldNameList(QueryContext queryCtx, ModelCluster modelCluster,
LLMParserConfig llmParserConfig) {
Set<String> results = getTopNFieldNames(modelId, llmParserConfig);
Set<String> results = getTopNFieldNames(modelCluster, llmParserConfig);
Set<String> fieldNameList = getMatchedFieldNames(queryCtx, modelId);
Set<String> fieldNameList = getMatchedFieldNames(queryCtx, modelCluster);
results.addAll(fieldNameList);
return new ArrayList<>(results);
}
private String getPriorExts(Long modelId, List<String> fieldNameList) {
private String getPriorExts(Set<Long> modelIds, List<String> fieldNameList) {
StringBuilder extraInfoSb = new StringBuilder();
List<ModelSchemaResp> modelSchemaResps = semanticInterpreter.fetchModelSchema(
Collections.singletonList(modelId), true);
new ArrayList<>(modelIds), true);
if (!CollectionUtils.isEmpty(modelSchemaResps)) {
ModelSchemaResp modelSchemaResp = modelSchemaResps.get(0);
@@ -187,10 +194,11 @@ public class LLMRequestService {
}
protected List<ElementValue> getValueList(QueryContext queryCtx, Long modelId) {
Map<Long, String> itemIdToName = getItemIdToName(modelId);
protected List<ElementValue> getValueList(QueryContext queryCtx, ModelCluster modelCluster) {
Map<Long, String> itemIdToName = getItemIdToName(modelCluster);
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId);
List<SchemaElementMatch> matchedElements = queryCtx.getModelClusterMapInfo()
.getMatchedElements(modelCluster.getKey());
if (CollectionUtils.isEmpty(matchedElements)) {
return new ArrayList<>();
}
@@ -210,22 +218,22 @@ public class LLMRequestService {
return new ArrayList<>(valueMatches);
}
protected Map<Long, String> getItemIdToName(Long modelId) {
protected Map<Long, String> getItemIdToName(ModelCluster modelCluster) {
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
return semanticSchema.getDimensions(modelId).stream()
return semanticSchema.getDimensions(modelCluster.getModelIds()).stream()
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
}
private Set<String> getTopNFieldNames(Long modelId, LLMParserConfig llmParserConfig) {
private Set<String> getTopNFieldNames(ModelCluster modelCluster, LLMParserConfig llmParserConfig) {
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
Set<String> results = semanticSchema.getDimensions(modelId).stream()
Set<String> results = semanticSchema.getDimensions(modelCluster.getModelIds()).stream()
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getDimensionTopN())
.map(entry -> entry.getName())
.collect(Collectors.toSet());
Set<String> metrics = semanticSchema.getMetrics(modelId).stream()
Set<String> metrics = semanticSchema.getMetrics(modelCluster.getModelIds()).stream()
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getMetricTopN())
.map(entry -> entry.getName())
@@ -236,9 +244,10 @@ public class LLMRequestService {
}
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, Long modelId) {
Map<Long, String> itemIdToName = getItemIdToName(modelId);
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId);
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, ModelCluster modelCluster) {
Map<Long, String> itemIdToName = getItemIdToName(modelCluster);
List<SchemaElementMatch> matchedElements = queryCtx.getModelClusterMapInfo()
.getMatchedElements(modelCluster.getKey());
if (CollectionUtils.isEmpty(matchedElements)) {
return new HashSet<>();
}

View File

@@ -2,24 +2,21 @@ package com.tencent.supersonic.chat.parser.llm.s2sql;
import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.chat.query.llm.s2sql.S2SQLQuery;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserEqualHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.MapUtils;
import org.springframework.stereotype.Service;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
@Slf4j
@Service
public class LLMResponseService {
@@ -30,9 +27,10 @@ public class LLMResponseService {
}
LLMSemanticQuery semanticQuery = QueryManager.createLLMQuery(S2SQLQuery.QUERY_MODE);
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
Long modelId = parseResult.getModelId();
parseInfo.setModel(parseResult.getModelCluster());
CommonAgentTool commonAgentTool = parseResult.getCommonAgentTool();
parseInfo.getElementMatches().addAll(queryCtx.getMapInfo().getMatchedElements(modelId));
parseInfo.getElementMatches().addAll(queryCtx.getModelClusterMapInfo()
.getMatchedElements(parseInfo.getModelClusterKey()));
Map<String, Object> properties = new HashMap<>();
properties.put(Constants.CONTEXT, parseResult);
@@ -43,15 +41,7 @@ public class LLMResponseService {
parseInfo.setScore(queryCtx.getRequest().getQueryText().length() * (1 + weight));
parseInfo.setQueryMode(semanticQuery.getQueryMode());
parseInfo.getSqlInfo().setS2SQL(s2SQL);
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
SchemaElement model = new SchemaElement();
model.setModel(modelId);
model.setId(modelId);
model.setName(modelIdToName.get(modelId));
parseInfo.setModel(model);
parseInfo.setModel(parseResult.getModelCluster());
queryCtx.getCandidateQueries().add(semanticQuery);
return parseInfo;
}

View File

@@ -4,16 +4,21 @@ import com.tencent.supersonic.chat.agent.tool.CommonAgentTool;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.pojo.ModelCluster;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.StringUtils;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.MapUtils;
@Slf4j
public class LLMS2SQLParser implements SemanticParser {
@@ -22,36 +27,39 @@ public class LLMS2SQLParser implements SemanticParser {
public void parse(QueryContext queryCtx, ChatContext chatCtx) {
QueryReq request = queryCtx.getRequest();
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
//1.determine whether to skip this parser.
if (requestService.check(queryCtx)) {
return;
}
try {
//2.get modelId from queryCtx and chatCtx.
Long modelId = requestService.getModelId(queryCtx, chatCtx, request.getAgentId());
if (Objects.isNull(modelId) || modelId <= 0) {
ModelCluster modelCluster = requestService.getModelCluster(queryCtx, chatCtx, request.getAgentId());
if (StringUtils.isBlank(modelCluster.getKey())) {
return;
}
//3.get agent tool and determine whether to skip this parser.
CommonAgentTool commonAgentTool = requestService.getParserTool(request, modelId);
CommonAgentTool commonAgentTool = requestService.getParserTool(request, modelCluster.getModelIds());
if (Objects.isNull(commonAgentTool)) {
log.info("no tool in this agent, skip {}", LLMS2SQLParser.class);
return;
}
//4.construct a request, call the API for the large model, and retrieve the results.
List<ElementValue> linkingValues = requestService.getValueList(queryCtx, modelId);
LLMReq llmReq = requestService.getLlmReq(queryCtx, modelId, linkingValues);
LLMResp llmResp = requestService.requestLLM(llmReq, modelId);
List<ElementValue> linkingValues = requestService.getValueList(queryCtx, modelCluster);
SemanticSchema semanticSchema = semanticService.getSemanticSchema();
LLMReq llmReq = requestService.getLlmReq(queryCtx, semanticSchema, modelCluster, linkingValues);
LLMResp llmResp = requestService.requestLLM(llmReq, modelCluster.getKey());
if (Objects.isNull(llmResp)) {
return;
}
//5. deduplicate the SQL result list and build parserInfo
modelCluster.buildName(semanticSchema.getModelIdToName());
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
Map<String, Double> deduplicationSqlWeight = responseService.getDeduplicationSqlWeight(llmResp);
ParseResult parseResult = ParseResult.builder()
.request(request)
.modelId(modelId)
.modelCluster(modelCluster)
.commonAgentTool(commonAgentTool)
.llmReq(llmReq)
.llmResp(llmResp)

View File

@@ -7,6 +7,6 @@ import java.util.Set;
public interface ModelResolver {
Long resolve(QueryContext queryContext, ChatContext chatCtx, Set<Long> restrictiveModels);
String resolve(QueryContext queryContext, ChatContext chatCtx, Set<Long> restrictiveModels);
}

View File

@@ -5,19 +5,21 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.ElementValue;
import com.tencent.supersonic.chat.query.llm.s2sql.LLMResp;
import java.util.List;
import com.tencent.supersonic.common.pojo.ModelCluster;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.List;
@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class ParseResult {
private Long modelId;
private ModelCluster modelCluster;
private LLMReq llmReq;

View File

@@ -7,11 +7,10 @@ import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.plugin.Plugin;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.plugin.PluginParseResult;
@@ -19,8 +18,10 @@ import com.tencent.supersonic.chat.plugin.PluginRecallResult;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.ModelCluster;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import org.springframework.util.CollectionUtils;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -58,8 +59,10 @@ public abstract class PluginParser implements SemanticParser {
}
for (Long modelId : modelIds) {
PluginSemanticQuery pluginQuery = QueryManager.createPluginQuery(plugin.getType());
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(modelId, plugin, queryContext.getRequest(),
queryContext.getMapInfo().getMatchedElements(modelId), pluginRecallResult.getDistance());
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(modelId, plugin,
queryContext.getRequest(),
queryContext.getModelClusterMapInfo().getMatchedElements(modelId),
pluginRecallResult.getDistance());
semanticParseInfo.setQueryMode(pluginQuery.getQueryMode());
semanticParseInfo.setScore(pluginRecallResult.getScore());
pluginQuery.setParseInfo(semanticParseInfo);
@@ -79,12 +82,9 @@ public abstract class PluginParser implements SemanticParser {
if (schemaElementMatches == null) {
schemaElementMatches = Lists.newArrayList();
}
SchemaElement model = new SchemaElement();
model.setModel(modelId);
model.setId(modelId);
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
semanticParseInfo.setElementMatches(schemaElementMatches);
semanticParseInfo.setModel(model);
semanticParseInfo.setModel(ModelCluster.build(Sets.newHashSet(modelId)));
Map<String, Object> properties = new HashMap<>();
PluginParseResult pluginParseResult = new PluginParseResult();
pluginParseResult.setPlugin(plugin);

View File

@@ -13,6 +13,8 @@ import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import java.util.HashSet;
import java.util.List;
import java.util.stream.Collectors;
@@ -47,7 +49,11 @@ public class AgentCheckParser implements SemanticParser {
if (CollectionUtils.isEmpty(tool.getModelIds())) {
return true;
}
if (tool.isContainsAllModel() || tool.getModelIds().contains(query.getParseInfo().getModelId())) {
if (tool.isContainsAllModel()) {
return false;
}
if (new HashSet<>(tool.getModelIds())
.containsAll(query.getParseInfo().getModel().getModelIds())) {
return false;
}
}

View File

@@ -1,32 +1,41 @@
package com.tencent.supersonic.chat.parser.rule;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.DIMENSION;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ENTITY;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ID;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.METRIC;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.MODEL;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.VALUE;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricTagQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricSemanticQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricTagQuery;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.chat.utils.ModelClusterBuilder;
import com.tencent.supersonic.common.pojo.ModelCluster;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.extern.slf4j.Slf4j;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.DIMENSION;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ENTITY;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ID;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.METRIC;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.MODEL;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.TAG;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.VALUE;
/**
* ContextInheritParser tries to inherit certain schema elements from context
@@ -41,20 +50,22 @@ public class ContextInheritParser implements SemanticParser {
new AbstractMap.SimpleEntry<>(DIMENSION, Arrays.asList(DIMENSION, VALUE)),
new AbstractMap.SimpleEntry<>(VALUE, Arrays.asList(VALUE, DIMENSION)),
new AbstractMap.SimpleEntry<>(ENTITY, Arrays.asList(ENTITY)),
new AbstractMap.SimpleEntry<>(TAG, Arrays.asList(TAG)),
new AbstractMap.SimpleEntry<>(MODEL, Arrays.asList(MODEL)),
new AbstractMap.SimpleEntry<>(ID, Arrays.asList(ID))
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
@Override
public void parse(QueryContext queryContext, ChatContext chatContext) {
if (!shouldInherit(queryContext, chatContext)) {
if (!shouldInherit(queryContext)) {
return;
}
Long modelId = chatContext.getParseInfo().getModelId();
List<SchemaElementMatch> elementMatches = queryContext.getMapInfo()
.getMatchedElements(modelId);
ModelCluster modelCluster = getMatchedModelCluster(queryContext, chatContext);
if (modelCluster == null) {
return;
}
List<SchemaElementMatch> elementMatches = queryContext.getModelClusterMapInfo()
.getMatchedElements(modelCluster.getKey());
List<SchemaElementMatch> matchesToInherit = new ArrayList<>();
for (SchemaElementMatch match : chatContext.getParseInfo().getElementMatches()) {
SchemaElementType matchType = match.getElement().getType();
@@ -69,18 +80,18 @@ public class ContextInheritParser implements SemanticParser {
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(elementMatches, queryContext);
for (RuleSemanticQuery query : queries) {
query.fillParseInfo(modelId, queryContext, chatContext);
if (existSameQuery(query.getParseInfo().getModelId(), query.getQueryMode(), queryContext)) {
query.fillParseInfo(chatContext);
if (existSameQuery(query.getParseInfo().getModelClusterKey(), query.getQueryMode(), queryContext)) {
continue;
}
queryContext.getCandidateQueries().add(query);
}
}
private boolean existSameQuery(Long modelId, String queryMode, QueryContext queryContext) {
private boolean existSameQuery(String modelClusterKey, String queryMode, QueryContext queryContext) {
for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) {
if (semanticQuery.getQueryMode().equals(queryMode)
&& semanticQuery.getParseInfo().getModelId().equals(modelId)) {
&& semanticQuery.getParseInfo().getModelClusterKey().equals(modelClusterKey)) {
return true;
}
}
@@ -101,23 +112,34 @@ public class ContextInheritParser implements SemanticParser {
});
}
protected boolean shouldInherit(QueryContext queryContext, ChatContext chatContext) {
Long contextModelId = chatContext.getParseInfo().getModelId();
// if map info doesn't contain the same Model of the context,
// no inheritance could be done
if (queryContext.getMapInfo().getMatchedElements(contextModelId) == null) {
return false;
}
protected boolean shouldInherit(QueryContext queryContext) {
// if candidates only have MetricModel mode, count in context
List<SemanticQuery> metricModelQueries = queryContext.getCandidateQueries().stream()
.filter(query -> query instanceof MetricModelQuery).collect(
Collectors.toList());
if (metricModelQueries.size() == queryContext.getCandidateQueries().size()) {
return true;
} else {
return queryContext.getCandidateQueries().size() == 0;
return metricModelQueries.size() == queryContext.getCandidateQueries().size();
}
/**
 * Finds a model cluster matched by the current query that is compatible with the cluster
 * stored in the chat context.
 * <p>A matched cluster is compatible when at least one cluster built from the semantic schema
 * contains the model ids of both the context cluster and the matched cluster. The first
 * compatible matched cluster (in iteration order of the matched keys) is returned.
 *
 * @param queryContext context of the current query, providing the matched model cluster keys
 * @param chatContext  chat context, providing the model cluster key of the previous parse
 * @return the first compatible matched cluster, or {@code null} if the context has no cluster
 *         key or no matched cluster is compatible
 */
protected ModelCluster getMatchedModelCluster(QueryContext queryContext, ChatContext chatContext) {
    String contextKey = chatContext.getParseInfo().getModelClusterKey();
    if (StringUtils.isBlank(contextKey)) {
        return null;
    }
    SemanticSchema schema = ContextUtils.getBean(SemanticService.class).getSemanticSchema();
    List<ModelCluster> knownClusters = ModelClusterBuilder.buildModelClusters(schema);
    Set<String> candidateKeys = queryContext.getModelClusterMapInfo().getMatchedModelClusters();
    ModelCluster contextCluster = ModelCluster.build(contextKey);
    for (String candidateKey : candidateKeys) {
        ModelCluster candidate = ModelCluster.build(candidateKey);
        // both the context models and the candidate models must live in one known cluster
        boolean coveredTogether = knownClusters.stream().anyMatch(known ->
                known.getModelIds().containsAll(contextCluster.getModelIds())
                        && known.getModelIds().containsAll(candidate.getModelIds()));
        if (coveredTogether) {
            return candidate;
        }
    }
    return null;
}
}

View File

@@ -1,16 +1,15 @@
package com.tencent.supersonic.chat.parser.rule;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaModelClusterMapInfo;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import lombok.extern.slf4j.Slf4j;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
/**
* QueryModeParser resolves a specific query mode according to co-appearance
* of certain schema element types.
@@ -20,13 +19,13 @@ public class QueryModeParser implements SemanticParser {
@Override
public void parse(QueryContext queryContext, ChatContext chatContext) {
SchemaMapInfo mapInfo = queryContext.getMapInfo();
SchemaModelClusterMapInfo modelClusterMapInfo = queryContext.getModelClusterMapInfo();
// iterate all schemaElementMatches to resolve query mode
for (Long modelId : mapInfo.getMatchedModels()) {
List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(modelId);
for (String modelClusterKey : modelClusterMapInfo.getMatchedModelClusters()) {
List<SchemaElementMatch> elementMatches = modelClusterMapInfo.getMatchedElements(modelClusterKey);
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(elementMatches, queryContext);
for (RuleSemanticQuery query : queries) {
query.fillParseInfo(modelId, queryContext, chatContext);
query.fillParseInfo(chatContext);
queryContext.getCandidateQueries().add(query);
}
}

View File

@@ -3,24 +3,23 @@ package com.tencent.supersonic.chat.postprocessor;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.RelateSchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.pojo.QueryType;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserRemoveHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
@@ -34,17 +33,15 @@ public class MetricCheckPostProcessor implements PostProcessor {
@Override
public void process(QueryContext queryContext) {
List<SemanticQuery> semanticQueries = queryContext.getCandidateQueries();
Map<Long, ModelSchema> modelSchemaMap = new HashMap<>();
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
SemanticSchema semanticSchema = semanticService.getSemanticSchema();
for (SemanticQuery semanticQuery : semanticQueries) {
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
if (!QueryType.METRIC.equals(parseInfo.getQueryType())) {
continue;
}
SchemaService schemaService = ContextUtils.getBean(SchemaService.class);
ModelSchema modelSchema = schemaService.getModelSchema(parseInfo.getModelId());
String processedSql = processCorrectSql(parseInfo.getSqlInfo().getCorrectS2SQL(), modelSchema);
parseInfo.getSqlInfo().setCorrectS2SQL(processedSql);
modelSchemaMap.put(modelSchema.getModel().getModel(), modelSchema);
String correctSqlProcessed = processCorrectSql(parseInfo, semanticSchema);
parseInfo.getSqlInfo().setCorrectS2SQL(correctSqlProcessed);
}
semanticQueries.removeIf(semanticQuery -> {
if (!QueryType.METRIC.equals(semanticQuery.getParseInfo().getQueryType())) {
@@ -54,14 +51,14 @@ public class MetricCheckPostProcessor implements PostProcessor {
if (StringUtils.isBlank(correctSql)) {
return false;
}
return !checkHasMetric(correctSql, modelSchemaMap.get(semanticQuery.getParseInfo().getModelId()));
return !checkHasMetric(correctSql, semanticSchema);
});
}
public String processCorrectSql(String correctSql, ModelSchema modelSchema) {
public String processCorrectSql(SemanticParseInfo parseInfo, SemanticSchema semanticSchema) {
String correctSql = parseInfo.getSqlInfo().getCorrectS2SQL();
List<String> groupByFields = SqlParserSelectHelper.getGroupByFields(correctSql);
List<String> metricFields = SqlParserSelectHelper.getAggregateFields(correctSql)
.stream().filter(metricField -> !metricField.equals("*")).collect(Collectors.toList());
List<String> metricFields = SqlParserSelectHelper.getAggregateFields(correctSql);
List<String> whereFields = SqlParserSelectHelper.getWhereFields(correctSql);
List<String> dimensionFields = getDimensionFields(groupByFields, whereFields);
if (CollectionUtils.isEmpty(metricFields) || StringUtils.isBlank(correctSql)) {
@@ -71,35 +68,33 @@ public class MetricCheckPostProcessor implements PostProcessor {
Set<String> groupByToRemove = Sets.newHashSet();
Set<String> whereFieldsToRemove = Sets.newHashSet();
for (String metricName : metricFields) {
SchemaElement metricElement = modelSchema.getElement(SchemaElementType.METRIC, metricName);
SchemaElement metricElement = semanticSchema.getElementByName(SchemaElementType.METRIC, metricName);
if (metricElement == null) {
metricToRemove.add(metricName);
}
if (!checkNecessaryDimension(metricElement, modelSchema, dimensionFields)) {
if (!checkNecessaryDimension(metricElement, semanticSchema, dimensionFields)) {
metricToRemove.add(metricName);
}
}
for (String dimensionName : whereFields) {
if (TimeDimensionEnum.getNameList().contains(dimensionName)
|| TimeDimensionEnum.getChNameList().contains(dimensionName)) {
if (TimeDimensionEnum.getNameList().contains(dimensionName)) {
continue;
}
if (!checkInModelSchema(dimensionName, SchemaElementType.DIMENSION, modelSchema)) {
if (!checkInModelSchema(dimensionName, SchemaElementType.DIMENSION, semanticSchema)) {
whereFieldsToRemove.add(dimensionName);
}
if (!checkDrillDownDimension(dimensionName, metricFields, modelSchema)) {
if (!checkDrillDownDimension(dimensionName, metricFields, semanticSchema)) {
whereFieldsToRemove.add(dimensionName);
}
}
for (String dimensionName : groupByFields) {
if (TimeDimensionEnum.getNameList().contains(dimensionName)
|| TimeDimensionEnum.getChNameList().contains(dimensionName)) {
if (TimeDimensionEnum.getNameList().contains(dimensionName)) {
continue;
}
if (!checkInModelSchema(dimensionName, SchemaElementType.DIMENSION, modelSchema)) {
if (!checkInModelSchema(dimensionName, SchemaElementType.DIMENSION, semanticSchema)) {
groupByToRemove.add(dimensionName);
}
if (!checkDrillDownDimension(dimensionName, metricFields, modelSchema)) {
if (!checkDrillDownDimension(dimensionName, metricFields, semanticSchema)) {
groupByToRemove.add(dimensionName);
}
}
@@ -111,9 +106,9 @@ public class MetricCheckPostProcessor implements PostProcessor {
* To check whether the dimension bound to the metric exists,
* eg: metric like UV is calculated in a certain dimension, it cannot be used on other dimensions.
*/
private boolean checkNecessaryDimension(SchemaElement metric, ModelSchema modelSchema,
private boolean checkNecessaryDimension(SchemaElement metric, SemanticSchema semanticSchema,
List<String> dimensionFields) {
List<String> necessaryDimensions = getNecessaryDimensionNames(metric, modelSchema);
List<String> necessaryDimensions = getNecessaryDimensionNames(metric, semanticSchema);
if (CollectionUtils.isEmpty(necessaryDimensions)) {
return true;
}
@@ -130,8 +125,8 @@ public class MetricCheckPostProcessor implements PostProcessor {
* eg: some descriptive dimensions are not suitable as drill-down dimensions
*/
private boolean checkDrillDownDimension(String dimensionName, List<String> metrics,
ModelSchema modelSchema) {
List<SchemaElement> metricElements = modelSchema.getMetrics().stream()
SemanticSchema semanticSchema) {
List<SchemaElement> metricElements = semanticSchema.getMetrics().stream()
.filter(schemaElement -> metrics.contains(schemaElement.getName()))
.collect(Collectors.toList());
if (CollectionUtils.isEmpty(metricElements)) {
@@ -142,7 +137,7 @@ public class MetricCheckPostProcessor implements PostProcessor {
.map(schemaElement -> schemaElement.getRelateSchemaElements().stream()
.map(RelateSchemaElement::getDimensionId).collect(Collectors.toList()))
.flatMap(Collection::stream)
.map(id -> convertDimensionIdToName(id, modelSchema))
.map(id -> convertDimensionIdToName(id, semanticSchema))
.filter(Objects::nonNull)
.collect(Collectors.toList());
//if no metric has drill down dimension, return true
@@ -153,9 +148,9 @@ public class MetricCheckPostProcessor implements PostProcessor {
return relateDimensions.contains(dimensionName);
}
private List<String> getNecessaryDimensionNames(SchemaElement metric, ModelSchema modelSchema) {
private List<String> getNecessaryDimensionNames(SchemaElement metric, SemanticSchema semanticSchema) {
List<Long> necessaryDimensionIds = getNecessaryDimensions(metric);
return necessaryDimensionIds.stream().map(id -> convertDimensionIdToName(id, modelSchema))
return necessaryDimensionIds.stream().map(id -> convertDimensionIdToName(id, semanticSchema))
.filter(Objects::nonNull).collect(Collectors.toList());
}
@@ -183,23 +178,23 @@ public class MetricCheckPostProcessor implements PostProcessor {
return dimensionFields;
}
private String convertDimensionIdToName(Long id, ModelSchema modelSchema) {
SchemaElement schemaElement = modelSchema.getElement(SchemaElementType.DIMENSION, id);
private String convertDimensionIdToName(Long id, SemanticSchema semanticSchema) {
SchemaElement schemaElement = semanticSchema.getElement(SchemaElementType.DIMENSION, id);
if (schemaElement == null) {
return null;
}
return schemaElement.getName();
}
private boolean checkInModelSchema(String name, SchemaElementType type, ModelSchema modelSchema) {
SchemaElement schemaElement = modelSchema.getElement(type, name);
private boolean checkInModelSchema(String name, SchemaElementType type, SemanticSchema semanticSchema) {
SchemaElement schemaElement = semanticSchema.getElementByName(type, name);
return schemaElement != null;
}
private boolean checkHasMetric(String correctSql, ModelSchema modelSchema) {
private boolean checkHasMetric(String correctSql, SemanticSchema semanticSchema) {
List<String> selectFields = SqlParserSelectHelper.getSelectFields(correctSql);
List<String> aggFields = SqlParserSelectHelper.getAggregateFields(correctSql);
List<String> collect = modelSchema.getMetrics().stream()
List<String> collect = semanticSchema.getMetrics().stream()
.map(SchemaElement::getName).collect(Collectors.toList());
for (String field : selectFields) {
if (collect.contains(field)) {

View File

@@ -1,17 +1,40 @@
package com.tencent.supersonic.chat.postprocessor;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.service.ParseInfoService;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.QueryType;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.FieldExpression;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
/**
* update parse info from correct sql
*/
@Slf4j
public class ParseInfoUpdateProcessor implements PostProcessor {
@Override
@@ -20,10 +43,174 @@ public class ParseInfoUpdateProcessor implements PostProcessor {
if (CollectionUtils.isEmpty(candidateQueries)) {
return;
}
ParseInfoService parseInfoService = ContextUtils.getBean(ParseInfoService.class);
List<SemanticParseInfo> candidateParses = candidateQueries.stream()
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
candidateParses.forEach(parseInfoService::updateParseInfo);
candidateParses.forEach(this::updateParseInfo);
}
/**
 * Refreshes the given parse info from its corrected S2SQL: date range, dimension filters,
 * metrics, query type and dimensions.
 * <p>Nothing is changed when the corrected S2SQL is blank or equal to the original S2SQL.
 * Failures while deriving the date range or the dimension filters are logged and do not
 * abort the remaining updates.
 *
 * @param parseInfo parse info to update in place
 */
public void updateParseInfo(SemanticParseInfo parseInfo) {
    SqlInfo sqlInfo = parseInfo.getSqlInfo();
    String correctS2SQL = sqlInfo.getCorrectS2SQL();
    // a blank corrected S2SQL, or one equal to the original S2SQL, carries nothing new
    if (StringUtils.isBlank(correctS2SQL) || correctS2SQL.equals(sqlInfo.getS2SQL())) {
        return;
    }
    List<FieldExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctS2SQL);

    // set dateInfo, but never overwrite a date range that is already present
    try {
        if (expressions != null && !expressions.isEmpty()) {
            DateConf parsedDate = getDateInfo(expressions);
            if (parsedDate != null && parseInfo.getDateInfo() == null) {
                parseInfo.setDateInfo(parsedDate);
            }
        }
    } catch (Exception e) {
        log.error("set dateInfo error :", e);
    }

    // set filter
    try {
        Map<String, SchemaElement> nameToElement = getNameToElement(parseInfo.getModel().getModelIds());
        List<QueryFilter> filters = getDimensionFilter(nameToElement, expressions);
        parseInfo.getDimensionFilters().addAll(filters);
    } catch (Exception e) {
        log.error("set dimensionFilter error :", e);
    }

    SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
    if (semanticSchema == null) {
        return;
    }

    // metrics: every non-date field of the SQL that names a metric of the parsed models
    List<String> referencedFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(correctS2SQL));
    Set<SchemaElement> metrics = getElements(parseInfo.getModel().getModelIds(),
            referencedFields, semanticSchema.getMetrics());
    parseInfo.setMetrics(metrics);

    // an aggregate function makes it a METRIC query whose dimensions come from GROUP BY;
    // otherwise it is a TAG query whose dimensions come from the select list
    boolean aggregated = SqlParserSelectFunctionHelper.hasAggregateFunction(correctS2SQL);
    parseInfo.setQueryType(aggregated ? QueryType.METRIC : QueryType.TAG);
    List<String> dimensionCandidates = aggregated
            ? SqlParserSelectHelper.getGroupByFields(correctS2SQL)
            : SqlParserSelectHelper.getSelectFields(correctS2SQL);
    List<String> dimensionNames = getFieldsExceptDate(dimensionCandidates);
    parseInfo.setDimensions(
            getElements(parseInfo.getModel().getModelIds(), dimensionNames, semanticSchema.getDimensions()));
}
/**
 * Selects the schema elements that belong to one of the given models and whose name
 * appears in the given field list.
 *
 * @param modelIds  ids of the models an element may belong to
 * @param allFields field names to keep
 * @param elements  candidate schema elements
 * @return the matching elements as a set
 */
private Set<SchemaElement> getElements(Set<Long> modelIds, List<String> allFields, List<SchemaElement> elements) {
    Set<SchemaElement> matched = new HashSet<>();
    for (SchemaElement element : elements) {
        if (!modelIds.contains(element.getModel())) {
            continue;
        }
        if (allFields.contains(element.getName())) {
            matched.add(element);
        }
    }
    return matched;
}
/**
 * Returns the given field names without the day time dimension, which is matched
 * case-insensitively against {@code TimeDimensionEnum.DAY.getChName()}.
 *
 * @param allFields field names, may be {@code null} or empty
 * @return a new list with the remaining names; empty for {@code null} or empty input
 */
private List<String> getFieldsExceptDate(List<String> allFields) {
    List<String> remaining = new ArrayList<>();
    if (allFields == null || allFields.isEmpty()) {
        return remaining;
    }
    String dayFieldName = TimeDimensionEnum.DAY.getChName();
    for (String field : allFields) {
        if (!dayFieldName.equalsIgnoreCase(field)) {
            remaining.add(field);
        }
    }
    return remaining;
}
/**
 * Converts filter expressions into query filters.
 * <p>An expression is converted only when its field name resolves to a schema element in
 * {@code fieldNameToElement}; all other expressions are skipped.
 *
 * @param fieldNameToElement lookup from field name (or alias) to schema element
 * @param fieldExpressions   filter expressions to convert
 * @return one query filter per resolvable expression, in input order
 */
private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> fieldNameToElement,
        List<FieldExpression> fieldExpressions) {
    List<QueryFilter> filters = Lists.newArrayList();
    for (FieldExpression expression : fieldExpressions) {
        SchemaElement element = fieldNameToElement.get(expression.getFieldName());
        if (element == null) {
            // field is neither a known name nor a known alias
            continue;
        }
        QueryFilter filter = new QueryFilter();
        filter.setName(element.getName());
        filter.setBizName(element.getBizName());
        filter.setElementID(element.getId());
        filter.setValue(expression.getFieldValue());
        filter.setOperator(FilterOperatorEnum.getSqlOperator(expression.getOperator()));
        filter.setFunction(expression.getFunction());
        filters.add(filter);
    }
    return filters;
}
/**
 * Derives a date range from the filter expressions on the day time dimension.
 * <p>The first day expression decides the shape of the range:
 * an equality sets start and end to the same value; a greater-than(-or-equal) sets the
 * start and takes the end from the second day expression if it has a value; a
 * less-than(-or-equal) sets the end and takes the start from the second day expression
 * if it has a value. The date mode is always {@code BETWEEN}.
 *
 * @param fieldExpressions filter expressions of the SQL
 * @return the derived date configuration, or {@code null} when no expression targets the day field
 */
private DateConf getDateInfo(List<FieldExpression> fieldExpressions) {
    String dayFieldName = TimeDimensionEnum.DAY.getChName();
    List<FieldExpression> dayExpressions = new ArrayList<>();
    for (FieldExpression expression : fieldExpressions) {
        if (dayFieldName.equalsIgnoreCase(expression.getFieldName())) {
            dayExpressions.add(expression);
        }
    }
    if (dayExpressions.isEmpty()) {
        return null;
    }

    DateConf dateInfo = new DateConf();
    dateInfo.setDateMode(DateConf.DateMode.BETWEEN);
    FieldExpression first = dayExpressions.get(0);
    FilterOperatorEnum operator = FilterOperatorEnum.getSqlOperator(first.getOperator());

    // single day: start and end coincide
    if (FilterOperatorEnum.EQUALS.equals(operator) && first.getFieldValue() != null) {
        String day = first.getFieldValue().toString();
        dateInfo.setStartDate(day);
        dateInfo.setEndDate(day);
        return dateInfo;
    }
    // first expression is the lower bound, the optional second one the upper bound
    if (containOperators(first, operator, FilterOperatorEnum.GREATER_THAN,
            FilterOperatorEnum.GREATER_THAN_EQUALS)) {
        dateInfo.setStartDate(first.getFieldValue().toString());
        if (hasSecondDate(dayExpressions)) {
            dateInfo.setEndDate(dayExpressions.get(1).getFieldValue().toString());
        }
    }
    // first expression is the upper bound, the optional second one the lower bound
    if (containOperators(first, operator, FilterOperatorEnum.MINOR_THAN,
            FilterOperatorEnum.MINOR_THAN_EQUALS)) {
        dateInfo.setEndDate(first.getFieldValue().toString());
        if (hasSecondDate(dayExpressions)) {
            dateInfo.setStartDate(dayExpressions.get(1).getFieldValue().toString());
        }
    }
    return dateInfo;
}
/**
 * Tells whether {@code firstOperator} is one of {@code operatorEnums} and the expression
 * carries a non-null field value.
 *
 * @param expression    expression whose value must be present
 * @param firstOperator operator to look up
 * @param operatorEnums accepted operators
 * @return {@code true} only if the operator is accepted and the value is non-null
 */
private boolean containOperators(FieldExpression expression, FilterOperatorEnum firstOperator,
        FilterOperatorEnum... operatorEnums) {
    if (!Arrays.asList(operatorEnums).contains(firstOperator)) {
        return false;
    }
    return expression.getFieldValue() != null;
}
/**
 * Tells whether a second date expression exists and carries a non-null field value.
 *
 * @param dateExpressions expressions on the date field
 * @return {@code true} if the element at index 1 exists and has a value
 */
private boolean hasSecondDate(List<FieldExpression> dateExpressions) {
    if (dateExpressions.size() <= 1) {
        return false;
    }
    return dateExpressions.get(1).getFieldValue() != null;
}
/**
 * Builds a lookup from element name and from each alias to the schema element, covering
 * the dimensions and metrics of the given models.
 * <p>When several elements share a key, the element processed last wins; dimensions are
 * processed before metrics.
 *
 * @param modelIds ids of the models whose dimensions and metrics are indexed
 * @return map from name or alias to schema element
 */
protected Map<String, SchemaElement> getNameToElement(Set<Long> modelIds) {
    SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
    List<SchemaElement> dimensions = semanticSchema.getDimensions(modelIds);
    List<SchemaElement> metrics = semanticSchema.getMetrics(modelIds);
    List<SchemaElement> allElements = Lists.newArrayList();
    allElements.addAll(dimensions);
    allElements.addAll(metrics);

    Map<String, SchemaElement> nameToElement = new java.util.HashMap<>();
    for (SchemaElement element : allElements) {
        nameToElement.put(element.getName(), element);
        // support alias
        List<String> aliases = element.getAlias();
        if (aliases == null || aliases.isEmpty()) {
            continue;
        }
        for (String alias : aliases) {
            nameToElement.put(alias, element);
        }
    }
    return nameToElement;
}
}

View File

@@ -20,15 +20,16 @@ import com.tencent.supersonic.semantic.api.model.response.ExplainResp;
import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq;
import com.tencent.supersonic.semantic.api.query.request.QueryS2SQLReq;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import lombok.ToString;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import java.io.Serializable;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.ToString;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
@Slf4j
@ToString
@@ -49,7 +50,7 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
explainSqlReq = ExplainSqlReq.builder()
.queryTypeEnum(QueryTypeEnum.SQL)
.queryReq(QueryReqBuilder.buildS2SQLReq(
sqlInfo.getCorrectS2SQL(), parseInfo.getModelId()
sqlInfo.getCorrectS2SQL(), parseInfo.getModel().getModelIds()
))
.build();
} else {
@@ -86,7 +87,7 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
protected void convertBizNameToName(QueryStructReq queryStructReq) {
SchemaService schemaService = ContextUtils.getBean(SchemaService.class);
Map<String, String> bizNameToName = schemaService.getSemanticSchema()
.getBizNameToName(queryStructReq.getModelId());
.getBizNameToName(queryStructReq.getModelIdSet());
bizNameToName.putAll(TimeDimensionEnum.getNameToNameMap());
List<Order> orders = queryStructReq.getOrders();

View File

@@ -1,72 +0,0 @@
package com.tencent.supersonic.chat.query;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
/**
 * Ranks candidate semantic queries by parse score and assigns sequential ids to the
 * selected ones.
 */
@Slf4j
@Component
public class QueryRanker {

    /** Maximum number of candidates kept after ranking; configurable, defaults to 5. */
    @Value("${candidate.top.size:5}")
    private int candidateTopSize;

    /**
     * Selects the queries to keep and numbers them starting from 1.
     * <p>A null or empty input is returned as is. A single candidate is kept without
     * filtering; otherwise the top candidates are chosen by {@link #getTopCandidateQuery}.
     *
     * @param candidateQueries candidates to rank
     * @return the selected queries with their parse info ids set
     */
    public List<SemanticQuery> rank(List<SemanticQuery> candidateQueries) {
        log.debug("pick before [{}]", candidateQueries);
        if (CollectionUtils.isEmpty(candidateQueries)) {
            return candidateQueries;
        }
        List<SemanticQuery> selectedQueries;
        if (candidateQueries.size() == 1) {
            selectedQueries = new ArrayList<>();
            selectedQueries.addAll(candidateQueries);
        } else {
            selectedQueries = getTopCandidateQuery(candidateQueries);
        }
        generateParseInfoId(selectedQueries);
        log.debug("pick after [{}]", selectedQueries);
        return selectedQueries;
    }

    /**
     * Drops fully inherited rule queries, orders the rest by descending score and keeps
     * at most {@code candidateTopSize} of them.
     *
     * @param semanticQueries candidates to rank
     * @return the top candidates, highest score first
     */
    public List<SemanticQuery> getTopCandidateQuery(List<SemanticQuery> semanticQueries) {
        return semanticQueries.stream()
                .filter(query -> !checkFullyInherited(query))
                .sorted((left, right) -> {
                    // higher score sorts first
                    if (right.getParseInfo().getScore() > left.getParseInfo().getScore()) {
                        return 1;
                    }
                    if (right.getParseInfo().getScore() < left.getParseInfo().getScore()) {
                        return -1;
                    }
                    return 0;
                })
                .limit(candidateTopSize)
                .collect(Collectors.toList());
    }

    /** Numbers the parse infos 1..n in list order. */
    private void generateParseInfoId(List<SemanticQuery> semanticQueries) {
        int nextId = 1;
        for (SemanticQuery query : semanticQueries) {
            query.getParseInfo().setId(nextId);
            nextId++;
        }
    }

    /**
     * Tells whether a rule query consists only of inherited content: every element match
     * is inherited and the date info is either absent or inherited. Non-rule queries are
     * never considered fully inherited.
     */
    private boolean checkFullyInherited(SemanticQuery query) {
        SemanticParseInfo parseInfo = query.getParseInfo();
        if (!(query instanceof RuleSemanticQuery)) {
            return false;
        }
        for (SchemaElementMatch match : parseInfo.getElementMatches()) {
            if (!match.isInherited()) {
                return false;
            }
        }
        if (parseInfo.getDateInfo() == null) {
            return true;
        }
        return parseInfo.getDateInfo().isInherited();
    }
}

View File

@@ -12,12 +12,13 @@ import com.tencent.supersonic.chat.utils.QueryReqBuilder;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.request.QueryS2SQLReq;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
@Slf4j
@Component
@@ -40,7 +41,7 @@ public class S2SQLQuery extends LLMSemanticQuery {
long startTime = System.currentTimeMillis();
String querySql = parseInfo.getSqlInfo().getCorrectS2SQL();
QueryS2SQLReq queryS2SQLReq = QueryReqBuilder.buildS2SQLReq(querySql, parseInfo.getModelId());
QueryS2SQLReq queryS2SQLReq = QueryReqBuilder.buildS2SQLReq(querySql, parseInfo.getModel().getModelIds());
QueryResultWithSchemaResp queryResp = semanticInterpreter.queryByS2SQL(queryS2SQLReq, user);
log.info("queryByS2SQL cost:{},querySql:{}", System.currentTimeMillis() - startTime, querySql);

View File

@@ -2,7 +2,6 @@ package com.tencent.supersonic.chat.query.plugin.webpage;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
@@ -17,17 +16,15 @@ import com.tencent.supersonic.chat.query.plugin.ParamOption;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.query.plugin.WebBase;
import com.tencent.supersonic.chat.query.plugin.WebBaseResult;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.HashMap;
@Slf4j
@Component
@@ -53,9 +50,6 @@ public class WebPageQuery extends PluginSemanticQuery {
PluginParseResult.class);
WebPageResponse webPageResponse = buildResponse(pluginParseResult);
queryResult.setResponse(webPageResponse);
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
ModelSchema modelSchema = semanticService.getModelSchema(parseInfo.getModelId());
parseInfo.setModel(modelSchema.getModel());
queryResult.setQueryState(QueryState.SUCCESS);
return queryResult;
}
@@ -79,7 +73,8 @@ public class WebPageQuery extends PluginSemanticQuery {
List<ParamOption> paramOptions = Lists.newArrayList();
if (!CollectionUtils.isEmpty(webPage.getParamOptions()) && !CollectionUtils.isEmpty(elementValueMap)) {
for (ParamOption paramOption : webPage.getParamOptions()) {
if (paramOption.getModelId() != null && !paramOption.getModelId().equals(parseInfo.getModelId())) {
if (paramOption.getModelId() != null
&& !parseInfo.getModel().getModelIds().contains(paramOption.getModelId())) {
continue;
}
paramOptions.add(paramOption);

View File

@@ -4,12 +4,12 @@ package com.tencent.supersonic.chat.query.rule;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
@@ -19,21 +19,25 @@ import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
import com.tencent.supersonic.common.pojo.ModelCluster;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import lombok.ToString;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import lombok.ToString;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import java.util.Set;
import java.util.stream.Collectors;
@Slf4j
@ToString
@@ -56,13 +60,13 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
initS2SqlByStruct();
}
public void fillParseInfo(Long modelId, QueryContext queryContext, ChatContext chatContext) {
public void fillParseInfo(ChatContext chatContext) {
parseInfo.setQueryMode(getQueryMode());
SemanticService schemaService = ContextUtils.getBean(SemanticService.class);
ModelSchema modelSchema = schemaService.getModelSchema(modelId);
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
fillSchemaElement(parseInfo, modelSchema);
fillSchemaElement(parseInfo, semanticSchema);
fillScore(parseInfo);
fillDateConf(parseInfo, chatContext.getParseInfo());
}
@@ -101,9 +105,12 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
parseInfo.setScore(parseInfo.getScore() + totalScore);
}
private void fillSchemaElement(SemanticParseInfo parseInfo, ModelSchema modelSchema) {
parseInfo.setModel(modelSchema.getModel());
private void fillSchemaElement(SemanticParseInfo parseInfo, SemanticSchema semanticSchema) {
Set<Long> modelIds = parseInfo.getElementMatches().stream().map(SchemaElementMatch::getElement)
.map(SchemaElement::getModel).collect(Collectors.toSet());
ModelCluster modelCluster = ModelCluster.build(modelIds);
modelCluster.buildName(semanticSchema.getModelIdToName());
parseInfo.setModel(modelCluster);
Map<Long, List<SchemaElementMatch>> dim2Values = new HashMap<>();
Map<Long, List<SchemaElementMatch>> id2Values = new HashMap<>();
@@ -112,7 +119,7 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
element.setOrder(1 - schemaMatch.getSimilarity());
switch (element.getType()) {
case ID:
SchemaElement entityElement = modelSchema.getElement(SchemaElementType.ENTITY, element.getId());
SchemaElement entityElement = semanticSchema.getElement(SchemaElementType.ENTITY, element.getId());
if (entityElement != null) {
if (id2Values.containsKey(element.getId())) {
id2Values.get(element.getId()).add(schemaMatch);
@@ -122,7 +129,7 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
}
break;
case VALUE:
SchemaElement dimElement = modelSchema.getElement(SchemaElementType.DIMENSION, element.getId());
SchemaElement dimElement = semanticSchema.getElement(SchemaElementType.DIMENSION, element.getId());
if (dimElement != null) {
if (dim2Values.containsKey(element.getId())) {
dim2Values.get(element.getId()).add(schemaMatch);
@@ -146,20 +153,20 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
if (!id2Values.isEmpty()) {
for (Map.Entry<Long, List<SchemaElementMatch>> entry : id2Values.entrySet()) {
addFilters(parseInfo, modelSchema, entry, SchemaElementType.ENTITY);
addFilters(parseInfo, semanticSchema, entry, SchemaElementType.ENTITY);
}
}
if (!dim2Values.isEmpty()) {
for (Map.Entry<Long, List<SchemaElementMatch>> entry : dim2Values.entrySet()) {
addFilters(parseInfo, modelSchema, entry, SchemaElementType.DIMENSION);
addFilters(parseInfo, semanticSchema, entry, SchemaElementType.DIMENSION);
}
}
}
private void addFilters(SemanticParseInfo parseInfo, ModelSchema modelSchema,
Entry<Long, List<SchemaElementMatch>> entry, SchemaElementType dimension1) {
SchemaElement dimension = modelSchema.getElement(dimension1, entry.getKey());
private void addFilters(SemanticParseInfo parseInfo, SemanticSchema semanticSchema,
Entry<Long, List<SchemaElementMatch>> entry, SchemaElementType elementType) {
SchemaElement dimension = semanticSchema.getElement(elementType, entry.getKey());
if (entry.getValue().size() == 1) {
SchemaElementMatch schemaMatch = entry.getValue().get(0);
@@ -170,7 +177,7 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
dimensionFilter.setOperator(FilterOperatorEnum.EQUALS);
dimensionFilter.setElementID(schemaMatch.getElement().getId());
parseInfo.getDimensionFilters().add(dimensionFilter);
parseInfo.setEntity(modelSchema.getEntity());
parseInfo.setEntity(semanticSchema.getElement(SchemaElementType.ENTITY, entry.getKey()));
} else {
QueryFilter dimensionFilter = new QueryFilter();
List<String> vals = new ArrayList<>();
@@ -189,7 +196,7 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
public QueryResult execute(User user) {
String queryMode = parseInfo.getQueryMode();
if (parseInfo.getModelId() < 0 || StringUtils.isEmpty(queryMode)
if (StringUtils.isBlank(parseInfo.getModelClusterKey()) || StringUtils.isEmpty(queryMode)
|| !QueryManager.containsRuleQuery(queryMode)) {
// reaching this branch means the model or query mode is missing, or the query mode is not a registered rule query
log.error("not find QueryMode");
@@ -230,7 +237,7 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
public QueryResult multiStructExecute(User user) {
String queryMode = parseInfo.getQueryMode();
if (parseInfo.getModelId() < 0 || StringUtils.isEmpty(queryMode)
if (StringUtils.isBlank(parseInfo.getModelClusterKey()) || StringUtils.isEmpty(queryMode)
|| !QueryManager.containsRuleQuery(queryMode)) {
// reaching this branch means the model or query mode is missing, or the query mode is not a registered rule query
log.error("not find QueryMode");

View File

@@ -1,8 +1,5 @@
package com.tencent.supersonic.chat.query.rule.metric;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.METRIC;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
@@ -20,12 +17,17 @@ import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.METRIC;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
@Slf4j
public abstract class MetricSemanticQuery extends RuleSemanticQuery {
@@ -82,8 +84,8 @@ public abstract class MetricSemanticQuery extends RuleSemanticQuery {
@Override
public void fillParseInfo(Long modelId, QueryContext queryContext, ChatContext chatContext) {
super.fillParseInfo(modelId, queryContext, chatContext);
public void fillParseInfo(ChatContext chatContext) {
super.fillParseInfo(chatContext);
parseInfo.setLimit(METRIC_MAX_RESULTS);
if (parseInfo.getDateInfo() == null) {

View File

@@ -25,7 +25,7 @@ import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNum
@Component
public class MetricTagQuery extends MetricSemanticQuery {
public static final String QUERY_MODE = "TAG_ENTITY";
public static final String QUERY_MODE = "METRIC_ENTITY";
public MetricTagQuery() {
super();

View File

@@ -1,12 +1,5 @@
package com.tencent.supersonic.chat.query.rule.metric;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.DIMENSION;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.VALUE;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.OPTIONAL;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.common.pojo.Constants.DESC_UPPER;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
@@ -20,6 +13,13 @@ import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.DIMENSION;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.VALUE;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.OPTIONAL;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
import static com.tencent.supersonic.common.pojo.Constants.DESC_UPPER;
@Component
public class MetricTopNQuery extends MetricSemanticQuery {
@@ -50,8 +50,8 @@ public class MetricTopNQuery extends MetricSemanticQuery {
}
@Override
public void fillParseInfo(Long modelId, QueryContext queryContext, ChatContext chatContext) {
super.fillParseInfo(modelId, queryContext, chatContext);
public void fillParseInfo(ChatContext chatContext) {
super.fillParseInfo(chatContext);
parseInfo.setLimit(ORDERBY_MAX_RESULTS);
parseInfo.setScore(parseInfo.getScore() + 2.0);

View File

@@ -2,7 +2,6 @@ package com.tencent.supersonic.chat.query.rule.tag;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
@@ -12,16 +11,17 @@ import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.util.ContextUtils;
import org.apache.commons.collections.CollectionUtils;
import java.util.LinkedHashSet;
import java.util.Objects;
import java.util.Set;
import org.apache.commons.collections.CollectionUtils;
public abstract class TagListQuery extends TagSemanticQuery {
@Override
public void fillParseInfo(Long modelId, QueryContext queryContext, ChatContext chatContext) {
super.fillParseInfo(modelId, queryContext, chatContext);
public void fillParseInfo(ChatContext chatContext) {
super.fillParseInfo(chatContext);
this.addEntityDetailAndOrderByMetric(parseInfo);
}
@@ -29,13 +29,12 @@ public abstract class TagListQuery extends TagSemanticQuery {
Long modelId = parseInfo.getModelId();
if (Objects.nonNull(modelId) && modelId > 0L) {
ConfigService configService = ContextUtils.getBean(ConfigService.class);
ChatConfigRichResp chaConfigRichDesc = configService.getConfigRichInfo(parseInfo.getModelId());
ChatConfigRichResp chaConfigRichDesc = configService.getConfigRichInfo(modelId);
SemanticService schemaService = ContextUtils.getBean(SemanticService.class);
ModelSchema modelSchema = schemaService.getModelSchema(modelId);
ModelSchema modelSchema = schemaService.getModelSchema(parseInfo.getModelId());
if (chaConfigRichDesc != null && chaConfigRichDesc.getChatDetailRichConfig() != null
&& Objects.nonNull(modelSchema) && Objects.nonNull(modelSchema.getEntity())) {
Set<SchemaElement> dimensions = new LinkedHashSet();
Set<SchemaElement> dimensions = new LinkedHashSet<>();
Set<SchemaElement> metrics = new LinkedHashSet();
Set<Order> orders = new LinkedHashSet();
ChatDefaultRichConfigResp chatDefaultConfig = chaConfigRichDesc
@@ -52,9 +51,7 @@ public abstract class TagListQuery extends TagSemanticQuery {
chatDefaultConfig.getDimensions().stream()
.forEach(dimension -> dimensions.add(dimension));
}
}
parseInfo.setDimensions(dimensions);
parseInfo.setMetrics(metrics);
parseInfo.setOrders(orders);

View File

@@ -1,12 +1,7 @@
package com.tencent.supersonic.chat.query.rule.tag;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ENTITY;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.QueryType;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
@@ -15,13 +10,19 @@ import com.tencent.supersonic.chat.api.pojo.response.ChatDefaultRichConfigResp;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.service.ConfigService;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.QueryType;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ENTITY;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
@Slf4j
public abstract class TagSemanticQuery extends RuleSemanticQuery {
@@ -78,8 +79,8 @@ public abstract class TagSemanticQuery extends RuleSemanticQuery {
}
@Override
public void fillParseInfo(Long modelId, QueryContext queryContext, ChatContext chatContext) {
super.fillParseInfo(modelId, queryContext, chatContext);
public void fillParseInfo(ChatContext chatContext) {
super.fillParseInfo(chatContext);
parseInfo.setQueryType(QueryType.TAG);
parseInfo.setLimit(TAG_MAX_RESULTS);

View File

@@ -14,7 +14,7 @@ public class EntityInfoExecuteResponder implements ExecuteResponder {
@Override
public void fillResponse(QueryResult queryResult, SemanticParseInfo semanticParseInfo, ExecuteQueryReq queryReq) {
if (semanticParseInfo == null || semanticParseInfo.getModelId() <= 0L) {
if (semanticParseInfo == null) {
return;
}
String queryMode = semanticParseInfo.getQueryMode();

View File

@@ -39,7 +39,7 @@ public class SimilarMetricExecuteResponder implements ExecuteResponder {
}
List<String> metricNames = Collections.singletonList(parseInfo.getMetrics().iterator().next().getName());
Map<String, String> filterCondition = new HashMap<>();
filterCondition.put("modelId", parseInfo.getModelId().toString());
filterCondition.put("modelId", parseInfo.getMetrics().iterator().next().getModel().toString());
filterCondition.put("type", SchemaElementType.METRIC.name());
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(metricNames)
.filterCondition(filterCondition).queryEmbeddings(null).build();

View File

@@ -3,22 +3,80 @@ package com.tencent.supersonic.chat.responder.parse;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.query.QueryRanker;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
/**
* Rank queries by score.
*/
@Slf4j
public class QueryRankParseResponder implements ParseResponder {
private static final int candidateTopSize = 5;
@Override
public void fillResponse(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
List<SemanticQuery> candidateQueries = queryContext.getCandidateQueries();
QueryRanker queryRanker = ContextUtils.getBean(QueryRanker.class);
candidateQueries = queryRanker.rank(candidateQueries);
candidateQueries = rank(candidateQueries);
queryContext.setCandidateQueries(candidateQueries);
}
/**
 * Selects the candidate queries to keep and numbers them.
 *
 * <p>A null or empty input is returned untouched. A single candidate is kept
 * as-is (copied into a new list, without any filtering); with several
 * candidates the selection is delegated to {@link #getTopCandidateQuery(List)}.
 * The parse info of every selected query then receives a 1-based id.
 *
 * @param candidateQueries the queries produced by the parsers; may be null or empty
 * @return the input itself when it is null or empty, otherwise the selected queries
 */
public List<SemanticQuery> rank(List<SemanticQuery> candidateQueries) {
    log.debug("pick before [{}]", candidateQueries);
    if (CollectionUtils.isEmpty(candidateQueries)) {
        return candidateQueries;
    }
    // A lone candidate bypasses the ranking pipeline entirely.
    List<SemanticQuery> picked = candidateQueries.size() == 1
            ? new ArrayList<SemanticQuery>(candidateQueries)
            : getTopCandidateQuery(candidateQueries);
    generateParseInfoId(picked);
    log.debug("pick after [{}]", picked);
    return picked;
}
/**
 * Returns the best-scoring candidates, highest score first.
 *
 * <p>Queries for which {@code checkFullyInherited} returns true are dropped,
 * the rest are ordered by descending parse-info score, and at most
 * {@code candidateTopSize} of them are kept.
 *
 * @param semanticQueries the candidates to choose from; must not be null
 * @return a new list holding at most {@code candidateTopSize} queries
 */
public List<SemanticQuery> getTopCandidateQuery(List<SemanticQuery> semanticQueries) {
    return semanticQueries.stream()
            .filter(query -> !checkFullyInherited(query))
            // Arguments are swapped to sort descending. Double.compare yields a consistent
            // total order even for NaN scores, which the previous hand-written </> comparison
            // did not (assumes getScore() is a double/Double -- confirm against SemanticParseInfo).
            .sorted((left, right) -> Double.compare(
                    right.getParseInfo().getScore(), left.getParseInfo().getScore()))
            .limit(candidateTopSize)
            .collect(Collectors.toList());
}
/**
 * Numbers the given queries in list order by setting the id of each query's
 * parse info to its 1-based position.
 *
 * @param semanticQueries the already ordered queries to number
 */
private void generateParseInfoId(List<SemanticQuery> semanticQueries) {
    int nextId = 1;
    for (SemanticQuery semanticQuery : semanticQueries) {
        semanticQuery.getParseInfo().setId(nextId);
        nextId++;
    }
}
/**
 * Tells whether a rule-based query consists solely of inherited information.
 *
 * <p>Only {@link RuleSemanticQuery} instances can qualify. Such a query is
 * fully inherited when every element match is flagged as inherited and its
 * date info is either absent or flagged as inherited too.
 *
 * @param query the candidate query to inspect
 * @return true when the query is rule-based and nothing in it is non-inherited
 */
private boolean checkFullyInherited(SemanticQuery query) {
    SemanticParseInfo info = query.getParseInfo();
    if (!(query instanceof RuleSemanticQuery)) {
        return false;
    }
    boolean everyMatchInherited = info.getElementMatches().stream()
            .allMatch(SchemaElementMatch::isInherited);
    if (!everyMatchInherited) {
        return false;
    }
    // A missing date info counts as inherited.
    if (info.getDateInfo() == null) {
        return true;
    }
    return info.getDateInfo().isInherited();
}
}

View File

@@ -9,6 +9,7 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatConfigEditReqReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
import com.tencent.supersonic.chat.service.ConfigService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.enums.AuthType;
import com.tencent.supersonic.semantic.api.model.request.PageDimensionReq;
@@ -17,12 +18,6 @@ import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
import com.tencent.supersonic.semantic.api.model.response.DomainResp;
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
import com.tencent.supersonic.semantic.api.model.response.ModelResp;
import com.tencent.supersonic.chat.service.ConfigService;
import java.util.List;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
@@ -32,6 +27,10 @@ import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.List;
@RestController
@RequestMapping({"/api/chat/conf", "/openapi/chat/conf"})

View File

@@ -14,6 +14,7 @@ import com.tencent.supersonic.chat.persistence.dataobject.ChatDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
import java.util.List;
import java.util.Set;
public interface ChatService {
@@ -22,7 +23,7 @@ public interface ChatService {
* @param chatId
* @return
*/
Long getContextModel(Integer chatId);
Set<Long> getContextModel(Integer chatId);
ChatContext getOrCreateContext(int chatId);

View File

@@ -1,9 +0,0 @@
package com.tencent.supersonic.chat.service;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
public interface ParseInfoService {
void updateParseInfo(SemanticParseInfo parseInfo);
}

View File

@@ -1,26 +1,14 @@
package com.tencent.supersonic.chat.service;
import static com.tencent.supersonic.common.pojo.Constants.DAY;
import static com.tencent.supersonic.common.pojo.Constants.DAY_FORMAT;
import static com.tencent.supersonic.common.pojo.Constants.DAY_FORMAT_INT;
import static com.tencent.supersonic.common.pojo.Constants.MONTH;
import static com.tencent.supersonic.common.pojo.Constants.MONTH_FORMAT;
import static com.tencent.supersonic.common.pojo.Constants.MONTH_FORMAT_INT;
import static com.tencent.supersonic.common.pojo.Constants.TIMES_FORMAT;
import static com.tencent.supersonic.common.pojo.Constants.TIME_FORMAT;
import static com.tencent.supersonic.common.pojo.Constants.WEEK;
import com.google.common.collect.Sets;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.common.pojo.QueryType;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatDetailConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.ItemVisibility;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.response.AggregateInfo;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
@@ -35,7 +23,9 @@ import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
import com.tencent.supersonic.common.pojo.ModelCluster;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.pojo.QueryType;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.RatioOverType;
@@ -44,6 +34,13 @@ import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.text.DecimalFormat;
import java.time.DayOfWeek;
import java.time.LocalDate;
@@ -63,12 +60,16 @@ import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import static com.tencent.supersonic.common.pojo.Constants.DAY;
import static com.tencent.supersonic.common.pojo.Constants.DAY_FORMAT;
import static com.tencent.supersonic.common.pojo.Constants.DAY_FORMAT_INT;
import static com.tencent.supersonic.common.pojo.Constants.MONTH;
import static com.tencent.supersonic.common.pojo.Constants.MONTH_FORMAT;
import static com.tencent.supersonic.common.pojo.Constants.MONTH_FORMAT_INT;
import static com.tencent.supersonic.common.pojo.Constants.TIMES_FORMAT;
import static com.tencent.supersonic.common.pojo.Constants.TIME_FORMAT;
import static com.tencent.supersonic.common.pojo.Constants.WEEK;
@Service
@Slf4j
@@ -83,45 +84,38 @@ public class SemanticService {
private SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
public ModelSchema getModelSchema(Long id) {
ModelSchema modelSchema = schemaService.getModelSchema(id);
if (!Objects.isNull(modelSchema) && !Objects.isNull(modelSchema.getModel())) {
ChatConfigResp chaConfigInfo =
configService.fetchConfigByModelId(modelSchema.getModel().getId());
// filter dimensions in blacklist
filterBlackDim(modelSchema, chaConfigInfo);
// filter metrics in blacklist
filterBlackMetric(modelSchema, chaConfigInfo);
}
public SemanticSchema getSemanticSchema() {
return schemaService.getSemanticSchema();
}
return modelSchema;
public ModelSchema getModelSchema(Long id) {
return schemaService.getModelSchema(id);
}
public EntityInfo getEntityInfo(SemanticParseInfo parseInfo, User user) {
if (parseInfo != null && parseInfo.getModelId() > 0) {
EntityInfo entityInfo = getEntityInfo(parseInfo.getModelId());
if (parseInfo.getDimensionFilters().size() <= 0) {
if (parseInfo.getDimensionFilters().size() <= 0 || entityInfo.getModelInfo() == null) {
entityInfo.setMetrics(null);
entityInfo.setDimensions(null);
return entityInfo;
}
if (entityInfo.getModelInfo() != null && entityInfo.getModelInfo().getPrimaryEntityBizName() != null) {
String modelInfoPrimaryName = entityInfo.getModelInfo().getPrimaryEntityBizName();
String primaryKey = entityInfo.getModelInfo().getPrimaryKey();
if (StringUtils.isNotBlank(primaryKey)) {
String modelInfoId = "";
for (QueryFilter chatFilter : parseInfo.getDimensionFilters()) {
if (chatFilter != null && chatFilter.getBizName() != null && chatFilter.getBizName()
.equals(modelInfoPrimaryName)) {
.equals(primaryKey)) {
if (chatFilter.getOperator().equals(FilterOperatorEnum.EQUALS)) {
modelInfoId = chatFilter.getValue().toString();
}
}
}
try {
setMainModel(entityInfo, parseInfo.getModelId(),
modelInfoId, user);
setMainModel(entityInfo, parseInfo, modelInfoId, user);
return entityInfo;
} catch (Exception e) {
log.error("setMainModel error {}", e);
log.error("setMainModel error", e);
}
}
}
@@ -152,8 +146,7 @@ public class SemanticService {
modelInfo.setWords(modelSchema.getModel().getAlias());
modelInfo.setBizName(modelSchema.getModel().getBizName());
if (Objects.nonNull(modelSchema.getEntity())) {
modelInfo.setPrimaryEntityName(modelSchema.getEntity().getName());
modelInfo.setPrimaryEntityBizName(modelSchema.getEntity().getBizName());
modelInfo.setPrimaryKey(modelSchema.getEntity().getBizName());
}
entityInfo.setModelInfo(modelInfo);
@@ -190,21 +183,14 @@ public class SemanticService {
return entityInfo;
}
public String getPrimaryEntityBizName(EntityInfo entityInfo) {
if (Objects.isNull(entityInfo) || Objects.isNull(entityInfo.getModelInfo())) {
return null;
}
return entityInfo.getModelInfo().getPrimaryEntityBizName();
}
public void setMainModel(EntityInfo modelInfo, Long model, String entity, User user) {
public void setMainModel(EntityInfo modelInfo, SemanticParseInfo parseInfo, String entity, User user) {
if (StringUtils.isEmpty(entity)) {
return;
}
List<String> entities = Collections.singletonList(entity);
QueryResultWithSchemaResp queryResultWithColumns = getQueryResultWithSchemaResp(modelInfo, model, entities,
QueryResultWithSchemaResp queryResultWithColumns = getQueryResultWithSchemaResp(modelInfo, parseInfo, entities,
user);
if (queryResultWithColumns != null) {
@@ -225,15 +211,15 @@ public class SemanticService {
}
}
public QueryResultWithSchemaResp getQueryResultWithSchemaResp(EntityInfo modelInfo, Long model,
public QueryResultWithSchemaResp getQueryResultWithSchemaResp(EntityInfo modelInfo, SemanticParseInfo parseInfo,
List<String> entities, User user) {
if (CollectionUtils.isEmpty(entities)) {
return null;
}
ModelSchema modelSchema = schemaService.getModelSchema(model);
ModelSchema modelSchema = schemaService.getModelSchema(parseInfo.getModelId());
modelInfo.setEntityId(entities.get(0));
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
semanticParseInfo.setModel(modelSchema.getModel());
semanticParseInfo.setModel(ModelCluster.build(Sets.newHashSet(parseInfo.getModelId())));
semanticParseInfo.setQueryType(QueryType.TAG);
semanticParseInfo.setMetrics(getMetrics(modelInfo));
semanticParseInfo.setDimensions(getDimensions(modelInfo));
@@ -314,54 +300,7 @@ public class SemanticService {
}
private String getEntityPrimaryName(EntityInfo modelInfo) {
return modelInfo.getModelInfo().getPrimaryEntityBizName();
}
private void filterBlackMetric(ModelSchema modelSchema, ChatConfigResp chaConfigInfo) {
ItemVisibility visibility = generateFinalVisibility(chaConfigInfo);
if (Objects.nonNull(chaConfigInfo) && Objects.nonNull(visibility)
&& !CollectionUtils.isEmpty(visibility.getBlackMetricIdList())
&& !CollectionUtils.isEmpty(modelSchema.getMetrics())) {
Set<SchemaElement> metric4Chat = modelSchema.getMetrics().stream()
.filter(metric -> !visibility.getBlackMetricIdList().contains(metric.getId()))
.collect(Collectors.toSet());
modelSchema.setMetrics(metric4Chat);
}
}
/**
 * Removes black-listed dimensions from the given model schema, in place.
 *
 * <p>The black list is taken from the visibility computed by
 * {@code generateFinalVisibility}. The schema is left untouched when there is no
 * chat config, no black-listed dimension id, or no dimension in the schema.
 *
 * @param modelSchema    schema whose dimension set is replaced by the filtered set
 * @param chatConfigInfo chat configuration of the model; may be {@code null}
 */
private void filterBlackDim(ModelSchema modelSchema, ChatConfigResp chatConfigInfo) {
    // The config must be checked first: generateFinalVisibility dereferences it,
    // so a null check placed after that call can never take effect.
    if (Objects.isNull(chatConfigInfo)) {
        return;
    }
    ItemVisibility visibility = generateFinalVisibility(chatConfigInfo);
    if (Objects.isNull(visibility)
            || CollectionUtils.isEmpty(visibility.getBlackDimIdList())
            || CollectionUtils.isEmpty(modelSchema.getDimensions())) {
        return;
    }
    Set<SchemaElement> dim4Chat = modelSchema.getDimensions().stream()
            .filter(dim -> !visibility.getBlackDimIdList().contains(dim.getId()))
            .collect(Collectors.toSet());
    modelSchema.setDimensions(dim4Chat);
}
private ItemVisibility generateFinalVisibility(ChatConfigResp chatConfigInfo) {
ItemVisibility visibility = new ItemVisibility();
ChatAggConfigReq chatAggConfig = chatConfigInfo.getChatAggConfig();
ChatDetailConfigReq chatDetailConfig = chatConfigInfo.getChatDetailConfig();
// both black is exist
if (Objects.nonNull(chatAggConfig) && Objects.nonNull(chatAggConfig.getVisibility())
&& Objects.nonNull(chatDetailConfig) && Objects.nonNull(chatDetailConfig.getVisibility())) {
List<Long> blackDimIdList = new ArrayList<>();
blackDimIdList.addAll(chatAggConfig.getVisibility().getBlackDimIdList());
blackDimIdList.retainAll(chatDetailConfig.getVisibility().getBlackDimIdList());
List<Long> blackMetricIdList = new ArrayList<>();
blackMetricIdList.addAll(chatAggConfig.getVisibility().getBlackMetricIdList());
blackMetricIdList.retainAll(chatDetailConfig.getVisibility().getBlackMetricIdList());
visibility.setBlackDimIdList(blackDimIdList);
visibility.setBlackMetricIdList(blackMetricIdList);
}
return visibility;
return modelInfo.getModelInfo().getPrimaryKey();
}
public AggregateInfo getAggregateInfo(User user, SemanticParseInfo semanticParseInfo,

View File

@@ -4,8 +4,10 @@ import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp;
@@ -13,11 +15,18 @@ import com.tencent.supersonic.chat.persistence.dataobject.ChatDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.persistence.dataobject.QueryDO;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.persistence.repository.ChatContextRepository;
import com.tencent.supersonic.chat.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.persistence.repository.ChatRepository;
import com.tencent.supersonic.chat.service.ChatService;
import com.tencent.supersonic.chat.utils.SolvedQueryManager;
import com.tencent.supersonic.common.util.JsonUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.compress.utils.Lists;
import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Comparator;
@@ -28,14 +37,6 @@ import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import com.tencent.supersonic.chat.service.ChatService;
import com.tencent.supersonic.chat.utils.SolvedQueryManager;
import com.tencent.supersonic.common.util.JsonUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.compress.utils.Lists;
import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
@Service("ChatService")
@Primary
@@ -56,7 +57,7 @@ public class ChatServiceImpl implements ChatService {
}
@Override
public Long getContextModel(Integer chatId) {
public Set<Long> getContextModel(Integer chatId) {
if (Objects.isNull(chatId)) {
return null;
}
@@ -65,8 +66,8 @@ public class ChatServiceImpl implements ChatService {
return null;
}
SemanticParseInfo originalSemanticParse = chatContext.getParseInfo();
if (Objects.nonNull(originalSemanticParse) && Objects.nonNull(originalSemanticParse.getModelId())) {
return originalSemanticParse.getModelId();
if (Objects.nonNull(originalSemanticParse) && Objects.nonNull(originalSemanticParse.getModel().getModelIds())) {
return originalSemanticParse.getModel().getModelIds();
}
return null;
}

View File

@@ -6,39 +6,31 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.request.ItemVisibility;
import com.tencent.supersonic.chat.api.pojo.request.ItemNameVisibilityInfo;
import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatDetailConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigBaseReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigEditReqReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.KnowledgeInfoReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatDetailConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.Entity;
import com.tencent.supersonic.chat.api.pojo.request.ItemNameVisibilityInfo;
import com.tencent.supersonic.chat.api.pojo.request.ItemVisibility;
import com.tencent.supersonic.chat.api.pojo.request.KnowledgeInfoReq;
import com.tencent.supersonic.chat.api.pojo.response.ChatAggRichConfigResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatDefaultRichConfigResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatAggRichConfigResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatDetailRichConfigResp;
import com.tencent.supersonic.chat.api.pojo.response.EntityRichInfoResp;
import com.tencent.supersonic.chat.api.pojo.response.ItemVisibilityInfo;
import com.tencent.supersonic.chat.config.ChatConfig;
import com.tencent.supersonic.chat.persistence.repository.ChatConfigRepository;
import com.tencent.supersonic.chat.service.ConfigService;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.persistence.repository.ChatConfigRepository;
import com.tencent.supersonic.chat.utils.ChatConfigHelper;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.utils.VisibilityEvent;
import com.tencent.supersonic.common.util.JsonUtil;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
import com.tencent.supersonic.semantic.api.model.pojo.SchemaItem;
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
@@ -52,6 +44,13 @@ import org.springframework.context.ApplicationEventPublisher;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
@Slf4j
@Service
@@ -83,7 +82,6 @@ public class ConfigServiceImpl implements ConfigService {
public Long addConfig(ChatConfigBaseReq configBaseCmd, User user) {
log.info("[create model extend] object:{}", JsonUtil.toString(configBaseCmd, true));
duplicateCheck(configBaseCmd.getModelId());
permissionCheckLogic(configBaseCmd.getModelId(), user.getName());
ChatConfig chaConfig = chatConfigHelper.newChatConfig(configBaseCmd, user);
Long id = chatConfigRepository.createConfig(chaConfig);
applicationEventPublisher.publishEvent(new VisibilityEvent(this, chaConfig));
@@ -107,7 +105,6 @@ public class ConfigServiceImpl implements ConfigService {
configEditCmd.getModelId())) {
throw new RuntimeException("editConfig, id and modelId are not allowed to be empty at the same time");
}
permissionCheckLogic(configEditCmd.getModelId(), user.getName());
ChatConfig chaConfig = chatConfigHelper.editChatConfig(configEditCmd, user);
chatConfigRepository.updateConfig(chaConfig);
applicationEventPublisher.publishEvent(new VisibilityEvent(this, chaConfig));
@@ -164,14 +161,6 @@ public class ConfigServiceImpl implements ConfigService {
return itemNameVisibility;
}
/**
 * Permission hook for modifying the chat configuration of a model.
 * Intended rule: model administrators have the right to modify related configuration information.
 * Currently a stub: neither argument is inspected and every request is allowed.
 *
 * @param modelId   id of the model whose configuration is being changed (currently unused)
 * @param staffName name of the user requesting the change (currently unused)
 * @return always {@code true} until the real check is implemented
 */
private Boolean permissionCheckLogic(Long modelId, String staffName) {
// TODO: implement the administrator check; until then nothing is rejected here.
return true;
}
@Override
public List<ChatConfigResp> search(ChatConfigFilter filter, User user) {
log.info("[search model extend] object:{}", JsonUtil.toString(filter, true));

View File

@@ -1,205 +0,0 @@
package com.tencent.supersonic.chat.service.impl;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.QueryType;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.service.ParseInfoService;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.FieldExpression;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.stereotype.Service;
/**
 * Refreshes a {@link SemanticParseInfo} from its corrected S2SQL statement.
 *
 * <p>When the corrected S2SQL differs from the original S2SQL, the date range,
 * dimension filters, metrics, dimensions and query type of the parse info are
 * re-derived from the corrected statement.
 */
@Slf4j
@Service
public class ParserInfoServiceImpl implements ParseInfoService {

    /**
     * Updates the parse info in place from {@code sqlInfo.getCorrectS2SQL()}.
     *
     * <p>Nothing is changed when there is no SQL info, the corrected S2SQL is blank,
     * or it equals the original S2SQL. Failures while deriving the date info or the
     * dimension filters are logged and do not abort the remaining steps.
     *
     * @param parseInfo parse info to update
     */
    public void updateParseInfo(SemanticParseInfo parseInfo) {
        SqlInfo sqlInfo = parseInfo.getSqlInfo();
        if (Objects.isNull(sqlInfo)) {
            return;
        }
        String correctS2SQL = sqlInfo.getCorrectS2SQL();
        if (StringUtils.isBlank(correctS2SQL)) {
            return;
        }
        // if S2SQL equals correctS2SQL, then do not update the parseInfo.
        if (correctS2SQL.equals(sqlInfo.getS2SQL())) {
            return;
        }
        List<FieldExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctS2SQL);
        updateDateInfo(parseInfo, expressions);
        updateDimensionFilters(parseInfo, expressions);

        SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
        if (Objects.isNull(semanticSchema)) {
            return;
        }
        updateMetricsAndDimensions(parseInfo, correctS2SQL, semanticSchema);
    }

    /** Sets the date info from the date filter expressions, unless the parse info already has one. */
    private void updateDateInfo(SemanticParseInfo parseInfo, List<FieldExpression> expressions) {
        try {
            if (!CollectionUtils.isEmpty(expressions)) {
                DateConf dateInfo = getDateInfo(expressions);
                if (dateInfo != null && parseInfo.getDateInfo() == null) {
                    parseInfo.setDateInfo(dateInfo);
                }
            }
        } catch (Exception e) {
            log.error("set dateInfo error :", e);
        }
    }

    /** Appends the filters resolved from the filter expressions to the parse info's dimension filters. */
    private void updateDimensionFilters(SemanticParseInfo parseInfo, List<FieldExpression> expressions) {
        try {
            Map<String, SchemaElement> fieldNameToElement = getNameToElement(parseInfo.getModelId());
            List<QueryFilter> result = getDimensionFilter(fieldNameToElement, expressions);
            parseInfo.getDimensionFilters().addAll(result);
        } catch (Exception e) {
            log.error("set dimensionFilter error :", e);
        }
    }

    /** Derives metrics, dimensions and query type from the fields used by the corrected S2SQL. */
    private void updateMetricsAndDimensions(SemanticParseInfo parseInfo, String correctS2SQL,
            SemanticSchema semanticSchema) {
        Long modelId = parseInfo.getModelId();
        List<String> allFields = getFieldsExceptDate(SqlParserSelectHelper.getAllFields(correctS2SQL));
        parseInfo.setMetrics(getElements(modelId, allFields, semanticSchema.getMetrics()));

        List<String> dimensionFields;
        if (SqlParserSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
            // aggregated query: the dimensions are the group-by fields
            parseInfo.setQueryType(QueryType.METRIC);
            dimensionFields = SqlParserSelectHelper.getGroupByFields(correctS2SQL);
        } else {
            // plain query: the dimensions are the selected fields
            parseInfo.setQueryType(QueryType.TAG);
            dimensionFields = SqlParserSelectHelper.getSelectFields(correctS2SQL);
        }
        parseInfo.setDimensions(
                getElements(modelId, getFieldsExceptDate(dimensionFields), semanticSchema.getDimensions()));
    }

    /**
     * Selects the elements that belong to the model and whose name is one of the given fields.
     *
     * @return the matching elements; empty when the model id is null or either list is empty
     */
    private Set<SchemaElement> getElements(Long modelId, List<String> allFields, List<SchemaElement> elements) {
        if (Objects.isNull(modelId) || CollectionUtils.isEmpty(allFields) || CollectionUtils.isEmpty(elements)) {
            return new HashSet<>();
        }
        // a set makes the per-element name lookup constant time
        Set<String> fieldNames = new HashSet<>(allFields);
        return elements.stream()
                .filter(schemaElement -> modelId.equals(schemaElement.getModel())
                        && fieldNames.contains(schemaElement.getName()))
                .collect(Collectors.toSet());
    }

    /** Returns the given fields without the day time-dimension field (compared case-insensitively). */
    private List<String> getFieldsExceptDate(List<String> allFields) {
        if (CollectionUtils.isEmpty(allFields)) {
            return new ArrayList<>();
        }
        return allFields.stream()
                .filter(entry -> !TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(entry))
                .collect(Collectors.toList());
    }

    /**
     * Converts filter expressions into query filters.
     * Expressions whose field name is not present in {@code fieldNameToElement} are skipped.
     */
    private List<QueryFilter> getDimensionFilter(Map<String, SchemaElement> fieldNameToElement,
            List<FieldExpression> fieldExpressions) {
        List<QueryFilter> result = Lists.newArrayList();
        if (CollectionUtils.isEmpty(fieldExpressions)) {
            return result;
        }
        for (FieldExpression expression : fieldExpressions) {
            SchemaElement schemaElement = fieldNameToElement.get(expression.getFieldName());
            if (Objects.isNull(schemaElement)) {
                continue;
            }
            QueryFilter dimensionFilter = new QueryFilter();
            dimensionFilter.setValue(expression.getFieldValue());
            dimensionFilter.setName(schemaElement.getName());
            dimensionFilter.setBizName(schemaElement.getBizName());
            dimensionFilter.setElementID(schemaElement.getId());
            dimensionFilter.setOperator(FilterOperatorEnum.getSqlOperator(expression.getOperator()));
            dimensionFilter.setFunction(expression.getFunction());
            result.add(dimensionFilter);
        }
        return result;
    }

    /**
     * Builds a BETWEEN date range from the expressions on the day time-dimension field.
     *
     * <p>An equality on the first date expression yields a single-day range. A
     * greater-than(-or-equal) or less-than(-or-equal) first expression sets the start
     * or end date respectively, and the second date expression, when present,
     * supplies the other bound.
     *
     * @return the date range, or {@code null} when there is no date expression
     */
    private DateConf getDateInfo(List<FieldExpression> fieldExpressions) {
        List<FieldExpression> dateExpressions = fieldExpressions.stream()
                .filter(expression -> TimeDimensionEnum.DAY.getChName().equalsIgnoreCase(expression.getFieldName()))
                .collect(Collectors.toList());
        if (CollectionUtils.isEmpty(dateExpressions)) {
            return null;
        }
        DateConf dateInfo = new DateConf();
        dateInfo.setDateMode(DateMode.BETWEEN);
        FieldExpression firstExpression = dateExpressions.get(0);
        FilterOperatorEnum firstOperator = FilterOperatorEnum.getSqlOperator(firstExpression.getOperator());
        if (FilterOperatorEnum.EQUALS.equals(firstOperator) && Objects.nonNull(firstExpression.getFieldValue())) {
            String day = firstExpression.getFieldValue().toString();
            dateInfo.setStartDate(day);
            dateInfo.setEndDate(day);
            return dateInfo;
        }
        if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.GREATER_THAN,
                FilterOperatorEnum.GREATER_THAN_EQUALS)) {
            dateInfo.setStartDate(firstExpression.getFieldValue().toString());
            if (hasSecondDate(dateExpressions)) {
                dateInfo.setEndDate(dateExpressions.get(1).getFieldValue().toString());
            }
        }
        if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.MINOR_THAN,
                FilterOperatorEnum.MINOR_THAN_EQUALS)) {
            dateInfo.setEndDate(firstExpression.getFieldValue().toString());
            if (hasSecondDate(dateExpressions)) {
                dateInfo.setStartDate(dateExpressions.get(1).getFieldValue().toString());
            }
        }
        return dateInfo;
    }

    /** Returns true when the operator is one of the given operators and the expression has a value. */
    private boolean containOperators(FieldExpression expression, FilterOperatorEnum firstOperator,
            FilterOperatorEnum... operatorEnums) {
        return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(expression.getFieldValue()));
    }

    /** Returns true when a second date expression exists and carries a value. */
    private boolean hasSecondDate(List<FieldExpression> dateExpressions) {
        return dateExpressions.size() > 1 && Objects.nonNull(dateExpressions.get(1).getFieldValue());
    }

    /**
     * Maps every name and alias of the model's dimensions and metrics to its schema element.
     * When the same key occurs more than once, the later entry wins.
     *
     * @param modelId id of the model whose elements are indexed
     * @return name/alias to element map; empty when the schema or the model id is unavailable
     */
    protected Map<String, SchemaElement> getNameToElement(Long modelId) {
        SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
        if (Objects.isNull(semanticSchema) || Objects.isNull(modelId)) {
            return new java.util.HashMap<>();
        }
        List<SchemaElement> allElements = Lists.newArrayList();
        if (!CollectionUtils.isEmpty(semanticSchema.getDimensions())) {
            allElements.addAll(semanticSchema.getDimensions());
        }
        if (!CollectionUtils.isEmpty(semanticSchema.getMetrics())) {
            allElements.addAll(semanticSchema.getMetrics());
        }
        //support alias
        return allElements.stream()
                // modelId is non-null here, so an element without a model is skipped instead of throwing
                .filter(schemaElement -> modelId.equals(schemaElement.getModel()))
                .flatMap(schemaElement -> {
                    Set<Pair<String, SchemaElement>> result = new HashSet<>();
                    result.add(Pair.of(schemaElement.getName(), schemaElement));
                    List<String> aliasList = schemaElement.getAlias();
                    if (!CollectionUtils.isEmpty(aliasList)) {
                        for (String alias : aliasList) {
                            result.add(Pair.of(alias, schemaElement));
                        }
                    }
                    return result.stream();
                })
                .collect(Collectors.toMap(pair -> pair.getLeft(), pair -> pair.getRight(), (value1, value2) -> value2));
    }
}

View File

@@ -14,9 +14,15 @@ import com.tencent.supersonic.chat.plugin.event.PluginDelEvent;
import com.tencent.supersonic.chat.plugin.event.PluginUpdateEvent;
import com.tencent.supersonic.chat.service.PluginService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.pojo.enums.AuthType;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.semantic.api.model.response.ModelResp;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.Arrays;
import java.util.Date;
@@ -25,13 +31,6 @@ import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
@Service
@Slf4j
public class PluginServiceImpl implements PluginService {

View File

@@ -252,7 +252,7 @@ public class QueryServiceImpl implements QueryService {
solvedQueryManager.saveSolvedQuery(SolvedQueryReq.builder().parseId(queryReq.getParseId())
.queryId(queryReq.getQueryId())
.agentId(chatQueryDO.getAgentId())
.modelId(parseInfo.getModelId())
.modelId(parseInfo.getModelClusterKey())
.queryText(queryReq.getQueryText()).build());
}

View File

@@ -14,21 +14,27 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.SearchResult;
import com.tencent.supersonic.chat.mapper.MapperHelper;
import com.tencent.supersonic.chat.service.ConfigService;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.dictionary.ModelInfoStat;
import com.tencent.supersonic.chat.mapper.ModelWithSemanticType;
import com.tencent.supersonic.chat.mapper.MatchText;
import com.tencent.supersonic.chat.mapper.ModelWithSemanticType;
import com.tencent.supersonic.chat.mapper.SearchMatchStrategy;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.service.ChatService;
import com.tencent.supersonic.chat.service.ConfigService;
import com.tencent.supersonic.chat.service.SearchService;
import com.tencent.supersonic.knowledge.utils.NatureHelper;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.dictionary.DictWord;
import com.tencent.supersonic.knowledge.dictionary.HanlpMapResult;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.knowledge.dictionary.ModelInfoStat;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
import com.tencent.supersonic.knowledge.utils.NatureHelper;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
@@ -41,11 +47,6 @@ import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service;
/**
@@ -152,15 +153,13 @@ public class SearchServiceImpl implements SearchService {
List<Long> possibleModels = NatureHelper.selectPossibleModels(originals);
Long contextModel = chatService.getContextModel(queryCtx.getChatId());
Set<Long> contextModel = chatService.getContextModel(queryCtx.getChatId());
log.debug("possibleModels:{},modelStat:{},contextModel:{}", possibleModels, modelStat, contextModel);
// If nothing is recognized or only metric are present, then add the contextModel.
if (nothingOrOnlyMetric(modelStat) && effectiveModel(contextModel)) {
List<Long> result = new ArrayList<>();
result.add(contextModel);
return result;
if (nothingOrOnlyMetric(modelStat)) {
return contextModel.stream().filter(modelId -> modelId > 0).collect(Collectors.toList());
}
return possibleModels;
}

View File

@@ -1,8 +1,5 @@
package com.tencent.supersonic.chat.utils;
import static com.tencent.supersonic.common.pojo.Constants.DAY;
import static com.tencent.supersonic.common.pojo.Constants.UNDERLINE;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
@@ -11,12 +8,20 @@ import com.tencent.supersonic.chat.api.pojo.request.KnowledgeAdvancedConfig;
import com.tencent.supersonic.chat.api.pojo.request.KnowledgeInfoReq;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatDefaultRichConfigResp;
import com.tencent.supersonic.chat.config.Dim4Dict;
import com.tencent.supersonic.chat.config.DefaultMetric;
import com.tencent.supersonic.chat.service.ConfigService;
import com.tencent.supersonic.chat.config.Dim4Dict;
import com.tencent.supersonic.chat.persistence.dataobject.DimValueDO;
import com.tencent.supersonic.chat.service.ConfigService;
import com.tencent.supersonic.knowledge.dictionary.DictUpdateMode;
import com.tencent.supersonic.knowledge.dictionary.DimValue2DictCommand;
import com.tencent.supersonic.semantic.api.model.request.PageDimensionReq;
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Arrays;
@@ -28,14 +33,8 @@ import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import com.tencent.supersonic.semantic.api.model.request.PageDimensionReq;
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import static com.tencent.supersonic.common.pojo.Constants.DAY;
import static com.tencent.supersonic.common.pojo.Constants.UNDERLINE;
@Component
public class DictMetaHelper {
@@ -202,7 +201,7 @@ public class DictMetaHelper {
if (Objects.nonNull(dimIdAndRespPair)
&& dimIdAndRespPair.containsKey(dim4Dict.getDimId())) {
String datasourceFilterSql = dimIdAndRespPair.get(
dim4Dict.getDimId()).getDatasourceFilterSql();
dim4Dict.getDimId()).getModelFilterSql();
if (StringUtils.isNotEmpty(datasourceFilterSql)) {
dim4Dict.getRuleList().add(datasourceFilterSql);
}
@@ -241,7 +240,7 @@ public class DictMetaHelper {
PageInfo<DimensionResp> dimensionPage = semanticInterpreter.getDimensionPage(pageDimensionCmd);
if (Objects.nonNull(dimensionPage) && !CollectionUtils.isEmpty(dimensionPage.getList())) {
List<DimensionResp> list = dimensionPage.getList();
return list.get(0).getDatasourceBizName();
return list.get(0).getModelBizName();
}
return "";
}

View File

@@ -0,0 +1,42 @@
package com.tencent.supersonic.chat.utils;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.common.pojo.ModelCluster;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
/**
 * Groups the models of a semantic schema into clusters.
 *
 * <p>Two models end up in the same cluster when one is reachable from the other by
 * following the ids returned by {@code ModelSchema.getModelClusterSet()}
 * (presumably the models linked through model relations; confirm against ModelSchema).
 */
public class ModelClusterBuilder {

    /**
     * Builds one {@link ModelCluster} per group of connected models.
     *
     * @param semanticSchema schema providing the model id to model schema map
     * @return the clusters; empty when the schema or its model map is missing
     */
    public static List<ModelCluster> buildModelClusters(SemanticSchema semanticSchema) {
        List<Set<Long>> modelClusters = new ArrayList<>();
        if (semanticSchema == null || semanticSchema.getModelSchemaMap() == null) {
            return new ArrayList<>();
        }
        Map<Long, ModelSchema> modelMap = semanticSchema.getModelSchemaMap();
        Set<Long> visited = new HashSet<>();
        for (ModelSchema model : modelMap.values()) {
            if (!visited.contains(model.getModel().getModel())) {
                Set<Long> modelCluster = new HashSet<>();
                dfs(model, modelMap, visited, modelCluster);
                modelClusters.add(modelCluster);
            }
        }
        return modelClusters.stream().map(ModelCluster::build).collect(Collectors.toList());
    }

    /**
     * Depth-first walk that adds the model and every model reachable from it to the cluster.
     * Neighbor ids without an entry in {@code modelMap} are ignored.
     */
    private static void dfs(ModelSchema model, Map<Long, ModelSchema> modelMap,
            Set<Long> visited, Set<Long> modelCluster) {
        Long modelId = model.getModel().getModel();
        visited.add(modelId);
        modelCluster.add(modelId);
        for (Long neighborId : model.getModelClusterSet()) {
            if (visited.contains(neighborId)) {
                continue;
            }
            ModelSchema neighbor = modelMap.get(neighborId);
            // A relation may point at a model that is not part of this schema;
            // recursing with null would throw, so such ids are skipped.
            if (neighbor == null) {
                continue;
            }
            dfs(neighbor, modelMap, visited, modelCluster);
        }
    }
}

View File

@@ -16,6 +16,12 @@ import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
import com.tencent.supersonic.semantic.api.query.request.QueryS2SQLReq;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.util.Strings;
import org.springframework.beans.BeanUtils;
import org.springframework.util.CollectionUtils;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.Arrays;
@@ -25,18 +31,13 @@ import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.util.Strings;
import org.springframework.beans.BeanUtils;
import org.springframework.util.CollectionUtils;
@Slf4j
public class QueryReqBuilder {
public static QueryStructReq buildStructReq(SemanticParseInfo parseInfo) {
QueryStructReq queryStructCmd = new QueryStructReq();
queryStructCmd.setModelId(parseInfo.getModelId());
queryStructCmd.setModelIds(parseInfo.getModel().getModelIds());
queryStructCmd.setQueryType(parseInfo.getQueryType());
queryStructCmd.setDateInfo(rewrite2Between(parseInfo.getDateInfo()));
@@ -128,15 +129,15 @@ public class QueryReqBuilder {
* convert to QueryS2SQLReq
*
* @param querySql
* @param modelId
* @param modelIds
* @return
*/
public static QueryS2SQLReq buildS2SQLReq(String querySql, Long modelId) {
public static QueryS2SQLReq buildS2SQLReq(String querySql, Set<Long> modelIds) {
QueryS2SQLReq queryS2SQLReq = new QueryS2SQLReq();
if (Objects.nonNull(querySql)) {
queryS2SQLReq.setSql(querySql);
}
queryS2SQLReq.setModelId(modelId);
queryS2SQLReq.setModelIds(modelIds);
return queryS2SQLReq;
}

View File

@@ -6,8 +6,11 @@ import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.RelateSchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.Set;
@@ -17,7 +20,8 @@ class MetricCheckPostProcessorTest {
void testProcessCorrectSql_necessaryDimension_groupBy() {
MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor();
String correctSql = "select 用户名, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 用户名";
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql, mockModelSchema());
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
String expectedProcessedSql = "SELECT 用户名, sum(访问次数) FROM 超音数 GROUP BY 用户名";
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
}
@@ -26,7 +30,8 @@ class MetricCheckPostProcessorTest {
void testProcessCorrectSql_necessaryDimension_where() {
MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor();
String correctSql = "select 用户名, sum(访问次数), count(distinct 访问用户数) from 超音数 where 部门 = 'HR' group by 用户名";
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql, mockModelSchema());
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
String expectedProcessedSql = "SELECT 用户名, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 "
+ "WHERE 部门 = 'HR' GROUP BY 用户名";
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
@@ -36,7 +41,8 @@ class MetricCheckPostProcessorTest {
void testProcessCorrectSql_dimensionNotDrillDown_groupBy() {
MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor();
String correctSql = "select 页面, 部门, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 页面, 部门";
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql, mockModelSchema());
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
String expectedProcessedSql = "SELECT 部门, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 GROUP BY 部门";
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
}
@@ -45,7 +51,8 @@ class MetricCheckPostProcessorTest {
void testProcessCorrectSql_dimensionNotDrillDown_where() {
    // Case: 页面 is used only as a WHERE filter. The expected SQL has no WHERE clause at all
    // while the SELECT list and GROUP BY 部门 are retained.
    // NOTE(review): presumably the filter is removed because 页面 is not a drill-down
    // dimension of the mocked metrics -- confirm against mockModelSchema().
    MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor();
    String correctSql = "select 部门, sum(访问次数), count(distinct 访问用户数) from 超音数 where 页面 = 'P1' group by 部门";
    // The SQL under test is handed to the processor wrapped in a SemanticParseInfo; the
    // result is declared exactly once (a second declaration of the same local would not compile).
    SemanticParseInfo parseInfo = mockParseInfo(correctSql);
    String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
    String expectedProcessedSql = "SELECT 部门, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 GROUP BY 部门";
    Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
}
@@ -54,7 +61,8 @@ class MetricCheckPostProcessorTest {
void testProcessCorrectSql_dimensionNotDrillDown_necessaryDimension() {
    // Case: the only dimension in the query is 页面. The expected SQL keeps just sum(访问次数):
    // 页面 is gone from SELECT/GROUP BY and count(distinct 访问用户数) is gone as well.
    // NOTE(review): presumably the distinct-count metric is removed because its necessary
    // dimension (部门, per the schema comment in this test class) is absent once 页面 is
    // dropped -- confirm against MetricCheckPostProcessor.
    MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor();
    String correctSql = "select 页面, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 页面";
    // The SQL under test is handed to the processor wrapped in a SemanticParseInfo; the
    // result is declared exactly once (a second declaration of the same local would not compile).
    SemanticParseInfo parseInfo = mockParseInfo(correctSql);
    String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
    String expectedProcessedSql = "SELECT sum(访问次数) FROM 超音数";
    Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
}
@@ -63,7 +71,8 @@ class MetricCheckPostProcessorTest {
void testProcessCorrectSql_dimensionDrillDown() {
    // Case: the query groups by 用户名 and 部门. The expected SQL keeps every select item and
    // both GROUP BY columns; only keyword casing differs from the input.
    MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor();
    String correctSql = "select 用户名, 部门, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 用户名, 部门";
    // The SQL under test is handed to the processor wrapped in a SemanticParseInfo; the
    // result is declared exactly once (a second declaration of the same local would not compile).
    SemanticParseInfo parseInfo = mockParseInfo(correctSql);
    String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo, mockModelSchema());
    String expectedProcessedSql = "SELECT 用户名, 部门, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 GROUP BY 用户名, 部门";
    Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
}
@@ -72,7 +81,8 @@ class MetricCheckPostProcessorTest {
void testProcessCorrectSql_noDrillDownDimensionSetting() {
MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor();
String correctSql = "select 页面, 用户名, sum(访问次数), count(distinct 访问用户数) from 超音数 group by 页面, 用户名";
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql,
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo,
mockModelSchemaNoDimensionSetting());
String expectedProcessedSql = "SELECT 页面, 用户名, sum(访问次数), count(DISTINCT 访问用户数) FROM 超音数 GROUP BY 页面, 用户名";
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
@@ -82,7 +92,8 @@ class MetricCheckPostProcessorTest {
void testProcessCorrectSql_noDrillDownDimensionSetting_noAgg() {
MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor();
String correctSql = "select 访问次数 from 超音数";
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql,
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo,
mockModelSchemaNoDimensionSetting());
String expectedProcessedSql = "select 访问次数 from 超音数";
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
@@ -92,7 +103,8 @@ class MetricCheckPostProcessorTest {
void testProcessCorrectSql_noDrillDownDimensionSetting_count() {
MetricCheckPostProcessor metricCheckPostProcessor = new MetricCheckPostProcessor();
String correctSql = "select 部门, count(*) from 超音数 group by 部门";
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(correctSql,
SemanticParseInfo parseInfo = mockParseInfo(correctSql);
String actualProcessedSql = metricCheckPostProcessor.processCorrectSql(parseInfo,
mockModelSchemaNoDimensionSetting());
String expectedProcessedSql = "select 部门, count(*) from 超音数 group by 部门";
Assertions.assertEquals(expectedProcessedSql, actualProcessedSql);
@@ -102,7 +114,7 @@ class MetricCheckPostProcessorTest {
* 访问次数 drill down dimension is 用户名 and 部门
* 访问用户数 drill down dimension is 部门, and 部门 is necessary, 部门 need in select and group by or where expressions
*/
private ModelSchema mockModelSchema() {
private SemanticSchema mockModelSchema() {
ModelSchema modelSchema = new ModelSchema();
Set<SchemaElement> metrics = Sets.newHashSet(
mockElement(1L, "访问次数", SchemaElementType.METRIC,
@@ -113,10 +125,10 @@ class MetricCheckPostProcessorTest {
);
modelSchema.setMetrics(metrics);
modelSchema.setDimensions(mockDimensions());
return modelSchema;
return new SemanticSchema(Lists.newArrayList(modelSchema));
}
private ModelSchema mockModelSchemaNoDimensionSetting() {
private SemanticSchema mockModelSchemaNoDimensionSetting() {
ModelSchema modelSchema = new ModelSchema();
Set<SchemaElement> metrics = Sets.newHashSet(
mockElement(1L, "访问次数", SchemaElementType.METRIC, Lists.newArrayList()),
@@ -124,7 +136,7 @@ class MetricCheckPostProcessorTest {
);
modelSchema.setMetrics(metrics);
modelSchema.setDimensions(mockDimensions());
return modelSchema;
return new SemanticSchema(Lists.newArrayList(modelSchema));
}
private Set<SchemaElement> mockDimensions() {
@@ -141,4 +153,10 @@ class MetricCheckPostProcessorTest {
.relateSchemaElements(relateSchemaElements).build();
}
/**
 * Builds a fresh {@code SemanticParseInfo} whose SQL info carries the given statement
 * as its corrected S2SQL.
 *
 * @param correctSql the SQL text to store via {@code setCorrectS2SQL}
 * @return the newly created parse info holding {@code correctSql}
 */
private SemanticParseInfo mockParseInfo(String correctSql) {
    SemanticParseInfo parseInfo = new SemanticParseInfo();
    // Relies on getSqlInfo() returning a non-null holder on a freshly constructed instance.
    parseInfo.getSqlInfo().setCorrectS2SQL(correctSql);
    return parseInfo;
}
}

View File

@@ -6,8 +6,8 @@ import com.tencent.supersonic.chat.persistence.mapper.ChatContextMapper;
import com.tencent.supersonic.knowledge.semantic.RemoteSemanticInterpreter;
import com.tencent.supersonic.chat.test.ChatBizLauncher;
import com.tencent.supersonic.semantic.model.domain.DimensionService;
import com.tencent.supersonic.semantic.model.domain.ModelService;
import com.tencent.supersonic.semantic.model.domain.MetricService;
import com.tencent.supersonic.semantic.model.domain.ModelService;
import com.tencent.supersonic.semantic.query.service.QueryService;
import org.junit.runner.RunWith;
import org.slf4j.Logger;

View File

@@ -1,32 +1,24 @@
package com.tencent.supersonic.chat.test.context;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.when;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.config.DefaultMetric;
import com.tencent.supersonic.chat.config.DefaultMetricInfo;
import com.tencent.supersonic.chat.config.EntityInternalDetail;
import com.tencent.supersonic.chat.persistence.mapper.ChatContextMapper;
import com.tencent.supersonic.chat.persistence.repository.impl.ChatContextRepositoryImpl;
import com.tencent.supersonic.chat.service.ChatService;
import com.tencent.supersonic.chat.service.QueryService;
import com.tencent.supersonic.chat.service.impl.ConfigServiceImpl;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.semantic.api.model.response.DimSchemaResp;
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
import com.tencent.supersonic.semantic.api.model.response.MetricSchemaResp;
import com.tencent.supersonic.chat.service.impl.ConfigServiceImpl;
import com.tencent.supersonic.chat.service.ChatService;
import com.tencent.supersonic.chat.persistence.mapper.ChatContextMapper;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
import com.tencent.supersonic.semantic.model.domain.DimensionService;
import com.tencent.supersonic.semantic.model.domain.ModelService;
import com.tencent.supersonic.semantic.model.domain.MetricService;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import com.tencent.supersonic.semantic.model.domain.ModelService;
import com.tencent.supersonic.semantic.model.domain.pojo.DimensionFilter;
import com.tencent.supersonic.semantic.model.domain.pojo.MetaFilter;
import org.mockito.Mockito;
@@ -34,6 +26,14 @@ import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.client.RestTemplate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.when;
@Configuration
public class MockBeansConfiguration {

View File

@@ -17,19 +17,20 @@ import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq;
import com.tencent.supersonic.semantic.api.query.request.QueryDimValueReq;
import com.tencent.supersonic.semantic.api.query.request.QueryS2SQLReq;
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
import com.tencent.supersonic.semantic.api.query.request.QueryS2SQLReq;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import com.tencent.supersonic.semantic.model.domain.DimensionService;
import com.tencent.supersonic.semantic.model.domain.MetricService;
import com.tencent.supersonic.semantic.query.service.QueryService;
import com.tencent.supersonic.semantic.query.service.SchemaService;
import java.util.HashMap;
import java.util.List;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import java.util.HashMap;
import java.util.List;
@Slf4j
public class LocalSemanticInterpreter extends BaseSemanticInterpreter {
@@ -44,7 +45,7 @@ public class LocalSemanticInterpreter extends BaseSemanticInterpreter {
if (StringUtils.isNotBlank(queryStructReq.getCorrectS2SQL())) {
QueryS2SQLReq queryS2SQLReq = new QueryS2SQLReq();
queryS2SQLReq.setSql(queryStructReq.getCorrectS2SQL());
queryS2SQLReq.setModelId(queryStructReq.getModelId());
queryS2SQLReq.setModelIds(queryStructReq.getModelIdSet());
queryS2SQLReq.setVariables(new HashMap<>());
return queryByS2SQL(queryS2SQLReq, user);
}

View File

@@ -7,7 +7,6 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
import com.tencent.supersonic.semantic.api.model.pojo.DimValueMap;
import com.tencent.supersonic.semantic.api.model.pojo.Entity;
import com.tencent.supersonic.semantic.api.model.pojo.RelateDimension;
import com.tencent.supersonic.semantic.api.model.pojo.SchemaItem;
import com.tencent.supersonic.semantic.api.model.response.DimSchemaResp;
@@ -20,8 +19,6 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
@@ -38,6 +35,7 @@ public class ModelSchemaBuilder {
.alias(SchemaItem.getAliasList(resp.getAlias()))
.build();
modelSchema.setModel(domain);
modelSchema.setModelRelas(resp.getModelRelas());
Set<SchemaElement> metrics = new HashSet<>();
for (MetricSchemaResp metric : resp.getMetrics()) {
@@ -124,23 +122,19 @@ public class ModelSchemaBuilder {
modelSchema.getDimensionValues().addAll(dimensionValues);
modelSchema.getTags().addAll(tags);
Entity entity = resp.getEntity();
if (Objects.nonNull(entity)) {
SchemaElement entityElement = new SchemaElement();
if (!CollectionUtils.isEmpty(entity.getNames()) && Objects.nonNull(entity.getEntityId())) {
Map<Long, SchemaElement> idAndDimPair = dimensions.stream()
.collect(
Collectors.toMap(SchemaElement::getId, schemaElement -> schemaElement, (k1, k2) -> k2));
if (idAndDimPair.containsKey(entity.getEntityId())) {
BeanUtils.copyProperties(idAndDimPair.get(entity.getEntityId()), entityElement);
entityElement.setType(SchemaElementType.ENTITY);
}
entityElement.setAlias(entity.getNames());
modelSchema.setEntity(entityElement);
}
DimSchemaResp dim = resp.getPrimaryKey();
if (dim != null) {
SchemaElement entity = SchemaElement.builder()
.model(resp.getId())
.id(dim.getId())
.name(dim.getName())
.bizName(dim.getBizName())
.type(SchemaElementType.ENTITY)
.useCnt(dim.getUseCnt())
.alias(dim.getEntityAlias())
.build();
modelSchema.setEntity(entity);
}
return modelSchema;
}

View File

@@ -1,10 +1,5 @@
package com.tencent.supersonic.knowledge.semantic;
import static com.tencent.supersonic.common.pojo.Constants.LIST_LOWER;
import static com.tencent.supersonic.common.pojo.Constants.PAGESIZE_LOWER;
import static com.tencent.supersonic.common.pojo.Constants.TOTAL_LOWER;
import static com.tencent.supersonic.common.pojo.Constants.TRUE_LOWER;
import com.alibaba.fastjson.JSON;
import com.github.pagehelper.PageInfo;
import com.google.gson.Gson;
@@ -31,15 +26,9 @@ import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq;
import com.tencent.supersonic.semantic.api.query.request.QueryDimValueReq;
import com.tencent.supersonic.semantic.api.query.request.QueryS2SQLReq;
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
import com.tencent.supersonic.semantic.api.query.request.QueryS2SQLReq;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import java.net.URI;
import java.net.URL;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.util.Strings;
@@ -53,6 +42,18 @@ import org.springframework.http.ResponseEntity;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;
import java.net.URI;
import java.net.URL;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Objects;
import static com.tencent.supersonic.common.pojo.Constants.LIST_LOWER;
import static com.tencent.supersonic.common.pojo.Constants.PAGESIZE_LOWER;
import static com.tencent.supersonic.common.pojo.Constants.TOTAL_LOWER;
import static com.tencent.supersonic.common.pojo.Constants.TRUE_LOWER;
@Slf4j
public class RemoteSemanticInterpreter extends BaseSemanticInterpreter {
@@ -73,7 +74,7 @@ public class RemoteSemanticInterpreter extends BaseSemanticInterpreter {
if (StringUtils.isNotBlank(queryStructReq.getCorrectS2SQL())) {
QueryS2SQLReq queryS2SQLReq = new QueryS2SQLReq();
queryS2SQLReq.setSql(queryStructReq.getCorrectS2SQL());
queryS2SQLReq.setModelId(queryStructReq.getModelId());
queryS2SQLReq.setModelIds(queryStructReq.getModelIdSet());
queryS2SQLReq.setVariables(new HashMap<>());
return queryByS2SQL(queryS2SQLReq, user);
}