[improvement][project] supersonic 0.7.2 version backend update (#28)

Co-authored-by: jipengli <jipengli@tencent.com>
This commit is contained in:
jipeli
2023-08-15 08:56:18 +08:00
committed by GitHub
parent 27283001a8
commit b1952d64ab
461 changed files with 18548 additions and 11939 deletions

View File

@@ -3,11 +3,13 @@ package com.tencent.supersonic.chat.config;
import lombok.Data;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration;
@Configuration
@Data
public class AggregatorConfig {
@Value("${metric.aggregator.ratio.enable:true}")
private Boolean enableRatio;
@Value("${metric.aggregator.ratio.enable:true}")
private Boolean enableRatio;
}

View File

@@ -3,13 +3,12 @@ package com.tencent.supersonic.chat.config;
import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatDetailConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.RecommendedQuestionReq;
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import com.tencent.supersonic.common.pojo.RecordInfo;
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import java.util.List;
import lombok.Data;
import lombok.ToString;
import java.util.List;
@Data
@ToString
public class ChatConfig {
@@ -19,15 +18,15 @@ public class ChatConfig {
*/
private Long id;
private Long domainId;
private Long modelId;
/**
* the chatDetailConfig about the domain
* the chatDetailConfig about the model
*/
private ChatDetailConfigReq chatDetailConfig;
/**
* the chatAggConfig about the domain
* the chatAggConfig about the model
*/
private ChatAggConfigReq chatAggConfig;

View File

@@ -6,6 +6,6 @@ import lombok.Data;
public class ChatConfigFilterInternal {
private Long id;
private Long domainId;
private Long modelId;
private Integer status;
}

View File

@@ -6,7 +6,7 @@ import lombok.ToString;
/**
* default metrics about the domain
* default metrics about the model
*/
@ToString

View File

@@ -7,6 +7,7 @@ import org.springframework.context.annotation.Configuration;
@Configuration
@Data
public class FunctionCallInfoConfig {
@Value("${functionCall.url:}")
private String url;
}

View File

@@ -1,21 +0,0 @@
package com.tencent.supersonic.chat.mapper;
import java.io.Serializable;
import lombok.Builder;
import lombok.Data;
import lombok.ToString;
@Data
@ToString
@Builder
public class DomainInfoStat implements Serializable {
private long domainCount;
private long metricDomainCount;
private long dimensionDomainCount;
private long dimensionValueDomainCount;
}

View File

@@ -2,14 +2,19 @@ package com.tencent.supersonic.chat.mapper;
import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.util.ContextUtils;
import java.util.List;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.stream.Collectors;
@Slf4j
@@ -18,20 +23,20 @@ public class EntityMapper implements SchemaMapper {
@Override
public void map(QueryContext queryContext) {
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
for (Long domainId : schemaMapInfo.getMatchedDomains()) {
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(domainId);
for (Long modelId : schemaMapInfo.getMatchedModels()) {
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(modelId);
if (CollectionUtils.isEmpty(schemaElementMatchList)) {
continue;
}
SchemaElement entity = getEntity(domainId);
SchemaElement entity = getEntity(modelId);
if (entity == null || entity.getId() == null) {
continue;
}
List<SchemaElementMatch> valueSchemaElements = schemaElementMatchList.stream().filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()))
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()))
.collect(Collectors.toList());
for (SchemaElementMatch schemaElementMatch : valueSchemaElements) {
if (!entity.getId().equals(schemaElementMatch.getElement().getId())){
if (!entity.getId().equals(schemaElementMatch.getElement().getId())) {
continue;
}
if (!checkExistSameEntitySchemaElements(schemaElementMatch, schemaElementMatchList)) {
@@ -46,7 +51,7 @@ public class EntityMapper implements SchemaMapper {
}
private boolean checkExistSameEntitySchemaElements(SchemaElementMatch valueSchemaElementMatch,
List<SchemaElementMatch> schemaElementMatchList) {
List<SchemaElementMatch> schemaElementMatchList) {
List<SchemaElementMatch> entitySchemaElements = schemaElementMatchList.stream().filter(schemaElementMatch ->
SchemaElementType.ENTITY.equals(schemaElementMatch.getElement().getType()))
.collect(Collectors.toList());
@@ -58,11 +63,11 @@ public class EntityMapper implements SchemaMapper {
return false;
}
private SchemaElement getEntity(Long domainId) {
private SchemaElement getEntity(Long modelId) {
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
DomainSchema domainSchema = semanticService.getDomainSchema(domainId);
if (domainSchema != null && domainSchema.getEntity() != null) {
return domainSchema.getEntity();
ModelSchema modelSchema = semanticService.getModelSchema(modelId);
if (modelSchema != null && modelSchema.getEntity() != null) {
return modelSchema.getEntity();
}
return null;
}

View File

@@ -2,10 +2,14 @@ package com.tencent.supersonic.chat.mapper;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
import java.util.ArrayList;
import java.util.Comparator;
@@ -39,13 +43,13 @@ public class FuzzyNameMapper implements SchemaMapper {
log.debug("after db mapper,mapInfo:{}", queryContext.getMapInfo());
}
private void detectAndAddToSchema(QueryContext queryContext, List<Term> terms, List<SchemaElement> domains,
SchemaElementType schemaElementType) {
private void detectAndAddToSchema(QueryContext queryContext, List<Term> terms, List<SchemaElement> Models,
SchemaElementType schemaElementType) {
try {
Map<String, Set<SchemaElement>> domainResultSet = getResultSet(queryContext, terms, domains);
Map<String, Set<SchemaElement>> ModelResultSet = getResultSet(queryContext, terms, Models);
addToSchemaMapInfo(domainResultSet, queryContext.getMapInfo(), schemaElementType);
addToSchemaMapInfo(ModelResultSet, queryContext.getMapInfo(), schemaElementType);
} catch (Exception e) {
log.error("detectAndAddToSchema error", e);
@@ -53,7 +57,7 @@ public class FuzzyNameMapper implements SchemaMapper {
}
private Map<String, Set<SchemaElement>> getResultSet(QueryContext queryContext, List<Term> terms,
List<SchemaElement> domains) {
List<SchemaElement> Models) {
String queryText = queryContext.getRequest().getQueryText();
@@ -61,12 +65,12 @@ public class FuzzyNameMapper implements SchemaMapper {
Double metricDimensionThresholdConfig = getThreshold(queryContext, mapperHelper);
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(domains);
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(Models);
Map<Integer, Integer> regOffsetToLength = terms.stream().sorted(Comparator.comparing(Term::length))
.collect(Collectors.toMap(Term::getOffset, term -> term.word.length(), (value1, value2) -> value2));
Map<String, Set<SchemaElement>> domainResultSet = new HashMap<>();
Map<String, Set<SchemaElement>> ModelResultSet = new HashMap<>();
for (Integer startIndex = 0; startIndex <= queryText.length() - 1; ) {
for (Integer endIndex = startIndex; endIndex <= queryText.length(); ) {
endIndex = mapperHelper.getStepIndex(regOffsetToLength, endIndex);
@@ -82,7 +86,7 @@ public class FuzzyNameMapper implements SchemaMapper {
|| mapperHelper.getSimilarity(detectSegment, name) < metricDimensionThresholdConfig) {
continue;
}
Set<SchemaElement> preSchemaElements = domainResultSet.putIfAbsent(detectSegment,
Set<SchemaElement> preSchemaElements = ModelResultSet.putIfAbsent(detectSegment,
schemaElements);
if (Objects.nonNull(preSchemaElements)) {
preSchemaElements.addAll(schemaElements);
@@ -91,7 +95,7 @@ public class FuzzyNameMapper implements SchemaMapper {
}
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
}
return domainResultSet;
return ModelResultSet;
}
private Double getThreshold(QueryContext queryContext, MapperHelper mapperHelper) {
@@ -99,9 +103,9 @@ public class FuzzyNameMapper implements SchemaMapper {
Double metricDimensionThresholdConfig = mapperHelper.getMetricDimensionThresholdConfig();
Double metricDimensionMinThresholdConfig = mapperHelper.getMetricDimensionMinThresholdConfig();
Map<Long, List<SchemaElementMatch>> domainElementMatches = queryContext.getMapInfo()
.getDomainElementMatches();
boolean existElement = domainElementMatches.entrySet().stream()
Map<Long, List<SchemaElementMatch>> ModelElementMatches = queryContext.getMapInfo()
.getModelElementMatches();
boolean existElement = ModelElementMatches.entrySet().stream()
.anyMatch(entry -> entry.getValue().size() >= 1);
if (!existElement) {
@@ -109,14 +113,14 @@ public class FuzzyNameMapper implements SchemaMapper {
metricDimensionThresholdConfig = halfThreshold >= metricDimensionMinThresholdConfig ? halfThreshold
: metricDimensionMinThresholdConfig;
log.info("domainElementMatches:{} , not exist Element metricDimensionThresholdConfig reduce by half:{}",
domainElementMatches, metricDimensionThresholdConfig);
log.info("ModelElementMatches:{} , not exist Element metricDimensionThresholdConfig reduce by half:{}",
ModelElementMatches, metricDimensionThresholdConfig);
}
return metricDimensionThresholdConfig;
}
private Map<String, Set<SchemaElement>> getNameToItems(List<SchemaElement> domains) {
return domains.stream().collect(
private Map<String, Set<SchemaElement>> getNameToItems(List<SchemaElement> Models) {
return Models.stream().collect(
Collectors.toMap(SchemaElement::getName, a -> {
Set<SchemaElement> result = new HashSet<>();
result.add(a);
@@ -139,10 +143,10 @@ public class FuzzyNameMapper implements SchemaMapper {
Set<SchemaElement> schemaElements = entry.getValue();
for (SchemaElement schemaElement : schemaElements) {
List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getDomain());
List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getModel());
if (CollectionUtils.isEmpty(elements)) {
elements = new ArrayList<>();
schemaMap.setMatchedElements(schemaElement.getDomain(), elements);
schemaMap.setMatchedElements(schemaElement.getModel(), elements);
}
Set<Long> regElementSet = elements.stream()
.filter(elementMatch -> schemaElementType.equals(elementMatch.getElement().getType()))

View File

@@ -2,13 +2,18 @@ package com.tencent.supersonic.chat.mapper;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.chat.utils.NatureHelper;
import com.tencent.supersonic.knowledge.dictionary.builder.BaseWordBuilder;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.dictionary.DictWordType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.dictionary.DictWordType;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.dictionary.builder.BaseWordBuilder;
import com.tencent.supersonic.knowledge.dictionary.builder.WordBuilderFactory;
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
import java.util.ArrayList;
@@ -32,10 +37,10 @@ public class HanlpDictMapper implements SchemaMapper {
for (Term term : terms) {
log.info("word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency());
}
Long domainId = queryContext.getRequest().getDomainId();
Long modelId = queryContext.getRequest().getModelId();
QueryMatchStrategy matchStrategy = ContextUtils.getBean(QueryMatchStrategy.class);
Map<MatchText, List<MapResult>> matchResult = matchStrategy.match(queryText, terms, domainId);
Map<MatchText, List<MapResult>> matchResult = matchStrategy.match(queryText, terms, modelId);
List<MapResult> matches = getMatches(matchResult);
@@ -57,8 +62,8 @@ public class HanlpDictMapper implements SchemaMapper {
for (MapResult mapResult : mapResults) {
for (String nature : mapResult.getNatures()) {
Long domainId = NatureHelper.getDomainId(nature);
if (Objects.isNull(domainId)) {
Long modelId = NatureHelper.getModelId(nature);
if (Objects.isNull(modelId)) {
continue;
}
SchemaElementType elementType = NatureHelper.convertToElementType(nature);
@@ -67,14 +72,14 @@ public class HanlpDictMapper implements SchemaMapper {
}
SemanticService schemaService = ContextUtils.getBean(SemanticService.class);
DomainSchema domainSchema = schemaService.getDomainSchema(domainId);
ModelSchema modelSchema = schemaService.getModelSchema(modelId);
BaseWordBuilder baseWordBuilder = WordBuilderFactory.get(DictWordType.getNatureType(nature));
Long elementID = baseWordBuilder.getElementID(nature);
Long frequency = wordNatureToFrequency.get(mapResult.getName() + nature);
SchemaElement element = domainSchema.getElement(elementType, elementID);
if(Objects.isNull(element)){
SchemaElement element = modelSchema.getElement(elementType, elementID);
if (Objects.isNull(element)) {
log.info("element is null, elementType:{},elementID:{}", elementType, elementID);
continue;
}
@@ -89,11 +94,11 @@ public class HanlpDictMapper implements SchemaMapper {
.detectWord(mapResult.getDetectWord())
.build();
Map<Long, List<SchemaElementMatch>> domainElementMatches = schemaMap.getDomainElementMatches();
List<SchemaElementMatch> schemaElementMatches = domainElementMatches.putIfAbsent(domainId,
Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getModelElementMatches();
List<SchemaElementMatch> schemaElementMatches = modelElementMatches.putIfAbsent(modelId,
new ArrayList<>());
if (schemaElementMatches == null) {
schemaElementMatches = domainElementMatches.get(domainId);
schemaElementMatches = modelElementMatches.get(modelId);
}
schemaElementMatches.add(schemaElementMatch);
}

View File

@@ -83,4 +83,6 @@ public class MapperHelper {
return 1 - (double) EditDistance.compute(detectSegmentLower, matchNameLower) / Math.max(matchName.length(),
detectSegment.length());
}
}

View File

@@ -2,7 +2,6 @@ package com.tencent.supersonic.chat.mapper;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import java.util.List;
import java.util.Map;
@@ -11,6 +10,6 @@ import java.util.Map;
*/
public interface MatchStrategy {
Map<MatchText, List<MapResult>> match(String text, List<Term> terms, Long detectDomainId);
Map<MatchText, List<MapResult>> match(String text, List<Term> terms, Long detectModelId);
}

View File

@@ -0,0 +1,21 @@
package com.tencent.supersonic.chat.mapper;
import java.io.Serializable;
import lombok.Builder;
import lombok.Data;
import lombok.ToString;
@Data
@ToString
@Builder
public class ModelInfoStat implements Serializable {
private long modelCount;
private long metricModelCount;
private long dimensionModelCount;
private long dimensionValueModelCount;
}

View File

@@ -7,13 +7,13 @@ import lombok.ToString;
@Data
@ToString
public class DomainWithSemanticType implements Serializable {
public class ModelWithSemanticType implements Serializable {
private Long domain;
private Long model;
private SchemaElementType semanticType;
public DomainWithSemanticType(Long domain, SchemaElementType semanticType) {
this.domain = domain;
public ModelWithSemanticType(Long model, SchemaElementType semanticType) {
this.model = model;
this.semanticType = semanticType;
}
}

View File

@@ -2,15 +2,20 @@ package com.tencent.supersonic.chat.mapper;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
import com.tencent.supersonic.common.pojo.Constants;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
@Slf4j
public class QueryFilterMapper implements SchemaMapper {
@@ -21,35 +26,35 @@ public class QueryFilterMapper implements SchemaMapper {
@Override
public void map(QueryContext queryContext) {
QueryReq queryReq = queryContext.getRequest();
Long domainId = queryReq.getDomainId();
if (domainId == null || domainId <= 0) {
Long modelId = queryReq.getModelId();
if (modelId == null || modelId <= 0) {
return;
}
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
clearOtherSchemaElementMatch(domainId, schemaMapInfo);
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(domainId);
clearOtherSchemaElementMatch(modelId, schemaMapInfo);
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(modelId);
if (schemaElementMatches == null) {
schemaElementMatches = Lists.newArrayList();
schemaMapInfo.setMatchedElements(domainId, schemaElementMatches);
schemaMapInfo.setMatchedElements(modelId, schemaElementMatches);
}
addValueSchemaElementMatch(schemaElementMatches, queryReq.getQueryFilters());
}
private void clearOtherSchemaElementMatch(Long domainId, SchemaMapInfo schemaMapInfo) {
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMapInfo.getDomainElementMatches().entrySet()) {
if (!entry.getKey().equals(domainId)) {
private void clearOtherSchemaElementMatch(Long modelId, SchemaMapInfo schemaMapInfo) {
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMapInfo.getModelElementMatches().entrySet()) {
if (!entry.getKey().equals(modelId)) {
entry.getValue().clear();
}
}
}
private List<SchemaElementMatch> addValueSchemaElementMatch(List<SchemaElementMatch> candidateElementMatches,
QueryFilters queryFilter) {
QueryFilters queryFilter) {
if (queryFilter == null || CollectionUtils.isEmpty(queryFilter.getFilters())) {
return candidateElementMatches;
}
for (QueryFilter filter : queryFilter.getFilters()) {
if (checkExistSameValueSchemaElementMatch(filter, candidateElementMatches)) {
if (checkExistSameValueSchemaElementMatch(filter, candidateElementMatches)) {
continue;
}
SchemaElement element = SchemaElement.builder()
@@ -63,7 +68,7 @@ public class QueryFilterMapper implements SchemaMapper {
.frequency(FREQUENCY)
.word(String.valueOf(filter.getValue()))
.similarity(SIMILARITY)
.detectWord(filter.getName())
.detectWord(Constants.EMPTY)
.build();
candidateElementMatches.add(schemaElementMatch);
}
@@ -71,13 +76,13 @@ public class QueryFilterMapper implements SchemaMapper {
}
private boolean checkExistSameValueSchemaElementMatch(QueryFilter queryFilter,
List<SchemaElementMatch> schemaElementMatches) {
List<SchemaElementMatch> schemaElementMatches) {
List<SchemaElementMatch> valueSchemaElements = schemaElementMatches.stream().filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()))
.collect(Collectors.toList());
for (SchemaElementMatch schemaElementMatch : valueSchemaElements) {
if (schemaElementMatch.getElement().getId().equals(queryFilter.getElementID())
&& schemaElementMatch.getWord().equals(String.valueOf(queryFilter.getValue()))) {
&& schemaElementMatch.getWord().equals(String.valueOf(queryFilter.getValue()))) {
return true;
}
}

View File

@@ -1,10 +1,10 @@
package com.tencent.supersonic.chat.mapper;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.knowledge.dictionary.DictWordType;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.service.SearchService;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
@@ -32,7 +32,7 @@ public class QueryMatchStrategy implements MatchStrategy {
private MapperHelper mapperHelper;
@Override
public Map<MatchText, List<MapResult>> match(String text, List<Term> terms, Long detectDomainId) {
public Map<MatchText, List<MapResult>> match(String text, List<Term> terms, Long detectmodelId) {
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
return null;
}
@@ -43,10 +43,10 @@ public class QueryMatchStrategy implements MatchStrategy {
List<Integer> offsetList = terms.stream().sorted(Comparator.comparing(Term::getOffset))
.map(term -> term.getOffset()).collect(Collectors.toList());
log.debug("retryCount:{},terms:{},regOffsetToLength:{},offsetList:{},detectDomainId:{}", terms,
regOffsetToLength, offsetList, detectDomainId);
log.debug("retryCount:{},terms:{},regOffsetToLength:{},offsetList:{},detectmodelId:{}", terms,
regOffsetToLength, offsetList, detectmodelId);
List<MapResult> detects = detect(text, regOffsetToLength, offsetList, detectDomainId);
List<MapResult> detects = detect(text, regOffsetToLength, offsetList, detectmodelId);
Map<MatchText, List<MapResult>> result = new HashMap<>();
MatchText matchText = MatchText.builder()
@@ -58,7 +58,7 @@ public class QueryMatchStrategy implements MatchStrategy {
}
private List<MapResult> detect(String text, Map<Integer, Integer> regOffsetToLength, List<Integer> offsetList,
Long detectDomainId) {
Long detectmodelId) {
List<MapResult> results = Lists.newArrayList();
for (Integer index = 0; index <= text.length() - 1; ) {
@@ -69,18 +69,44 @@ public class QueryMatchStrategy implements MatchStrategy {
int offset = mapperHelper.getStepOffset(offsetList, index);
i = mapperHelper.getStepIndex(regOffsetToLength, i);
if (i <= text.length()) {
List<MapResult> mapResults = detectByStep(text, detectDomainId, index, i, offset);
mapResultRowSet.addAll(mapResults);
List<MapResult> mapResults = detectByStep(text, detectmodelId, index, i, offset);
selectMapResultInOneRound(mapResultRowSet, mapResults);
}
}
index = mapperHelper.getStepIndex(regOffsetToLength, index);
results.addAll(mapResultRowSet);
}
return results;
}
private List<MapResult> detectByStep(String text, Long detectDomainId, Integer index, Integer i, int offset) {
private void selectMapResultInOneRound(Set<MapResult> mapResultRowSet, List<MapResult> mapResults) {
for (MapResult mapResult : mapResults) {
if (mapResultRowSet.contains(mapResult)) {
boolean isDeleted = mapResultRowSet.removeIf(
entry -> {
boolean deleted = getMapKey(mapResult).equals(getMapKey(entry))
&& entry.getDetectWord().length() < mapResult.getDetectWord().length();
if (deleted) {
log.info("deleted entry:{}", entry);
}
return deleted;
}
);
if (isDeleted) {
log.info("deleted, add mapResult:{}", mapResult);
mapResultRowSet.add(mapResult);
}
} else {
mapResultRowSet.add(mapResult);
}
}
}
private String getMapKey(MapResult a) {
return a.getName() + Constants.UNDERLINE + String.join(Constants.UNDERLINE, a.getNatures());
}
private List<MapResult> detectByStep(String text, Long detectmodelId, Integer index, Integer i, int offset) {
String detectSegment = text.substring(index, i);
Integer oneDetectionSize = mapperHelper.getOneDetectionSize();
// step1. pre search
@@ -100,17 +126,17 @@ public class QueryMatchStrategy implements MatchStrategy {
mapResults = mapResults.stream().sorted((a, b) -> -(b.getName().length() - a.getName().length()))
.collect(Collectors.toCollection(LinkedHashSet::new));
// step4. filter by classId
if (Objects.nonNull(detectDomainId) && detectDomainId > 0) {
log.debug("detectDomainId:{}, before parseResults:{}", mapResults);
if (Objects.nonNull(detectmodelId) && detectmodelId > 0) {
log.debug("detectmodelId:{}, before parseResults:{}", mapResults);
mapResults = mapResults.stream().map(entry -> {
List<String> natures = entry.getNatures().stream().filter(
nature -> nature.startsWith(DictWordType.NATURE_SPILT + detectDomainId) || (nature.startsWith(
nature -> nature.startsWith(DictWordType.NATURE_SPILT + detectmodelId) || (nature.startsWith(
DictWordType.NATURE_SPILT))
).collect(Collectors.toList());
entry.setNatures(natures);
return entry;
}).collect(Collectors.toCollection(LinkedHashSet::new));
log.info("after domainId parseResults:{}", mapResults);
log.info("after modelId parseResults:{}", mapResults);
}
// step5. filter by similarity
mapResults = mapResults.stream()

View File

@@ -2,10 +2,9 @@ package com.tencent.supersonic.chat.mapper;
import com.google.common.collect.Lists;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.dictionary.DictWordType;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.service.SearchService;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -25,7 +24,7 @@ public class SearchMatchStrategy implements MatchStrategy {
@Override
public Map<MatchText, List<MapResult>> match(String text, List<Term> originals,
Long detectDomainId) {
Long detectModelId) {
Map<Integer, Integer> regOffsetToLength = originals.stream()
.filter(entry -> !entry.nature.toString().startsWith(DictWordType.NATURE_SPILT))
@@ -60,10 +59,10 @@ public class SearchMatchStrategy implements MatchStrategy {
List<String> natures = entry.getNatures().stream()
.filter(nature -> !nature.endsWith(DictWordType.ENTITY.getType()))
.filter(nature -> {
if (Objects.isNull(detectDomainId) || detectDomainId <= 0) {
if (Objects.isNull(detectModelId) || detectModelId <= 0) {
return true;
}
if (nature.startsWith(DictWordType.NATURE_SPILT + detectDomainId)
if (nature.startsWith(DictWordType.NATURE_SPILT + detectModelId)
&& nature.startsWith(DictWordType.NATURE_SPILT)) {
return true;
}

View File

@@ -2,22 +2,9 @@ package com.tencent.supersonic.chat.parser;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.plugin.PluginParseResult;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.compress.utils.Lists;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
/**
* This checker can be used by semantic parsers to check if query intent
@@ -30,86 +17,20 @@ public class SatisfactionChecker {
private static final double LONG_TEXT_THRESHOLD = 0.8;
private static final double SHORT_TEXT_THRESHOLD = 0.5;
private static final int QUERY_TEXT_LENGTH_THRESHOLD = 10;
public static final double EMBEDDING_THRESHOLD = 0.2;
// check all the parse info in candidate
public static boolean check(QueryContext queryCtx) {
for (SemanticQuery query : queryCtx.getCandidateQueries()) {
if (query instanceof RuleSemanticQuery) {
if (checkRuleThreshHold(queryCtx.getRequest().getQueryText(), query.getParseInfo())) {
return true;
}
} else if (query instanceof PluginSemanticQuery) {
if (checkEmbeddingThreshold(query.getParseInfo())) {
log.info("query mode :{} satisfy check", query.getQueryMode());
return true;
}
public static boolean check(QueryContext queryContext) {
for (SemanticQuery query : queryContext.getCandidateQueries()) {
if (checkThreshold(queryContext.getRequest().getQueryText(), query.getParseInfo())) {
return true;
}
}
return false;
}
private static boolean checkEmbeddingThreshold(SemanticParseInfo semanticParseInfo) {
Object object = semanticParseInfo.getProperties().get(Constants.CONTEXT);
PluginParseResult pluginParseResult = JsonUtil.toObject(JsonUtil.toString(object), PluginParseResult.class);
return EMBEDDING_THRESHOLD > pluginParseResult.getDistance();
}
//check single parse info
private static boolean checkRuleThreshHold(String text, SemanticParseInfo semanticParseInfo) {
List<SchemaElementMatch> schemaElementMatches = semanticParseInfo.getElementMatches();
if (CollectionUtils.isEmpty(schemaElementMatches)) {
return false;
}
List<String> detectWords = Lists.newArrayList();
Map<Long, String> detectWordMap = schemaElementMatches.stream()
.collect(Collectors.toMap(m -> m.getElement().getId(), SchemaElementMatch::getDetectWord,
(id1, id2) -> id1));
// get detect word in text by element id in semantic layer
Long domainId = semanticParseInfo.getDomainId();
if (domainId != null && domainId > 0) {
for (SchemaElementMatch schemaElementMatch : schemaElementMatches) {
if (SchemaElementType.DOMAIN.equals(schemaElementMatch.getElement().getType())) {
detectWords.add(schemaElementMatch.getDetectWord());
}
}
}
for (SchemaElementMatch schemaElementMatch : schemaElementMatches) {
if (SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())) {
detectWords.add(schemaElementMatch.getDetectWord());
}
}
for (SchemaElementMatch schemaElementMatch : schemaElementMatches) {
if (SchemaElementType.ID.equals(schemaElementMatch.getElement().getType())) {
detectWords.add(schemaElementMatch.getDetectWord());
}
}
for (SchemaElement schemaItem : semanticParseInfo.getMetrics()) {
detectWords.add(
detectWordMap.getOrDefault(Optional.ofNullable(schemaItem.getId()).orElse(0L), ""));
// only first metric
break;
}
for (SchemaElement schemaItem : semanticParseInfo.getDimensions()) {
detectWords.add(
detectWordMap.getOrDefault(Optional.ofNullable(schemaItem.getId()).orElse(0L), ""));
// only first dimension
break;
}
String dateText = Optional.ofNullable(semanticParseInfo.getDateInfo()).orElse(new DateConf()).getText();
if (StringUtils.isNotBlank(dateText) && !dateText.equalsIgnoreCase(Constants.NULL)) {
detectWords.add(dateText);
}
detectWords.removeIf(word -> !text.contains(word) && !text.contains(StringUtils.reverse(word)));
//compare the length between detect words and query text
return checkThreshold(text, detectWords, semanticParseInfo);
}
private static boolean checkThreshold(String queryText, List<String> detectWords, SemanticParseInfo semanticParseInfo) {
String detectWordsDistinct = StringUtils.join(new HashSet<>(detectWords), "");
int detectWordsLength = detectWordsDistinct.length();
private static boolean checkThreshold(String queryText, SemanticParseInfo semanticParseInfo) {
int queryTextLength = queryText.length();
double degree = detectWordsLength * 1.0 / queryTextLength;
double degree = semanticParseInfo.getScore() / queryTextLength;
if (queryTextLength > QUERY_TEXT_LENGTH_THRESHOLD) {
if (degree < LONG_TEXT_THRESHOLD) {
return false;
@@ -117,8 +38,8 @@ public class SatisfactionChecker {
} else if (degree < SHORT_TEXT_THRESHOLD) {
return false;
}
log.info("queryMode:{}, degree:{}, detectWords:{}, parse info:{}",
semanticParseInfo.getQueryMode(), degree, detectWordsDistinct, semanticParseInfo);
log.info("queryMode:{}, degree:{}, parse info:{}",
semanticParseInfo.getQueryMode(), degree, semanticParseInfo);
return true;
}

View File

@@ -1,13 +1,18 @@
package com.tencent.supersonic.chat.parser.embedding;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
import com.tencent.supersonic.chat.parser.ParseMode;
import com.tencent.supersonic.chat.plugin.Plugin;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.plugin.PluginParseResult;
@@ -17,18 +22,20 @@ import com.tencent.supersonic.chat.service.PluginService;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.ContextUtils;
import java.util.*;
import java.util.stream.Collectors;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.util.CollectionUtils;
@Slf4j
public class EmbeddingBasedParser implements SemanticParser {
private final static double THRESHOLD = 0.2d;
@Override
public void parse(QueryContext queryContext, ChatContext chatContext) {
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
@@ -36,109 +43,15 @@ public class EmbeddingBasedParser implements SemanticParser {
return;
}
log.info("EmbeddingBasedParser parser query ctx: {}, chat ctx: {}", queryContext, chatContext);
Set<Long> domainIds = getDomainMatched(queryContext);
String text = queryContext.getRequest().getQueryText();
if (!CollectionUtils.isEmpty(domainIds)) {
for (Long domainId : domainIds) {
List<SchemaElementMatch> schemaElementMatches = getMatchedElements(queryContext, domainId);
String textReplaced = replaceText(text, schemaElementMatches);
List<RecallRetrieval> embeddingRetrievals = recallResult(textReplaced, hasCandidateQuery(queryContext));
Optional<Plugin> pluginOptional = choosePlugin(embeddingRetrievals, domainId);
log.info("domain id :{} embedding result, text:{} embeddingResp:{} ",domainId, textReplaced, embeddingRetrievals);
pluginOptional.ifPresent(plugin -> buildQuery(plugin, embeddingRetrievals, domainId, textReplaced, queryContext, schemaElementMatches));
}
} else {
List<RecallRetrieval> embeddingRetrievals = recallResult(text, hasCandidateQuery(queryContext));
Optional<Plugin> pluginOptional = choosePlugin(embeddingRetrievals, null);
pluginOptional.ifPresent(plugin -> buildQuery(plugin, embeddingRetrievals, null, text, queryContext, Lists.newArrayList()));
}
List<RecallRetrieval> embeddingRetrievals = recallResult(text);
choosePlugin(embeddingRetrievals, queryContext);
}
private void buildQuery(Plugin plugin, List<RecallRetrieval> embeddingRetrievals,
Long domainId, String text,
QueryContext queryContext, List<SchemaElementMatch> schemaElementMatches) {
Map<String, RecallRetrieval> embeddingRetrievalMap = embeddingRetrievals.stream()
.collect(Collectors.toMap(RecallRetrieval::getId, e -> e, (value1, value2) -> value1));
log.info("EmbeddingBasedParser text: {} domain: {} choose plugin: [{} {}]",
text, domainId, plugin.getId(), plugin.getName());
PluginSemanticQuery pluginQuery = QueryManager.createPluginQuery(plugin.getType());
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(domainId, plugin, text,
queryContext.getRequest(), embeddingRetrievalMap, schemaElementMatches);
semanticParseInfo.setQueryMode(pluginQuery.getQueryMode());
pluginQuery.setParseInfo(semanticParseInfo);
queryContext.getCandidateQueries().add(pluginQuery);
}
private Set<Long> getDomainMatched(QueryContext queryContext) {
Long queryDomainId = queryContext.getRequest().getDomainId();
if (queryDomainId != null && queryDomainId > 0) {
return Sets.newHashSet(queryDomainId);
}
return queryContext.getMapInfo().getMatchedDomains();
}
private SemanticParseInfo buildSemanticParseInfo(Long domainId, Plugin plugin, String text, QueryReq queryReq,
Map<String, RecallRetrieval> embeddingRetrievalMap,
List<SchemaElementMatch> schemaElementMatches) {
SchemaElement schemaElement = new SchemaElement();
schemaElement.setDomain(domainId);
schemaElement.setId(domainId);
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
semanticParseInfo.setElementMatches(schemaElementMatches);
semanticParseInfo.setDomain(schemaElement);
double distance = Double.parseDouble(embeddingRetrievalMap.get(plugin.getId().toString()).getDistance());
double score = text.length() * (1 - distance);
Map<String, Object> properties = new HashMap<>();
PluginParseResult pluginParseResult = new PluginParseResult();
pluginParseResult.setPlugin(plugin);
pluginParseResult.setRequest(queryReq);
pluginParseResult.setDistance(distance);
properties.put(Constants.CONTEXT, pluginParseResult);
semanticParseInfo.setProperties(properties);
semanticParseInfo.setScore(score);
fillSemanticParseInfo(semanticParseInfo);
setEntity(domainId, semanticParseInfo);
return semanticParseInfo;
}
private List<SchemaElementMatch> getMatchedElements(QueryContext queryContext, Long domainId) {
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(domainId);
if (schemaElementMatches == null) {
return Lists.newArrayList();
}
QueryReq queryReq = queryContext.getRequest();
QueryFilters queryFilters = queryReq.getQueryFilters();
if (queryFilters == null || CollectionUtils.isEmpty(queryFilters.getFilters())) {
return schemaElementMatches;
}
Map<Long, Object> element = queryFilters.getFilters().stream()
.collect(Collectors.toMap(QueryFilter::getElementID, QueryFilter::getValue, (v1, v2) -> v1));
return schemaElementMatches.stream().filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType())
|| SchemaElementType.ENTITY.equals(schemaElementMatch.getElement().getType()))
.filter(schemaElementMatch ->
!element.containsKey(schemaElementMatch.getElement().getId()) || (
element.containsKey(schemaElementMatch.getElement().getId()) &&
element.get(schemaElementMatch.getElement().getId()).toString()
.equalsIgnoreCase(schemaElementMatch.getWord())
))
.collect(Collectors.toList());
}
private void setEntity(Long domainId, SemanticParseInfo semanticParseInfo) {
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
DomainSchema domainSchema = semanticService.getDomainSchema(domainId);
if (domainSchema != null && domainSchema.getEntity() != null) {
semanticParseInfo.setEntity(domainSchema.getEntity());
}
}
private Optional<Plugin> choosePlugin(List<RecallRetrieval> embeddingRetrievals,
Long domainId) {
private void choosePlugin(List<RecallRetrieval> embeddingRetrievals,
QueryContext queryContext) {
if (CollectionUtils.isEmpty(embeddingRetrievals)) {
return Optional.empty();
return;
}
PluginService pluginService = ContextUtils.getBean(PluginService.class);
List<Plugin> plugins = pluginService.getPluginList();
@@ -148,27 +61,92 @@ public class EmbeddingBasedParser implements SemanticParser {
if (plugin == null) {
continue;
}
if (domainId == null) {
return Optional.of(plugin);
}
if (!CollectionUtils.isEmpty(plugin.getDomainList()) && plugin.getDomainList().contains(domainId)) {
return Optional.of(plugin);
Pair<Boolean, List<Long>> pair = PluginManager.resolve(plugin, queryContext);
log.info("embedding plugin resolve: {}", pair);
if (pair.getLeft()) {
List<Long> modelList = pair.getRight();
if (CollectionUtils.isEmpty(modelList)) {
return;
}
modelList = distinctModelList(plugin, queryContext.getMapInfo(), modelList);
for (Long modelId : modelList) {
buildQuery(plugin, Double.parseDouble(embeddingRetrieval.getDistance()), modelId, queryContext,
queryContext.getMapInfo().getMatchedElements(modelId));
}
return;
}
}
return Optional.empty();
}
public List<RecallRetrieval> recallResult(String embeddingText, boolean hasCandidateQuery) {
public List<Long> distinctModelList(Plugin plugin, SchemaMapInfo schemaMapInfo, List<Long> modelList) {
if (!plugin.isContainsAllModel()) {
return modelList;
}
boolean noElementMatch = true;
for (Long model : modelList) {
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(model);
if (!CollectionUtils.isEmpty(schemaElementMatches)) {
noElementMatch = false;
}
}
if (noElementMatch) {
return modelList.subList(0, 1);
}
return modelList;
}
private void buildQuery(Plugin plugin, double distance, Long modelId,
QueryContext queryContext, List<SchemaElementMatch> schemaElementMatches) {
log.info("EmbeddingBasedParser Model: {} choose plugin: [{} {}]", modelId, plugin.getId(), plugin.getName());
PluginSemanticQuery pluginQuery = QueryManager.createPluginQuery(plugin.getType());
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(modelId, plugin, queryContext.getRequest(),
schemaElementMatches, distance);
double score = queryContext.getRequest().getQueryText().length() * (1 - distance);
semanticParseInfo.setQueryMode(pluginQuery.getQueryMode());
semanticParseInfo.setScore(score);
pluginQuery.setParseInfo(semanticParseInfo);
queryContext.getCandidateQueries().add(pluginQuery);
}
private SemanticParseInfo buildSemanticParseInfo(Long modelId, Plugin plugin, QueryReq queryReq,
List<SchemaElementMatch> schemaElementMatches, double distance) {
if (modelId == null && !CollectionUtils.isEmpty(plugin.getModelList())) {
modelId = plugin.getModelList().get(0);
}
SchemaElement Model = new SchemaElement();
Model.setModel(modelId);
Model.setId(modelId);
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
semanticParseInfo.setElementMatches(schemaElementMatches);
semanticParseInfo.setModel(Model);
Map<String, Object> properties = new HashMap<>();
PluginParseResult pluginParseResult = new PluginParseResult();
pluginParseResult.setPlugin(plugin);
pluginParseResult.setRequest(queryReq);
pluginParseResult.setDistance(distance);
properties.put(Constants.CONTEXT, pluginParseResult);
semanticParseInfo.setProperties(properties);
semanticParseInfo.setScore(distance);
fillSemanticParseInfo(semanticParseInfo);
setEntity(modelId, semanticParseInfo);
return semanticParseInfo;
}
private void setEntity(Long modelId, SemanticParseInfo semanticParseInfo) {
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
ModelSchema ModelSchema = semanticService.getModelSchema(modelId);
if (ModelSchema != null && ModelSchema.getEntity() != null) {
semanticParseInfo.setEntity(ModelSchema.getEntity());
}
}
public List<RecallRetrieval> recallResult(String embeddingText) {
try {
PluginManager pluginManager = ContextUtils.getBean(PluginManager.class);
EmbeddingResp embeddingResp = pluginManager.recognize(embeddingText);
List<RecallRetrieval> embeddingRetrievals = embeddingResp.getRetrieval();
if(!CollectionUtils.isEmpty(embeddingRetrievals)){
if (hasCandidateQuery) {
embeddingRetrievals = embeddingRetrievals.stream()
.filter(llmRetrieval -> Double.parseDouble(llmRetrieval.getDistance())<THRESHOLD)
.collect(Collectors.toList());
}
if (!CollectionUtils.isEmpty(embeddingRetrievals)) {
embeddingRetrievals = embeddingRetrievals.stream().sorted(Comparator.comparingDouble(o ->
Math.abs(Double.parseDouble(o.getDistance())))).collect(Collectors.toList());
embeddingResp.setRetrieval(embeddingRetrievals);
@@ -180,45 +158,22 @@ public class EmbeddingBasedParser implements SemanticParser {
return Lists.newArrayList();
}
private boolean hasCandidateQuery(QueryContext queryContext) {
return !CollectionUtils.isEmpty(queryContext.getCandidateQueries());
}
private void fillSemanticParseInfo(SemanticParseInfo semanticParseInfo) {
List<SchemaElementMatch> schemaElementMatches = semanticParseInfo.getElementMatches();
if (!CollectionUtils.isEmpty(schemaElementMatches)) {
schemaElementMatches.stream().filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
.forEach(schemaElementMatch -> {
QueryFilter queryFilter = new QueryFilter();
queryFilter.setValue(schemaElementMatch.getWord());
queryFilter.setElementID(schemaElementMatch.getElement().getId());
queryFilter.setName(schemaElementMatch.getElement().getName());
queryFilter.setOperator(FilterOperatorEnum.EQUALS);
queryFilter.setBizName(schemaElementMatch.getElement().getBizName());
semanticParseInfo.getDimensionFilters().add(queryFilter);
QueryFilter queryFilter = new QueryFilter();
queryFilter.setValue(schemaElementMatch.getWord());
queryFilter.setElementID(schemaElementMatch.getElement().getId());
queryFilter.setName(schemaElementMatch.getElement().getName());
queryFilter.setOperator(FilterOperatorEnum.EQUALS);
queryFilter.setBizName(schemaElementMatch.getElement().getBizName());
semanticParseInfo.getDimensionFilters().add(queryFilter);
});
}
}
protected String replaceText(String text, List<SchemaElementMatch> schemaElementMatches) {
if (CollectionUtils.isEmpty(schemaElementMatches)) {
return text;
}
List<SchemaElementMatch> valueSchemaElementMatches = schemaElementMatches.stream()
.filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
.collect(Collectors.toList());
for (SchemaElementMatch schemaElementMatch : valueSchemaElementMatches) {
String detectWord = schemaElementMatch.getDetectWord();
if (StringUtils.isBlank(detectWord)) {
continue;
}
text = text.replace(detectWord, "");
}
return text;
}
}

View File

@@ -1,23 +1,26 @@
package com.tencent.supersonic.chat.parser.embedding;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.service.ConfigService;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Component;
import java.util.List;
import java.util.Objects;
import java.util.Set;
@Slf4j
@Component("EmbeddingEntityResolver")
public class EmbeddingEntityResolver {
private ConfigService configService;
public EmbeddingEntityResolver(ConfigService configService) {
@@ -25,18 +28,19 @@ public class EmbeddingEntityResolver {
}
private Long getEntityValue(Long domainId, Long entityElementId, QueryContext queryCtx, ChatContext chatCtx) {
private Long getEntityValue(Long modelId, Long entityElementId, QueryContext queryCtx, ChatContext chatCtx) {
Long entityId = null;
QueryFilters queryFilters = queryCtx.getRequest().getQueryFilters();
if (queryFilters != null) {
entityId = getEntityValueFromQueryFilter(queryFilters.getFilters());
if (entityId != null) {
log.info("get entity id:{} domain id:{} from query filter :{} ", entityId, domainId, queryFilters);
log.info("get entity id:{} model id:{} from query filter :{} ", entityId, modelId, queryFilters);
return entityId;
}
}
entityId = getEntityValueFromSchemaMapInfo(domainId, queryCtx.getMapInfo(), entityElementId);
log.info("get entity id:{} from schema map Info :{} ", entityId, JSONObject.toJSONString(queryCtx.getMapInfo()));
entityId = getEntityValueFromSchemaMapInfo(modelId, queryCtx.getMapInfo(), entityElementId);
log.info("get entity id:{} from schema map Info :{} ", entityId,
JSONObject.toJSONString(queryCtx.getMapInfo()));
if (entityId == null || entityId == 0) {
Long entityIdFromChat = getEntityValueFromParseInfo(chatCtx.getParseInfo(), entityElementId);
if (entityIdFromChat != null && entityIdFromChat > 0) {
@@ -75,8 +79,8 @@ public class EmbeddingEntityResolver {
}
private Long getEntityValueFromSchemaMapInfo(Long domainId, SchemaMapInfo schemaMapInfo, Long entityElementId) {
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(domainId);
private Long getEntityValueFromSchemaMapInfo(Long modelId, SchemaMapInfo schemaMapInfo, Long entityElementId) {
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(modelId);
if (CollectionUtils.isEmpty(schemaElementMatchList)) {
return null;
}

View File

@@ -1,9 +1,8 @@
package com.tencent.supersonic.chat.parser.embedding;
import lombok.Data;
import java.util.List;
import lombok.Data;
@Data
public class EmbeddingResp {

View File

@@ -7,18 +7,20 @@ import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.config.FunctionCallInfoConfig;
import com.tencent.supersonic.chat.parser.ParseMode;
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
import com.tencent.supersonic.chat.plugin.Plugin;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
import com.tencent.supersonic.chat.plugin.PluginParseResult;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.dsl.DSLQuery;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.query.plugin.dsl.DSLQuery;
import com.tencent.supersonic.chat.service.PluginService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import java.net.URI;
import java.util.ArrayList;
import java.util.HashMap;
@@ -29,10 +31,9 @@ import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import com.tencent.supersonic.common.util.JsonUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
@@ -55,21 +56,12 @@ public class FunctionBasedParser implements SemanticParser {
FunctionCallInfoConfig functionCallConfig = ContextUtils.getBean(FunctionCallInfoConfig.class);
PluginService pluginService = ContextUtils.getBean(PluginService.class);
String functionUrl = functionCallConfig.getUrl();
if (StringUtils.isBlank(functionUrl) || SatisfactionChecker.check(queryCtx)) {
log.info("functionUrl:{}, skip function parser, queryText:{}", functionUrl,
queryCtx.getRequest().getQueryText());
return;
}
Set<Long> matchedDomains = getMatchDomains(queryCtx);
List<String> functionNames = getFunctionNames(matchedDomains);
log.info("matchedDomains:{},functionNames:{}", matchedDomains, functionNames);
if (CollectionUtils.isEmpty(functionNames) || CollectionUtils.isEmpty(matchedDomains)) {
return;
}
List<PluginParseConfig> functionDOList = getFunctionDO(queryCtx.getRequest().getDomainId());
List<PluginParseConfig> functionDOList = getFunctionDO(queryCtx.getRequest().getModelId(), queryCtx);
FunctionReq functionReq = FunctionReq.builder()
.queryText(queryCtx.getRequest().getQueryText())
.pluginConfigs(functionDOList).build();
@@ -78,7 +70,6 @@ public class FunctionBasedParser implements SemanticParser {
if (skipFunction(queryCtx, functionResp)) {
return;
}
PluginParseResult functionCallParseResult = new PluginParseResult();
String toolSelection = functionResp.getToolSelection();
Optional<Plugin> pluginOptional = pluginService.getPluginByName(toolSelection);
@@ -87,24 +78,26 @@ public class FunctionBasedParser implements SemanticParser {
return;
}
Plugin plugin = pluginOptional.get();
plugin.setParseMode(ParseMode.FUNCTION_CALL);
toolSelection = plugin.getType();
functionCallParseResult.setPlugin(plugin);
log.info("QueryManager PluginQueryModes:{}", QueryManager.getPluginQueryModes());
PluginSemanticQuery semanticQuery = QueryManager.createPluginQuery(toolSelection);
DomainResolver domainResolver = ComponentFactory.getDomainResolver();
Long domainId = domainResolver.resolve(queryCtx, chatCtx, plugin.getDomainList());
log.info("FunctionBasedParser domainId:{}",domainId);
if ((Objects.isNull(domainId) || domainId <= 0) && !plugin.isContainsAllDomain()) {
log.info("domain is null, skip the parse, select tool: {}", toolSelection);
ModelResolver ModelResolver = ComponentFactory.getModelResolver();
log.info("plugin ModelList:{}", plugin.getModelList());
Pair<Boolean, List<Long>> pluginResolveResult = PluginManager.resolve(plugin, queryCtx);
Long modelId = ModelResolver.resolve(queryCtx, chatCtx, pluginResolveResult.getRight());
log.info("FunctionBasedParser modelId:{}", modelId);
if ((Objects.isNull(modelId) || modelId <= 0) && !plugin.isContainsAllModel()) {
log.info("Model is null, skip the parse, select tool: {}", toolSelection);
return;
}
if (!plugin.getDomainList().contains(domainId) && !plugin.isContainsAllDomain()) {
if (!plugin.getModelList().contains(modelId) && !plugin.isContainsAllModel()) {
return;
}
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
if (Objects.nonNull(domainId) && domainId > 0){
parseInfo.getElementMatches().addAll(queryCtx.getMapInfo().getMatchedElements(domainId));
if (Objects.nonNull(modelId) && modelId > 0) {
parseInfo.getElementMatches().addAll(queryCtx.getMapInfo().getMatchedElements(modelId));
}
functionCallParseResult.setRequest(queryCtx.getRequest());
Map<String, Object> properties = new HashMap<>();
@@ -112,21 +105,22 @@ public class FunctionBasedParser implements SemanticParser {
parseInfo.setProperties(properties);
parseInfo.setScore(FUNCTION_BONUS_THRESHOLD);
parseInfo.setQueryMode(semanticQuery.getQueryMode());
SchemaElement domain = new SchemaElement();
domain.setDomain(domainId);
domain.setId(domainId);
parseInfo.setDomain(domain);
SchemaElement Model = new SchemaElement();
Model.setModel(modelId);
Model.setId(modelId);
parseInfo.setModel(Model);
queryCtx.getCandidateQueries().add(semanticQuery);
}
private Set<Long> getMatchDomains(QueryContext queryCtx) {
private Set<Long> getMatchModels(QueryContext queryCtx) {
Set<Long> result = new HashSet<>();
Long domainId = queryCtx.getRequest().getDomainId();
if (Objects.nonNull(domainId) && domainId > 0) {
result.add(domainId);
Long modelId = queryCtx.getRequest().getModelId();
if (Objects.nonNull(modelId) && modelId > 0) {
result.add(modelId);
return result;
}
return queryCtx.getMapInfo().getMatchedDomains();
return queryCtx.getMapInfo().getMatchedModels();
}
private boolean skipFunction(QueryContext queryCtx, FunctionResp functionResp) {
@@ -144,36 +138,46 @@ public class FunctionBasedParser implements SemanticParser {
return false;
}
private List<PluginParseConfig> getFunctionDO(Long domainId) {
log.info("user decide domain:{}", domainId);
private List<PluginParseConfig> getFunctionDO(Long modelId, QueryContext queryContext) {
log.info("user decide Model:{}", modelId);
List<Plugin> plugins = PluginManager.getPlugins();
List<PluginParseConfig> functionDOList = plugins.stream().filter(o -> {
if (o.getParseModeConfig() == null) {
List<PluginParseConfig> functionDOList = plugins.stream().filter(plugin -> {
if (DSLQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType())) {
return false;
}
if (!CollectionUtils.isEmpty(o.getDomainList())) {//过滤掉没选主题域的插件
return true;
if (plugin.getParseModeConfig() == null) {
return false;
}
if (domainId == null || domainId <= 0L) {
return true;
} else {
return o.getDomainList().contains(domainId);
}
}).map(o -> {
PluginParseConfig functionCallConfig = JsonUtil.toObject(o.getParseModeConfig(),
PluginParseConfig pluginParseConfig = JsonUtil.toObject(plugin.getParseModeConfig(),
PluginParseConfig.class);
return functionCallConfig;
}).collect(Collectors.toList());
if (StringUtils.isBlank(pluginParseConfig.getName())) {
return false;
}
Pair<Boolean, List<Long>> pluginResolverResult = PluginManager.resolve(plugin, queryContext);
log.info("embedding plugin [{}-{}] resolve: {}", plugin.getId(), plugin.getName(), pluginResolverResult);
if (!pluginResolverResult.getLeft()) {
return false;
} else {
List<Long> resolveModel = pluginResolverResult.getRight();
if (modelId != null && modelId > 0) {
if (plugin.isContainsAllModel()) {
return true;
}
return resolveModel.contains(modelId);
}
return true;
}
}).map(o -> JsonUtil.toObject(o.getParseModeConfig(), PluginParseConfig.class)).collect(Collectors.toList());
log.info("getFunctionDO:{}", JsonUtil.toString(functionDOList));
return functionDOList;
}
private List<String> getFunctionNames(Set<Long> matchedDomains) {
private List<String> getFunctionNames(Set<Long> matchedModels) {
List<Plugin> plugins = PluginManager.getPlugins();
Set<String> functionNames = plugins.stream()
.filter(entry -> {
if (!CollectionUtils.isEmpty(entry.getDomainList()) && !CollectionUtils.isEmpty(matchedDomains)) {
return entry.getDomainList().stream().anyMatch(matchedDomains::contains);
if (!CollectionUtils.isEmpty(entry.getModelList()) && !CollectionUtils.isEmpty(matchedModels)) {
return entry.getModelList().stream().anyMatch(matchedModels::contains);
}
return true;
}

View File

@@ -1,8 +1,7 @@
package com.tencent.supersonic.chat.parser.function;
import java.util.List;
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
import java.util.List;
import lombok.Builder;
import lombok.Data;

View File

@@ -1,181 +0,0 @@
package com.tencent.supersonic.chat.parser.function;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import lombok.extern.slf4j.Slf4j;
import java.util.*;
import java.util.stream.Collectors;
import org.apache.commons.collections.CollectionUtils;
@Slf4j
public class HeuristicDomainResolver implements DomainResolver {
protected static Long selectDomainBySchemaElementCount(Map<Long, SemanticQuery> domainQueryModes,
SchemaMapInfo schemaMap) {
Map<Long, DomainMatchResult> domainTypeMap = getDomainTypeMap(schemaMap);
if (domainTypeMap.size() == 1) {
Long domainSelect = domainTypeMap.entrySet().stream().collect(Collectors.toList()).get(0).getKey();
if (domainQueryModes.containsKey(domainSelect)) {
log.info("selectDomain with only one domain [{}]", domainSelect);
return domainSelect;
}
} else {
Map.Entry<Long, DomainMatchResult> maxDomain = domainTypeMap.entrySet().stream()
.filter(entry -> domainQueryModes.containsKey(entry.getKey()))
.sorted((o1, o2) -> {
int difference = o2.getValue().getCount() - o1.getValue().getCount();
if (difference == 0) {
return (int) ((o2.getValue().getMaxSimilarity()
- o1.getValue().getMaxSimilarity()) * 100);
}
return difference;
}).findFirst().orElse(null);
if (maxDomain != null) {
log.info("selectDomain with multiple domains [{}]", maxDomain.getKey());
return maxDomain.getKey();
}
}
return 0L;
}
/**
* to check can switch domain if context exit domain
*
* @return false will use context domain, true will use other domain , maybe include context domain
*/
protected static boolean isAllowSwitch(Map<Long, SemanticQuery> domainQueryModes, SchemaMapInfo schemaMap,
ChatContext chatCtx, QueryReq searchCtx, Long domainId, List<Long> restrictiveDomains) {
if (!Objects.nonNull(domainId) || domainId <= 0) {
return true;
}
// except content domain, calculate the number of types for each domain, if numbers<=1 will not switch
Map<Long, DomainMatchResult> domainTypeMap = getDomainTypeMap(schemaMap);
log.info("isAllowSwitch domainTypeMap [{}]", domainTypeMap);
long otherDomainTypeNumBigOneCount = domainTypeMap.entrySet().stream()
.filter(entry -> domainQueryModes.containsKey(entry.getKey()) && !entry.getKey().equals(domainId))
.filter(entry -> entry.getValue().getCount() > 1).count();
if (otherDomainTypeNumBigOneCount >= 1) {
return true;
}
// if query text only contain time , will not switch
if (!CollectionUtils.isEmpty(domainQueryModes.values())) {
for (SemanticQuery semanticQuery : domainQueryModes.values()) {
if (semanticQuery == null) {
continue;
}
SemanticParseInfo semanticParseInfo = semanticQuery.getParseInfo();
if (semanticParseInfo == null) {
continue;
}
if (searchCtx.getQueryText() != null && semanticParseInfo.getDateInfo() != null) {
if (semanticParseInfo.getDateInfo().getText() != null) {
if (semanticParseInfo.getDateInfo().getText().equalsIgnoreCase(searchCtx.getQueryText())) {
log.info("timeParseResults is not null , can not switch context , timeParseResults:{},",
semanticParseInfo.getDateInfo());
return false;
}
}
}
}
}
if (CollectionUtils.isNotEmpty(restrictiveDomains) && !restrictiveDomains.contains(domainId)) {
return true;
}
// if context domain not in schemaMap , will switch
if (schemaMap.getMatchedElements(domainId) == null || schemaMap.getMatchedElements(domainId).size() <= 0) {
log.info("domainId not in schemaMap ");
return true;
}
// other will not switch
return false;
}
public static Map<Long, DomainMatchResult> getDomainTypeMap(SchemaMapInfo schemaMap) {
Map<Long, DomainMatchResult> domainCount = new HashMap<>();
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMap.getDomainElementMatches().entrySet()) {
List<SchemaElementMatch> schemaElementMatches = schemaMap.getMatchedElements(entry.getKey());
if (schemaElementMatches != null && schemaElementMatches.size() > 0) {
if (!domainCount.containsKey(entry.getKey())) {
domainCount.put(entry.getKey(), new DomainMatchResult());
}
DomainMatchResult domainMatchResult = domainCount.get(entry.getKey());
Set<SchemaElementType> schemaElementTypes = new HashSet<>();
schemaElementMatches.stream()
.forEach(schemaElementMatch -> schemaElementTypes.add(
schemaElementMatch.getElement().getType()));
SchemaElementMatch schemaElementMatchMax = schemaElementMatches.stream()
.sorted((o1, o2) ->
((int) ((o2.getSimilarity() - o1.getSimilarity()) * 100))
).findFirst().orElse(null);
if (schemaElementMatchMax != null) {
domainMatchResult.setMaxSimilarity(schemaElementMatchMax.getSimilarity());
}
domainMatchResult.setCount(schemaElementTypes.size());
}
}
return domainCount;
}
public Long resolve(QueryContext queryContext, ChatContext chatCtx, List<Long> restrictiveDomains) {
Long domainId = queryContext.getRequest().getDomainId();
if (Objects.nonNull(domainId) && domainId > 0) {
if (CollectionUtils.isNotEmpty(restrictiveDomains) && restrictiveDomains.contains(domainId)) {
return domainId;
} else {
return null;
}
}
SchemaMapInfo mapInfo = queryContext.getMapInfo();
Set<Long> matchedDomains = mapInfo.getMatchedDomains();
if (CollectionUtils.isNotEmpty(restrictiveDomains)) {
matchedDomains = matchedDomains.stream()
.filter(matchedDomain -> restrictiveDomains.contains(matchedDomain))
.collect(Collectors.toSet());
}
Map<Long, SemanticQuery> domainQueryModes = new HashMap<>();
for (Long matchedDomain : matchedDomains) {
domainQueryModes.put(matchedDomain, null);
}
if(domainQueryModes.size()==1){
return domainQueryModes.keySet().stream().findFirst().get();
}
return resolve(domainQueryModes, queryContext, chatCtx,
queryContext.getMapInfo(),restrictiveDomains);
}
public Long resolve(Map<Long, SemanticQuery> domainQueryModes, QueryContext queryContext,
ChatContext chatCtx, SchemaMapInfo schemaMap, List<Long> restrictiveDomains) {
Long selectDomain = selectDomain(domainQueryModes, queryContext.getRequest(), chatCtx, schemaMap,restrictiveDomains);
if (selectDomain > 0) {
log.info("selectDomain {} ", selectDomain);
return selectDomain;
}
// get the max SchemaElementType number
return selectDomainBySchemaElementCount(domainQueryModes, schemaMap);
}
public Long selectDomain(Map<Long, SemanticQuery> domainQueryModes, QueryReq queryContext,
ChatContext chatCtx,
SchemaMapInfo schemaMap, List<Long> restrictiveDomains) {
// if QueryContext has domainId and in domainQueryModes
if (domainQueryModes.containsKey(queryContext.getDomainId())) {
log.info("selectDomain from QueryContext [{}]", queryContext.getDomainId());
return queryContext.getDomainId();
}
// if ChatContext has domainId and in domainQueryModes
if (chatCtx.getParseInfo().getDomainId() > 0) {
Long domainId = chatCtx.getParseInfo().getDomainId();
if (!isAllowSwitch(domainQueryModes, schemaMap, chatCtx, queryContext, domainId,restrictiveDomains)) {
log.info("selectDomain from ChatContext [{}]", domainId);
return domainId;
}
}
// default 0
return 0L;
}
}

View File

@@ -0,0 +1,192 @@
package com.tencent.supersonic.chat.parser.function;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
@Slf4j
public class HeuristicModelResolver implements ModelResolver {
protected static Long selectModelBySchemaElementCount(Map<Long, SemanticQuery> ModelQueryModes,
SchemaMapInfo schemaMap) {
Map<Long, ModelMatchResult> ModelTypeMap = getModelTypeMap(schemaMap);
if (ModelTypeMap.size() == 1) {
Long ModelSelect = ModelTypeMap.entrySet().stream().collect(Collectors.toList()).get(0).getKey();
if (ModelQueryModes.containsKey(ModelSelect)) {
log.info("selectModel with only one Model [{}]", ModelSelect);
return ModelSelect;
}
} else {
Map.Entry<Long, ModelMatchResult> maxModel = ModelTypeMap.entrySet().stream()
.filter(entry -> ModelQueryModes.containsKey(entry.getKey()))
.sorted((o1, o2) -> {
int difference = o2.getValue().getCount() - o1.getValue().getCount();
if (difference == 0) {
return (int) ((o2.getValue().getMaxSimilarity()
- o1.getValue().getMaxSimilarity()) * 100);
}
return difference;
}).findFirst().orElse(null);
if (maxModel != null) {
log.info("selectModel with multiple Models [{}]", maxModel.getKey());
return maxModel.getKey();
}
}
return 0L;
}
/**
* to check can switch Model if context exit Model
*
* @return false will use context Model, true will use other Model , maybe include context Model
*/
protected static boolean isAllowSwitch(Map<Long, SemanticQuery> ModelQueryModes, SchemaMapInfo schemaMap,
ChatContext chatCtx, QueryReq searchCtx, Long modelId, List<Long> restrictiveModels) {
if (!Objects.nonNull(modelId) || modelId <= 0) {
return true;
}
// except content Model, calculate the number of types for each Model, if numbers<=1 will not switch
Map<Long, ModelMatchResult> ModelTypeMap = getModelTypeMap(schemaMap);
log.info("isAllowSwitch ModelTypeMap [{}]", ModelTypeMap);
long otherModelTypeNumBigOneCount = ModelTypeMap.entrySet().stream()
.filter(entry -> ModelQueryModes.containsKey(entry.getKey()) && !entry.getKey().equals(modelId))
.filter(entry -> entry.getValue().getCount() > 1).count();
if (otherModelTypeNumBigOneCount >= 1) {
return true;
}
// if query text only contain time , will not switch
if (!CollectionUtils.isEmpty(ModelQueryModes.values())) {
for (SemanticQuery semanticQuery : ModelQueryModes.values()) {
if (semanticQuery == null) {
continue;
}
SemanticParseInfo semanticParseInfo = semanticQuery.getParseInfo();
if (semanticParseInfo == null) {
continue;
}
if (searchCtx.getQueryText() != null && semanticParseInfo.getDateInfo() != null) {
if (semanticParseInfo.getDateInfo().getDetectWord() != null) {
if (semanticParseInfo.getDateInfo().getDetectWord()
.equalsIgnoreCase(searchCtx.getQueryText())) {
log.info("timeParseResults is not null , can not switch context , timeParseResults:{},",
semanticParseInfo.getDateInfo());
return false;
}
}
}
}
}
if (CollectionUtils.isNotEmpty(restrictiveModels) && !restrictiveModels.contains(modelId)) {
return true;
}
// if context Model not in schemaMap , will switch
if (schemaMap.getMatchedElements(modelId) == null || schemaMap.getMatchedElements(modelId).size() <= 0) {
log.info("modelId not in schemaMap ");
return true;
}
// other will not switch
return false;
}
public static Map<Long, ModelMatchResult> getModelTypeMap(SchemaMapInfo schemaMap) {
Map<Long, ModelMatchResult> ModelCount = new HashMap<>();
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMap.getModelElementMatches().entrySet()) {
List<SchemaElementMatch> schemaElementMatches = schemaMap.getMatchedElements(entry.getKey());
if (schemaElementMatches != null && schemaElementMatches.size() > 0) {
if (!ModelCount.containsKey(entry.getKey())) {
ModelCount.put(entry.getKey(), new ModelMatchResult());
}
ModelMatchResult ModelMatchResult = ModelCount.get(entry.getKey());
Set<SchemaElementType> schemaElementTypes = new HashSet<>();
schemaElementMatches.stream()
.forEach(schemaElementMatch -> schemaElementTypes.add(
schemaElementMatch.getElement().getType()));
SchemaElementMatch schemaElementMatchMax = schemaElementMatches.stream()
.sorted((o1, o2) ->
((int) ((o2.getSimilarity() - o1.getSimilarity()) * 100))
).findFirst().orElse(null);
if (schemaElementMatchMax != null) {
ModelMatchResult.setMaxSimilarity(schemaElementMatchMax.getSimilarity());
}
ModelMatchResult.setCount(schemaElementTypes.size());
}
}
return ModelCount;
}
public Long resolve(QueryContext queryContext, ChatContext chatCtx, List<Long> restrictiveModels) {
Long modelId = queryContext.getRequest().getModelId();
if (Objects.nonNull(modelId) && modelId > 0) {
if (CollectionUtils.isNotEmpty(restrictiveModels) && restrictiveModels.contains(modelId)) {
return modelId;
} else {
return null;
}
}
SchemaMapInfo mapInfo = queryContext.getMapInfo();
Set<Long> matchedModels = mapInfo.getMatchedModels();
if (CollectionUtils.isNotEmpty(restrictiveModels)) {
matchedModels = matchedModels.stream()
.filter(restrictiveModels::contains)
.collect(Collectors.toSet());
}
Map<Long, SemanticQuery> ModelQueryModes = new HashMap<>();
for (Long matchedModel : matchedModels) {
ModelQueryModes.put(matchedModel, null);
}
if (ModelQueryModes.size() == 1) {
return ModelQueryModes.keySet().stream().findFirst().get();
}
return resolve(ModelQueryModes, queryContext, chatCtx,
queryContext.getMapInfo(), restrictiveModels);
}
public Long resolve(Map<Long, SemanticQuery> ModelQueryModes, QueryContext queryContext,
ChatContext chatCtx, SchemaMapInfo schemaMap, List<Long> restrictiveModels) {
Long selectModel = selectModel(ModelQueryModes, queryContext.getRequest(), chatCtx, schemaMap,
restrictiveModels);
if (selectModel > 0) {
log.info("selectModel {} ", selectModel);
return selectModel;
}
// get the max SchemaElementType number
return selectModelBySchemaElementCount(ModelQueryModes, schemaMap);
}
public Long selectModel(Map<Long, SemanticQuery> ModelQueryModes, QueryReq queryContext,
ChatContext chatCtx,
SchemaMapInfo schemaMap, List<Long> restrictiveModels) {
// if QueryContext has modelId and in ModelQueryModes
if (ModelQueryModes.containsKey(queryContext.getModelId())) {
log.info("selectModel from QueryContext [{}]", queryContext.getModelId());
return queryContext.getModelId();
}
// if ChatContext has modelId and in ModelQueryModes
if (chatCtx.getParseInfo().getModelId() > 0) {
Long modelId = chatCtx.getParseInfo().getModelId();
if (!isAllowSwitch(ModelQueryModes, schemaMap, chatCtx, queryContext, modelId, restrictiveModels)) {
log.info("selectModel from ChatContext [{}]", modelId);
return modelId;
}
}
// default 0
return 0L;
}
}

View File

@@ -3,7 +3,8 @@ package com.tencent.supersonic.chat.parser.function;
import lombok.Data;
@Data
public class DomainMatchResult {
public class ModelMatchResult {
private Integer count = 0;
private double maxSimilarity;
}

View File

@@ -5,8 +5,8 @@ import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import java.util.List;
public interface DomainResolver {
public interface ModelResolver {
Long resolve(QueryContext queryContext, ChatContext chatCtx, List<Long> restrictiveDomains);
Long resolve(QueryContext queryContext, ChatContext chatCtx, List<Long> restrictiveModels);
}

View File

@@ -1,8 +1,8 @@
package com.tencent.supersonic.chat.parser.function;
import lombok.Data;
import java.util.List;
import java.util.Map;
import lombok.Data;
@Data
public class Parameters {

View File

@@ -0,0 +1,11 @@
package com.tencent.supersonic.chat.parser.llm;
import com.tencent.supersonic.chat.plugin.PluginParseResult;
import com.tencent.supersonic.chat.query.dsl.LLMResp;
import lombok.Data;
@Data
public class DSLParseResult extends PluginParseResult {
private LLMResp llmResp;
}

View File

@@ -0,0 +1,231 @@
package com.tencent.supersonic.chat.parser.llm;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.config.LLMConfig;
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
import com.tencent.supersonic.chat.parser.function.ModelResolver;
import com.tencent.supersonic.chat.plugin.Plugin;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.dsl.DSLBuilder;
import com.tencent.supersonic.chat.query.dsl.DSLQuery;
import com.tencent.supersonic.chat.query.dsl.LLMReq;
import com.tencent.supersonic.chat.query.dsl.LLMReq.ElementValue;
import com.tencent.supersonic.chat.query.dsl.LLMResp;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.knowledge.service.SchemaService;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.util.CollectionUtils;
import org.springframework.web.client.RestTemplate;
@Slf4j
public class LLMDSLParser implements SemanticParser {
public static final double FUNCTION_BONUS_THRESHOLD = 201;
@Override
public void parse(QueryContext queryCtx, ChatContext chatCtx) {
final LLMConfig llmConfig = ContextUtils.getBean(LLMConfig.class);
if (StringUtils.isEmpty(llmConfig.getUrl()) || SatisfactionChecker.check(queryCtx)) {
log.info("llmConfig:{}, skip function parser, queryText:{}", llmConfig,
queryCtx.getRequest().getQueryText());
return;
}
List<Plugin> dslPlugins = PluginManager.getPlugins().stream()
.filter(plugin -> DSLQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType()))
.collect(Collectors.toList());
if (CollectionUtils.isEmpty(dslPlugins)) {
return;
}
Plugin plugin = dslPlugins.get(0);
List<Long> dslModels = plugin.getModelList();
try {
ModelResolver modelResolver = ComponentFactory.getModelResolver();
Long modelId = modelResolver.resolve(queryCtx, chatCtx, dslModels);
log.info("resolve modelId:{},dslModels:{}", modelId, dslModels);
if (Objects.isNull(modelId)) {
return;
}
LLMResp llmResp = requestLLM(queryCtx, modelId);
if (Objects.isNull(llmResp)) {
return;
}
PluginSemanticQuery semanticQuery = QueryManager.createPluginQuery(DSLQuery.QUERY_MODE);
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
if (Objects.nonNull(modelId) && modelId > 0) {
parseInfo.getElementMatches().addAll(queryCtx.getMapInfo().getMatchedElements(modelId));
}
DSLParseResult dslParseResult = new DSLParseResult();
dslParseResult.setRequest(queryCtx.getRequest());
dslParseResult.setLlmResp(llmResp);
dslParseResult.setPlugin(plugin);
Map<String, Object> properties = new HashMap<>();
properties.put(Constants.CONTEXT, dslParseResult);
parseInfo.setProperties(properties);
parseInfo.setScore(FUNCTION_BONUS_THRESHOLD);
parseInfo.setQueryMode(semanticQuery.getQueryMode());
SchemaElement Model = new SchemaElement();
Model.setModel(modelId);
Model.setId(modelId);
parseInfo.setModel(Model);
queryCtx.getCandidateQueries().add(semanticQuery);
} catch (Exception e) {
log.error("LLMDSLParser error", e);
}
}
private LLMResp requestLLM(QueryContext queryCtx, Long modelId) {
long startTime = System.currentTimeMillis();
String queryText = queryCtx.getRequest().getQueryText();
final LLMConfig llmConfig = ContextUtils.getBean(LLMConfig.class);
if (StringUtils.isEmpty(llmConfig.getUrl())) {
log.warn("llmConfig url is null, skip llm parser");
return null;
}
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
LLMReq llmReq = new LLMReq();
llmReq.setQueryText(queryText);
LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema();
llmSchema.setModelName(modelIdToName.get(modelId));
llmSchema.setDomainName(modelIdToName.get(modelId));
List<String> fieldNameList = getFieldNameList(queryCtx, modelId, semanticSchema);
fieldNameList.add(DSLBuilder.DATA_Field);
llmSchema.setFieldNameList(fieldNameList);
llmReq.setSchema(llmSchema);
List<ElementValue> linking = new ArrayList<>();
linking.addAll(getValueList(queryCtx, modelId, semanticSchema));
llmReq.setLinking(linking);
String currentDate = getCurrentDate(modelId);
llmReq.setCurrentDate(currentDate);
log.info("requestLLM request, modelId:{},llmReq:{}", modelId, llmReq);
String questUrl = llmConfig.getUrl() + llmConfig.getQueryToSqlPath();
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
try {
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
HttpEntity<String> entity = new HttpEntity<>(JsonUtil.toString(llmReq), headers);
ResponseEntity<LLMResp> responseEntity = restTemplate.exchange(questUrl, HttpMethod.POST, entity,
LLMResp.class);
log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
System.currentTimeMillis() - startTime, questUrl, entity, responseEntity.getBody());
return responseEntity.getBody();
} catch (Exception e) {
log.error("requestLLM error", e);
}
return null;
}
private String getCurrentDate(Long modelId) {
return DateUtils.getBeforeDate(4);
// ChatConfigFilter filter = new ChatConfigFilter();
// filter.setModelId(modelId);
//
// List<ChatConfigResp> configResps = ContextUtils.getBean(ConfigService.class).search(filter, null);
// if (CollectionUtils.isEmpty(configResps)) {
// return
// }
// ChatConfigResp chatConfigResp = configResps.get(0);
// chatConfigResp.getChatDetailConfig().getChatDefaultConfig().get
}
private List<ElementValue> getValueList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) {
Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema);
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId);
if (CollectionUtils.isEmpty(matchedElements)) {
return new ArrayList<>();
}
Set<ElementValue> valueMatches = matchedElements.stream()
.filter(schemaElementMatch -> {
SchemaElementType type = schemaElementMatch.getElement().getType();
return SchemaElementType.VALUE.equals(type) || SchemaElementType.ID.equals(type);
})
.map(elementMatch ->
{
ElementValue elementValue = new ElementValue();
elementValue.setFieldName(itemIdToName.get(elementMatch.getElement().getId()));
elementValue.setFieldValue(elementMatch.getWord());
return elementValue;
}
)
.collect(Collectors.toSet());
return new ArrayList<>(valueMatches);
}
private List<String> getFieldNameList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) {
Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema);
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId);
if (CollectionUtils.isEmpty(matchedElements)) {
return new ArrayList<>();
}
Set<String> fieldNameList = matchedElements.stream()
.filter(schemaElementMatch -> {
SchemaElementType elementType = schemaElementMatch.getElement().getType();
return SchemaElementType.METRIC.equals(elementType) ||
SchemaElementType.DIMENSION.equals(elementType) ||
SchemaElementType.VALUE.equals(elementType);
})
.map(schemaElementMatch -> {
SchemaElementType elementType = schemaElementMatch.getElement().getType();
if (!SchemaElementType.VALUE.equals(elementType)) {
return schemaElementMatch.getWord();
}
return itemIdToName.get(schemaElementMatch.getElement().getId());
})
.filter(name -> StringUtils.isNotEmpty(name) && !name.contains("%"))
.collect(Collectors.toSet());
return new ArrayList<>(fieldNameList);
}
private Map<Long, String> getItemIdToName(Long modelId, SemanticSchema semanticSchema) {
return semanticSchema.getDimensions().stream()
.filter(entry -> modelId.equals(entry.getModel()))
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
}
}

View File

@@ -17,7 +17,7 @@ public class LLMTimeEnhancementParse implements SemanticParser {
@Override
public void parse(QueryContext queryContext, ChatContext chatContext) {
log.info("before queryContext:{},chatContext:{}",queryContext,chatContext);
log.info("before queryContext:{},chatContext:{}", queryContext, chatContext);
ChatGptHelper chatGptHelper = ContextUtils.getBean(ChatGptHelper.class);
String inferredTime = chatGptHelper.inferredTime(queryContext.getRequest().getQueryText());
try {
@@ -25,12 +25,12 @@ public class LLMTimeEnhancementParse implements SemanticParser {
for (SemanticQuery query : queryContext.getCandidateQueries()) {
DateConf dateInfo = query.getParseInfo().getDateInfo();
JSONObject jsonObject = JSON.parseObject(inferredTime);
if (jsonObject.containsKey("date")){
if (jsonObject.containsKey("date")) {
dateInfo.setDateMode(DateConf.DateMode.BETWEEN);
dateInfo.setStartDate(jsonObject.getString("date"));
dateInfo.setEndDate(jsonObject.getString("date"));
query.getParseInfo().setDateInfo(dateInfo);
}else if (jsonObject.containsKey("start")){
} else if (jsonObject.containsKey("start")) {
dateInfo.setDateMode(DateConf.DateMode.BETWEEN);
dateInfo.setStartDate(jsonObject.getString("start"));
dateInfo.setEndDate(jsonObject.getString("end"));
@@ -38,11 +38,12 @@ public class LLMTimeEnhancementParse implements SemanticParser {
}
}
}
}catch (Exception exception){
log.error("{} parse error,this reason is:{}",LLMTimeEnhancementParse.class.getSimpleName(), (Object) exception.getStackTrace());
} catch (Exception exception) {
log.error("{} parse error,this reason is:{}", LLMTimeEnhancementParse.class.getSimpleName(),
(Object) exception.getStackTrace());
}
log.info("after queryContext:{},chatContext:{}",queryContext,chatContext);
log.info("after queryContext:{},chatContext:{}", queryContext, chatContext);
}

View File

@@ -21,7 +21,9 @@ import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
@Slf4j
public class AggregateTypeParser implements SemanticParser {
@@ -35,36 +37,60 @@ public class AggregateTypeParser implements SemanticParser {
new AbstractMap.SimpleEntry<>(DISTINCT, Pattern.compile("(?i)(uv)")),
new AbstractMap.SimpleEntry<>(COUNT, Pattern.compile("(?i)(总数|pv)")),
new AbstractMap.SimpleEntry<>(NONE, Pattern.compile("(?i)(明细)"))
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue,(k1,k2)->k2));
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (k1, k2) -> k2));
@Override
public void parse(QueryContext queryContext, ChatContext chatContext) {
String queryText = queryContext.getRequest().getQueryText();
AggregateConf aggregateConf = resolveAggregateConf(queryText);
for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) {
if (!AggregateTypeEnum.NONE.equals(semanticQuery.getParseInfo().getAggType())) {
continue;
}
String queryText = queryContext.getRequest().getQueryText();
semanticQuery.getParseInfo().setAggType(resolveAggregateType(queryText));
semanticQuery.getParseInfo().setAggType(aggregateConf.type);
int detectWordLength = 0;
if (StringUtils.isNotEmpty(aggregateConf.detectWord)) {
detectWordLength = aggregateConf.detectWord.length();
}
semanticQuery.getParseInfo().setScore(semanticQuery.getParseInfo().getScore() + detectWordLength);
}
}
public static AggregateTypeEnum resolveAggregateType(String queryText) {
public AggregateTypeEnum resolveAggregateType(String queryText) {
AggregateConf aggregateConf = resolveAggregateConf(queryText);
return aggregateConf.type;
}
private AggregateConf resolveAggregateConf(String queryText) {
Map<AggregateTypeEnum, Integer> aggregateCount = new HashMap<>(REGX_MAP.size());
Map<AggregateTypeEnum, String> aggregateWord = new HashMap<>(REGX_MAP.size());
for (Map.Entry<AggregateTypeEnum, Pattern> entry : REGX_MAP.entrySet()) {
Matcher matcher = entry.getValue().matcher(queryText);
int count = 0;
String detectWord = null;
while (matcher.find()) {
count++;
detectWord = matcher.group();
}
if (count > 0) {
aggregateCount.put(entry.getKey(), count);
aggregateWord.put(entry.getKey(), detectWord);
}
}
return aggregateCount.entrySet().stream().max(Map.Entry.comparingByValue())
AggregateTypeEnum type = aggregateCount.entrySet().stream().max(Map.Entry.comparingByValue())
.map(entry -> entry.getKey()).orElse(NONE);
String detectWord = aggregateWord.get(type);
return new AggregateConf(type, detectWord);
}
@AllArgsConstructor
class AggregateConf {
public AggregateTypeEnum type;
public String detectWord;
}
}

View File

@@ -1,5 +1,12 @@
package com.tencent.supersonic.chat.parser.rule;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.DIMENSION;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ENTITY;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ID;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.METRIC;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.MODEL;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.VALUE;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
@@ -8,8 +15,8 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricDomainQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricSemanticQuery;
import java.util.AbstractMap;
import java.util.ArrayList;
@@ -21,8 +28,6 @@ import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.extern.slf4j.Slf4j;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.*;
@Slf4j
public class ContextInheritParser implements SemanticParser {
@@ -31,7 +36,7 @@ public class ContextInheritParser implements SemanticParser {
new AbstractMap.SimpleEntry<>(DIMENSION, Arrays.asList(DIMENSION, VALUE)),
new AbstractMap.SimpleEntry<>(VALUE, Arrays.asList(VALUE, DIMENSION)),
new AbstractMap.SimpleEntry<>(ENTITY, Arrays.asList(ENTITY)),
new AbstractMap.SimpleEntry<>(DOMAIN, Arrays.asList(DOMAIN)),
new AbstractMap.SimpleEntry<>(MODEL, Arrays.asList(MODEL)),
new AbstractMap.SimpleEntry<>(ID, Arrays.asList(ID))
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
@@ -41,9 +46,9 @@ public class ContextInheritParser implements SemanticParser {
return;
}
Long domainId = chatContext.getParseInfo().getDomainId();
Long modelId = chatContext.getParseInfo().getModelId();
List<SchemaElementMatch> elementMatches = queryContext.getMapInfo()
.getMatchedElements(domainId);
.getMatchedElements(modelId);
List<SchemaElementMatch> matchesToInherit = new ArrayList<>();
for (SchemaElementMatch match : chatContext.getParseInfo().getElementMatches()) {
@@ -51,7 +56,7 @@ public class ContextInheritParser implements SemanticParser {
// mutual exclusive element types should not be inherited
RuleSemanticQuery ruleQuery = QueryManager.getRuleQuery(chatContext.getParseInfo().getQueryMode());
if (!containsTypes(elementMatches, matchType, ruleQuery)) {
match.setMode(SchemaElementMatch.MatchMode.INHERIT);
match.setInherited(true);
matchesToInherit.add(match);
}
}
@@ -59,11 +64,24 @@ public class ContextInheritParser implements SemanticParser {
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(elementMatches, queryContext);
for (RuleSemanticQuery query : queries) {
query.fillParseInfo(domainId, chatContext);
query.fillParseInfo(modelId, queryContext, chatContext);
if (existSameQuery(query.getParseInfo().getModelId(), query.getQueryMode(), queryContext)) {
continue;
}
queryContext.getCandidateQueries().add(query);
}
}
private boolean existSameQuery(Long modelId, String queryMode, QueryContext queryContext) {
for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) {
if (semanticQuery.getQueryMode().equals(queryMode)
&& semanticQuery.getParseInfo().getModelId().equals(modelId)) {
return true;
}
}
return false;
}
private boolean containsTypes(List<SchemaElementMatch> matches, SchemaElementType matchType,
RuleSemanticQuery ruleQuery) {
List<SchemaElementType> types = MUTUAL_EXCLUSIVE_MAP.get(matchType);
@@ -79,16 +97,18 @@ public class ContextInheritParser implements SemanticParser {
}
protected boolean shouldInherit(QueryContext queryContext, ChatContext chatContext) {
Long contextDomainId = chatContext.getParseInfo().getDomainId();
if (queryContext.getMapInfo().getMatchedElements(contextDomainId) == null) {
Long contextmodelId = chatContext.getParseInfo().getModelId();
// if map info doesn't contain the same Model of the context,
// no inheritance could be done
if (queryContext.getMapInfo().getMatchedElements(contextmodelId) == null) {
return false;
}
// if candidates have only one MetricDomain mode and context has value filter , count in context
List<SemanticQuery> candidateQueries = queryContext.getCandidateQueries().stream()
.filter(semanticQuery -> semanticQuery.getParseInfo().getDomainId().equals(contextDomainId)).collect(
// if candidates only have MetricModel mode, count in context
List<SemanticQuery> metricModelQueries = queryContext.getCandidateQueries().stream()
.filter(query -> query instanceof MetricModelQuery).collect(
Collectors.toList());
if (candidateQueries.size() == 1 && (candidateQueries.get(0) instanceof MetricDomainQuery)) {
if (metricModelQueries.size() == queryContext.getCandidateQueries().size()) {
return true;
} else {
return queryContext.getCandidateQueries().size() == 0;

View File

@@ -1,11 +1,12 @@
package com.tencent.supersonic.chat.parser.rule;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import java.util.*;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
@Slf4j
@@ -16,12 +17,12 @@ public class QueryModeParser implements SemanticParser {
SchemaMapInfo mapInfo = queryContext.getMapInfo();
// iterate all schemaElementMatches to resolve semantic query
for (Long domainId : mapInfo.getMatchedDomains()) {
List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(domainId);
for (Long modelId : mapInfo.getMatchedModels()) {
List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(modelId);
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(elementMatches, queryContext);
for (RuleSemanticQuery query : queries) {
query.fillParseInfo(domainId, chatContext);
query.fillParseInfo(modelId, queryContext, chatContext);
queryContext.getCandidateQueries().add(query);
}
}

View File

@@ -4,21 +4,21 @@ import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DateConf;
import com.xkzhangsan.time.nlp.TimeNLP;
import com.xkzhangsan.time.nlp.TimeNLPUtil;
import java.text.DateFormat;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.time.LocalDate;
import java.util.*;
import java.util.Date;
import java.util.List;
import java.util.Stack;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import com.xkzhangsan.time.nlp.TimeNLP;
import com.xkzhangsan.time.nlp.TimeNLPUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.logging.log4j.util.Strings;
@@ -45,12 +45,16 @@ public class TimeRangeParser implements SemanticParser {
if (queryContext.getCandidateQueries().size() > 0) {
for (SemanticQuery query : queryContext.getCandidateQueries()) {
query.getParseInfo().setDateInfo(dateConf);
query.getParseInfo().setScore(query.getParseInfo().getScore()
+ dateConf.getDetectWord().length());
}
} else if (QueryManager.containsRuleQuery(chatContext.getParseInfo().getQueryMode())) {
RuleSemanticQuery semanticQuery = QueryManager.createRuleQuery(
chatContext.getParseInfo().getQueryMode());
// inherit parse info from context
chatContext.getParseInfo().setDateInfo(dateConf);
chatContext.getParseInfo().setScore(chatContext.getParseInfo().getScore()
+ dateConf.getDetectWord().length());
semanticQuery.setParseInfo(chatContext.getParseInfo());
queryContext.getCandidateQueries().add(semanticQuery);
}
@@ -60,41 +64,48 @@ public class TimeRangeParser implements SemanticParser {
private DateConf parseDateCN(String queryText) {
Date startDate = null;
Date endDate;
String detectWord = null;
List<TimeNLP> times = TimeNLPUtil.parse(queryText);
if (times.size() > 0) {
startDate = times.get(0).getTime();
detectWord = times.get(0).getTimeExpression();
} else {
return null;
}
if (times.size() > 1) {
endDate = times.get(1).getTime();
detectWord += "~" + times.get(0).getTimeExpression();
} else {
endDate = startDate;
}
return getDateConf(startDate, endDate);
return getDateConf(startDate, endDate, detectWord);
}
private DateConf parseDateNumber(String queryText) {
String startDate;
String endDate = null;
String detectWord = null;
Matcher dateMatcher = DATE_PATTERN_NUMBER.matcher(queryText);
if (dateMatcher.find()) {
startDate = dateMatcher.group();
detectWord = startDate;
} else {
return null;
}
if (dateMatcher.find()) {
endDate = dateMatcher.group();
detectWord += "~" + endDate;
}
endDate = endDate != null ? endDate : startDate;
try {
return getDateConf(DATE_FORMAT_NUMBER.parse(startDate), DATE_FORMAT_NUMBER.parse(endDate));
return getDateConf(DATE_FORMAT_NUMBER.parse(startDate), DATE_FORMAT_NUMBER.parse(endDate), detectWord);
} catch (ParseException e) {
return null;
}
@@ -134,11 +145,11 @@ public class TimeRangeParser implements SemanticParser {
}
days = days * num;
info.setDateMode(DateConf.DateMode.RECENT);
String text = "" + num + zhPeriod;
String detectWord = "" + num + zhPeriod;
if (Strings.isNotEmpty(m.group("periodStr"))) {
text = m.group("periodStr");
detectWord = m.group("periodStr");
}
info.setText(text);
info.setDetectWord(detectWord);
info.setStartDate(LocalDate.now().minusDays(days).toString());
info.setUnit(num);
@@ -173,7 +184,7 @@ public class TimeRangeParser implements SemanticParser {
return stack.stream().mapToInt(s -> s).sum();
}
private DateConf getDateConf(Date startDate, Date endDate) {
private DateConf getDateConf(Date startDate, Date endDate, String detectWord) {
if (startDate == null || endDate == null) {
return null;
}
@@ -182,6 +193,7 @@ public class TimeRangeParser implements SemanticParser {
info.setDateMode(DateConf.DateMode.BETWEEN);
info.setStartDate(DATE_FORMAT.format(startDate));
info.setEndDate(DATE_FORMAT.format(endDate));
info.setDetectWord(detectWord);
return info;
}

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.chat.persistence.dataobject;
import java.util.Date;
import lombok.Data;
import lombok.ToString;
@@ -15,7 +14,7 @@ public class ChatConfigDO {
*/
private Long id;
private Long domainId;
private Long modelId;
private String chatDetailConfig;

View File

@@ -1,11 +1,10 @@
package com.tencent.supersonic.chat.persistence.dataobject;
import java.util.ArrayList;
import java.util.List;
import com.tencent.supersonic.chat.config.DefaultMetric;
import com.tencent.supersonic.chat.config.Dim4Dict;
import java.util.ArrayList;
import java.util.List;
import lombok.Data;
import lombok.ToString;
@@ -14,14 +13,14 @@ import lombok.ToString;
@ToString
public class DimValueDO {
private Long domainId;
private Long modelId;
private List<DefaultMetric> defaultMetricDescList = new ArrayList<>();
private List<Dim4Dict> dimensions = new ArrayList<>();
public DimValueDO setDomainId(Long domainId) {
this.domainId = domainId;
public DimValueDO setModelId(Long modelId) {
this.modelId = modelId;
return this;
}

View File

@@ -3,8 +3,9 @@ package com.tencent.supersonic.chat.persistence.dataobject;
import java.util.Date;
public class PluginDO {
/**
*
*
*/
private Long id;
@@ -14,71 +15,69 @@ public class PluginDO {
private String type;
/**
*
*
*/
private String domain;
private String model;
/**
*
*
*/
private String pattern;
/**
*
*
*/
private String parseMode;
/**
*
*
*/
private String name;
/**
*
*
*/
private Date createdAt;
/**
*
*
*/
private String createdBy;
/**
*
*
*/
private Date updatedAt;
/**
*
*
*/
private String updatedBy;
/**
*
*
*/
private String parseModeConfig;
/**
*
*
*/
private String config;
/**
*
*
*/
private String comment;
/**
*
* @return id
* @return id
*/
public Long getId() {
return id;
}
/**
*
* @param id
* @param id
*/
public void setId(Long id) {
this.id = id;
@@ -86,6 +85,7 @@ public class PluginDO {
/**
* DASHBOARD,WIDGET,URL
*
* @return type DASHBOARD,WIDGET,URL
*/
public String getType() {
@@ -94,6 +94,7 @@ public class PluginDO {
/**
* DASHBOARD,WIDGET,URL
*
* @param type DASHBOARD,WIDGET,URL
*/
public void setType(String type) {
@@ -101,176 +102,154 @@ public class PluginDO {
}
/**
*
* @return domain
* @return model
*/
public String getDomain() {
return domain;
public String getModel() {
return model;
}
/**
*
* @param domain
* @param model
*/
public void setDomain(String domain) {
this.domain = domain == null ? null : domain.trim();
public void setModel(String model) {
this.model = model == null ? null : model.trim();
}
/**
*
* @return pattern
* @return pattern
*/
public String getPattern() {
return pattern;
}
/**
*
* @param pattern
* @param pattern
*/
public void setPattern(String pattern) {
this.pattern = pattern == null ? null : pattern.trim();
}
/**
*
* @return parse_mode
* @return parse_mode
*/
public String getParseMode() {
return parseMode;
}
/**
*
* @param parseMode
* @param parseMode
*/
public void setParseMode(String parseMode) {
this.parseMode = parseMode == null ? null : parseMode.trim();
}
/**
*
* @return name
* @return name
*/
public String getName() {
return name;
}
/**
*
* @param name
* @param name
*/
public void setName(String name) {
this.name = name == null ? null : name.trim();
}
/**
*
* @return created_at
* @return created_at
*/
public Date getCreatedAt() {
return createdAt;
}
/**
*
* @param createdAt
* @param createdAt
*/
public void setCreatedAt(Date createdAt) {
this.createdAt = createdAt;
}
/**
*
* @return created_by
* @return created_by
*/
public String getCreatedBy() {
return createdBy;
}
/**
*
* @param createdBy
* @param createdBy
*/
public void setCreatedBy(String createdBy) {
this.createdBy = createdBy == null ? null : createdBy.trim();
}
/**
*
* @return updated_at
* @return updated_at
*/
public Date getUpdatedAt() {
return updatedAt;
}
/**
*
* @param updatedAt
* @param updatedAt
*/
public void setUpdatedAt(Date updatedAt) {
this.updatedAt = updatedAt;
}
/**
*
* @return updated_by
* @return updated_by
*/
public String getUpdatedBy() {
return updatedBy;
}
/**
*
* @param updatedBy
* @param updatedBy
*/
public void setUpdatedBy(String updatedBy) {
this.updatedBy = updatedBy == null ? null : updatedBy.trim();
}
/**
*
* @return parse_mode_config
* @return parse_mode_config
*/
public String getParseModeConfig() {
return parseModeConfig;
}
/**
*
* @param parseModeConfig
* @param parseModeConfig
*/
public void setParseModeConfig(String parseModeConfig) {
this.parseModeConfig = parseModeConfig == null ? null : parseModeConfig.trim();
}
/**
*
* @return config
* @return config
*/
public String getConfig() {
return config;
}
/**
*
* @param config
* @param config
*/
public void setConfig(String config) {
this.config = config == null ? null : config.trim();
}
/**
*
* @return comment
* @return comment
*/
public String getComment() {
return comment;
}
/**
*
* @param comment
* @param comment
*/
public void setComment(String comment) {
this.comment = comment == null ? null : comment.trim();

View File

@@ -5,6 +5,7 @@ import java.util.Date;
import java.util.List;
public class PluginDOExample {
/**
* s2_plugin
*/
@@ -31,7 +32,6 @@ public class PluginDOExample {
protected Integer limitEnd;
/**
*
* @mbg.generated
*/
public PluginDOExample() {
@@ -39,7 +39,6 @@ public class PluginDOExample {
}
/**
*
* @mbg.generated
*/
public void setOrderByClause(String orderByClause) {
@@ -47,7 +46,6 @@ public class PluginDOExample {
}
/**
*
* @mbg.generated
*/
public String getOrderByClause() {
@@ -55,7 +53,6 @@ public class PluginDOExample {
}
/**
*
* @mbg.generated
*/
public void setDistinct(boolean distinct) {
@@ -63,7 +60,6 @@ public class PluginDOExample {
}
/**
*
* @mbg.generated
*/
public boolean isDistinct() {
@@ -71,7 +67,6 @@ public class PluginDOExample {
}
/**
*
* @mbg.generated
*/
public List<Criteria> getOredCriteria() {
@@ -79,7 +74,6 @@ public class PluginDOExample {
}
/**
*
* @mbg.generated
*/
public void or(Criteria criteria) {
@@ -87,7 +81,6 @@ public class PluginDOExample {
}
/**
*
* @mbg.generated
*/
public Criteria or() {
@@ -97,7 +90,6 @@ public class PluginDOExample {
}
/**
*
* @mbg.generated
*/
public Criteria createCriteria() {
@@ -109,7 +101,6 @@ public class PluginDOExample {
}
/**
*
* @mbg.generated
*/
protected Criteria createCriteriaInternal() {
@@ -118,7 +109,6 @@ public class PluginDOExample {
}
/**
*
* @mbg.generated
*/
public void clear() {
@@ -128,15 +118,13 @@ public class PluginDOExample {
}
/**
*
* @mbg.generated
*/
public void setLimitStart(Integer limitStart) {
this.limitStart=limitStart;
this.limitStart = limitStart;
}
/**
*
* @mbg.generated
*/
public Integer getLimitStart() {
@@ -144,15 +132,13 @@ public class PluginDOExample {
}
/**
*
* @mbg.generated
*/
public void setLimitEnd(Integer limitEnd) {
this.limitEnd=limitEnd;
this.limitEnd = limitEnd;
}
/**
*
* @mbg.generated
*/
public Integer getLimitEnd() {
@@ -163,6 +149,7 @@ public class PluginDOExample {
* s2_plugin null
*/
protected abstract static class GeneratedCriteria {
protected List<Criterion> criteria;
protected GeneratedCriteria() {
@@ -333,73 +320,73 @@ public class PluginDOExample {
return (Criteria) this;
}
public Criteria andDomainIsNull() {
addCriterion("domain is null");
public Criteria andModelIsNull() {
addCriterion("model is null");
return (Criteria) this;
}
public Criteria andDomainIsNotNull() {
addCriterion("domain is not null");
public Criteria andModelIsNotNull() {
addCriterion("model is not null");
return (Criteria) this;
}
public Criteria andDomainEqualTo(String value) {
addCriterion("domain =", value, "domain");
public Criteria andModelEqualTo(String value) {
addCriterion("model =", value, "model");
return (Criteria) this;
}
public Criteria andDomainNotEqualTo(String value) {
addCriterion("domain <>", value, "domain");
public Criteria andModelNotEqualTo(String value) {
addCriterion("model <>", value, "model");
return (Criteria) this;
}
public Criteria andDomainGreaterThan(String value) {
addCriterion("domain >", value, "domain");
public Criteria andModelGreaterThan(String value) {
addCriterion("model >", value, "model");
return (Criteria) this;
}
public Criteria andDomainGreaterThanOrEqualTo(String value) {
addCriterion("domain >=", value, "domain");
public Criteria andModelGreaterThanOrEqualTo(String value) {
addCriterion("model >=", value, "model");
return (Criteria) this;
}
public Criteria andDomainLessThan(String value) {
addCriterion("domain <", value, "domain");
public Criteria andModelLessThan(String value) {
addCriterion("model <", value, "model");
return (Criteria) this;
}
public Criteria andDomainLessThanOrEqualTo(String value) {
addCriterion("domain <=", value, "domain");
public Criteria andModelLessThanOrEqualTo(String value) {
addCriterion("model <=", value, "model");
return (Criteria) this;
}
public Criteria andDomainLike(String value) {
addCriterion("domain like", value, "domain");
public Criteria andModelLike(String value) {
addCriterion("model like", value, "model");
return (Criteria) this;
}
public Criteria andDomainNotLike(String value) {
addCriterion("domain not like", value, "domain");
public Criteria andModelNotLike(String value) {
addCriterion("model not like", value, "model");
return (Criteria) this;
}
public Criteria andDomainIn(List<String> values) {
addCriterion("domain in", values, "domain");
public Criteria andModelIn(List<String> values) {
addCriterion("model in", values, "model");
return (Criteria) this;
}
public Criteria andDomainNotIn(List<String> values) {
addCriterion("domain not in", values, "domain");
public Criteria andModelNotIn(List<String> values) {
addCriterion("model not in", values, "model");
return (Criteria) this;
}
public Criteria andDomainBetween(String value1, String value2) {
addCriterion("domain between", value1, value2, "domain");
public Criteria andModelBetween(String value1, String value2) {
addCriterion("model between", value1, value2, "model");
return (Criteria) this;
}
public Criteria andDomainNotBetween(String value1, String value2) {
addCriterion("domain not between", value1, value2, "domain");
public Criteria andModelNotBetween(String value1, String value2) {
addCriterion("model not between", value1, value2, "model");
return (Criteria) this;
}
@@ -888,6 +875,7 @@ public class PluginDOExample {
* s2_plugin null
*/
public static class Criterion {
private String condition;
private Object value;

View File

@@ -14,5 +14,5 @@ public interface ChatConfigMapper {
List<ChatConfigDO> search(ChatConfigFilterInternal filterInternal);
ChatConfigDO fetchConfigByDomainId(Long domainId);
ChatConfigDO fetchConfigByModelId(Long modelId);
}

View File

@@ -2,9 +2,8 @@ package com.tencent.supersonic.chat.persistence.mapper;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDOExample;
import org.apache.ibatis.annotations.Mapper;
import java.util.List;
import org.apache.ibatis.annotations.Mapper;
@Mapper
public interface ChatQueryDOMapper {

View File

@@ -2,67 +2,58 @@ package com.tencent.supersonic.chat.persistence.mapper;
import com.tencent.supersonic.chat.persistence.dataobject.PluginDO;
import com.tencent.supersonic.chat.persistence.dataobject.PluginDOExample;
import org.apache.ibatis.annotations.Mapper;
import java.util.List;
import org.apache.ibatis.annotations.Mapper;
@Mapper
public interface PluginDOMapper {
/**
*
* @mbg.generated
*/
long countByExample(PluginDOExample example);
/**
*
* @mbg.generated
*/
int deleteByPrimaryKey(Long id);
/**
*
* @mbg.generated
*/
int insert(PluginDO record);
/**
*
* @mbg.generated
*/
int insertSelective(PluginDO record);
/**
*
* @mbg.generated
*/
List<PluginDO> selectByExampleWithBLOBs(PluginDOExample example);
/**
*
* @mbg.generated
*/
List<PluginDO> selectByExample(PluginDOExample example);
/**
*
* @mbg.generated
*/
PluginDO selectByPrimaryKey(Long id);
/**
*
* @mbg.generated
*/
int updateByPrimaryKeySelective(PluginDO record);
/**
*
* @mbg.generated
*/
int updateByPrimaryKeyWithBLOBs(PluginDO record);
/**
*
* @mbg.generated
*/
int updateByPrimaryKey(PluginDO record);

View File

@@ -1,10 +1,9 @@
package com.tencent.supersonic.chat.persistence.repository;
import com.tencent.supersonic.chat.config.ChatConfig;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.config.ChatConfig;
import java.util.List;
public interface ChatConfigRepository {
@@ -15,5 +14,5 @@ public interface ChatConfigRepository {
List<ChatConfigResp> getChatConfig(ChatConfigFilter filter);
ChatConfigResp getConfigByDomainId(Long domainId);
ChatConfigResp getConfigByModelId(Long modelId);
}

View File

@@ -2,11 +2,10 @@ package com.tencent.supersonic.chat.persistence.repository;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
public interface ChatQueryRepository {

View File

@@ -2,10 +2,10 @@ package com.tencent.supersonic.chat.persistence.repository;
import com.tencent.supersonic.chat.persistence.dataobject.PluginDO;
import com.tencent.supersonic.chat.persistence.dataobject.PluginDOExample;
import java.util.List;
public interface PluginRepository {
List<PluginDO> getPlugins();
List<PluginDO> fetchPluginDOs(String queryText, String type);

View File

@@ -1,14 +1,13 @@
package com.tencent.supersonic.chat.persistence.repository.impl;
import com.tencent.supersonic.chat.config.ChatConfig;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
import com.tencent.supersonic.chat.config.ChatConfigFilterInternal;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.config.ChatConfig;
import com.tencent.supersonic.chat.config.ChatConfigFilterInternal;
import com.tencent.supersonic.chat.persistence.dataobject.ChatConfigDO;
import com.tencent.supersonic.chat.persistence.mapper.ChatConfigMapper;
import com.tencent.supersonic.chat.persistence.repository.ChatConfigRepository;
import com.tencent.supersonic.chat.utils.ChatConfigHelper;
import com.tencent.supersonic.chat.persistence.mapper.ChatConfigMapper;
import java.util.ArrayList;
import java.util.List;
import org.springframework.beans.BeanUtils;
@@ -24,7 +23,7 @@ public class ChatConfigRepositoryImpl implements ChatConfigRepository {
private final ChatConfigMapper chatConfigMapper;
public ChatConfigRepositoryImpl(ChatConfigHelper chatConfigHelper,
ChatConfigMapper chatConfigMapper) {
ChatConfigMapper chatConfigMapper) {
this.chatConfigHelper = chatConfigHelper;
this.chatConfigMapper = chatConfigMapper;
}
@@ -53,15 +52,16 @@ public class ChatConfigRepositoryImpl implements ChatConfigRepository {
List<ChatConfigDO> chaConfigDOList = chatConfigMapper.search(filterInternal);
if (!CollectionUtils.isEmpty(chaConfigDOList)) {
chaConfigDOList.stream().forEach(chaConfigDO ->
chaConfigDescriptorList.add(chatConfigHelper.chatConfigDO2Descriptor(chaConfigDO.getDomainId(), chaConfigDO)));
chaConfigDescriptorList.add(
chatConfigHelper.chatConfigDO2Descriptor(chaConfigDO.getModelId(), chaConfigDO)));
}
return chaConfigDescriptorList;
}
@Override
public ChatConfigResp getConfigByDomainId(Long domainId) {
ChatConfigDO chaConfigPO = chatConfigMapper.fetchConfigByDomainId(domainId);
return chatConfigHelper.chatConfigDO2Descriptor(domainId, chaConfigPO);
public ChatConfigResp getConfigByModelId(Long modelId) {
ChatConfigDO chaConfigPO = chatConfigMapper.fetchConfigByModelId(modelId);
return chatConfigHelper.chatConfigDO2Descriptor(modelId, chaConfigPO);
}
}

View File

@@ -3,13 +3,12 @@ package com.tencent.supersonic.chat.persistence.repository.impl;
import com.github.pagehelper.PageHelper;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDOExample;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDOExample.Criteria;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.persistence.mapper.ChatQueryDOMapper;
import com.tencent.supersonic.chat.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.common.util.JsonUtil;

View File

@@ -2,8 +2,8 @@ package com.tencent.supersonic.chat.persistence.repository.impl;
import com.tencent.supersonic.chat.persistence.dataobject.ChatDO;
import com.tencent.supersonic.chat.persistence.dataobject.QueryDO;
import com.tencent.supersonic.chat.persistence.repository.ChatRepository;
import com.tencent.supersonic.chat.persistence.mapper.ChatMapper;
import com.tencent.supersonic.chat.persistence.repository.ChatRepository;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Primary;

View File

@@ -5,13 +5,13 @@ import com.tencent.supersonic.chat.persistence.dataobject.PluginDOExample;
import com.tencent.supersonic.chat.persistence.mapper.PluginDOMapper;
import com.tencent.supersonic.chat.persistence.repository.PluginRepository;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.logging.log4j.util.Strings;
import org.springframework.stereotype.Repository;
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import lombok.extern.slf4j.Slf4j;
import org.apache.logging.log4j.util.Strings;
import org.springframework.stereotype.Repository;
@Repository
@Slf4j
@@ -58,7 +58,7 @@ public class PluginRepositoryImpl implements PluginRepository {
}
@Override
public void updatePlugin(PluginDO pluginDO){
public void updatePlugin(PluginDO pluginDO) {
pluginDOMapper.updateByPrimaryKeyWithBLOBs(pluginDO);
}
@@ -68,12 +68,12 @@ public class PluginRepositoryImpl implements PluginRepository {
}
@Override
public List<PluginDO> query(PluginDOExample pluginDOExample){
public List<PluginDO> query(PluginDOExample pluginDOExample) {
return pluginDOMapper.selectByExampleWithBLOBs(pluginDOExample);
}
@Override
public void deletePlugin(Long id){
public void deletePlugin(Long id) {
pluginDOMapper.deleteByPrimaryKey(id);
}

View File

@@ -5,10 +5,10 @@ import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.parser.ParseMode;
import com.tencent.supersonic.common.pojo.RecordInfo;
import java.util.List;
import lombok.Data;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import java.util.List;
@Data
public class Plugin extends RecordInfo {
@@ -20,7 +20,7 @@ public class Plugin extends RecordInfo {
*/
private String type;
private List<Long> domainList = Lists.newArrayList();
private List<Long> modelList = Lists.newArrayList();
/**
* description, for parsing
@@ -51,8 +51,8 @@ public class Plugin extends RecordInfo {
return Lists.newArrayList();
}
public boolean isContainsAllDomain() {
return CollectionUtils.isNotEmpty(domainList) && domainList.contains(-1L);
public boolean isContainsAllModel() {
return CollectionUtils.isNotEmpty(modelList) && modelList.contains(-1L);
}
}

View File

@@ -2,24 +2,44 @@ package com.tencent.supersonic.chat.plugin;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.parser.ParseMode;
import com.tencent.supersonic.chat.parser.embedding.EmbeddingConfig;
import com.tencent.supersonic.chat.parser.embedding.EmbeddingResp;
import com.tencent.supersonic.chat.parser.embedding.RecallRetrieval;
import com.tencent.supersonic.chat.plugin.event.PluginAddEvent;
import com.tencent.supersonic.chat.plugin.event.PluginUpdateEvent;
import com.tencent.supersonic.chat.query.plugin.ParamOption;
import com.tencent.supersonic.chat.query.plugin.WebBase;
import com.tencent.supersonic.chat.service.PluginService;
import com.tencent.supersonic.common.util.ContextUtils;
import java.net.URI;
import java.util.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.util.Strings;
import org.springframework.context.event.EventListener;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.*;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;
@@ -42,7 +62,7 @@ public class PluginManager {
public static List<Plugin> getPlugins() {
PluginService pluginService = ContextUtils.getBean(PluginService.class);
List<Plugin> pluginList = pluginService.getPluginList().stream().filter(plugin ->
CollectionUtils.isNotEmpty(plugin.getDomainList())).collect(Collectors.toList());
CollectionUtils.isNotEmpty(plugin.getModelList())).collect(Collectors.toList());
pluginList.addAll(internalPluginMap.values());
return new ArrayList<>(pluginList);
}
@@ -89,9 +109,9 @@ public class PluginManager {
doRequest(embeddingConfig.getAddPath(), JSONObject.toJSONString(maps));
}
public void doRequest(String path, String jsonBody) {
public ResponseEntity<String> doRequest(String path, String jsonBody) {
if (Strings.isEmpty(embeddingConfig.getUrl())) {
return;
return ResponseEntity.of(Optional.empty());
}
String url = embeddingConfig.getUrl() + path;
HttpHeaders headers = new HttpHeaders();
@@ -105,6 +125,7 @@ public class PluginManager {
restTemplate.exchange(requestUrl, HttpMethod.POST, entity, new ParameterizedTypeReference<String>() {
});
log.info("[embedding] result body:{}", responseEntity);
return responseEntity;
}
public void requestEmbeddingPluginAddALL(List<Plugin> plugins) {
@@ -115,7 +136,8 @@ public class PluginManager {
}
public EmbeddingResp recognize(String embeddingText) {
String url = embeddingConfig.getUrl() + embeddingConfig.getRecognizePath() + "?n_results=" + embeddingConfig.getNResult();
String url = embeddingConfig.getUrl() + embeddingConfig.getRecognizePath() + "?n_results="
+ embeddingConfig.getNResult();
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
headers.setLocation(URI.create(url));
@@ -125,8 +147,9 @@ public class PluginManager {
HttpEntity<String> entity = new HttpEntity<>(jsonBody, headers);
log.info("[embedding] request body:{}, url:{}", jsonBody, url);
ResponseEntity<List<EmbeddingResp>> embeddingResponseEntity =
restTemplate.exchange(requestUrl, HttpMethod.POST, entity, new ParameterizedTypeReference<List<EmbeddingResp>>() {
});
restTemplate.exchange(requestUrl, HttpMethod.POST, entity,
new ParameterizedTypeReference<List<EmbeddingResp>>() {
});
log.info("[embedding] recognize result body:{}", embeddingResponseEntity);
List<EmbeddingResp> embeddingResps = embeddingResponseEntity.getBody();
if (CollectionUtils.isNotEmpty(embeddingResps)) {
@@ -178,4 +201,88 @@ public class PluginManager {
return String.valueOf(Integer.parseInt(id) / 1000);
}
public static Pair<Boolean, List<Long>> resolve(Plugin plugin, QueryContext queryContext) {
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
Set<Long> pluginMatchedModel = getPluginMatchedModel(plugin, queryContext);
if (CollectionUtils.isEmpty(pluginMatchedModel) && !plugin.isContainsAllModel()) {
return Pair.of(false, Lists.newArrayList());
}
List<ParamOption> paramOptions = getSemanticOption(plugin);
if (CollectionUtils.isEmpty(paramOptions)) {
return Pair.of(true, new ArrayList<>(pluginMatchedModel));
}
List<Long> matchedModel = Lists.newArrayList();
Map<Long, List<ParamOption>> paramOptionMap = paramOptions.stream().
collect(Collectors.groupingBy(ParamOption::getModelId));
for (Long modelId : paramOptionMap.keySet()) {
List<ParamOption> params = paramOptionMap.get(modelId);
if (CollectionUtils.isEmpty(params)) {
matchedModel.add(modelId);
continue;
}
boolean matched = true;
for (ParamOption paramOption : params) {
Set<Long> elementIdSet = getSchemaElementMatch(modelId, schemaMapInfo);
if (CollectionUtils.isEmpty(elementIdSet)) {
matched = false;
break;
}
if (!elementIdSet.contains(paramOption.getElementId())) {
matched = false;
break;
}
}
if (matched) {
matchedModel.add(modelId);
}
}
if (CollectionUtils.isEmpty(matchedModel)) {
return Pair.of(false, Lists.newArrayList());
}
return Pair.of(true, matchedModel);
}
private static Set<Long> getSchemaElementMatch(Long modelId, SchemaMapInfo schemaMapInfo) {
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(modelId);
if (org.springframework.util.CollectionUtils.isEmpty(schemaElementMatches)) {
return Sets.newHashSet();
}
return schemaElementMatches.stream().filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()) ||
SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
.map(SchemaElementMatch::getElement)
.map(SchemaElement::getId)
.collect(Collectors.toSet());
}
private static List<ParamOption> getSemanticOption(Plugin plugin) {
WebBase webBase = JSONObject.parseObject(plugin.getConfig(), WebBase.class);
if (Objects.isNull(webBase)) {
return null;
}
List<ParamOption> paramOptions = webBase.getParamOptions();
if (org.springframework.util.CollectionUtils.isEmpty(paramOptions)) {
return Lists.newArrayList();
}
return paramOptions.stream()
.filter(paramOption -> ParamOption.ParamType.SEMANTIC.equals(paramOption.getParamType()))
.collect(Collectors.toList());
}
private static Set<Long> getPluginMatchedModel(Plugin plugin, QueryContext queryContext) {
Set<Long> matchedModel = queryContext.getMapInfo().getMatchedModels();
if (plugin.isContainsAllModel()) {
return matchedModel;
}
List<Long> modelIds = plugin.getModelList();
Set<Long> pluginMatchedModel = Sets.newHashSet();
for (Long modelId : modelIds) {
if (matchedModel.contains(modelId)) {
pluginMatchedModel.add(modelId);
}
}
return pluginMatchedModel;
}
}

View File

@@ -2,12 +2,19 @@ package com.tencent.supersonic.chat.plugin;
import com.tencent.supersonic.chat.parser.function.Parameters;
import lombok.Data;
import java.io.Serializable;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.ToString;
@Data
@Builder
@AllArgsConstructor
@ToString
@NoArgsConstructor
public class PluginParseConfig implements Serializable {
private String name;

View File

@@ -0,0 +1,149 @@
package com.tencent.supersonic.chat.query.ContentInterpret;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticLayer;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.plugin.PluginParseResult;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.Aggregator;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.pojo.Filter;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.calcite.sql.parser.SqlParseException;
import org.springframework.beans.BeanUtils;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
@Slf4j
@Component
public class ContentInterpretQuery extends PluginSemanticQuery {
@Override
public String getQueryMode() {
return "CONTENT_INTERPRET";
}
public ContentInterpretQuery() {
QueryManager.register(this);
}
@Override
public QueryResult execute(User user) throws SqlParseException {
QueryResultWithSchemaResp queryResultWithSchemaResp = queryMetric(user);
String text = generateDataText(queryResultWithSchemaResp);
Map<String, Object> properties = parseInfo.getProperties();
PluginParseResult pluginParseResult = JsonUtil.toObject(JsonUtil.toString(properties.get(Constants.CONTEXT))
, PluginParseResult.class);
String answer = fetchInterpret(pluginParseResult.getRequest().getQueryText(), text);
QueryResult queryResult = new QueryResult();
List<QueryColumn> queryColumns = Lists.newArrayList(new QueryColumn("结果", "string", "answer"));
Map<String, Object> result = new HashMap<>();
result.put("answer", answer);
List<Map<String, Object>> resultList = Lists.newArrayList();
resultList.add(result);
queryResultWithSchemaResp.setResultList(resultList);
queryResultWithSchemaResp.setColumns(queryColumns);
queryResult.setResponse(queryResultWithSchemaResp);
queryResult.setQueryMode(getQueryMode());
queryResult.setQueryState(QueryState.SUCCESS);
return queryResult;
}
private QueryResultWithSchemaResp queryMetric(User user) {
SemanticLayer semanticLayer = ComponentFactory.getSemanticLayer();
QueryStructReq queryStructReq = new QueryStructReq();
queryStructReq.setModelId(parseInfo.getModelId());
queryStructReq.setGroups(Lists.newArrayList(TimeDimensionEnum.DAY.getName()));
ModelSchema modelSchema = semanticLayer.getModelSchema(parseInfo.getModelId(), true);
queryStructReq.setAggregators(buildAggregator(modelSchema));
List<Filter> filterList = Lists.newArrayList();
for (QueryFilter queryFilter : parseInfo.getDimensionFilters()) {
Filter filter = new Filter();
BeanUtils.copyProperties(queryFilter, filter);
filterList.add(filter);
}
queryStructReq.setDimensionFilters(filterList);
DateConf dateConf = new DateConf();
dateConf.setDateMode(DateConf.DateMode.RECENT);
dateConf.setUnit(7);
queryStructReq.setDateInfo(dateConf);
return semanticLayer.queryByStruct(queryStructReq, user);
}
private List<Aggregator> buildAggregator(ModelSchema modelSchema) {
List<Aggregator> aggregators = Lists.newArrayList();
Set<SchemaElement> metrics = modelSchema.getMetrics();
if (CollectionUtils.isEmpty(metrics)) {
return aggregators;
}
for (SchemaElement schemaElement : metrics) {
Aggregator aggregator = new Aggregator();
aggregator.setColumn(schemaElement.getBizName());
aggregator.setFunc(AggOperatorEnum.SUM);
aggregator.setNameCh(schemaElement.getName());
aggregators.add(aggregator);
}
return aggregators;
}
public String generateDataText(QueryResultWithSchemaResp queryResultWithSchemaResp) {
Map<String, String> map = queryResultWithSchemaResp.getColumns().stream()
.collect(Collectors.toMap(QueryColumn::getNameEn, QueryColumn::getName));
StringBuilder stringBuilder = new StringBuilder();
for (Map<String, Object> valueMap : queryResultWithSchemaResp.getResultList()) {
for (String key : valueMap.keySet()) {
String name = "";
if (TimeDimensionEnum.getNameList().contains(key)) {
name = "日期";
} else {
name = map.get(key);
}
String value = String.valueOf(valueMap.get(key));
stringBuilder.append(name).append(":").append(value).append(" ");
}
}
return stringBuilder.toString();
}
public String fetchInterpret(String queryText, String dataText) {
PluginManager pluginManager = ContextUtils.getBean(PluginManager.class);
LLmAnswerReq lLmAnswerReq = new LLmAnswerReq();
lLmAnswerReq.setQueryText(queryText);
lLmAnswerReq.setPluginOutput(dataText);
ResponseEntity<String> responseEntity = pluginManager.doRequest("answer_with_plugin_call",
JSONObject.toJSONString(lLmAnswerReq));
LLmAnswerResp lLmAnswerResp = JSONObject.parseObject(responseEntity.getBody(), LLmAnswerResp.class);
if (lLmAnswerResp != null) {
return lLmAnswerResp.getAssistant_message();
}
return null;
}
}

View File

@@ -0,0 +1,13 @@
package com.tencent.supersonic.chat.query.ContentInterpret;
import lombok.Data;
@Data
public class LLmAnswerReq {
private String queryText;
private String pluginOutput;
}

View File

@@ -0,0 +1,11 @@
package com.tencent.supersonic.chat.query.ContentInterpret;
import lombok.Data;
@Data
public class LLmAnswerResp {
private String assistant_message;
}

View File

@@ -2,20 +2,19 @@ package com.tencent.supersonic.chat.query;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.common.pojo.Constants;
import java.util.*;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
import java.util.ArrayList;
import java.util.List;
import java.util.OptionalDouble;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
@Slf4j
public class HeuristicQuerySelector implements QuerySelector {
private static final double MATCH_INHERIT_PENALTY = 0.5;
private static final double MATCH_CURRENT_REWORD = 2;
private static final double CANDIDATE_THRESHOLD = 0.2;
@Override
@@ -26,49 +25,53 @@ public class HeuristicQuerySelector implements QuerySelector {
selectedQueries.addAll(candidateQueries);
} else {
OptionalDouble maxScoreOp = candidateQueries.stream().mapToDouble(
q -> computeScore(q.getParseInfo())).max();
q -> q.getParseInfo().getScore()).max();
if (maxScoreOp.isPresent()) {
double maxScore = maxScoreOp.getAsDouble();
candidateQueries.stream().forEach(query -> {
SemanticParseInfo semanticParse = query.getParseInfo();
if ((maxScore - semanticParse.getScore()) / maxScore <= CANDIDATE_THRESHOLD) {
SemanticParseInfo parseInfo = query.getParseInfo();
if (!checkFullyInherited(query)
&& (maxScore - parseInfo.getScore()) / maxScore <= CANDIDATE_THRESHOLD
&& checkSatisfyOtherRules(query, candidateQueries)) {
selectedQueries.add(query);
}
log.info("candidate query (domain={}, queryMode={}) with score={}",
semanticParse.getDomainName(), semanticParse.getQueryMode(), semanticParse.getScore());
log.info("candidate query (Model={}, queryMode={}) with score={}",
parseInfo.getModelName(), parseInfo.getQueryMode(), parseInfo.getScore());
});
}
}
return selectedQueries;
}
private double computeScore(SemanticParseInfo semanticParse) {
double totalScore = 0;
Map<SchemaElementType, SchemaElementMatch> maxSimilarityMatch = new HashMap<>();
for (SchemaElementMatch match : semanticParse.getElementMatches()) {
SchemaElementType type = match.getElement().getType();
if (!maxSimilarityMatch.containsKey(type) ||
match.getSimilarity() > maxSimilarityMatch.get(type).getSimilarity()) {
maxSimilarityMatch.put(type, match);
private boolean checkSatisfyOtherRules(SemanticQuery semanticQuery, List<SemanticQuery> candidateQueries) {
if (!semanticQuery.getQueryMode().equals(MetricModelQuery.QUERY_MODE)) {
return true;
}
for (SemanticQuery candidateQuery : candidateQueries) {
if (candidateQuery.getQueryMode().equals(MetricEntityQuery.QUERY_MODE) &&
semanticQuery.getParseInfo().getScore() == candidateQuery.getParseInfo().getScore()) {
return false;
}
}
return true;
}
for (SchemaElementMatch match : maxSimilarityMatch.values()) {
double matchScore = Optional.ofNullable(match.getDetectWord()).orElse(Constants.EMPTY).length() * match.getSimilarity();
if (match.equals(SchemaElementMatch.MatchMode.INHERIT)) {
matchScore *= MATCH_INHERIT_PENALTY;
} else {
matchScore *= MATCH_CURRENT_REWORD;
}
totalScore += matchScore;
private boolean checkFullyInherited(SemanticQuery query) {
SemanticParseInfo parseInfo = query.getParseInfo();
if (!(query instanceof RuleSemanticQuery)) {
return false;
}
// original score in parse info acts like an extra bonus
totalScore += semanticParse.getScore();
semanticParse.setScore(totalScore);
for (SchemaElementMatch match : parseInfo.getElementMatches()) {
if (!match.isInherited()) {
return false;
}
}
if (parseInfo.getDateInfo() != null && !parseInfo.getDateInfo().isInherited()) {
return false;
}
return totalScore;
return true;
}
}

View File

@@ -5,7 +5,6 @@ import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.query.rule.entity.EntitySemanticQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricSemanticQuery;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@@ -56,6 +55,7 @@ public class QueryManager {
throw new RuntimeException("no supported queryMode :" + queryMode);
}
}
public static boolean containsRuleQuery(String queryMode) {
if (queryMode == null) {
return false;

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.chat.query;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import java.util.List;
/**

View File

@@ -0,0 +1,78 @@
package com.tencent.supersonic.chat.query.dsl;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.StringUtil;
import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
@Slf4j
public class DSLBuilder {
public static final String DATA_Field = "数据日期";
public static final String TABLE_PREFIX = "t_";
public String build(SemanticParseInfo parseInfo, QueryFilters queryFilters, LLMResp llmResp, Long modelId)
throws Exception {
String sqlOutput = llmResp.getSqlOutput();
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
List<SchemaElement> dbAllFields = new ArrayList<>();
dbAllFields.addAll(semanticSchema.getMetrics());
dbAllFields.addAll(semanticSchema.getDimensions());
Map<String, String> fieldToBizName = getMapInfo(modelId, dbAllFields);
fieldToBizName.put(DATA_Field, TimeDimensionEnum.DAY.getName());
sqlOutput = CCJSqlParserUtils.replaceFields(sqlOutput, fieldToBizName);
sqlOutput = CCJSqlParserUtils.replaceTable(sqlOutput, TABLE_PREFIX + modelId);
String queryFilter = getQueryFilter(queryFilters);
if (StringUtils.isNotEmpty(queryFilter)) {
log.info("add queryFilter to sql :{}", queryFilter);
Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
CCJSqlParserUtils.addWhere(sqlOutput, expression);
}
log.info("build sqlOutput:{}", sqlOutput);
return sqlOutput;
}
protected Map<String, String> getMapInfo(Long modelId, List<SchemaElement> metrics) {
return metrics.stream().filter(entry -> entry.getModel().equals(modelId))
.collect(Collectors.toMap(SchemaElement::getName, a -> a.getBizName(), (k1, k2) -> k1));
}
private String getQueryFilter(QueryFilters queryFilters) {
if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) {
return "";
}
List<QueryFilter> filters = queryFilters.getFilters();
return filters.stream()
.map(filter -> {
String bizNameWrap = StringUtil.getSpaceWrap(filter.getBizName());
String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue());
String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString());
return bizNameWrap + operatorWrap + valueWrap;
})
.collect(Collectors.joining(Constants.AND_UPPER));
}
}

View File

@@ -0,0 +1,96 @@
package com.tencent.supersonic.chat.query.dsl;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticLayer;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.parser.llm.DSLParseResult;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.request.QueryDslReq;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
@Slf4j
@Component
public class DSLQuery extends PluginSemanticQuery {
public static final String QUERY_MODE = "DSL";
private DSLBuilder dslBuilder = new DSLBuilder();
protected SemanticLayer semanticLayer = ComponentFactory.getSemanticLayer();
public DSLQuery() {
QueryManager.register(this);
}
@Override
public String getQueryMode() {
return QUERY_MODE;
}
@Override
public QueryResult execute(User user) {
String json = JsonUtil.toString(parseInfo.getProperties().get(Constants.CONTEXT));
DSLParseResult dslParseResult = JsonUtil.toObject(json, DSLParseResult.class);
LLMResp llmResp = dslParseResult.getLlmResp();
QueryReq queryReq = dslParseResult.getRequest();
Long modelId = parseInfo.getModelId();
String querySql = convertToSql(queryReq.getQueryFilters(), llmResp, parseInfo, modelId);
long startTime = System.currentTimeMillis();
QueryDslReq queryDslReq = QueryReqBuilder.buildDslReq(querySql, modelId);
QueryResultWithSchemaResp queryResp = semanticLayer.queryByDsl(queryDslReq, user);
log.info("queryByDsl cost:{},querySql:{}", System.currentTimeMillis() - startTime, querySql);
QueryResult queryResult = new QueryResult();
if (Objects.nonNull(queryResp)) {
queryResult.setQueryAuthorization(queryResp.getQueryAuthorization());
}
String resultQql = queryResp == null ? null : queryResp.getSql();
List<Map<String, Object>> resultList = queryResp == null ? new ArrayList<>() : queryResp.getResultList();
List<QueryColumn> columns = queryResp == null ? new ArrayList<>() : queryResp.getColumns();
queryResult.setQuerySql(resultQql);
queryResult.setQueryResults(resultList);
queryResult.setQueryColumns(columns);
queryResult.setQueryMode(QUERY_MODE);
queryResult.setQueryState(QueryState.SUCCESS);
// add model info
EntityInfo entityInfo = ContextUtils.getBean(SemanticService.class).getEntityInfo(parseInfo, user);
queryResult.setEntityInfo(entityInfo);
parseInfo.setProperties(null);
return queryResult;
}
protected String convertToSql(QueryFilters queryFilters, LLMResp llmResp, SemanticParseInfo parseInfo,
Long modelId) {
try {
return dslBuilder.build(parseInfo, queryFilters, llmResp, modelId);
} catch (Exception e) {
log.error("convertToSql error", e);
}
return null;
}
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.query.plugin.dsl;
package com.tencent.supersonic.chat.query.dsl;
import java.util.List;
import lombok.Data;
@@ -12,6 +12,8 @@ public class LLMReq {
private List<ElementValue> linking;
private String currentDate;
@Data
public static class ElementValue {
@@ -26,6 +28,8 @@ public class LLMReq {
private String domainName;
private String modelName;
private List<String> fieldNameList;
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.query.plugin.dsl;
package com.tencent.supersonic.chat.query.dsl;
import java.util.List;
import lombok.Data;
@@ -8,7 +8,7 @@ public class LLMResp {
private String query;
private String domainName;
private String modelName;
private String sqlOutput;

View File

@@ -15,7 +15,7 @@ public class ParamOption {
private String keyAlias;
private Long domainId;
private Long modelId;
private Long elementId;

View File

@@ -1,8 +1,8 @@
package com.tencent.supersonic.chat.query.plugin;
import com.google.common.collect.Lists;
import lombok.Data;
import java.util.List;
import lombok.Data;
@Data
public class WebBase {

View File

@@ -1,8 +1,8 @@
package com.tencent.supersonic.chat.query.plugin;
import com.google.common.collect.Lists;
import lombok.Data;
import java.util.List;
import lombok.Data;
@Data
public class WebBaseResult {

View File

@@ -1,124 +0,0 @@
package com.tencent.supersonic.chat.query.plugin.dsl;
import static java.time.LocalDate.now;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.calcite.SqlParseUtils;
import com.tencent.supersonic.common.util.calcite.SqlParserInfo;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
@Slf4j
public class DSLBuilder {
public static final String COMMA_WRAPPER = "'%s'";
public static final String SPACE_WRAPPER = " %s ";
protected static final String SUB_TABLE = " ( select * from t_{0} where {1} >= ''{2}'' and {1} <= ''{3}'' {4} ) as t_sub_{0}";
public String build(QueryFilters queryFilters, SemanticParseInfo parseInfo, LLMResp llmResp, Long domainId)
throws SqlParseException {
String sqlOutput = llmResp.getSqlOutput();
String domainName = llmResp.getDomainName();
// 1. extra deal with,such as add alias.
sqlOutput = extraConvert(sqlOutput, domainId);
SqlParserInfo sqlParseInfo = SqlParseUtils.getSqlParseInfo(sqlOutput);
String tableName = sqlParseInfo.getTableName();
List<String> allFields = sqlParseInfo.getAllFields();
if (StringUtils.isEmpty(domainName)) {
domainName = tableName;
}
// 2. replace the llm dsl, such as replace fieldName and tableName.
log.info("sqlParseInfo:{} ,domainName:{},domainId:{}", sqlParseInfo, domainName, domainId);
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
List<SchemaElement> dbAllFields = new ArrayList<>();
dbAllFields.addAll(semanticSchema.getMetrics());
dbAllFields.addAll(semanticSchema.getDimensions());
Map<String, String> fieldToBizName = getMapInfo(domainId, dbAllFields);
for (String fieldName : allFields) {
String fieldBizName = fieldToBizName.get(fieldName);
if (StringUtils.isNotEmpty(fieldBizName)) {
sqlOutput = sqlOutput.replaceAll(fieldName, fieldBizName);
}
}
//3. deal with dayNo.
DateConf dateInfo = new DateConf();
if (Objects.nonNull(parseInfo) && Objects.nonNull(parseInfo.getDateInfo())) {
dateInfo = parseInfo.getDateInfo();
} else {
String startDate = now().plusDays(-4).toString();
String endDate = now().plusDays(-4).toString();
dateInfo.setStartDate(startDate);
dateInfo.setEndDate(endDate);
}
String startDate = dateInfo.getStartDate();
String endDate = dateInfo.getEndDate();
String period = dateInfo.getPeriod();
TimeDimensionEnum timeDimension = TimeDimensionEnum.valueOf(period);
String dayField = timeDimension.getName();
String queryFilter = getQueryFilter(queryFilters);
String subTable = MessageFormat.format(SUB_TABLE, domainId, dayField, startDate, endDate, queryFilter);
String querySql = sqlOutput.replaceAll(tableName, subTable);
log.info("querySql:{},sqlOutput:{},dateInfo:{}", querySql, sqlOutput, dateInfo);
return querySql;
}
private String getQueryFilter(QueryFilters queryFilters) {
String queryFilter = "";
if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) {
return queryFilter;
}
List<QueryFilter> filters = queryFilters.getFilters();
for (QueryFilter filter : filters) {
queryFilter = getSpaceWrap(queryFilter) + "and" + getSpaceWrap(filter.getBizName()) + getSpaceWrap(
filter.getOperator().getValue()) + getCommaWrap(filter.getValue().toString());
}
return queryFilter;
}
protected String extraConvert(String sqlOutput, Long domainId) throws SqlParseException {
return SqlParseUtils.addAliasToSql(sqlOutput);
}
protected Map<String, String> getMapInfo(Long domainId, List<SchemaElement> metrics) {
return metrics.stream().filter(entry -> entry.getDomain().equals(domainId))
.collect(Collectors.toMap(SchemaElement::getName, a -> a.getBizName(), (k1, k2) -> k1));
}
private String getCommaWrap(String value) {
return String.format(COMMA_WRAPPER, value);
}
private String getSpaceWrap(String value) {
return String.format(SPACE_WRAPPER, value);
}
}

View File

@@ -1,204 +0,0 @@
package com.tencent.supersonic.chat.query.plugin.dsl;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticLayer;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.config.LLMConfig;
import com.tencent.supersonic.chat.plugin.PluginParseResult;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.query.plugin.dsl.LLMReq.ElementValue;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.commons.lang3.StringUtils;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component;
import org.springframework.web.client.RestTemplate;
@Slf4j
@Component
public class DSLQuery extends PluginSemanticQuery {
public static final String QUERY_MODE = "DSL";
private DSLBuilder dslBuilder = new DSLBuilder();
protected SemanticLayer semanticLayer = ComponentFactory.getSemanticLayer();
public DSLQuery() {
QueryManager.register(this);
}
@Override
public String getQueryMode() {
return QUERY_MODE;
}
@Override
public QueryResult execute(User user) {
PluginParseResult functionCallParseResult =JsonUtil.toObject(JsonUtil.toString(parseInfo.getProperties().get(Constants.CONTEXT)),PluginParseResult.class);
Long domainId = parseInfo.getDomainId();
LLMResp llmResp = requestLLM(functionCallParseResult, domainId);
if (Objects.isNull(llmResp)) {
return null;
}
String querySql = convertToSql(functionCallParseResult.getRequest().getQueryFilters(), llmResp, parseInfo,
domainId);
QueryResult queryResult = new QueryResult();
long startTime = System.currentTimeMillis();
QueryResultWithSchemaResp queryResp = semanticLayer.queryByDsl(
QueryReqBuilder.buildDslReq(querySql, domainId), user);
log.info("queryByDsl cost:{},querySql:{}", System.currentTimeMillis() - startTime, querySql);
if (queryResp != null) {
queryResult.setQueryAuthorization(queryResp.getQueryAuthorization());
}
String resultQql = queryResp == null ? null : queryResp.getSql();
List<Map<String, Object>> resultList = queryResp == null ? new ArrayList<>()
: queryResp.getResultList();
List<QueryColumn> columns = queryResp == null ? new ArrayList<>() : queryResp.getColumns();
queryResult.setQuerySql(resultQql);
queryResult.setQueryResults(resultList);
queryResult.setQueryColumns(columns);
queryResult.setQueryMode(QUERY_MODE);
queryResult.setQueryState(QueryState.SUCCESS);
// add domain info
EntityInfo entityInfo = ContextUtils.getBean(SemanticService.class)
.getEntityInfo(parseInfo, user);
queryResult.setEntityInfo(entityInfo);
parseInfo.setProperties(null);
return queryResult;
}
protected String convertToSql(QueryFilters queryFilters, LLMResp llmResp, SemanticParseInfo parseInfo,
Long domainId) {
try {
return dslBuilder.build(queryFilters, parseInfo, llmResp, domainId);
} catch (SqlParseException e) {
log.error("convertToSql error", e);
}
return null;
}
protected LLMResp requestLLM(PluginParseResult parseResult, Long domainId) {
long startTime = System.currentTimeMillis();
String queryText = parseResult.getRequest().getQueryText();
final LLMConfig llmConfig = ContextUtils.getBean(LLMConfig.class);
if (StringUtils.isEmpty(llmConfig.getUrl())) {
log.warn("llmConfig url is null, skip llm parser");
return null;
}
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
Map<Long, String> domainIdToName = semanticSchema.getDomainIdToName();
LLMReq llmReq = new LLMReq();
llmReq.setQueryText(queryText);
LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema();
llmSchema.setDomainName(domainIdToName.get(domainId));
List<String> fieldNameList = getFieldNameList(domainId, semanticSchema);
llmSchema.setFieldNameList(fieldNameList);
llmReq.setSchema(llmSchema);
List<ElementValue> linking = new ArrayList<>();
linking.addAll(getValueList(domainId, semanticSchema));
llmReq.setLinking(linking);
log.info("requestLLM request, domainId:{},llmReq:{}", domainId, llmReq);
String questUrl = llmConfig.getUrl() + llmConfig.getQueryToSqlPath();
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
try {
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
HttpEntity<String> entity = new HttpEntity<>(JsonUtil.toString(llmReq), headers);
ResponseEntity<LLMResp> responseEntity = restTemplate.exchange(questUrl, HttpMethod.POST, entity,
LLMResp.class);
log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
System.currentTimeMillis() - startTime, questUrl, entity, responseEntity.getBody());
return responseEntity.getBody();
} catch (Exception e) {
log.error("requestLLM error", e);
}
return null;
}
private List<ElementValue> getValueList(Long domainId, SemanticSchema semanticSchema) {
Map<Long, String> itemIdToName = semanticSchema.getDimensions().stream()
.filter(entry -> domainId.equals(entry.getDomain()))
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
List<SchemaElementMatch> matchedElements = parseInfo.getElementMatches();
Set<ElementValue> valueMatches = matchedElements.stream()
.filter(schemaElementMatch -> SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()))
.map(elementMatch ->
{
ElementValue elementValue = new ElementValue();
elementValue.setFieldName(itemIdToName.get(elementMatch.getElement().getId()));
elementValue.setFieldValue(elementMatch.getWord());
return elementValue;
}
)
.collect(Collectors.toSet());
return new ArrayList<>(valueMatches);
}
private List<String> getFieldNameList(Long domainId, SemanticSchema semanticSchema) {
Map<Long, String> itemIdToName = semanticSchema.getDimensions().stream()
.filter(entry -> domainId.equals(entry.getDomain()))
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
List<SchemaElementMatch> matchedElements = parseInfo.getElementMatches();
Set<String> fieldNameList = matchedElements.stream()
.filter(schemaElementMatch -> {
SchemaElementType elementType = schemaElementMatch.getElement().getType();
return SchemaElementType.METRIC.equals(elementType) ||
SchemaElementType.DIMENSION.equals(elementType) ||
SchemaElementType.VALUE.equals(elementType);
})
.map(schemaElementMatch -> {
SchemaElementType elementType = schemaElementMatch.getElement().getType();
if (!SchemaElementType.VALUE.equals(elementType)) {
return schemaElementMatch.getWord();
}
return itemIdToName.get(schemaElementMatch.getElement().getId());
})
.filter(name -> StringUtils.isNotEmpty(name) && !name.contains("%"))
.collect(Collectors.toSet());
return new ArrayList<>(fieldNameList);
}
}

View File

@@ -2,11 +2,15 @@ package com.tencent.supersonic.chat.query.plugin.webpage;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
import com.tencent.supersonic.chat.plugin.Plugin;
import com.tencent.supersonic.chat.plugin.PluginParseResult;
import com.tencent.supersonic.chat.query.QueryManager;
@@ -14,18 +18,18 @@ import com.tencent.supersonic.chat.query.plugin.ParamOption;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.query.plugin.WebBase;
import com.tencent.supersonic.chat.query.plugin.WebBaseResult;
import com.tencent.supersonic.chat.service.ConfigService;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import java.util.*;
import java.util.stream.Collectors;
@Slf4j
@Component
public class WebPageQuery extends PluginSemanticQuery {
@@ -43,51 +47,45 @@ public class WebPageQuery extends PluginSemanticQuery {
@Override
public QueryResult execute(User user) {
ConfigService configService = ContextUtils.getBean(ConfigService.class);
QueryResult queryResult = new QueryResult();
queryResult.setQueryMode(QUERY_MODE);
Map<String, Object> properties = parseInfo.getProperties();
PluginParseResult pluginParseResult = JsonUtil.toObject(JsonUtil.toString(properties.get(Constants.CONTEXT)), PluginParseResult.class);
WebPageResponse webPageResponse = buildResponse(pluginParseResult.getPlugin());
PluginParseResult pluginParseResult = JsonUtil.toObject(JsonUtil.toString(properties.get(Constants.CONTEXT))
, PluginParseResult.class);
WebPageResponse webPageResponse = buildResponse(pluginParseResult);
queryResult.setResponse(webPageResponse);
if (parseInfo.getDomainId() != null && parseInfo.getDomainId() > 0
&& parseInfo.getEntity() != null && Objects.nonNull(parseInfo.getEntity().getId())
&& parseInfo.getEntity().getId() > 0) {
ChatConfigRichResp chatConfigRichResp = configService.getConfigRichInfo(parseInfo.getDomainId());
updateSemanticParse(chatConfigRichResp);
EntityInfo entityInfo = ContextUtils.getBean(SemanticService.class).getEntityInfo(parseInfo, user);
queryResult.setEntityInfo(entityInfo);
} else {
queryResult.setEntityInfo(null);
}
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
ModelSchema modelSchema = semanticService.getModelSchema(parseInfo.getModelId());
parseInfo.setModel(modelSchema.getModel());
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, user);
queryResult.setEntityInfo(entityInfo);
queryResult.setQueryState(QueryState.SUCCESS);
return queryResult;
}
private void updateSemanticParse(ChatConfigRichResp chatConfigRichResp) {
SchemaElement domain = new SchemaElement();
domain.setId(chatConfigRichResp.getDomainId());
domain.setName(chatConfigRichResp.getDomainName());
parseInfo.setDomain(domain);
}
protected WebPageResponse buildResponse(Plugin plugin) {
protected WebPageResponse buildResponse(PluginParseResult pluginParseResult) {
Plugin plugin = pluginParseResult.getPlugin();
WebPageResponse webPageResponse = new WebPageResponse();
webPageResponse.setName(plugin.getName());
webPageResponse.setPluginId(plugin.getId());
webPageResponse.setPluginType(plugin.getType());
WebBase webPage = JsonUtil.toObject(plugin.getConfig(), WebBase.class);
WebBaseResult webBaseResult = buildWebPageResult(webPage);
WebBaseResult webBaseResult = buildWebPageResult(webPage, pluginParseResult);
webPageResponse.setWebPage(webBaseResult);
return webPageResponse;
}
private WebBaseResult buildWebPageResult(WebBase webPage) {
private WebBaseResult buildWebPageResult(WebBase webPage, PluginParseResult pluginParseResult) {
WebBaseResult webBaseResult = new WebBaseResult();
webBaseResult.setUrl(webPage.getUrl());
Map<String, Object> elementValueMap = getElementMap();
Map<String, Object> elementValueMap = getElementMap(pluginParseResult);
List<ParamOption> paramOptions = Lists.newArrayList();
if (!CollectionUtils.isEmpty(webPage.getParamOptions()) && !CollectionUtils.isEmpty(elementValueMap)) {
for (ParamOption paramOption : webPage.getParamOptions()) {
if (paramOption.getModelId() != null && !paramOption.getModelId().equals(parseInfo.getModelId())) {
continue;
}
paramOptions.add(paramOption);
if (!ParamOption.ParamType.SEMANTIC.equals(paramOption.getParamType())) {
continue;
}
@@ -96,12 +94,13 @@ public class WebPageQuery extends PluginSemanticQuery {
paramOption.setValue(elementValue);
}
}
webBaseResult.setParams(webPage.getParamOptions());
webBaseResult.setParams(paramOptions);
return webBaseResult;
}
private Map<String, Object> getElementMap() {
protected Map<String, Object> getElementMap(PluginParseResult pluginParseResult) {
Map<String, Object> elementValueMap = new HashMap<>();
Map<Long, Object> filterValueMap = getFilterMap(pluginParseResult);
List<SchemaElementMatch> schemaElementMatchList = parseInfo.getElementMatches();
if (!CollectionUtils.isEmpty(schemaElementMatchList)) {
schemaElementMatchList.stream()
@@ -109,11 +108,37 @@ public class WebPageQuery extends PluginSemanticQuery {
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
.sorted(Comparator.comparingDouble(SchemaElementMatch::getSimilarity))
.forEach(schemaElementMatch ->
.forEach(schemaElementMatch -> {
Object queryFilterValue = filterValueMap.get(schemaElementMatch.getElement().getId());
if (queryFilterValue != null) {
if (String.valueOf(queryFilterValue).equals(String.valueOf(schemaElementMatch.getWord()))) {
elementValueMap.put(String.valueOf(schemaElementMatch.getElement().getId()),
schemaElementMatch.getWord());
}
} else {
elementValueMap.put(String.valueOf(schemaElementMatch.getElement().getId()),
schemaElementMatch.getWord()));
schemaElementMatch.getWord());
}
});
}
return elementValueMap;
}
private Map<Long, Object> getFilterMap(PluginParseResult pluginParseResult) {
Map<Long, Object> map = new HashMap<>();
QueryReq queryReq = pluginParseResult.getRequest();
if (queryReq == null || queryReq.getQueryFilters() == null) {
return map;
}
QueryFilters queryFilters = queryReq.getQueryFilters();
List<QueryFilter> queryFilterList = queryFilters.getFilters();
if (CollectionUtils.isEmpty(queryFilterList)) {
return map;
}
for (QueryFilter queryFilter : queryFilterList) {
map.put(queryFilter.getElementID(), queryFilter.getValue());
}
return map;
}
}

View File

@@ -1,9 +1,8 @@
package com.tencent.supersonic.chat.query.plugin.webpage;
import com.tencent.supersonic.chat.query.plugin.WebBase;
import com.tencent.supersonic.chat.query.plugin.WebBaseResult;
import lombok.Data;
import java.util.List;
import lombok.Data;
@Data
public class WebPageResponse {

View File

@@ -11,22 +11,23 @@ import com.tencent.supersonic.chat.query.plugin.ParamOption;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.query.plugin.WebBase;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.util.*;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import java.net.URI;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import org.apache.calcite.sql.parser.SqlParseException;
import org.springframework.http.*;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;
import java.util.HashMap;
import java.util.Map;
@Slf4j
@Component
public class WebServiceQuery extends PluginSemanticQuery {
@@ -49,11 +50,12 @@ public class WebServiceQuery extends PluginSemanticQuery {
QueryResult queryResult = new QueryResult();
queryResult.setQueryMode(QUERY_MODE);
Map<String, Object> properties = parseInfo.getProperties();
PluginParseResult pluginParseResult =JsonUtil.toObject(JsonUtil.toString(properties.get(Constants.CONTEXT)),PluginParseResult.class);
PluginParseResult pluginParseResult = JsonUtil.toObject(JsonUtil.toString(properties.get(Constants.CONTEXT)),
PluginParseResult.class);
WebServiceResponse webServiceResponse = buildResponse(pluginParseResult);
queryResult.setResponse(webServiceResponse);
queryResult.setQueryState(QueryState.SUCCESS);
parseInfo.setProperties(null);
//parseInfo.setProperties(null);
return queryResult;
}

View File

@@ -2,8 +2,6 @@ package com.tencent.supersonic.chat.query.plugin.webservice;
import com.tencent.supersonic.chat.query.plugin.WebBase;
import lombok.Data;
import lombok.Getter;
import lombok.Setter;
@Data
public class WebServiceResponse {

View File

@@ -1,12 +1,8 @@
package com.tencent.supersonic.chat.query.rule;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
@@ -15,7 +11,6 @@ import java.util.Map;
import java.util.Objects;
import lombok.Data;
import lombok.ToString;
import org.springframework.util.CollectionUtils;
@Data
@ToString
@@ -29,7 +24,7 @@ public class QueryMatcher {
public QueryMatcher() {
for (SchemaElementType type : SchemaElementType.values()) {
if (type.equals(SchemaElementType.DOMAIN)) {
if (type.equals(SchemaElementType.MODEL)) {
elementOptionMap.put(type, QueryMatchOption.optional());
} else {
elementOptionMap.put(type, QueryMatchOption.unused());

View File

@@ -1,10 +1,15 @@
package com.tencent.supersonic.chat.query.rule;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticLayer;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
@@ -13,19 +18,21 @@ import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.ToString;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.util.Strings;
import java.io.Serializable;
import java.util.*;
@Slf4j
@ToString
@@ -40,35 +47,57 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable {
}
public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches,
QueryContext queryCtx) {
QueryContext queryCtx) {
return queryMatcher.match(candidateElementMatches);
}
public void fillParseInfo(Long domainId, ChatContext chatContext) {
public void fillParseInfo(Long modelId, QueryContext queryContext, ChatContext chatContext) {
parseInfo.setQueryMode(getQueryMode());
SemanticService schemaService = ContextUtils.getBean(SemanticService.class);
DomainSchema domainSchema = schemaService.getDomainSchema(domainId);
ModelSchema ModelSchema = schemaService.getModelSchema(modelId);
fillSchemaElement(parseInfo, domainSchema);
// inherit date info from context
if (parseInfo.getDateInfo() == null && chatContext.getParseInfo().getDateInfo() != null
&& isSameQueryMode(getQueryMode(), chatContext.getParseInfo().getQueryMode())) {
log.info("inherit date info from context");
parseInfo.setDateInfo(chatContext.getParseInfo().getDateInfo());
fillSchemaElement(parseInfo, ModelSchema);
fillScore(parseInfo);
fillDateConf(parseInfo, chatContext.getParseInfo());
}
private void fillDateConf(SemanticParseInfo queryParseInfo, SemanticParseInfo chatParseInfo) {
if (queryParseInfo.getDateInfo() != null || chatParseInfo.getDateInfo() == null) {
return;
}
if ((QueryManager.isEntityQuery(queryParseInfo.getQueryMode())
&& QueryManager.isEntityQuery(chatParseInfo.getQueryMode()))
|| (QueryManager.isMetricQuery(queryParseInfo.getQueryMode())
&& QueryManager.isMetricQuery(chatParseInfo.getQueryMode()))) {
// inherit date info from context
queryParseInfo.setDateInfo(chatParseInfo.getDateInfo());
queryParseInfo.getDateInfo().setInherited(true);
}
}
public boolean isSameQueryMode(String queryModeQuery, String queryModeChat) {
if (Strings.isNotEmpty(queryModeQuery) && Strings.isNotEmpty(queryModeChat)) {
return QueryManager.isEntityQuery(queryModeQuery) && QueryManager.isEntityQuery(queryModeChat)
|| QueryManager.isMetricQuery(queryModeQuery) && QueryManager.isMetricQuery(queryModeChat);
private void fillScore(SemanticParseInfo parseInfo) {
double totalScore = 0;
Map<SchemaElementType, SchemaElementMatch> maxSimilarityMatch = new HashMap<>();
for (SchemaElementMatch match : parseInfo.getElementMatches()) {
SchemaElementType type = match.getElement().getType();
if (!maxSimilarityMatch.containsKey(type) ||
match.getSimilarity() > maxSimilarityMatch.get(type).getSimilarity()) {
maxSimilarityMatch.put(type, match);
}
}
return true;
for (SchemaElementMatch match : maxSimilarityMatch.values()) {
totalScore += match.getDetectWord().length() * match.getSimilarity();
}
parseInfo.setScore(parseInfo.getScore() + totalScore);
}
private void fillSchemaElement(SemanticParseInfo parseInfo, DomainSchema domainSchema) {
parseInfo.setDomain(domainSchema.getDomain());
private void fillSchemaElement(SemanticParseInfo parseInfo, ModelSchema ModelSchema) {
parseInfo.setModel(ModelSchema.getModel());
Map<Long, List<SchemaElementMatch>> dim2Values = new HashMap<>();
Map<Long, List<SchemaElementMatch>> id2Values = new HashMap<>();
@@ -77,7 +106,7 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable {
SchemaElement element = schemaMatch.getElement();
switch (element.getType()) {
case ID:
SchemaElement entityElement = domainSchema.getElement(SchemaElementType.ENTITY, element.getId());
SchemaElement entityElement = ModelSchema.getElement(SchemaElementType.ENTITY, element.getId());
if (entityElement != null) {
if (id2Values.containsKey(element.getId())) {
id2Values.get(element.getId()).add(schemaMatch);
@@ -87,7 +116,7 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable {
}
break;
case VALUE:
SchemaElement dimElement = domainSchema.getElement(SchemaElementType.DIMENSION, element.getId());
SchemaElement dimElement = ModelSchema.getElement(SchemaElementType.DIMENSION, element.getId());
if (dimElement != null) {
if (dim2Values.containsKey(element.getId())) {
dim2Values.get(element.getId()).add(schemaMatch);
@@ -111,7 +140,7 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable {
if (!id2Values.isEmpty()) {
for (Map.Entry<Long, List<SchemaElementMatch>> entry : id2Values.entrySet()) {
SchemaElement entity = domainSchema.getElement(SchemaElementType.ENTITY, entry.getKey());
SchemaElement entity = ModelSchema.getElement(SchemaElementType.ENTITY, entry.getKey());
if (entry.getValue().size() == 1) {
SchemaElementMatch schemaMatch = entry.getValue().get(0);
@@ -122,7 +151,7 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable {
dimensionFilter.setOperator(FilterOperatorEnum.EQUALS);
dimensionFilter.setElementID(schemaMatch.getElement().getId());
parseInfo.getDimensionFilters().add(dimensionFilter);
parseInfo.setEntity(domainSchema.getEntity());
parseInfo.setEntity(ModelSchema.getEntity());
} else {
QueryFilter dimensionFilter = new QueryFilter();
List<String> vals = new ArrayList<>();
@@ -139,7 +168,7 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable {
if (!dim2Values.isEmpty()) {
for (Map.Entry<Long, List<SchemaElementMatch>> entry : dim2Values.entrySet()) {
SchemaElement dimension = domainSchema.getElement(SchemaElementType.DIMENSION, entry.getKey());
SchemaElement dimension = ModelSchema.getElement(SchemaElementType.DIMENSION, entry.getKey());
if (entry.getValue().size() == 1) {
SchemaElementMatch schemaMatch = entry.getValue().get(0);
@@ -150,7 +179,7 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable {
dimensionFilter.setOperator(FilterOperatorEnum.EQUALS);
dimensionFilter.setElementID(schemaMatch.getElement().getId());
parseInfo.getDimensionFilters().add(dimensionFilter);
parseInfo.setEntity(domainSchema.getEntity());
parseInfo.setEntity(ModelSchema.getEntity());
} else {
QueryFilter dimensionFilter = new QueryFilter();
List<String> vals = new ArrayList<>();
@@ -171,7 +200,7 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable {
public QueryResult execute(User user) {
String queryMode = parseInfo.getQueryMode();
if (parseInfo.getDomainId() < 0 || StringUtils.isEmpty(queryMode)
if (parseInfo.getModelId() < 0 || StringUtils.isEmpty(queryMode)
|| !QueryManager.containsRuleQuery(queryMode)) {
// reach here some error may happen
log.error("not find QueryMode");
@@ -195,7 +224,7 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable {
queryResult.setQueryMode(queryMode);
queryResult.setQueryState(QueryState.SUCCESS);
// add domain info
// add Model info
EntityInfo entityInfo = ContextUtils.getBean(SemanticService.class)
.getEntityInfo(parseInfo, user);
queryResult.setEntityInfo(entityInfo);
@@ -205,7 +234,7 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable {
public QueryResult multiStructExecute(User user) {
String queryMode = parseInfo.getQueryMode();
if (parseInfo.getDomainId() < 0 || StringUtils.isEmpty(queryMode)
if (parseInfo.getModelId() < 0 || StringUtils.isEmpty(queryMode)
|| !QueryManager.containsRuleQuery(queryMode)) {
// reach here some error may happen
log.error("not find QueryMode");
@@ -228,7 +257,7 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable {
queryResult.setQueryMode(queryMode);
queryResult.setQueryState(QueryState.SUCCESS);
// add domain info
// add Model info
EntityInfo entityInfo = ContextUtils.getBean(SemanticService.class)
.getEntityInfo(parseInfo, user);
queryResult.setEntityInfo(entityInfo);
@@ -246,8 +275,9 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable {
}
public static List<RuleSemanticQuery> resolve(List<SchemaElementMatch> candidateElementMatches,
QueryContext queryContext) {
QueryContext queryContext) {
List<RuleSemanticQuery> matchedQueries = new ArrayList<>();
for (RuleSemanticQuery semanticQuery : QueryManager.getRuleQueries()) {
List<SchemaElementMatch> matches = semanticQuery.match(candidateElementMatches, queryContext);
@@ -261,6 +291,7 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable {
return matchedQueries;
}
protected QueryStructReq convertQueryStruct() {
return QueryReqBuilder.buildStructReq(parseInfo);
}

View File

@@ -1,8 +1,9 @@
package com.tencent.supersonic.chat.query.rule.entity;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.*;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.DIMENSION;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ID;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
import org.springframework.stereotype.Component;

View File

@@ -1,8 +1,9 @@
package com.tencent.supersonic.chat.query.rule.entity;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.*;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ID;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.VALUE;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.OPTIONAL;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.*;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

View File

@@ -1,7 +1,8 @@
package com.tencent.supersonic.chat.query.rule.entity;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.DomainSchema;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
@@ -11,7 +12,6 @@ import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.util.ContextUtils;
import java.util.LinkedHashSet;
import java.util.Objects;
import java.util.Set;
@@ -19,25 +19,26 @@ import java.util.Set;
public abstract class EntityListQuery extends EntitySemanticQuery {
@Override
public void fillParseInfo(Long domainId, ChatContext chatContext) {
super.fillParseInfo(domainId, chatContext);
public void fillParseInfo(Long modelId, QueryContext queryContext, ChatContext chatContext) {
super.fillParseInfo(modelId, queryContext, chatContext);
this.addEntityDetailAndOrderByMetric(parseInfo);
}
private void addEntityDetailAndOrderByMetric(SemanticParseInfo parseInfo) {
Long domainId = parseInfo.getDomainId();
if (Objects.nonNull(domainId) && domainId > 0L) {
Long modelId = parseInfo.getModelId();
if (Objects.nonNull(modelId) && modelId > 0L) {
ConfigService configService = ContextUtils.getBean(ConfigService.class);
ChatConfigRichResp chaConfigRichDesc = configService.getConfigRichInfo(parseInfo.getDomainId());
ChatConfigRichResp chaConfigRichDesc = configService.getConfigRichInfo(parseInfo.getModelId());
SemanticService schemaService = ContextUtils.getBean(SemanticService.class);
DomainSchema domainSchema = schemaService.getDomainSchema(domainId);
ModelSchema ModelSchema = schemaService.getModelSchema(modelId);
if (chaConfigRichDesc != null && chaConfigRichDesc.getChatDetailRichConfig() != null
&& Objects.nonNull(domainSchema) && Objects.nonNull(domainSchema.getEntity())) {
&& Objects.nonNull(ModelSchema) && Objects.nonNull(ModelSchema.getEntity())) {
Set<SchemaElement> dimensions = new LinkedHashSet();
Set<SchemaElement> metrics = new LinkedHashSet();
Set<Order> orders = new LinkedHashSet();
ChatDefaultRichConfigResp chatDefaultConfig = chaConfigRichDesc.getChatDetailRichConfig().getChatDefaultConfig();
ChatDefaultRichConfigResp chatDefaultConfig = chaConfigRichDesc.getChatDetailRichConfig()
.getChatDefaultConfig();
if (chatDefaultConfig != null) {
chatDefaultConfig.getMetrics().stream()
.forEach(metric -> {

View File

@@ -1,5 +1,9 @@
package com.tencent.supersonic.chat.query.rule.entity;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ENTITY;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
@@ -11,17 +15,12 @@ import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.service.ConfigService;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ENTITY;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
@Slf4j
public abstract class EntitySemanticQuery extends RuleSemanticQuery {
@@ -35,7 +34,7 @@ public abstract class EntitySemanticQuery extends RuleSemanticQuery {
@Override
public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches,
QueryContext queryCtx) {
QueryContext queryCtx) {
candidateElementMatches = filterElementMatches(candidateElementMatches);
return super.match(candidateElementMatches, queryCtx);
}
@@ -43,13 +42,13 @@ public abstract class EntitySemanticQuery extends RuleSemanticQuery {
private List<SchemaElementMatch> filterElementMatches(List<SchemaElementMatch> candidateElementMatches) {
List<SchemaElementMatch> filteredMatches = new ArrayList<>();
if (CollectionUtils.isEmpty(candidateElementMatches)
|| Objects.isNull(candidateElementMatches.get(0).getElement().getDomain())) {
|| Objects.isNull(candidateElementMatches.get(0).getElement().getModel())) {
return candidateElementMatches;
}
Long domainId = candidateElementMatches.get(0).getElement().getDomain();
Long modelId = candidateElementMatches.get(0).getElement().getModel();
ConfigService configService = ContextUtils.getBean(ConfigService.class);
ChatConfigResp chatConfig = configService.fetchConfigByDomainId(domainId);
ChatConfigResp chatConfig = configService.fetchConfigByModelId(modelId);
List<Long> blackDimIdList = new ArrayList<>();
List<Long> blackMetricIdList = new ArrayList<>();
@@ -78,14 +77,14 @@ public abstract class EntitySemanticQuery extends RuleSemanticQuery {
}
@Override
public void fillParseInfo(Long domainId, ChatContext chatContext) {
super.fillParseInfo(domainId, chatContext);
public void fillParseInfo(Long modelId, QueryContext queryContext, ChatContext chatContext) {
super.fillParseInfo(modelId, queryContext, chatContext);
parseInfo.setNativeQuery(true);
parseInfo.setLimit(ENTITY_MAX_RESULTS);
if (parseInfo.getDateInfo() == null) {
ConfigService configService = ContextUtils.getBean(ConfigService.class);
ChatConfigRichResp chatConfig = configService.getConfigRichInfo(parseInfo.getDomainId());
ChatConfigRichResp chatConfig = configService.getConfigRichInfo(parseInfo.getModelId());
ChatDefaultRichConfigResp defaultConfig = chatConfig.getChatDetailRichConfig().getChatDefaultConfig();
int unit = 1;

View File

@@ -1,24 +1,23 @@
package com.tencent.supersonic.chat.query.rule.metric;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ENTITY;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ID;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import com.tencent.supersonic.semantic.api.query.pojo.Filter;
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.*;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.OPTIONAL;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
@Slf4j
@Component
@@ -83,7 +82,8 @@ public class MetricEntityQuery extends MetricSemanticQuery {
filters.forEach(d -> {
if (!dimensions.contains(d.getBizName())) {
dimensions.add(d.getBizName());
}});
}
});
queryStructReq.setGroups(dimensions);
log.info("addDimension after [{}]", queryStructReq.getGroups());
}

View File

@@ -79,7 +79,8 @@ public class MetricFilterQuery extends MetricSemanticQuery {
filters.forEach(d -> {
if (!dimensions.contains(d.getBizName())) {
dimensions.add(d.getBizName());
}});
}
});
queryStructReq.setGroups(dimensions);
log.info("addDimension after [{}]", queryStructReq.getGroups());
}

View File

@@ -1,11 +1,12 @@
package com.tencent.supersonic.chat.query.rule.metric;
import org.springframework.stereotype.Component;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.*;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.DIMENSION;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.VALUE;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.OPTIONAL;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
import org.springframework.stereotype.Component;
@Component
public class MetricGroupByQuery extends MetricSemanticQuery {

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.chat.query.rule.metric;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.DOMAIN;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.MODEL;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.OPTIONAL;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_MOST;
@@ -9,13 +9,13 @@ import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import org.springframework.stereotype.Component;
@Component
public class MetricDomainQuery extends MetricSemanticQuery {
public class MetricModelQuery extends MetricSemanticQuery {
public static final String QUERY_MODE = "METRIC_DOMAIN";
public static final String QUERY_MODE = "METRIC_MODEL";
public MetricDomainQuery() {
public MetricModelQuery() {
super();
queryMatcher.addOption(DOMAIN, OPTIONAL, AT_MOST, 1);
queryMatcher.addOption(MODEL, OPTIONAL, AT_MOST, 1);
}
@Override

View File

@@ -40,7 +40,7 @@ public abstract class MetricSemanticQuery extends RuleSemanticQuery {
@Override
public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches,
QueryContext queryCtx) {
QueryContext queryCtx) {
candidateElementMatches = filterElementMatches(candidateElementMatches);
return super.match(candidateElementMatches, queryCtx);
}
@@ -48,13 +48,13 @@ public abstract class MetricSemanticQuery extends RuleSemanticQuery {
private List<SchemaElementMatch> filterElementMatches(List<SchemaElementMatch> candidateElementMatches) {
List<SchemaElementMatch> filteredMatches = new ArrayList<>();
if (CollectionUtils.isEmpty(candidateElementMatches)
|| Objects.isNull(candidateElementMatches.get(0).getElement().getDomain())) {
|| Objects.isNull(candidateElementMatches.get(0).getElement().getModel())) {
return candidateElementMatches;
}
Long domainId = candidateElementMatches.get(0).getElement().getDomain();
Long modelId = candidateElementMatches.get(0).getElement().getModel();
ConfigService configService = ContextUtils.getBean(ConfigService.class);
ChatConfigResp chatConfig = configService.fetchConfigByDomainId(domainId);
ChatConfigResp chatConfig = configService.fetchConfigByModelId(modelId);
List<Long> blackDimIdList = new ArrayList<>();
List<Long> blackMetricIdList = new ArrayList<>();
@@ -83,13 +83,13 @@ public abstract class MetricSemanticQuery extends RuleSemanticQuery {
}
@Override
public void fillParseInfo(Long domainId, ChatContext chatContext) {
super.fillParseInfo(domainId, chatContext);
public void fillParseInfo(Long modelId, QueryContext queryContext, ChatContext chatContext) {
super.fillParseInfo(modelId, queryContext, chatContext);
parseInfo.setLimit(METRIC_MAX_RESULTS);
if (parseInfo.getDateInfo() == null) {
ConfigService configService = ContextUtils.getBean(ConfigService.class);
ChatConfigRichResp chatConfig = configService.getConfigRichInfo(parseInfo.getDomainId());
ChatConfigRichResp chatConfig = configService.getConfigRichInfo(parseInfo.getModelId());
ChatDefaultRichConfigResp defaultConfig = chatConfig.getChatAggRichConfig().getChatDefaultConfig();
DateConf dateInfo = new DateConf();
int unit = 1;

View File

@@ -3,8 +3,8 @@ package com.tencent.supersonic.chat.query.rule.metric;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.DIMENSION;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.VALUE;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.OPTIONAL;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
import static com.tencent.supersonic.common.pojo.Constants.DESC_UPPER;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
@@ -13,12 +13,11 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.springframework.stereotype.Component;
@Component
public class MetricTopNQuery extends MetricSemanticQuery {
@@ -36,7 +35,7 @@ public class MetricTopNQuery extends MetricSemanticQuery {
@Override
public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches,
QueryContext queryCtx) {
QueryContext queryCtx) {
Matcher matcher = INTENT_PATTERN.matcher(queryCtx.getRequest().getQueryText());
if (matcher.matches()) {
return super.match(candidateElementMatches, queryCtx);
@@ -50,11 +49,11 @@ public class MetricTopNQuery extends MetricSemanticQuery {
}
@Override
public void fillParseInfo(Long domainId, ChatContext chatContext){
super.fillParseInfo(domainId, chatContext);
public void fillParseInfo(Long modelId, QueryContext queryContext, ChatContext chatContext) {
super.fillParseInfo(modelId, queryContext, chatContext);
parseInfo.setLimit(ORDERBY_MAX_RESULTS);
parseInfo.setScore(2.0);
parseInfo.setScore(parseInfo.getScore() + 2.0);
parseInfo.setAggType(AggregateTypeEnum.SUM);
SchemaElement metric = parseInfo.getMetrics().iterator().next();

View File

@@ -9,18 +9,18 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatConfigEditReqReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
import com.tencent.supersonic.chat.service.ConfigService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.enums.AuthType;
import com.tencent.supersonic.semantic.api.model.request.PageDimensionReq;
import com.tencent.supersonic.semantic.api.model.request.PageMetricReq;
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
import com.tencent.supersonic.semantic.api.model.response.DomainResp;
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
import com.tencent.supersonic.chat.service.ConfigService;
import com.tencent.supersonic.semantic.api.model.response.ModelResp;
import java.util.List;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
@@ -44,16 +44,16 @@ public class ChatConfigController {
@PostMapping
public Long addChatConfig(@RequestBody ChatConfigBaseReq extendBaseCmd,
HttpServletRequest request,
HttpServletResponse response) {
HttpServletRequest request,
HttpServletResponse response) {
User user = UserHolder.findUser(request, response);
return configService.addConfig(extendBaseCmd, user);
}
@PutMapping
public Long editDomainExtend(@RequestBody ChatConfigEditReqReq extendEditCmd,
HttpServletRequest request,
HttpServletResponse response) {
public Long editModelExtend(@RequestBody ChatConfigEditReqReq extendEditCmd,
HttpServletRequest request,
HttpServletResponse response) {
User user = UserHolder.findUser(request, response);
return configService.editConfig(extendEditCmd, user);
}
@@ -61,16 +61,16 @@ public class ChatConfigController {
@PostMapping("/search")
public List<ChatConfigResp> search(@RequestBody ChatConfigFilter filter,
HttpServletRequest request,
HttpServletResponse response) {
HttpServletRequest request,
HttpServletResponse response) {
User user = UserHolder.findUser(request, response);
return configService.search(filter, user);
}
@GetMapping("/richDesc/{domainId}")
public ChatConfigRichResp getDomainExtendRichInfo(@PathVariable("domainId") Long domainId) {
return configService.getConfigRichInfo(domainId);
@GetMapping("/richDesc/{modelId}")
public ChatConfigRichResp getModelExtendRichInfo(@PathVariable("modelId") Long modelId) {
return configService.getConfigRichInfo(modelId);
}
@GetMapping("/richDesc/all")
@@ -78,34 +78,46 @@ public class ChatConfigController {
return configService.getAllChatRichConfig();
}
/**
* get domain list
*
* @param
*/
@GetMapping("/domainList")
public List<DomainResp> getDomainList() {
return semanticLayer.getDomainListForAdmin();
@GetMapping("/modelList/{domainId}")
public List<ModelResp> getModelList(@PathVariable("domainId") Long domainId,
HttpServletRequest request,
HttpServletResponse response) {
User user = UserHolder.findUser(request, response);
return semanticLayer.getModelList(AuthType.ADMIN, domainId, user);
}
@GetMapping("/domainList/view")
public List<DomainResp> getDomainListForViewer() {
return semanticLayer.getDomainListForViewer();
@GetMapping("/domainList")
public List<DomainResp> getDomainList(HttpServletRequest request,
HttpServletResponse response) {
User user = UserHolder.findUser(request, response);
return semanticLayer.getDomainList(user);
}
@GetMapping("/modelList")
public List<ModelResp> getModelList(HttpServletRequest request,
HttpServletResponse response) {
User user = UserHolder.findUser(request, response);
return semanticLayer.getModelList(AuthType.ADMIN, null, user);
}
@GetMapping("/modelList/view")
public List<ModelResp> getModelListVisible(HttpServletRequest request,
HttpServletResponse response) {
User user = UserHolder.findUser(request, response);
return semanticLayer.getModelList(AuthType.VISIBLE, null, user);
}
@PostMapping("/dimension/page")
public PageInfo<DimensionResp> getDimension(@RequestBody PageDimensionReq pageDimensionCmd,
HttpServletRequest request,
HttpServletResponse response) {
HttpServletRequest request,
HttpServletResponse response) {
return semanticLayer.getDimensionPage(pageDimensionCmd);
}
@PostMapping("/metric/page")
public PageInfo<MetricResp> getMetric(@RequestBody PageMetricReq pageMetrricCmd,
HttpServletRequest request,
HttpServletResponse response) {
HttpServletRequest request,
HttpServletResponse response) {
return semanticLayer.getMetricPage(pageMetrricCmd);
}

View File

@@ -3,9 +3,9 @@ package com.tencent.supersonic.chat.rest;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.persistence.dataobject.ChatDO;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.persistence.dataobject.ChatDO;
import com.tencent.supersonic.chat.service.ChatService;
import java.util.List;
import javax.servlet.http.HttpServletRequest;
@@ -69,9 +69,9 @@ public class ChatController {
@PostMapping("/pageQueryInfo")
public PageInfo<QueryResp> pageQueryInfo(@RequestBody PageQueryInfoReq pageQueryInfoCommand,
@RequestParam(value = "chatId") long chatId,
HttpServletRequest request,
HttpServletResponse response) {
@RequestParam(value = "chatId") long chatId,
HttpServletRequest request,
HttpServletResponse response) {
pageQueryInfoCommand.setUserName(UserHolder.findUser(request, response).getName());
return chatService.queryInfo(pageQueryInfoCommand, chatId);
}

View File

@@ -3,8 +3,8 @@ package com.tencent.supersonic.chat.rest;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.service.QueryService;
import com.tencent.supersonic.chat.service.SearchService;
import javax.servlet.http.HttpServletRequest;
@@ -33,7 +33,7 @@ public class ChatQueryController {
@PostMapping("search")
public Object search(@RequestBody QueryReq queryCtx, HttpServletRequest request,
HttpServletResponse response) {
HttpServletResponse response) {
queryCtx.setUser(UserHolder.findUser(request, response));
return searchService.search(queryCtx);
}
@@ -53,7 +53,8 @@ public class ChatQueryController {
}
@PostMapping("execute")
public Object execute(@RequestBody ExecuteQueryReq queryCtx, HttpServletRequest request, HttpServletResponse response)
public Object execute(@RequestBody ExecuteQueryReq queryCtx, HttpServletRequest request,
HttpServletResponse response)
throws Exception {
queryCtx.setUser(UserHolder.findUser(request, response));
return queryService.performExecution(queryCtx);
@@ -61,13 +62,14 @@ public class ChatQueryController {
@PostMapping("queryContext")
public Object queryContext(@RequestBody QueryReq queryCtx, HttpServletRequest request,
HttpServletResponse response) throws Exception {
HttpServletResponse response) throws Exception {
queryCtx.setUser(UserHolder.findUser(request, response));
return queryService.queryContext(queryCtx);
}
@PostMapping("queryData")
public Object queryData(@RequestBody QueryDataReq queryData, HttpServletRequest request, HttpServletResponse response)
public Object queryData(@RequestBody QueryDataReq queryData, HttpServletRequest request,
HttpServletResponse response)
throws Exception {
return queryService.executeDirectQuery(queryData, UserHolder.findUser(request, response));
}

View File

@@ -10,7 +10,6 @@ import com.tencent.supersonic.knowledge.dictionary.DimValueDictInfo;
import java.util.List;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.DeleteMapping;
import org.springframework.web.bind.annotation.GetMapping;

View File

@@ -5,10 +5,16 @@ import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.api.pojo.request.PluginQueryReq;
import com.tencent.supersonic.chat.plugin.Plugin;
import com.tencent.supersonic.chat.service.PluginService;
import org.springframework.web.bind.annotation.*;
import java.util.List;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.List;
import org.springframework.web.bind.annotation.DeleteMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.PutMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
@RestController
@RequestMapping("/api/chat/plugin")
@@ -22,8 +28,8 @@ public class PluginController {
@PostMapping
public boolean createPlugin(@RequestBody Plugin plugin,
HttpServletRequest httpServletRequest,
HttpServletResponse httpServletResponse) {
HttpServletRequest httpServletRequest,
HttpServletResponse httpServletResponse) {
User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
pluginService.createPlugin(plugin, user);
return true;
@@ -31,8 +37,8 @@ public class PluginController {
@PutMapping
public boolean updatePlugin(@RequestBody Plugin plugin,
HttpServletRequest httpServletRequest,
HttpServletResponse httpServletResponse) {
HttpServletRequest httpServletRequest,
HttpServletResponse httpServletResponse) {
User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
pluginService.updatePlugin(plugin, user);
return true;
@@ -50,8 +56,11 @@ public class PluginController {
}
@PostMapping("/query")
List<Plugin> query(@RequestBody PluginQueryReq pluginQueryReq) {
return pluginService.queryWithAuthCheck(pluginQueryReq);
List<Plugin> query(@RequestBody PluginQueryReq pluginQueryReq,
HttpServletRequest httpServletRequest,
HttpServletResponse httpServletResponse) {
User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
return pluginService.queryWithAuthCheck(pluginQueryReq, user);
}
}

View File

@@ -5,14 +5,15 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.RecommendQuestionResp;
import com.tencent.supersonic.chat.api.pojo.response.RecommendResp;
import com.tencent.supersonic.chat.service.RecommendService;
import java.util.List;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
import java.util.List;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
/**
* recommend controller
@@ -24,32 +25,33 @@ public class RecommendController {
@Autowired
private RecommendService recommendService;
@GetMapping("recommend/{domainId}")
public RecommendResp recommend(@PathVariable("domainId") Long domainId,
@RequestParam(value = "limit", required = false) Long limit,
HttpServletRequest request,
HttpServletResponse response) {
@GetMapping("recommend/{modelId}")
public RecommendResp recommend(@PathVariable("modelId") Long modelId,
@RequestParam(value = "limit", required = false) Long limit,
HttpServletRequest request,
HttpServletResponse response) {
QueryReq queryCtx = new QueryReq();
queryCtx.setUser(UserHolder.findUser(request, response));
queryCtx.setDomainId(domainId);
queryCtx.setModelId(modelId);
return recommendService.recommend(queryCtx, limit);
}
@GetMapping("recommend/metric/{domainId}")
public RecommendResp recommendMetricMode(@PathVariable("domainId") Long domainId,
@RequestParam(value = "limit", required = false) Long limit,
HttpServletRequest request,
HttpServletResponse response) {
@GetMapping("recommend/metric/{modelId}")
public RecommendResp recommendMetricMode(@PathVariable("modelId") Long modelId,
@RequestParam(value = "limit", required = false) Long limit,
HttpServletRequest request,
HttpServletResponse response) {
QueryReq queryCtx = new QueryReq();
queryCtx.setUser(UserHolder.findUser(request, response));
queryCtx.setDomainId(domainId);
queryCtx.setModelId(modelId);
return recommendService.recommendMetricMode(queryCtx, limit);
}
@GetMapping("recommend/question")
public List<RecommendQuestionResp> recommendQuestion(@RequestParam(value = "domainId", required = false) Long domainId,
HttpServletRequest request,
HttpServletResponse response) {
return recommendService.recommendQuestion(domainId);
public List<RecommendQuestionResp> recommendQuestion(
@RequestParam(value = "modelId", required = false) Long modelId,
HttpServletRequest request,
HttpServletResponse response) {
return recommendService.recommendQuestion(modelId);
}
}

View File

@@ -3,23 +3,21 @@ package com.tencent.supersonic.chat.service;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.persistence.dataobject.ChatDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import java.util.List;
public interface ChatService {
/***
* get the domain from context
* get the model from context
* @param chatId
* @return
*/
public Long getContextDomain(Integer chatId);
public Long getContextModel(Integer chatId);
public ChatContext getOrCreateContext(int chatId);

View File

@@ -7,7 +7,6 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatConfigEditReqReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
import java.util.List;
public interface ConfigService {
@@ -18,9 +17,9 @@ public interface ConfigService {
List<ChatConfigResp> search(ChatConfigFilter filter, User user);
ChatConfigRichResp getConfigRichInfo(Long domainId);
ChatConfigRichResp getConfigRichInfo(Long modelId);
ChatConfigResp fetchConfigByDomainId(Long domainId);
ChatConfigResp fetchConfigByModelId(Long modelId);
List<ChatConfigRichResp> getAllChatRichConfig();
}

View File

@@ -5,17 +5,17 @@ import com.tencent.supersonic.knowledge.dictionary.DictConfig;
import com.tencent.supersonic.knowledge.dictionary.DictTaskFilter;
import com.tencent.supersonic.knowledge.dictionary.DimValue2DictCommand;
import com.tencent.supersonic.knowledge.dictionary.DimValueDictInfo;
import java.util.List;
public interface DictionaryService {
Long addDictTask(DimValue2DictCommand dimValue2DictCommend, User user);
Long deleteDictTask(DimValue2DictCommand dimValue2DictCommend, User user);
List<DimValueDictInfo> searchDictTaskList(DictTaskFilter filter, User user);
DictConfig getDictInfoByDomainId(Long domainId);
DictConfig getDictInfoByModelId(Long modelId);
String getDictRootPath();
}

View File

@@ -2,9 +2,8 @@ package com.tencent.supersonic.chat.service;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.plugin.Plugin;
import com.tencent.supersonic.chat.api.pojo.request.PluginQueryReq;
import com.tencent.supersonic.chat.plugin.Plugin;
import java.util.List;
import java.util.Optional;
@@ -24,5 +23,5 @@ public interface PluginService {
Optional<Plugin> getPluginByName(String name);
List<Plugin> queryWithAuthCheck(PluginQueryReq pluginQueryReq);
List<Plugin> queryWithAuthCheck(PluginQueryReq pluginQueryReq, User user);
}

View File

@@ -3,10 +3,10 @@ package com.tencent.supersonic.chat.service;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq;
import org.apache.calcite.sql.parser.SqlParseException;
/***

View File

@@ -4,7 +4,6 @@ package com.tencent.supersonic.chat.service;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.RecommendQuestionResp;
import com.tencent.supersonic.chat.api.pojo.response.RecommendResp;
import java.util.List;
/***
@@ -16,5 +15,5 @@ public interface RecommendService {
RecommendResp recommendMetricMode(QueryReq queryCtx, Long limit);
List<RecommendQuestionResp> recommendQuestion(Long domainId);
List<RecommendQuestionResp> recommendQuestion(Long modelId);
}

View File

@@ -13,7 +13,7 @@ import static com.tencent.supersonic.common.pojo.Constants.WEEK;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticLayer;
import com.tencent.supersonic.chat.api.pojo.DomainSchema;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq;
@@ -26,9 +26,9 @@ import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatDefaultRichConfigResp;
import com.tencent.supersonic.chat.api.pojo.response.DataInfo;
import com.tencent.supersonic.chat.api.pojo.response.DomainInfo;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.MetricInfo;
import com.tencent.supersonic.chat.api.pojo.response.ModelInfo;
import com.tencent.supersonic.chat.config.AggregatorConfig;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
@@ -62,6 +62,7 @@ import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
@@ -79,47 +80,47 @@ public class SemanticService {
private SemanticLayer semanticLayer = ComponentFactory.getSemanticLayer();
public DomainSchema getDomainSchema(Long id) {
DomainSchema domainSchema = schemaService.getDomainSchema(id);
if (!Objects.isNull(domainSchema) && !Objects.isNull(domainSchema.getDomain())) {
public ModelSchema getModelSchema(Long id) {
ModelSchema ModelSchema = schemaService.getModelSchema(id);
if (!Objects.isNull(ModelSchema) && !Objects.isNull(ModelSchema.getModel())) {
ChatConfigResp chaConfigInfo =
configService.fetchConfigByDomainId(domainSchema.getDomain().getId());
configService.fetchConfigByModelId(ModelSchema.getModel().getId());
// filter dimensions in blacklist
filterBlackDim(domainSchema, chaConfigInfo);
filterBlackDim(ModelSchema, chaConfigInfo);
// filter metrics in blacklist
filterBlackMetric(domainSchema, chaConfigInfo);
filterBlackMetric(ModelSchema, chaConfigInfo);
}
return domainSchema;
return ModelSchema;
}
public EntityInfo getEntityInfo(SemanticParseInfo parseInfo, User user) {
if (parseInfo != null && parseInfo.getDomainId() > 0) {
EntityInfo entityInfo = getEntityInfo(parseInfo.getDomainId());
if (parseInfo != null && parseInfo.getModelId() > 0) {
EntityInfo entityInfo = getEntityInfo(parseInfo.getModelId());
if (parseInfo.getDimensionFilters().size() <= 0) {
entityInfo.setMetrics(null);
entityInfo.setDimensions(null);
return entityInfo;
}
if (entityInfo.getDomainInfo() != null && entityInfo.getDomainInfo().getPrimaryEntityBizName() != null) {
String domainInfoPrimaryName = entityInfo.getDomainInfo().getPrimaryEntityBizName();
String domainInfoId = "";
if (entityInfo.getModelInfo() != null && entityInfo.getModelInfo().getPrimaryEntityBizName() != null) {
String ModelInfoPrimaryName = entityInfo.getModelInfo().getPrimaryEntityBizName();
String ModelInfoId = "";
for (QueryFilter chatFilter : parseInfo.getDimensionFilters()) {
if (chatFilter != null && chatFilter.getBizName() != null && chatFilter.getBizName()
.equals(domainInfoPrimaryName)) {
.equals(ModelInfoPrimaryName)) {
if (chatFilter.getOperator().equals(FilterOperatorEnum.EQUALS)) {
domainInfoId = chatFilter.getValue().toString();
ModelInfoId = chatFilter.getValue().toString();
}
}
}
if (!"".equals(domainInfoId)) {
if (!"".equals(ModelInfoId)) {
try {
setMainDomain(entityInfo, parseInfo.getDomainId(),
domainInfoId, user);
setMainModel(entityInfo, parseInfo.getModelId(),
ModelInfoId, user);
return entityInfo;
} catch (Exception e) {
log.error("setMaintDomain error {}", e);
log.error("setMaintModel error {}", e);
}
}
}
@@ -127,8 +128,8 @@ public class SemanticService {
return null;
}
public EntityInfo getEntityInfo(Long domain) {
ChatConfigRichResp chaConfigRichDesc = configService.getConfigRichInfo(domain);
public EntityInfo getEntityInfo(Long Model) {
ChatConfigRichResp chaConfigRichDesc = configService.getConfigRichInfo(Model);
if (Objects.isNull(chaConfigRichDesc) || Objects.isNull(chaConfigRichDesc.getChatDetailRichConfig())) {
return new EntityInfo();
}
@@ -138,23 +139,23 @@ public class SemanticService {
private EntityInfo getEntityInfo(ChatConfigRichResp chaConfigRichDesc) {
EntityInfo entityInfo = new EntityInfo();
Long domainId = chaConfigRichDesc.getDomainId();
if (Objects.nonNull(chaConfigRichDesc) && Objects.nonNull(domainId)) {
Long modelId = chaConfigRichDesc.getModelId();
if (Objects.nonNull(chaConfigRichDesc) && Objects.nonNull(modelId)) {
SemanticService schemaService = ContextUtils.getBean(SemanticService.class);
DomainSchema domainSchema = schemaService.getDomainSchema(domainId);
if (Objects.isNull(domainSchema) || Objects.isNull(domainSchema.getEntity())) {
ModelSchema ModelSchema = schemaService.getModelSchema(modelId);
if (Objects.isNull(ModelSchema) || Objects.isNull(ModelSchema.getEntity())) {
return entityInfo;
}
DomainInfo domainInfo = new DomainInfo();
domainInfo.setItemId(domainId.intValue());
domainInfo.setName(domainSchema.getDomain().getName());
domainInfo.setWords(domainSchema.getDomain().getAlias());
domainInfo.setBizName(domainSchema.getDomain().getBizName());
if (Objects.nonNull(domainSchema.getEntity())) {
domainInfo.setPrimaryEntityBizName(domainSchema.getEntity().getBizName());
ModelInfo ModelInfo = new ModelInfo();
ModelInfo.setItemId(modelId.intValue());
ModelInfo.setName(ModelSchema.getModel().getName());
ModelInfo.setWords(ModelSchema.getModel().getAlias());
ModelInfo.setBizName(ModelSchema.getModel().getBizName());
if (Objects.nonNull(ModelSchema.getEntity())) {
ModelInfo.setPrimaryEntityBizName(ModelSchema.getEntity().getBizName());
}
entityInfo.setDomainInfo(domainInfo);
entityInfo.setModelInfo(ModelInfo);
List<DataInfo> dimensions = new ArrayList<>();
List<DataInfo> metrics = new ArrayList<>();
@@ -188,19 +189,19 @@ public class SemanticService {
return entityInfo;
}
public void setMainDomain(EntityInfo domainInfo, Long domain, String entity, User user) {
DomainSchema domainSchema = schemaService.getDomainSchema(domain);
public void setMainModel(EntityInfo ModelInfo, Long Model, String entity, User user) {
ModelSchema ModelSchema = schemaService.getModelSchema(Model);
domainInfo.setEntityId(entity);
ModelInfo.setEntityId(entity);
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
semanticParseInfo.setDomain(domainSchema.getDomain());
semanticParseInfo.setModel(ModelSchema.getModel());
semanticParseInfo.setNativeQuery(true);
semanticParseInfo.setMetrics(getMetrics(domainInfo));
semanticParseInfo.setDimensions(getDimensions(domainInfo));
semanticParseInfo.setMetrics(getMetrics(ModelInfo));
semanticParseInfo.setDimensions(getDimensions(ModelInfo));
DateConf dateInfo = new DateConf();
int unit = 1;
ChatConfigResp chatConfigInfo =
configService.fetchConfigByDomainId(domainSchema.getDomain().getId());
configService.fetchConfigByModelId(ModelSchema.getModel().getId());
if (Objects.nonNull(chatConfigInfo) && Objects.nonNull(chatConfigInfo.getChatDetailConfig())
&& Objects.nonNull(chatConfigInfo.getChatDetailConfig().getChatDefaultConfig())) {
ChatDefaultConfigReq chatDefaultConfig = chatConfigInfo.getChatDetailConfig().getChatDefaultConfig();
@@ -219,7 +220,7 @@ public class SemanticService {
QueryFilter chatFilter = new QueryFilter();
chatFilter.setValue(String.valueOf(entity));
chatFilter.setOperator(FilterOperatorEnum.EQUALS);
chatFilter.setBizName(getEntityPrimaryName(domainInfo));
chatFilter.setBizName(getEntityPrimaryName(ModelInfo));
Set<QueryFilter> chatFilters = new LinkedHashSet();
chatFilters.add(chatFilter);
semanticParseInfo.setDimensionFilters(chatFilters);
@@ -229,7 +230,7 @@ public class SemanticService {
queryResultWithColumns = semanticLayer.queryByStruct(QueryReqBuilder.buildStructReq(semanticParseInfo),
user);
} catch (Exception e) {
log.warn("setMainDomain queryByStruct error, e:", e);
log.warn("setMainModel queryByStruct error, e:", e);
}
if (queryResultWithColumns != null) {
@@ -241,18 +242,18 @@ public class SemanticService {
if (entry.getValue() == null || entryKey == null) {
continue;
}
domainInfo.getDimensions().stream().filter(i -> entryKey.equals(i.getBizName()))
ModelInfo.getDimensions().stream().filter(i -> entryKey.equals(i.getBizName()))
.forEach(i -> i.setValue(entry.getValue().toString()));
domainInfo.getMetrics().stream().filter(i -> entryKey.equals(i.getBizName()))
ModelInfo.getMetrics().stream().filter(i -> entryKey.equals(i.getBizName()))
.forEach(i -> i.setValue(entry.getValue().toString()));
}
}
}
}
private Set<SchemaElement> getDimensions(EntityInfo domainInfo) {
private Set<SchemaElement> getDimensions(EntityInfo ModelInfo) {
Set<SchemaElement> dimensions = new LinkedHashSet();
for (DataInfo mainEntityDimension : domainInfo.getDimensions()) {
for (DataInfo mainEntityDimension : ModelInfo.getDimensions()) {
SchemaElement dimension = new SchemaElement();
dimension.setBizName(mainEntityDimension.getBizName());
dimensions.add(dimension);
@@ -269,41 +270,41 @@ public class SemanticService {
return entryKey;
}
private Set<SchemaElement> getMetrics(EntityInfo domainInfo) {
private Set<SchemaElement> getMetrics(EntityInfo ModelInfo) {
Set<SchemaElement> metrics = new LinkedHashSet();
for (DataInfo metricValue : domainInfo.getMetrics()) {
for (DataInfo metricValue : ModelInfo.getMetrics()) {
SchemaElement metric = new SchemaElement();
metric.setBizName(metricValue.getBizName());
BeanUtils.copyProperties(metricValue, metric);
metrics.add(metric);
}
return metrics;
}
private String getEntityPrimaryName(EntityInfo domainInfo) {
return domainInfo.getDomainInfo().getPrimaryEntityBizName();
private String getEntityPrimaryName(EntityInfo ModelInfo) {
return ModelInfo.getModelInfo().getPrimaryEntityBizName();
}
private void filterBlackMetric(DomainSchema domainSchema, ChatConfigResp chaConfigInfo) {
private void filterBlackMetric(ModelSchema ModelSchema, ChatConfigResp chaConfigInfo) {
ItemVisibility visibility = generateFinalVisibility(chaConfigInfo);
if (Objects.nonNull(chaConfigInfo) && Objects.nonNull(visibility)
&& !CollectionUtils.isEmpty(visibility.getBlackMetricIdList())
&& !CollectionUtils.isEmpty(domainSchema.getMetrics())) {
Set<SchemaElement> metric4Chat = domainSchema.getMetrics().stream()
&& !CollectionUtils.isEmpty(ModelSchema.getMetrics())) {
Set<SchemaElement> metric4Chat = ModelSchema.getMetrics().stream()
.filter(metric -> !visibility.getBlackMetricIdList().contains(metric.getId()))
.collect(Collectors.toSet());
domainSchema.setMetrics(metric4Chat);
ModelSchema.setMetrics(metric4Chat);
}
}
private void filterBlackDim(DomainSchema domainSchema, ChatConfigResp chatConfigInfo) {
private void filterBlackDim(ModelSchema ModelSchema, ChatConfigResp chatConfigInfo) {
ItemVisibility visibility = generateFinalVisibility(chatConfigInfo);
if (Objects.nonNull(chatConfigInfo) && Objects.nonNull(visibility)
&& !CollectionUtils.isEmpty(visibility.getBlackDimIdList())
&& !CollectionUtils.isEmpty(domainSchema.getDimensions())) {
Set<SchemaElement> dim4Chat = domainSchema.getDimensions().stream()
&& !CollectionUtils.isEmpty(ModelSchema.getDimensions())) {
Set<SchemaElement> dim4Chat = ModelSchema.getDimensions().stream()
.filter(dim -> !visibility.getBlackDimIdList().contains(dim.getId()))
.collect(Collectors.toSet());
domainSchema.setDimensions(dim4Chat);
ModelSchema.setDimensions(dim4Chat);
}
}
@@ -331,7 +332,7 @@ public class SemanticService {
}
public AggregateInfo getAggregateInfo(User user, SemanticParseInfo semanticParseInfo,
QueryResultWithSchemaResp result) {
QueryResultWithSchemaResp result) {
if (CollectionUtils.isEmpty(semanticParseInfo.getMetrics()) || !aggregatorConfig.getEnableRatio()) {
return new AggregateInfo();
}
@@ -383,7 +384,7 @@ public class SemanticService {
}
private MetricInfo queryRatio(User user, SemanticParseInfo semanticParseInfo, SchemaElement metric,
AggOperatorEnum aggOperatorEnum, QueryResultWithSchemaResp results) {
AggOperatorEnum aggOperatorEnum, QueryResultWithSchemaResp results) {
MetricInfo metricInfo = new MetricInfo();
metricInfo.setStatistics(new HashMap<>());
QueryStructReq queryStructReq = QueryReqBuilder.buildStructRatioReq(semanticParseInfo, metric, aggOperatorEnum);
@@ -431,7 +432,7 @@ public class SemanticService {
}
private DateConf getRatioDateConf(AggOperatorEnum aggOperatorEnum, SemanticParseInfo semanticParseInfo,
QueryResultWithSchemaResp results) {
QueryResultWithSchemaResp results) {
String dateField = QueryReqBuilder.getDateField(semanticParseInfo.getDateInfo());
Optional<String> lastDayOp = results.getResultList().stream()
.map(r -> r.get(dateField).toString())

View File

@@ -3,23 +3,20 @@ package com.tencent.supersonic.chat.service.impl;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.persistence.dataobject.ChatDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.persistence.dataobject.QueryDO;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.persistence.repository.ChatContextRepository;
import com.tencent.supersonic.chat.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.persistence.repository.ChatRepository;
import com.tencent.supersonic.chat.service.ChatService;
import java.text.SimpleDateFormat;
import java.util.List;
import java.util.Objects;
import com.tencent.supersonic.chat.service.ChatService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Service;
@@ -41,7 +38,7 @@ public class ChatServiceImpl implements ChatService {
}
@Override
public Long getContextDomain(Integer chatId) {
public Long getContextModel(Integer chatId) {
if (Objects.isNull(chatId)) {
return null;
}
@@ -50,8 +47,8 @@ public class ChatServiceImpl implements ChatService {
return null;
}
SemanticParseInfo originalSemanticParse = chatContext.getParseInfo();
if (Objects.nonNull(originalSemanticParse) && Objects.nonNull(originalSemanticParse.getDomainId())) {
return originalSemanticParse.getDomainId();
if (Objects.nonNull(originalSemanticParse) && Objects.nonNull(originalSemanticParse.getModelId())) {
return originalSemanticParse.getModelId();
}
return null;
}

View File

@@ -3,26 +3,37 @@ package com.tencent.supersonic.chat.service.impl;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticLayer;
import com.tencent.supersonic.chat.api.pojo.DomainSchema;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.request.*;
import com.tencent.supersonic.chat.api.pojo.response.*;
import com.tencent.supersonic.chat.config.*;
import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigBaseReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigEditReqReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatDetailConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.Entity;
import com.tencent.supersonic.chat.api.pojo.request.ItemVisibility;
import com.tencent.supersonic.chat.api.pojo.request.KnowledgeInfoReq;
import com.tencent.supersonic.chat.api.pojo.response.ChatAggRichConfigResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatDefaultRichConfigResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatDetailRichConfigResp;
import com.tencent.supersonic.chat.api.pojo.response.EntityRichInfoResp;
import com.tencent.supersonic.chat.api.pojo.response.ItemVisibilityInfo;
import com.tencent.supersonic.chat.config.ChatConfig;
import com.tencent.supersonic.chat.persistence.repository.ChatConfigRepository;
import com.tencent.supersonic.chat.service.ConfigService;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.semantic.api.model.response.DomainResp;
import com.tencent.supersonic.chat.persistence.repository.ChatConfigRepository;
import com.tencent.supersonic.chat.utils.ChatConfigHelper;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.util.JsonUtil;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
@@ -43,24 +54,24 @@ public class ConfigServiceImpl implements ConfigService {
public ConfigServiceImpl(ChatConfigRepository chatConfigRepository,
ChatConfigHelper chatConfigHelper) {
ChatConfigHelper chatConfigHelper) {
this.chatConfigRepository = chatConfigRepository;
this.chatConfigHelper = chatConfigHelper;
}
@Override
public Long addConfig(ChatConfigBaseReq configBaseCmd, User user) {
log.info("[create domain extend] object:{}", JsonUtil.toString(configBaseCmd, true));
duplicateCheck(configBaseCmd.getDomainId());
permissionCheckLogic(configBaseCmd.getDomainId(), user.getName());
log.info("[create model extend] object:{}", JsonUtil.toString(configBaseCmd, true));
duplicateCheck(configBaseCmd.getModelId());
permissionCheckLogic(configBaseCmd.getModelId(), user.getName());
ChatConfig chaConfig = chatConfigHelper.newChatConfig(configBaseCmd, user);
Long id = chatConfigRepository.createConfig(chaConfig);
return id;
}
private void duplicateCheck(Long domainId) {
private void duplicateCheck(Long modelId) {
ChatConfigFilter filter = new ChatConfigFilter();
filter.setDomainId(domainId);
filter.setModelId(modelId);
List<ChatConfigResp> chaConfigDescList = chatConfigRepository.getChatConfig(filter);
if (!CollectionUtils.isEmpty(chaConfigDescList)) {
throw new RuntimeException("chat config existed, no need to add repeatedly");
@@ -70,12 +81,12 @@ public class ConfigServiceImpl implements ConfigService {
@Override
public Long editConfig(ChatConfigEditReqReq configEditCmd, User user) {
log.info("[edit domain extend] object:{}", JsonUtil.toString(configEditCmd, true));
log.info("[edit model extend] object:{}", JsonUtil.toString(configEditCmd, true));
if (Objects.isNull(configEditCmd) || Objects.isNull(configEditCmd.getId()) && Objects.isNull(
configEditCmd.getDomainId())) {
throw new RuntimeException("editConfig, id and domainId are not allowed to be empty at the same time");
configEditCmd.getModelId())) {
throw new RuntimeException("editConfig, id and modelId are not allowed to be empty at the same time");
}
permissionCheckLogic(configEditCmd.getDomainId(), user.getName());
permissionCheckLogic(configEditCmd.getModelId(), user.getName());
ChatConfig chaConfig = chatConfigHelper.editChatConfig(configEditCmd, user);
chatConfigRepository.updateConfig(chaConfig);
return configEditCmd.getId();
@@ -83,33 +94,33 @@ public class ConfigServiceImpl implements ConfigService {
/**
* domain administrators have the right to modify related configuration information.
* model administrators have the right to modify related configuration information.
*/
private Boolean permissionCheckLogic(Long domainId, String staffName) {
private Boolean permissionCheckLogic(Long modelId, String staffName) {
// todo
return true;
}
@Override
public List<ChatConfigResp> search(ChatConfigFilter filter, User user) {
log.info("[search domain extend] object:{}", JsonUtil.toString(filter, true));
log.info("[search model extend] object:{}", JsonUtil.toString(filter, true));
List<ChatConfigResp> chaConfigDescList = chatConfigRepository.getChatConfig(filter);
return chaConfigDescList;
}
@Override
public ChatConfigResp fetchConfigByDomainId(Long domainId) {
return chatConfigRepository.getConfigByDomainId(domainId);
public ChatConfigResp fetchConfigByModelId(Long modelId) {
return chatConfigRepository.getConfigByModelId(modelId);
}
private ItemVisibilityInfo fetchVisibilityDescByConfig(ItemVisibility visibility,
DomainSchema domainSchema) {
ModelSchema modelSchema) {
ItemVisibilityInfo itemVisibilityDesc = new ItemVisibilityInfo();
List<Long> dimIdAllList = chatConfigHelper.generateAllDimIdList(domainSchema);
List<Long> metricIdAllList = chatConfigHelper.generateAllMetricIdList(domainSchema);
List<Long> dimIdAllList = chatConfigHelper.generateAllDimIdList(modelSchema);
List<Long> metricIdAllList = chatConfigHelper.generateAllMetricIdList(modelSchema);
List<Long> blackDimIdList = new ArrayList<>();
List<Long> blackMetricIdList = new ArrayList<>();
@@ -131,91 +142,98 @@ public class ConfigServiceImpl implements ConfigService {
itemVisibilityDesc.setBlackDimIdList(blackDimIdList);
itemVisibilityDesc.setBlackMetricIdList(blackMetricIdList);
itemVisibilityDesc.setWhiteDimIdList(Objects.isNull(whiteDimIdList) ? new ArrayList<>() : whiteDimIdList);
itemVisibilityDesc.setWhiteMetricIdList(Objects.isNull(whiteMetricIdList) ? new ArrayList<>() : whiteMetricIdList);
itemVisibilityDesc.setWhiteMetricIdList(
Objects.isNull(whiteMetricIdList) ? new ArrayList<>() : whiteMetricIdList);
return itemVisibilityDesc;
}
@Override
public ChatConfigRichResp getConfigRichInfo(Long domainId) {
public ChatConfigRichResp getConfigRichInfo(Long modelId) {
ChatConfigRichResp chatConfigRich = new ChatConfigRichResp();
ChatConfigResp chatConfigResp = chatConfigRepository.getConfigByDomainId(domainId);
ChatConfigResp chatConfigResp = chatConfigRepository.getConfigByModelId(modelId);
if (Objects.isNull(chatConfigResp)) {
log.info("there is no chatConfigDesc for domainId:{}", domainId);
log.info("there is no chatConfigDesc for modelId:{}", modelId);
return chatConfigRich;
}
BeanUtils.copyProperties(chatConfigResp, chatConfigRich);
DomainSchema domainSchema = semanticService.getDomainSchema(domainId);
chatConfigRich.setBizName(domainSchema.getDomain().getBizName());
chatConfigRich.setDomainName(domainSchema.getDomain().getName());
ModelSchema modelSchema = semanticService.getModelSchema(modelId);
chatConfigRich.setBizName(modelSchema.getModel().getBizName());
chatConfigRich.setModelName(modelSchema.getModel().getName());
chatConfigRich.setChatAggRichConfig(fillChatAggRichConfig(domainSchema, chatConfigResp));
chatConfigRich.setChatDetailRichConfig(fillChatDetailRichConfig(domainSchema, chatConfigRich, chatConfigResp));
chatConfigRich.setChatAggRichConfig(fillChatAggRichConfig(modelSchema, chatConfigResp));
chatConfigRich.setChatDetailRichConfig(fillChatDetailRichConfig(modelSchema, chatConfigRich, chatConfigResp));
return chatConfigRich;
}
private ChatDetailRichConfigResp fillChatDetailRichConfig(DomainSchema domainSchema, ChatConfigRichResp chatConfigRich, ChatConfigResp chatConfigResp) {
private ChatDetailRichConfigResp fillChatDetailRichConfig(ModelSchema modelSchema,
ChatConfigRichResp chatConfigRich, ChatConfigResp chatConfigResp) {
if (Objects.isNull(chatConfigResp) || Objects.isNull(chatConfigResp.getChatDetailConfig())) {
return null;
}
ChatDetailRichConfigResp detailRichConfig = new ChatDetailRichConfigResp();
ChatDetailConfigReq chatDetailConfig = chatConfigResp.getChatDetailConfig();
ItemVisibilityInfo itemVisibilityInfo = fetchVisibilityDescByConfig(chatDetailConfig.getVisibility(), domainSchema);
ItemVisibilityInfo itemVisibilityInfo = fetchVisibilityDescByConfig(chatDetailConfig.getVisibility(),
modelSchema);
detailRichConfig.setVisibility(itemVisibilityInfo);
detailRichConfig.setKnowledgeInfos(fillKnowledgeBizName(chatDetailConfig.getKnowledgeInfos(), domainSchema));
detailRichConfig.setKnowledgeInfos(fillKnowledgeBizName(chatDetailConfig.getKnowledgeInfos(), modelSchema));
detailRichConfig.setGlobalKnowledgeConfig(chatDetailConfig.getGlobalKnowledgeConfig());
detailRichConfig.setChatDefaultConfig(fetchDefaultConfig(chatDetailConfig.getChatDefaultConfig(), domainSchema, itemVisibilityInfo));
detailRichConfig.setChatDefaultConfig(
fetchDefaultConfig(chatDetailConfig.getChatDefaultConfig(), modelSchema, itemVisibilityInfo));
return detailRichConfig;
}
private EntityRichInfoResp generateRichEntity(Entity entity, DomainSchema domainSchema) {
private EntityRichInfoResp generateRichEntity(Entity entity, ModelSchema modelSchema) {
EntityRichInfoResp entityRichInfo = new EntityRichInfoResp();
if (Objects.isNull(entity) || Objects.isNull(entity.getEntityId())) {
return entityRichInfo;
}
BeanUtils.copyProperties(entity, entityRichInfo);
Map<Long, SchemaElement> dimIdAndRespPair = domainSchema.getDimensions().stream()
Map<Long, SchemaElement> dimIdAndRespPair = modelSchema.getDimensions().stream()
.collect(Collectors.toMap(SchemaElement::getId, Function.identity(), (k1, k2) -> k1));
entityRichInfo.setDimItem(dimIdAndRespPair.get(entity.getEntityId()));
return entityRichInfo;
}
private ChatAggRichConfigResp fillChatAggRichConfig(DomainSchema domainSchema, ChatConfigResp chatConfigResp) {
private ChatAggRichConfigResp fillChatAggRichConfig(ModelSchema modelSchema, ChatConfigResp chatConfigResp) {
if (Objects.isNull(chatConfigResp) || Objects.isNull(chatConfigResp.getChatAggConfig())) {
return null;
}
ChatAggConfigReq chatAggConfig = chatConfigResp.getChatAggConfig();
ChatAggRichConfigResp chatAggRichConfig = new ChatAggRichConfigResp();
ItemVisibilityInfo itemVisibilityInfo = fetchVisibilityDescByConfig(chatAggConfig.getVisibility(), domainSchema);
ItemVisibilityInfo itemVisibilityInfo = fetchVisibilityDescByConfig(chatAggConfig.getVisibility(), modelSchema);
chatAggRichConfig.setVisibility(itemVisibilityInfo);
chatAggRichConfig.setKnowledgeInfos(fillKnowledgeBizName(chatAggConfig.getKnowledgeInfos(), domainSchema));
chatAggRichConfig.setKnowledgeInfos(fillKnowledgeBizName(chatAggConfig.getKnowledgeInfos(), modelSchema));
chatAggRichConfig.setGlobalKnowledgeConfig(chatAggConfig.getGlobalKnowledgeConfig());
chatAggRichConfig.setChatDefaultConfig(fetchDefaultConfig(chatAggConfig.getChatDefaultConfig(), domainSchema, itemVisibilityInfo));
chatAggRichConfig.setChatDefaultConfig(
fetchDefaultConfig(chatAggConfig.getChatDefaultConfig(), modelSchema, itemVisibilityInfo));
return chatAggRichConfig;
}
private ChatDefaultRichConfigResp fetchDefaultConfig(ChatDefaultConfigReq chatDefaultConfig, DomainSchema domainSchema, ItemVisibilityInfo itemVisibilityInfo) {
private ChatDefaultRichConfigResp fetchDefaultConfig(ChatDefaultConfigReq chatDefaultConfig,
ModelSchema modelSchema, ItemVisibilityInfo itemVisibilityInfo) {
ChatDefaultRichConfigResp defaultRichConfig = new ChatDefaultRichConfigResp();
if (Objects.isNull(chatDefaultConfig)) {
return defaultRichConfig;
}
BeanUtils.copyProperties(chatDefaultConfig, defaultRichConfig);
Map<Long, SchemaElement> dimIdAndRespPair = domainSchema.getDimensions().stream()
Map<Long, SchemaElement> dimIdAndRespPair = modelSchema.getDimensions().stream()
.collect(Collectors.toMap(SchemaElement::getId, Function.identity(), (k1, k2) -> k1));
Map<Long, SchemaElement> metricIdAndRespPair = domainSchema.getMetrics().stream()
Map<Long, SchemaElement> metricIdAndRespPair = modelSchema.getMetrics().stream()
.collect(Collectors.toMap(SchemaElement::getId, Function.identity(), (k1, k2) -> k1));
List<SchemaElement> dimensions = new ArrayList<>();
List<SchemaElement> metrics = new ArrayList<>();
if (!CollectionUtils.isEmpty(chatDefaultConfig.getDimensionIds())) {
chatDefaultConfig.getDimensionIds().stream()
.filter(dimId -> dimIdAndRespPair.containsKey(dimId) && itemVisibilityInfo.getWhiteDimIdList().contains(dimId))
.filter(dimId -> dimIdAndRespPair.containsKey(dimId) && itemVisibilityInfo.getWhiteDimIdList()
.contains(dimId))
.forEach(dimId -> {
SchemaElement dimSchemaResp = dimIdAndRespPair.get(dimId);
if (Objects.nonNull(dimSchemaResp)) {
@@ -229,7 +247,8 @@ public class ConfigServiceImpl implements ConfigService {
if (!CollectionUtils.isEmpty(chatDefaultConfig.getMetricIds())) {
chatDefaultConfig.getMetricIds().stream()
.filter(metricId -> metricIdAndRespPair.containsKey(metricId) && itemVisibilityInfo.getWhiteMetricIdList().contains(metricId))
.filter(metricId -> metricIdAndRespPair.containsKey(metricId)
&& itemVisibilityInfo.getWhiteMetricIdList().contains(metricId))
.forEach(metricId -> {
SchemaElement metricSchemaResp = metricIdAndRespPair.get(metricId);
if (Objects.nonNull(metricSchemaResp)) {
@@ -247,12 +266,12 @@ public class ConfigServiceImpl implements ConfigService {
private List<KnowledgeInfoReq> fillKnowledgeBizName(List<KnowledgeInfoReq> knowledgeInfos,
DomainSchema domainSchema) {
ModelSchema modelSchema) {
if (CollectionUtils.isEmpty(knowledgeInfos)) {
return new ArrayList<>();
}
Map<Long, SchemaElement> dimIdAndRespPair = domainSchema.getDimensions().stream()
.collect(Collectors.toMap(SchemaElement::getId, Function.identity(),(k1, k2) -> k1));
Map<Long, SchemaElement> dimIdAndRespPair = modelSchema.getDimensions().stream()
.collect(Collectors.toMap(SchemaElement::getId, Function.identity(), (k1, k2) -> k1));
knowledgeInfos.stream().forEach(knowledgeInfo -> {
if (Objects.nonNull(knowledgeInfo)) {
SchemaElement dimSchemaResp = dimIdAndRespPair.get(knowledgeInfo.getItemId());
@@ -267,9 +286,9 @@ public class ConfigServiceImpl implements ConfigService {
@Override
public List<ChatConfigRichResp> getAllChatRichConfig() {
List<ChatConfigRichResp> chatConfigRichInfoList = new ArrayList<>();
List<DomainResp> domainRespList = semanticLayer.getDomainListForAdmin();
domainRespList.stream().forEach(domainResp -> {
ChatConfigRichResp chatConfigRichInfo = getConfigRichInfo(domainResp.getId());
List<ModelSchema> modelSchemas = semanticLayer.getModelSchema();
modelSchemas.stream().forEach(modelSchema -> {
ChatConfigRichResp chatConfigRichInfo = getConfigRichInfo(modelSchema.getModel().getId());
if (Objects.nonNull(chatConfigRichInfo)) {
chatConfigRichInfoList.add(chatConfigRichInfo);
}

View File

@@ -9,15 +9,15 @@ import com.tencent.supersonic.chat.utils.DictMetaHelper;
import com.tencent.supersonic.chat.utils.DictQueryHelper;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.TaskStatusEnum;
import com.tencent.supersonic.knowledge.dictionary.FileHandler;
import com.tencent.supersonic.knowledge.persistence.dataobject.DictTaskDO;
import com.tencent.supersonic.knowledge.utils.DictTaskConverter;
import com.tencent.supersonic.knowledge.dictionary.DictConfig;
import com.tencent.supersonic.knowledge.dictionary.DictTaskFilter;
import com.tencent.supersonic.knowledge.dictionary.DictUpdateMode;
import com.tencent.supersonic.knowledge.dictionary.DimValue2DictCommand;
import com.tencent.supersonic.knowledge.dictionary.DimValueDictInfo;
import com.tencent.supersonic.knowledge.dictionary.FileHandler;
import com.tencent.supersonic.knowledge.persistence.dataobject.DictTaskDO;
import com.tencent.supersonic.knowledge.persistence.repository.DictRepository;
import com.tencent.supersonic.knowledge.utils.DictTaskConverter;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -41,9 +41,9 @@ public class DictionaryServiceImpl implements DictionaryService {
private String dimValue = "DimValue_%d_%d";
public DictionaryServiceImpl(DictMetaHelper metaUtils,
DictQueryHelper dictQueryHelper,
FileHandler fileHandler,
DictRepository dictRepository) {
DictQueryHelper dictQueryHelper,
FileHandler fileHandler,
DictRepository dictRepository) {
this.metaUtils = metaUtils;
this.dictQueryHelper = dictQueryHelper;
this.fileHandler = fileHandler;
@@ -65,12 +65,12 @@ public class DictionaryServiceImpl implements DictionaryService {
log.info("dimValueDOList:{}", dimValueDOList);
//2. query dimension value information
for (DimValueDO dimValueDO : dimValueDOList) {
Long domainId = dimValueDO.getDomainId();
Long modelId = dimValueDO.getModelId();
DefaultMetric defaultMetricDesc = dimValueDO.getDefaultMetricDescList().get(0);
for (Dim4Dict dim4Dict : dimValueDO.getDimensions()) {
List<String> data = dictQueryHelper.fetchDimValueSingle(domainId, defaultMetricDesc, dim4Dict, user);
List<String> data = dictQueryHelper.fetchDimValueSingle(modelId, defaultMetricDesc, dim4Dict, user);
//3. local file changes
String fileName = String.format(dimValue + Constants.DOT + dictFileType, domainId,
String fileName = String.format(dimValue + Constants.DOT + dictFileType, modelId,
dim4Dict.getDimId());
fileHandler.writeFile(data, fileName, false);
}
@@ -93,16 +93,16 @@ public class DictionaryServiceImpl implements DictionaryService {
dimValue2DictCommend.getUpdateMode())) {
throw new RuntimeException("illegal parameter");
}
Map<Long, List<Long>> domainAndDimPair = dimValue2DictCommend.getDomainAndDimPair();
if (CollectionUtils.isEmpty(domainAndDimPair)) {
Map<Long, List<Long>> modelAndDimPair = dimValue2DictCommend.getModelAndDimPair();
if (CollectionUtils.isEmpty(modelAndDimPair)) {
return 0L;
}
for (Long domainId : domainAndDimPair.keySet()) {
if (CollectionUtils.isEmpty(domainAndDimPair.get(domainId))) {
for (Long modelId : modelAndDimPair.keySet()) {
if (CollectionUtils.isEmpty(modelAndDimPair.get(modelId))) {
continue;
}
for (Long dimId : domainAndDimPair.get(domainId)) {
String fileName = String.format(dimValue + Constants.DOT + dictFileType, domainId, dimId);
for (Long dimId : modelAndDimPair.get(modelId)) {
String fileName = String.format(dimValue + Constants.DOT + dictFileType, modelId, dimId);
fileHandler.deleteDictFile(fileName);
}
}
@@ -118,7 +118,7 @@ public class DictionaryServiceImpl implements DictionaryService {
return dictRepository.searchDictTaskList(filter);
}
public DictConfig getDictInfoByDomainId(Long domainId) {
return dictRepository.getDictInfoByDomainId(domainId);
public DictConfig getDictInfoByModelId(Long modelId) {
return dictRepository.getDictInfoByModelId(modelId);
}
}

View File

@@ -3,20 +3,26 @@ package com.tencent.supersonic.chat.service.impl;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticLayer;
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
import com.tencent.supersonic.chat.plugin.Plugin;
import com.tencent.supersonic.chat.api.pojo.request.PluginQueryReq;
import com.tencent.supersonic.chat.persistence.dataobject.PluginDO;
import com.tencent.supersonic.chat.persistence.dataobject.PluginDOExample;
import com.tencent.supersonic.chat.persistence.repository.PluginRepository;
import com.tencent.supersonic.chat.parser.ParseMode;
import com.tencent.supersonic.chat.plugin.Plugin;
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
import com.tencent.supersonic.chat.plugin.event.PluginAddEvent;
import com.tencent.supersonic.chat.plugin.event.PluginDelEvent;
import com.tencent.supersonic.chat.plugin.event.PluginUpdateEvent;
import com.tencent.supersonic.chat.service.PluginService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.enums.AuthType;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.semantic.api.model.response.DomainResp;
import com.tencent.supersonic.semantic.api.model.response.ModelResp;
import java.util.Arrays;
import java.util.Date;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
@@ -24,9 +30,6 @@ import org.springframework.context.ApplicationEventPublisher;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.*;
import java.util.stream.Collectors;
@Service
@Slf4j
public class PluginServiceImpl implements PluginService {
@@ -36,22 +39,22 @@ public class PluginServiceImpl implements PluginService {
private ApplicationEventPublisher publisher;
public PluginServiceImpl(PluginRepository pluginRepository,
ApplicationEventPublisher publisher) {
ApplicationEventPublisher publisher) {
this.pluginRepository = pluginRepository;
this.publisher = publisher;
}
@Override
public synchronized void createPlugin(Plugin plugin, User user){
public synchronized void createPlugin(Plugin plugin, User user) {
PluginDO pluginDO = convert(plugin, user);
pluginRepository.createPlugin(pluginDO);
//compatible with H2 db
List<Plugin> plugins = getPluginList();
publisher.publishEvent(new PluginAddEvent(this, plugins.get(plugins.size()-1)));
publisher.publishEvent(new PluginAddEvent(this, plugins.get(plugins.size() - 1)));
}
@Override
public void updatePlugin(Plugin plugin, User user){
public void updatePlugin(Plugin plugin, User user) {
Long id = plugin.getId();
PluginDO pluginDO = pluginRepository.getPlugin(id);
Plugin oldPlugin = convert(pluginDO);
@@ -61,7 +64,7 @@ public class PluginServiceImpl implements PluginService {
}
@Override
public void deletePlugin(Long id){
public void deletePlugin(Long id) {
PluginDO pluginDO = pluginRepository.getPlugin(id);
if (pluginDO != null) {
pluginRepository.deletePlugin(id);
@@ -74,7 +77,7 @@ public class PluginServiceImpl implements PluginService {
public List<Plugin> getPluginList() {
List<Plugin> plugins = Lists.newArrayList();
List<PluginDO> pluginDOS = pluginRepository.getPlugins();
if(CollectionUtils.isEmpty(pluginDOS)){
if (CollectionUtils.isEmpty(pluginDOS)) {
return plugins;
}
return pluginDOS.stream().map(this::convert).collect(Collectors.toList());
@@ -82,7 +85,7 @@ public class PluginServiceImpl implements PluginService {
@Override
public List<Plugin> fetchPluginDOs(String queryText, String type) {
List<PluginDO> pluginDOS = pluginRepository.fetchPluginDOs(queryText,type);
List<PluginDO> pluginDOS = pluginRepository.fetchPluginDOs(queryText, type);
return convertList(pluginDOS);
}
@@ -94,8 +97,8 @@ public class PluginServiceImpl implements PluginService {
if (StringUtils.isNotBlank(pluginQueryReq.getType())) {
pluginDOExample.getOredCriteria().get(0).andTypeEqualTo(pluginQueryReq.getType());
}
if (StringUtils.isNotBlank(pluginQueryReq.getDomain())) {
pluginDOExample.getOredCriteria().get(0).andDomainLike('%' + pluginQueryReq.getDomain() + '%');
if (StringUtils.isNotBlank(pluginQueryReq.getModel())) {
pluginDOExample.getOredCriteria().get(0).andModelLike('%' + pluginQueryReq.getModel() + '%');
}
if (StringUtils.isNotBlank(pluginQueryReq.getParseMode())) {
pluginDOExample.getOredCriteria().get(0).andParseModeEqualTo(pluginQueryReq.getParseMode());
@@ -112,8 +115,8 @@ public class PluginServiceImpl implements PluginService {
List<PluginDO> pluginDOS = pluginRepository.query(pluginDOExample);
if (StringUtils.isNotBlank(pluginQueryReq.getPattern())) {
pluginDOS = pluginDOS.stream().filter(pluginDO ->
pluginDO.getPattern().contains(pluginQueryReq.getPattern()) ||
(pluginDO.getName()!=null && pluginDO.getName().contains(pluginQueryReq.getPattern())))
pluginDO.getPattern().contains(pluginQueryReq.getPattern()) ||
(pluginDO.getName() != null && pluginDO.getName().contains(pluginQueryReq.getPattern())))
.collect(Collectors.toList());
}
return convertList(pluginDOS);
@@ -127,8 +130,12 @@ public class PluginServiceImpl implements PluginService {
if (StringUtils.isBlank(plugin.getParseModeConfig())) {
return false;
}
PluginParseConfig functionCallConfig = JsonUtil.toObject(plugin.getParseModeConfig(), PluginParseConfig.class);
if (Objects.isNull(functionCallConfig)) {
PluginParseConfig functionCallConfig = JsonUtil.toObject(plugin.getParseModeConfig(),
PluginParseConfig.class);
if (Objects.isNull(functionCallConfig) || StringUtils.isEmpty(functionCallConfig.getName())) {
return false;
}
if (StringUtils.isBlank(functionCallConfig.getName())) {
return false;
}
return functionCallConfig.getName().equalsIgnoreCase(name);
@@ -137,20 +144,20 @@ public class PluginServiceImpl implements PluginService {
}
@Override
public List<Plugin> queryWithAuthCheck(PluginQueryReq pluginQueryReq) {
return authCheck(query(pluginQueryReq));
public List<Plugin> queryWithAuthCheck(PluginQueryReq pluginQueryReq, User user) {
return authCheck(query(pluginQueryReq), user);
}
private List<Plugin> authCheck(List<Plugin> plugins) {
private List<Plugin> authCheck(List<Plugin> plugins, User user) {
SemanticLayer semanticLayer = ComponentFactory.getSemanticLayer();
List<Long> domainIdAuthorized = semanticLayer.getDomainListForAdmin().stream()
.map(DomainResp::getId).collect(Collectors.toList());
List<Long> modelIdAuthorized = semanticLayer.getModelList(AuthType.ADMIN, null, user).stream()
.map(ModelResp::getId).collect(Collectors.toList());
plugins = plugins.stream().filter(plugin -> {
if (CollectionUtils.isEmpty(plugin.getDomainList()) || plugin.isContainsAllDomain()) {
if (CollectionUtils.isEmpty(plugin.getModelList()) || plugin.isContainsAllModel()) {
return true;
}
for (Long domainId : plugin.getDomainList()) {
if (domainIdAuthorized.contains(domainId)) {
for (Long modelId : plugin.getModelList()) {
if (modelIdAuthorized.contains(modelId)) {
return true;
}
}
@@ -159,37 +166,37 @@ public class PluginServiceImpl implements PluginService {
return plugins;
}
public Plugin convert(PluginDO pluginDO){
public Plugin convert(PluginDO pluginDO) {
Plugin plugin = new Plugin();
BeanUtils.copyProperties(pluginDO,plugin);
if (StringUtils.isNotBlank(pluginDO.getDomain())) {
plugin.setDomainList(Arrays.stream(pluginDO.getDomain().split(","))
BeanUtils.copyProperties(pluginDO, plugin);
if (StringUtils.isNotBlank(pluginDO.getModel())) {
plugin.setModelList(Arrays.stream(pluginDO.getModel().split(","))
.map(Long::parseLong).collect(Collectors.toList()));
}
return plugin;
}
public PluginDO convert(Plugin plugin, User user){
public PluginDO convert(Plugin plugin, User user) {
PluginDO pluginDO = new PluginDO();
BeanUtils.copyProperties(plugin,pluginDO);
BeanUtils.copyProperties(plugin, pluginDO);
pluginDO.setCreatedAt(new Date());
pluginDO.setCreatedBy(user.getName());
pluginDO.setUpdatedAt(new Date());
pluginDO.setUpdatedBy(user.getName());
pluginDO.setDomain(StringUtils.join(plugin.getDomainList(), ","));
pluginDO.setModel(StringUtils.join(plugin.getModelList(), ","));
return pluginDO;
}
public PluginDO convert(Plugin plugin, PluginDO pluginDO, User user){
BeanUtils.copyProperties(plugin,pluginDO);
public PluginDO convert(Plugin plugin, PluginDO pluginDO, User user) {
BeanUtils.copyProperties(plugin, pluginDO);
pluginDO.setUpdatedAt(new Date());
pluginDO.setUpdatedBy(user.getName());
pluginDO.setDomain(StringUtils.join(plugin.getDomainList(), ","));
pluginDO.setModel(StringUtils.join(plugin.getModelList(), ","));
return pluginDO;
}
public List<Plugin> convertList(List<PluginDO> pluginDOS){
if(!CollectionUtils.isEmpty(pluginDOS)){
public List<Plugin> convertList(List<PluginDO> pluginDOS) {
if (!CollectionUtils.isEmpty(pluginDOS)) {
return pluginDOS.stream().map(this::convert).collect(Collectors.toList());
}
return Lists.newArrayList();

View File

@@ -2,27 +2,26 @@ package com.tencent.supersonic.chat.service.impl;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.*;
import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.query.QuerySelector;
import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.QuerySelector;
import com.tencent.supersonic.chat.service.ChatService;
import com.tencent.supersonic.chat.service.QueryService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import java.util.Arrays;
import com.tencent.supersonic.common.util.JsonUtil;
import java.util.List;
import java.util.stream.Collectors;
import com.tencent.supersonic.common.util.JsonUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.calcite.sql.parser.SqlParseException;
import org.springframework.beans.BeanUtils;
@@ -69,9 +68,9 @@ public class QueryServiceImpl implements QueryService {
Collectors.toList()));
List<SemanticParseInfo> selectedParses = selectedQueries.stream()
.map(q -> q.getParseInfo()).collect(Collectors.toList());
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
List<SemanticParseInfo> candidateParses = queryCtx.getCandidateQueries().stream()
.map(q -> q.getParseInfo()).collect(Collectors.toList());
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
parseResult = ParseResp.builder()
.chatId(queryReq.getChatId())

Some files were not shown because too many files have changed in this diff Show More