[improvement][project] supersonic 0.7.2 version backend update (#28)

Co-authored-by: jipengli <jipengli@tencent.com>
This commit is contained in:
jipeli
2023-08-15 08:56:18 +08:00
committed by GitHub
parent 27283001a8
commit b1952d64ab
461 changed files with 18548 additions and 11939 deletions

View File

@@ -2,17 +2,18 @@ package com.tencent.supersonic.chat.api.component;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.DomainSchema;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.common.pojo.enums.AuthType;
import com.tencent.supersonic.semantic.api.model.request.PageDimensionReq;
import com.tencent.supersonic.semantic.api.model.request.PageMetricReq;
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
import com.tencent.supersonic.semantic.api.model.response.DomainResp;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
import com.tencent.supersonic.semantic.api.model.response.ModelResp;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.request.QueryDslReq;
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import java.util.List;
/**
@@ -30,16 +31,22 @@ import java.util.List;
public interface SemanticLayer {
QueryResultWithSchemaResp queryByStruct(QueryStructReq queryStructReq, User user);
QueryResultWithSchemaResp queryByMultiStruct(QueryMultiStructReq queryMultiStructReq, User user);
QueryResultWithSchemaResp queryByDsl(QueryDslReq queryDslReq, User user);
List<DomainSchema> getDomainSchema();
List<DomainSchema> getDomainSchema(List<Long> ids);
DomainSchema getDomainSchema(Long domain, Boolean cacheEnable);
List<ModelSchema> getModelSchema();
List<ModelSchema> getModelSchema(List<Long> ids);
ModelSchema getModelSchema(Long model, Boolean cacheEnable);
PageInfo<DimensionResp> getDimensionPage(PageDimensionReq pageDimensionCmd);
PageInfo<MetricResp> getMetricPage(PageMetricReq pageMetricCmd);
List<DomainResp> getDomainListForViewer();
List<DomainResp> getDomainListForAdmin();
List<DomainResp> getDomainList(User user);
List<ModelResp> getModelList(AuthType authType, Long domainId, User user);
}

View File

@@ -1,15 +1,14 @@
package com.tencent.supersonic.chat.api.pojo;
import lombok.Data;
import java.util.HashSet;
import java.util.Optional;
import java.util.Set;
import lombok.Data;
@Data
public class DomainSchema {
public class ModelSchema {
private SchemaElement domain;
private SchemaElement model;
private Set<SchemaElement> metrics = new HashSet<>();
private Set<SchemaElement> dimensions = new HashSet<>();
private Set<SchemaElement> dimensionValues = new HashSet<>();
@@ -22,8 +21,8 @@ public class DomainSchema {
case ENTITY:
element = Optional.ofNullable(entity);
break;
case DOMAIN:
element = Optional.of(domain);
case MODEL:
element = Optional.of(model);
break;
case METRIC:
element = metrics.stream().filter(e -> e.getId() == elementID).findFirst();

View File

@@ -2,12 +2,13 @@ package com.tencent.supersonic.chat.api.pojo;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import lombok.Data;
import java.util.ArrayList;
import java.util.List;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@NoArgsConstructor
public class QueryContext {
private QueryReq request;

View File

@@ -1,11 +1,12 @@
package com.tencent.supersonic.chat.api.pojo;
import com.google.common.base.Objects;
import java.io.Serializable;
import java.util.List;
import lombok.*;
import lombok.Builder;
import lombok.Data;
import lombok.Getter;
import lombok.NoArgsConstructor;
@Data
@Getter
@@ -14,7 +15,7 @@ import lombok.*;
//@AllArgsConstructor
public class SchemaElement implements Serializable {
private Long domain;
private Long model;
private Long id;
private String name;
private String bizName;
@@ -25,9 +26,9 @@ public class SchemaElement implements Serializable {
// public SchemaElement() {
// }
public SchemaElement(Long domain, Long id, String name, String bizName,
Long useCnt, SchemaElementType type, List<String> alias) {
this.domain = domain;
public SchemaElement(Long model, Long id, String name, String bizName,
Long useCnt, SchemaElementType type, List<String> alias) {
this.model = model;
this.id = id;
this.name = name;
this.bizName = bizName;
@@ -45,7 +46,7 @@ public class SchemaElement implements Serializable {
return false;
}
SchemaElement schemaElement = (SchemaElement) o;
return Objects.equal(domain, schemaElement.domain) && Objects.equal(id,
return Objects.equal(model, schemaElement.model) && Objects.equal(id,
schemaElement.id) && Objects.equal(name, schemaElement.name)
&& Objects.equal(bizName, schemaElement.bizName) && Objects.equal(
useCnt, schemaElement.useCnt) && Objects.equal(type, schemaElement.type);
@@ -53,6 +54,6 @@ public class SchemaElement implements Serializable {
@Override
public int hashCode() {
return Objects.hashCode(domain, id, name, bizName, useCnt, type);
return Objects.hashCode(model, id, name, bizName, useCnt, type);
}
}

View File

@@ -18,10 +18,6 @@ public class SchemaElementMatch {
String detectWord;
String word;
Long frequency;
MatchMode mode = MatchMode.CURRENT;
boolean isInherited;
public enum MatchMode {
CURRENT,
INHERIT
}
}

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.chat.api.pojo;
public enum SchemaElementType {
DOMAIN,
MODEL,
METRIC,
DIMENSION,
VALUE,

View File

@@ -7,21 +7,21 @@ import java.util.Set;
public class SchemaMapInfo {
private Map<Long, List<SchemaElementMatch>> domainElementMatches = new HashMap<>();
private Map<Long, List<SchemaElementMatch>> modelElementMatches = new HashMap<>();
public Set<Long> getMatchedDomains() {
return domainElementMatches.keySet();
public Set<Long> getMatchedModels() {
return modelElementMatches.keySet();
}
public List<SchemaElementMatch> getMatchedElements(Long domain) {
return domainElementMatches.get(domain);
public List<SchemaElementMatch> getMatchedElements(Long model) {
return modelElementMatches.get(model);
}
public Map<Long, List<SchemaElementMatch>> getDomainElementMatches() {
return domainElementMatches;
public Map<Long, List<SchemaElementMatch>> getModelElementMatches() {
return modelElementMatches;
}
public void setMatchedElements(Long domain, List<SchemaElementMatch> elementMatches) {
domainElementMatches.put(domain, elementMatches);
public void setMatchedElements(Long model, List<SchemaElementMatch> elementMatches) {
modelElementMatches.put(model, elementMatches);
}
}

View File

@@ -1,19 +1,25 @@
package com.tencent.supersonic.chat.api.pojo;
import java.util.*;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import lombok.Data;
@Data
public class SemanticParseInfo {
private String queryMode;
private SchemaElement domain;
private SchemaElement model;
private Set<SchemaElement> metrics = new TreeSet<>(new SchemaNameLengthComparator());
private Set<SchemaElement> dimensions = new LinkedHashSet();
private SchemaElement entity;
@@ -28,15 +34,16 @@ public class SemanticParseInfo {
private List<SchemaElementMatch> elementMatches = new ArrayList<>();
private Map<String, Object> properties = new HashMap<>();
public Long getDomainId() {
return domain != null ? domain.getId() : 0L;
public Long getModelId() {
return model != null ? model.getId() : 0L;
}
public String getDomainName() {
return domain != null ? domain.getName() : "null";
public String getModelName() {
return model != null ? model.getName() : "null";
}
private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {
@Override
public int compare(SchemaElement o1, SchemaElement o2) {
int len1 = o1.getName().length();
@@ -49,4 +56,11 @@ public class SemanticParseInfo {
}
}
public Set<SchemaElement> getMetrics() {
Set<SchemaElement> metricSet = new TreeSet<>(new SchemaNameLengthComparator());
metricSet.addAll(metrics);
metrics = metricSet;
return metrics;
}
}

View File

@@ -7,48 +7,49 @@ import java.util.Map;
import java.util.stream.Collectors;
public class SemanticSchema implements Serializable {
private List<DomainSchema> domainSchemaList;
public SemanticSchema(List<DomainSchema> domainSchemaList) {
this.domainSchemaList = domainSchemaList;
private List<ModelSchema> modelSchemaList;
public SemanticSchema(List<ModelSchema> modelSchemaList) {
this.modelSchemaList = modelSchemaList;
}
public void add(DomainSchema schema) {
domainSchemaList.add(schema);
public void add(ModelSchema schema) {
modelSchemaList.add(schema);
}
public Map<Long, String> getDomainIdToName() {
return domainSchemaList.stream()
.collect(Collectors.toMap(a -> a.getDomain().getId(), a -> a.getDomain().getName(), (k1, k2) -> k1));
public Map<Long, String> getModelIdToName() {
return modelSchemaList.stream()
.collect(Collectors.toMap(a -> a.getModel().getId(), a -> a.getModel().getName(), (k1, k2) -> k1));
}
public List<SchemaElement> getDimensionValues() {
List<SchemaElement> dimensionValues = new ArrayList<>();
domainSchemaList.stream().forEach(d -> dimensionValues.addAll(d.getDimensionValues()));
modelSchemaList.stream().forEach(d -> dimensionValues.addAll(d.getDimensionValues()));
return dimensionValues;
}
public List<SchemaElement> getDimensions() {
List<SchemaElement> dimensions = new ArrayList<>();
domainSchemaList.stream().forEach(d -> dimensions.addAll(d.getDimensions()));
modelSchemaList.stream().forEach(d -> dimensions.addAll(d.getDimensions()));
return dimensions;
}
public List<SchemaElement> getMetrics() {
List<SchemaElement> metrics = new ArrayList<>();
domainSchemaList.stream().forEach(d -> metrics.addAll(d.getMetrics()));
modelSchemaList.stream().forEach(d -> metrics.addAll(d.getMetrics()));
return metrics;
}
public List<SchemaElement> getDomains() {
List<SchemaElement> domains = new ArrayList<>();
domainSchemaList.stream().forEach(d -> domains.add(d.getDomain()));
return domains;
public List<SchemaElement> getModels() {
List<SchemaElement> models = new ArrayList<>();
modelSchemaList.stream().forEach(d -> models.add(d.getModel()));
return models;
}
public List<SchemaElement> getEntities() {
List<SchemaElement> entities = new ArrayList<>();
domainSchemaList.stream().forEach(d -> entities.add(d.getEntity()));
modelSchemaList.stream().forEach(d -> entities.add(d.getEntity()));
return entities;
}
}

View File

@@ -1,8 +1,7 @@
package com.tencent.supersonic.chat.api.pojo.request;
import lombok.Data;
import java.util.List;
import lombok.Data;
@Data
public class ChatAggConfigReq {
@@ -13,7 +12,7 @@ public class ChatAggConfigReq {
private ItemVisibility visibility;
/**
* information about dictionary about the domain
* information about dictionary about the model
*/
private List<KnowledgeInfoReq> knowledgeInfos;

View File

@@ -1,34 +1,32 @@
package com.tencent.supersonic.chat.api.pojo.request;
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import java.util.List;
import lombok.Data;
import lombok.ToString;
import java.util.List;
/**
* extended information command about domain
* extended information command about model
*/
@Data
@ToString
public class ChatConfigBaseReq {
private Long domainId;
private Long modelId;
/**
* the chatDetailConfig about the domain
* the chatDetailConfig about the model
*/
private ChatDetailConfigReq chatDetailConfig;
/**
* the chatAggConfig about the domain
* the chatAggConfig about the model
*/
private ChatAggConfigReq chatAggConfig;
/**
* the recommended questions about the domain
* the recommended questions about the model
*/
private List<RecommendedQuestionReq> recommendedQuestions;

View File

@@ -10,6 +10,6 @@ import lombok.NoArgsConstructor;
public class ChatConfigFilter {
private Long id;
private Long domainId;
private Long modelId;
private StatusEnum status = StatusEnum.ONLINE;
}

View File

@@ -2,10 +2,9 @@ package com.tencent.supersonic.chat.api.pojo.request;
import com.tencent.supersonic.common.pojo.Constants;
import lombok.Data;
import java.util.ArrayList;
import java.util.List;
import lombok.Data;
@Data
public class ChatDefaultConfigReq {

View File

@@ -1,8 +1,7 @@
package com.tencent.supersonic.chat.api.pojo.request;
import lombok.Data;
import java.util.List;
import lombok.Data;
@Data
public class ChatDetailConfigReq {
@@ -13,7 +12,7 @@ public class ChatDetailConfigReq {
private ItemVisibility visibility;
/**
* information about dictionary about the domain
* information about dictionary about the model
*/
private List<KnowledgeInfoReq> knowledgeInfos;

View File

@@ -7,7 +7,7 @@ import lombok.NoArgsConstructor;
import lombok.ToString;
/**
* the entity info about the domain
* the entity info about the model
*/
@Data
@AllArgsConstructor

View File

@@ -7,6 +7,7 @@ import lombok.Data;
@Data
public class ExecuteQueryReq {
private User user;
private Integer chatId;
private String queryText;

View File

@@ -1,9 +1,8 @@
package com.tencent.supersonic.chat.api.pojo.request;
import lombok.Data;
import java.util.ArrayList;
import java.util.List;
import lombok.Data;
/**
* advanced knowledge config

View File

@@ -1,20 +1,18 @@
package com.tencent.supersonic.chat.api.pojo.request;
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
import javax.validation.constraints.NotNull;
import lombok.Data;
/**
* information about dictionary about the domain
* information about dictionary about the model
*/
@Data
public class KnowledgeInfoReq {
/**
* metricId、DimensionId、domainId
* metricId、DimensionId、modelId
*/
private Long itemId;

View File

@@ -13,7 +13,7 @@ public class PluginQueryReq {
private String type;
private String domain;
private String model;
private String pattern;

View File

@@ -1,18 +1,18 @@
package com.tencent.supersonic.chat.api.pojo.request;
import java.util.HashSet;
import java.util.Set;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.Order;
import java.util.HashSet;
import java.util.Set;
import lombok.Data;
@Data
public class QueryDataReq {
String queryMode;
SchemaElement domain;
SchemaElement model;
Set<SchemaElement> metrics = new HashSet<>();
Set<SchemaElement> dimensions = new HashSet<>();
Set<QueryFilter> dimensionFilters = new HashSet<>();

View File

@@ -1,13 +1,14 @@
package com.tencent.supersonic.chat.api.pojo.request;
import lombok.Data;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.Data;
@Data
public class QueryFilters {
private List<QueryFilter> filters = new ArrayList<>();
private Map<String, Object> params = new HashMap<>();
}

View File

@@ -5,9 +5,10 @@ import lombok.Data;
@Data
public class QueryReq {
private String queryText;
private Integer chatId;
private Long domainId = 0L;
private Long modelId = 0L;
private User user;
private QueryFilters queryFilters;
private boolean saveAnswer = true;

View File

@@ -6,5 +6,6 @@ import lombok.Data;
@Data
public class AggregateInfo {
private List<MetricInfo> metricInfos = new ArrayList<>();
private List<MetricInfo> metricInfos = new ArrayList<>();
}

View File

@@ -2,9 +2,8 @@ package com.tencent.supersonic.chat.api.pojo.response;
import com.tencent.supersonic.chat.api.pojo.request.KnowledgeAdvancedConfig;
import com.tencent.supersonic.chat.api.pojo.request.KnowledgeInfoReq;
import lombok.Data;
import java.util.List;
import lombok.Data;
@Data
public class ChatAggRichConfigResp {
@@ -15,7 +14,7 @@ public class ChatAggRichConfigResp {
private ItemVisibilityInfo visibility;
/**
* information about dictionary about the domain
* information about dictionary about the model
*/
private List<KnowledgeInfoReq> knowledgeInfos;

View File

@@ -4,10 +4,8 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatDetailConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.RecommendedQuestionReq;
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import java.util.Date;
import java.util.List;
import lombok.Data;
@Data
@@ -15,7 +13,7 @@ public class ChatConfigResp {
private Long id;
private Long domainId;
private Long modelId;
private ChatDetailConfigReq chatDetailConfig;

View File

@@ -4,7 +4,6 @@ import com.tencent.supersonic.chat.api.pojo.request.RecommendedQuestionReq;
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import java.util.Date;
import java.util.List;
import lombok.Data;
@Data
@@ -12,9 +11,9 @@ public class ChatConfigRichResp {
private Long id;
private Long domainId;
private Long modelId;
private String domainName;
private String modelName;
private String bizName;
private ChatAggRichConfigResp chatAggRichConfig;

View File

@@ -4,9 +4,8 @@ package com.tencent.supersonic.chat.api.pojo.response;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq;
import com.tencent.supersonic.common.pojo.Constants;
import lombok.Data;
import java.util.List;
import lombok.Data;
@Data
public class ChatDefaultRichConfigResp {

View File

@@ -2,9 +2,8 @@ package com.tencent.supersonic.chat.api.pojo.response;
import com.tencent.supersonic.chat.api.pojo.request.KnowledgeAdvancedConfig;
import com.tencent.supersonic.chat.api.pojo.request.KnowledgeInfoReq;
import lombok.Data;
import java.util.List;
import lombok.Data;
@Data
public class ChatDetailRichConfigResp {
@@ -15,7 +14,7 @@ public class ChatDetailRichConfigResp {
private ItemVisibilityInfo visibility;
/**
* information about dictionary about the domain
* information about dictionary about the model
*/
private List<KnowledgeInfoReq> knowledgeInfos;

View File

@@ -7,7 +7,7 @@ import lombok.Data;
@Data
public class EntityInfo {
private DomainInfo domainInfo = new DomainInfo();
private ModelInfo modelInfo = new ModelInfo();
private List<DataInfo> dimensions = new ArrayList<>();
private List<DataInfo> metrics = new ArrayList<>();
private String entityId;

View File

@@ -1,14 +1,14 @@
package com.tencent.supersonic.chat.api.pojo.response;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import java.util.List;
import lombok.Data;
@Data
public class EntityRichInfoResp {
/**
* entity alias
* entity alias
*/
private List<String> names;

View File

@@ -5,7 +5,7 @@ import java.util.List;
import lombok.Data;
@Data
public class DomainInfo extends DataInfo implements Serializable {
public class ModelInfo extends DataInfo implements Serializable {
private List<String> words;
private String primaryEntityBizName;

View File

@@ -1,9 +1,12 @@
package com.tencent.supersonic.chat.api.pojo.response;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import lombok.*;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.Getter;
import lombok.NoArgsConstructor;
@Data
@Getter
@@ -11,6 +14,7 @@ import java.util.List;
@NoArgsConstructor
@AllArgsConstructor
public class ParseResp {
private Integer chatId;
private String queryText;
private ParseState state;

View File

@@ -1,14 +1,14 @@
package com.tencent.supersonic.chat.api.pojo.response;
import com.tencent.supersonic.chat.api.pojo.request.RecommendedQuestionReq;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.Data;
import java.util.List;
@Data
@AllArgsConstructor
public class RecommendQuestionResp {
private Long domainId;
private Long modelId;
private List<RecommendedQuestionReq> recommendedQuestions;
}

View File

@@ -1,12 +1,12 @@
package com.tencent.supersonic.chat.api.pojo.response;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import lombok.Data;
import java.util.List;
import lombok.Data;
@Data
public class RecommendResp {
private List<SchemaElement> dimensions;
private List<SchemaElement> metrics;
}

View File

@@ -2,8 +2,6 @@ package com.tencent.supersonic.chat.api.pojo.response;
import java.util.List;
import lombok.Data;
import lombok.Getter;
import lombok.Setter;
@Data
public class SearchResp {

View File

@@ -17,9 +17,9 @@ public class SearchResult {
private String subRecommend;
private String domainName;
private String modelName;
private Long domainId;
private Long modelId;
private SchemaElementType schemaElementType;
@@ -34,12 +34,12 @@ public class SearchResult {
return false;
}
SearchResult searchResult1 = (SearchResult) o;
return Objects.equals(recommend, searchResult1.recommend) && Objects.equals(domainName,
searchResult1.domainName);
return Objects.equals(recommend, searchResult1.recommend) && Objects.equals(modelName,
searchResult1.modelName);
}
@Override
public int hashCode() {
return Objects.hash(recommend, domainName);
return Objects.hash(recommend, modelName);
}
}

View File

@@ -1,7 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>chat</artifactId>
<groupId>com.tencent.supersonic</groupId>

View File

@@ -3,11 +3,13 @@ package com.tencent.supersonic.chat.config;
import lombok.Data;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration;
@Configuration
@Data
public class AggregatorConfig {
@Value("${metric.aggregator.ratio.enable:true}")
private Boolean enableRatio;
@Value("${metric.aggregator.ratio.enable:true}")
private Boolean enableRatio;
}

View File

@@ -3,13 +3,12 @@ package com.tencent.supersonic.chat.config;
import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatDetailConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.RecommendedQuestionReq;
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import com.tencent.supersonic.common.pojo.RecordInfo;
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import java.util.List;
import lombok.Data;
import lombok.ToString;
import java.util.List;
@Data
@ToString
public class ChatConfig {
@@ -19,15 +18,15 @@ public class ChatConfig {
*/
private Long id;
private Long domainId;
private Long modelId;
/**
* the chatDetailConfig about the domain
* the chatDetailConfig about the model
*/
private ChatDetailConfigReq chatDetailConfig;
/**
* the chatAggConfig about the domain
* the chatAggConfig about the model
*/
private ChatAggConfigReq chatAggConfig;

View File

@@ -6,6 +6,6 @@ import lombok.Data;
public class ChatConfigFilterInternal {
private Long id;
private Long domainId;
private Long modelId;
private Integer status;
}

View File

@@ -6,7 +6,7 @@ import lombok.ToString;
/**
* default metrics about the domain
* default metrics about the model
*/
@ToString

View File

@@ -7,6 +7,7 @@ import org.springframework.context.annotation.Configuration;
@Configuration
@Data
public class FunctionCallInfoConfig {
@Value("${functionCall.url:}")
private String url;
}

View File

@@ -1,21 +0,0 @@
package com.tencent.supersonic.chat.mapper;
import java.io.Serializable;
import lombok.Builder;
import lombok.Data;
import lombok.ToString;
@Data
@ToString
@Builder
public class DomainInfoStat implements Serializable {
private long domainCount;
private long metricDomainCount;
private long dimensionDomainCount;
private long dimensionValueDomainCount;
}

View File

@@ -2,14 +2,19 @@ package com.tencent.supersonic.chat.mapper;
import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.util.ContextUtils;
import java.util.List;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.stream.Collectors;
@Slf4j
@@ -18,20 +23,20 @@ public class EntityMapper implements SchemaMapper {
@Override
public void map(QueryContext queryContext) {
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
for (Long domainId : schemaMapInfo.getMatchedDomains()) {
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(domainId);
for (Long modelId : schemaMapInfo.getMatchedModels()) {
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(modelId);
if (CollectionUtils.isEmpty(schemaElementMatchList)) {
continue;
}
SchemaElement entity = getEntity(domainId);
SchemaElement entity = getEntity(modelId);
if (entity == null || entity.getId() == null) {
continue;
}
List<SchemaElementMatch> valueSchemaElements = schemaElementMatchList.stream().filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()))
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()))
.collect(Collectors.toList());
for (SchemaElementMatch schemaElementMatch : valueSchemaElements) {
if (!entity.getId().equals(schemaElementMatch.getElement().getId())){
if (!entity.getId().equals(schemaElementMatch.getElement().getId())) {
continue;
}
if (!checkExistSameEntitySchemaElements(schemaElementMatch, schemaElementMatchList)) {
@@ -46,7 +51,7 @@ public class EntityMapper implements SchemaMapper {
}
private boolean checkExistSameEntitySchemaElements(SchemaElementMatch valueSchemaElementMatch,
List<SchemaElementMatch> schemaElementMatchList) {
List<SchemaElementMatch> schemaElementMatchList) {
List<SchemaElementMatch> entitySchemaElements = schemaElementMatchList.stream().filter(schemaElementMatch ->
SchemaElementType.ENTITY.equals(schemaElementMatch.getElement().getType()))
.collect(Collectors.toList());
@@ -58,11 +63,11 @@ public class EntityMapper implements SchemaMapper {
return false;
}
private SchemaElement getEntity(Long domainId) {
private SchemaElement getEntity(Long modelId) {
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
DomainSchema domainSchema = semanticService.getDomainSchema(domainId);
if (domainSchema != null && domainSchema.getEntity() != null) {
return domainSchema.getEntity();
ModelSchema modelSchema = semanticService.getModelSchema(modelId);
if (modelSchema != null && modelSchema.getEntity() != null) {
return modelSchema.getEntity();
}
return null;
}

View File

@@ -2,10 +2,14 @@ package com.tencent.supersonic.chat.mapper;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
import java.util.ArrayList;
import java.util.Comparator;
@@ -39,13 +43,13 @@ public class FuzzyNameMapper implements SchemaMapper {
log.debug("after db mapper,mapInfo:{}", queryContext.getMapInfo());
}
private void detectAndAddToSchema(QueryContext queryContext, List<Term> terms, List<SchemaElement> domains,
SchemaElementType schemaElementType) {
private void detectAndAddToSchema(QueryContext queryContext, List<Term> terms, List<SchemaElement> Models,
SchemaElementType schemaElementType) {
try {
Map<String, Set<SchemaElement>> domainResultSet = getResultSet(queryContext, terms, domains);
Map<String, Set<SchemaElement>> ModelResultSet = getResultSet(queryContext, terms, Models);
addToSchemaMapInfo(domainResultSet, queryContext.getMapInfo(), schemaElementType);
addToSchemaMapInfo(ModelResultSet, queryContext.getMapInfo(), schemaElementType);
} catch (Exception e) {
log.error("detectAndAddToSchema error", e);
@@ -53,7 +57,7 @@ public class FuzzyNameMapper implements SchemaMapper {
}
private Map<String, Set<SchemaElement>> getResultSet(QueryContext queryContext, List<Term> terms,
List<SchemaElement> domains) {
List<SchemaElement> Models) {
String queryText = queryContext.getRequest().getQueryText();
@@ -61,12 +65,12 @@ public class FuzzyNameMapper implements SchemaMapper {
Double metricDimensionThresholdConfig = getThreshold(queryContext, mapperHelper);
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(domains);
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(Models);
Map<Integer, Integer> regOffsetToLength = terms.stream().sorted(Comparator.comparing(Term::length))
.collect(Collectors.toMap(Term::getOffset, term -> term.word.length(), (value1, value2) -> value2));
Map<String, Set<SchemaElement>> domainResultSet = new HashMap<>();
Map<String, Set<SchemaElement>> ModelResultSet = new HashMap<>();
for (Integer startIndex = 0; startIndex <= queryText.length() - 1; ) {
for (Integer endIndex = startIndex; endIndex <= queryText.length(); ) {
endIndex = mapperHelper.getStepIndex(regOffsetToLength, endIndex);
@@ -82,7 +86,7 @@ public class FuzzyNameMapper implements SchemaMapper {
|| mapperHelper.getSimilarity(detectSegment, name) < metricDimensionThresholdConfig) {
continue;
}
Set<SchemaElement> preSchemaElements = domainResultSet.putIfAbsent(detectSegment,
Set<SchemaElement> preSchemaElements = ModelResultSet.putIfAbsent(detectSegment,
schemaElements);
if (Objects.nonNull(preSchemaElements)) {
preSchemaElements.addAll(schemaElements);
@@ -91,7 +95,7 @@ public class FuzzyNameMapper implements SchemaMapper {
}
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
}
return domainResultSet;
return ModelResultSet;
}
private Double getThreshold(QueryContext queryContext, MapperHelper mapperHelper) {
@@ -99,9 +103,9 @@ public class FuzzyNameMapper implements SchemaMapper {
Double metricDimensionThresholdConfig = mapperHelper.getMetricDimensionThresholdConfig();
Double metricDimensionMinThresholdConfig = mapperHelper.getMetricDimensionMinThresholdConfig();
Map<Long, List<SchemaElementMatch>> domainElementMatches = queryContext.getMapInfo()
.getDomainElementMatches();
boolean existElement = domainElementMatches.entrySet().stream()
Map<Long, List<SchemaElementMatch>> ModelElementMatches = queryContext.getMapInfo()
.getModelElementMatches();
boolean existElement = ModelElementMatches.entrySet().stream()
.anyMatch(entry -> entry.getValue().size() >= 1);
if (!existElement) {
@@ -109,14 +113,14 @@ public class FuzzyNameMapper implements SchemaMapper {
metricDimensionThresholdConfig = halfThreshold >= metricDimensionMinThresholdConfig ? halfThreshold
: metricDimensionMinThresholdConfig;
log.info("domainElementMatches:{} , not exist Element metricDimensionThresholdConfig reduce by half:{}",
domainElementMatches, metricDimensionThresholdConfig);
log.info("ModelElementMatches:{} , not exist Element metricDimensionThresholdConfig reduce by half:{}",
ModelElementMatches, metricDimensionThresholdConfig);
}
return metricDimensionThresholdConfig;
}
private Map<String, Set<SchemaElement>> getNameToItems(List<SchemaElement> domains) {
return domains.stream().collect(
private Map<String, Set<SchemaElement>> getNameToItems(List<SchemaElement> Models) {
return Models.stream().collect(
Collectors.toMap(SchemaElement::getName, a -> {
Set<SchemaElement> result = new HashSet<>();
result.add(a);
@@ -139,10 +143,10 @@ public class FuzzyNameMapper implements SchemaMapper {
Set<SchemaElement> schemaElements = entry.getValue();
for (SchemaElement schemaElement : schemaElements) {
List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getDomain());
List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getModel());
if (CollectionUtils.isEmpty(elements)) {
elements = new ArrayList<>();
schemaMap.setMatchedElements(schemaElement.getDomain(), elements);
schemaMap.setMatchedElements(schemaElement.getModel(), elements);
}
Set<Long> regElementSet = elements.stream()
.filter(elementMatch -> schemaElementType.equals(elementMatch.getElement().getType()))

View File

@@ -2,13 +2,18 @@ package com.tencent.supersonic.chat.mapper;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.chat.utils.NatureHelper;
import com.tencent.supersonic.knowledge.dictionary.builder.BaseWordBuilder;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.dictionary.DictWordType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.dictionary.DictWordType;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.dictionary.builder.BaseWordBuilder;
import com.tencent.supersonic.knowledge.dictionary.builder.WordBuilderFactory;
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
import java.util.ArrayList;
@@ -32,10 +37,10 @@ public class HanlpDictMapper implements SchemaMapper {
for (Term term : terms) {
log.info("word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency());
}
Long domainId = queryContext.getRequest().getDomainId();
Long modelId = queryContext.getRequest().getModelId();
QueryMatchStrategy matchStrategy = ContextUtils.getBean(QueryMatchStrategy.class);
Map<MatchText, List<MapResult>> matchResult = matchStrategy.match(queryText, terms, domainId);
Map<MatchText, List<MapResult>> matchResult = matchStrategy.match(queryText, terms, modelId);
List<MapResult> matches = getMatches(matchResult);
@@ -57,8 +62,8 @@ public class HanlpDictMapper implements SchemaMapper {
for (MapResult mapResult : mapResults) {
for (String nature : mapResult.getNatures()) {
Long domainId = NatureHelper.getDomainId(nature);
if (Objects.isNull(domainId)) {
Long modelId = NatureHelper.getModelId(nature);
if (Objects.isNull(modelId)) {
continue;
}
SchemaElementType elementType = NatureHelper.convertToElementType(nature);
@@ -67,14 +72,14 @@ public class HanlpDictMapper implements SchemaMapper {
}
SemanticService schemaService = ContextUtils.getBean(SemanticService.class);
DomainSchema domainSchema = schemaService.getDomainSchema(domainId);
ModelSchema modelSchema = schemaService.getModelSchema(modelId);
BaseWordBuilder baseWordBuilder = WordBuilderFactory.get(DictWordType.getNatureType(nature));
Long elementID = baseWordBuilder.getElementID(nature);
Long frequency = wordNatureToFrequency.get(mapResult.getName() + nature);
SchemaElement element = domainSchema.getElement(elementType, elementID);
if(Objects.isNull(element)){
SchemaElement element = modelSchema.getElement(elementType, elementID);
if (Objects.isNull(element)) {
log.info("element is null, elementType:{},elementID:{}", elementType, elementID);
continue;
}
@@ -89,11 +94,11 @@ public class HanlpDictMapper implements SchemaMapper {
.detectWord(mapResult.getDetectWord())
.build();
Map<Long, List<SchemaElementMatch>> domainElementMatches = schemaMap.getDomainElementMatches();
List<SchemaElementMatch> schemaElementMatches = domainElementMatches.putIfAbsent(domainId,
Map<Long, List<SchemaElementMatch>> modelElementMatches = schemaMap.getModelElementMatches();
List<SchemaElementMatch> schemaElementMatches = modelElementMatches.putIfAbsent(modelId,
new ArrayList<>());
if (schemaElementMatches == null) {
schemaElementMatches = domainElementMatches.get(domainId);
schemaElementMatches = modelElementMatches.get(modelId);
}
schemaElementMatches.add(schemaElementMatch);
}

View File

@@ -83,4 +83,6 @@ public class MapperHelper {
return 1 - (double) EditDistance.compute(detectSegmentLower, matchNameLower) / Math.max(matchName.length(),
detectSegment.length());
}
}

View File

@@ -2,7 +2,6 @@ package com.tencent.supersonic.chat.mapper;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import java.util.List;
import java.util.Map;
@@ -11,6 +10,6 @@ import java.util.Map;
*/
public interface MatchStrategy {
Map<MatchText, List<MapResult>> match(String text, List<Term> terms, Long detectDomainId);
Map<MatchText, List<MapResult>> match(String text, List<Term> terms, Long detectModelId);
}

View File

@@ -0,0 +1,21 @@
package com.tencent.supersonic.chat.mapper;
import java.io.Serializable;
import lombok.Builder;
import lombok.Data;
import lombok.ToString;
@Data
@ToString
@Builder
public class ModelInfoStat implements Serializable {
private long modelCount;
private long metricModelCount;
private long dimensionModelCount;
private long dimensionValueModelCount;
}

View File

@@ -7,13 +7,13 @@ import lombok.ToString;
@Data
@ToString
public class DomainWithSemanticType implements Serializable {
public class ModelWithSemanticType implements Serializable {
private Long domain;
private Long model;
private SchemaElementType semanticType;
public DomainWithSemanticType(Long domain, SchemaElementType semanticType) {
this.domain = domain;
public ModelWithSemanticType(Long model, SchemaElementType semanticType) {
this.model = model;
this.semanticType = semanticType;
}
}

View File

@@ -2,15 +2,20 @@ package com.tencent.supersonic.chat.mapper;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
import com.tencent.supersonic.common.pojo.Constants;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
@Slf4j
public class QueryFilterMapper implements SchemaMapper {
@@ -21,35 +26,35 @@ public class QueryFilterMapper implements SchemaMapper {
@Override
public void map(QueryContext queryContext) {
QueryReq queryReq = queryContext.getRequest();
Long domainId = queryReq.getDomainId();
if (domainId == null || domainId <= 0) {
Long modelId = queryReq.getModelId();
if (modelId == null || modelId <= 0) {
return;
}
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
clearOtherSchemaElementMatch(domainId, schemaMapInfo);
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(domainId);
clearOtherSchemaElementMatch(modelId, schemaMapInfo);
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(modelId);
if (schemaElementMatches == null) {
schemaElementMatches = Lists.newArrayList();
schemaMapInfo.setMatchedElements(domainId, schemaElementMatches);
schemaMapInfo.setMatchedElements(modelId, schemaElementMatches);
}
addValueSchemaElementMatch(schemaElementMatches, queryReq.getQueryFilters());
}
private void clearOtherSchemaElementMatch(Long domainId, SchemaMapInfo schemaMapInfo) {
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMapInfo.getDomainElementMatches().entrySet()) {
if (!entry.getKey().equals(domainId)) {
private void clearOtherSchemaElementMatch(Long modelId, SchemaMapInfo schemaMapInfo) {
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMapInfo.getModelElementMatches().entrySet()) {
if (!entry.getKey().equals(modelId)) {
entry.getValue().clear();
}
}
}
private List<SchemaElementMatch> addValueSchemaElementMatch(List<SchemaElementMatch> candidateElementMatches,
QueryFilters queryFilter) {
QueryFilters queryFilter) {
if (queryFilter == null || CollectionUtils.isEmpty(queryFilter.getFilters())) {
return candidateElementMatches;
}
for (QueryFilter filter : queryFilter.getFilters()) {
if (checkExistSameValueSchemaElementMatch(filter, candidateElementMatches)) {
if (checkExistSameValueSchemaElementMatch(filter, candidateElementMatches)) {
continue;
}
SchemaElement element = SchemaElement.builder()
@@ -63,7 +68,7 @@ public class QueryFilterMapper implements SchemaMapper {
.frequency(FREQUENCY)
.word(String.valueOf(filter.getValue()))
.similarity(SIMILARITY)
.detectWord(filter.getName())
.detectWord(Constants.EMPTY)
.build();
candidateElementMatches.add(schemaElementMatch);
}
@@ -71,13 +76,13 @@ public class QueryFilterMapper implements SchemaMapper {
}
private boolean checkExistSameValueSchemaElementMatch(QueryFilter queryFilter,
List<SchemaElementMatch> schemaElementMatches) {
List<SchemaElementMatch> schemaElementMatches) {
List<SchemaElementMatch> valueSchemaElements = schemaElementMatches.stream().filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()))
.collect(Collectors.toList());
for (SchemaElementMatch schemaElementMatch : valueSchemaElements) {
if (schemaElementMatch.getElement().getId().equals(queryFilter.getElementID())
&& schemaElementMatch.getWord().equals(String.valueOf(queryFilter.getValue()))) {
&& schemaElementMatch.getWord().equals(String.valueOf(queryFilter.getValue()))) {
return true;
}
}

View File

@@ -1,10 +1,10 @@
package com.tencent.supersonic.chat.mapper;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.knowledge.dictionary.DictWordType;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.service.SearchService;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
@@ -32,7 +32,7 @@ public class QueryMatchStrategy implements MatchStrategy {
private MapperHelper mapperHelper;
@Override
public Map<MatchText, List<MapResult>> match(String text, List<Term> terms, Long detectDomainId) {
public Map<MatchText, List<MapResult>> match(String text, List<Term> terms, Long detectmodelId) {
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
return null;
}
@@ -43,10 +43,10 @@ public class QueryMatchStrategy implements MatchStrategy {
List<Integer> offsetList = terms.stream().sorted(Comparator.comparing(Term::getOffset))
.map(term -> term.getOffset()).collect(Collectors.toList());
log.debug("retryCount:{},terms:{},regOffsetToLength:{},offsetList:{},detectDomainId:{}", terms,
regOffsetToLength, offsetList, detectDomainId);
log.debug("retryCount:{},terms:{},regOffsetToLength:{},offsetList:{},detectmodelId:{}", terms,
regOffsetToLength, offsetList, detectmodelId);
List<MapResult> detects = detect(text, regOffsetToLength, offsetList, detectDomainId);
List<MapResult> detects = detect(text, regOffsetToLength, offsetList, detectmodelId);
Map<MatchText, List<MapResult>> result = new HashMap<>();
MatchText matchText = MatchText.builder()
@@ -58,7 +58,7 @@ public class QueryMatchStrategy implements MatchStrategy {
}
private List<MapResult> detect(String text, Map<Integer, Integer> regOffsetToLength, List<Integer> offsetList,
Long detectDomainId) {
Long detectmodelId) {
List<MapResult> results = Lists.newArrayList();
for (Integer index = 0; index <= text.length() - 1; ) {
@@ -69,18 +69,44 @@ public class QueryMatchStrategy implements MatchStrategy {
int offset = mapperHelper.getStepOffset(offsetList, index);
i = mapperHelper.getStepIndex(regOffsetToLength, i);
if (i <= text.length()) {
List<MapResult> mapResults = detectByStep(text, detectDomainId, index, i, offset);
mapResultRowSet.addAll(mapResults);
List<MapResult> mapResults = detectByStep(text, detectmodelId, index, i, offset);
selectMapResultInOneRound(mapResultRowSet, mapResults);
}
}
index = mapperHelper.getStepIndex(regOffsetToLength, index);
results.addAll(mapResultRowSet);
}
return results;
}
private List<MapResult> detectByStep(String text, Long detectDomainId, Integer index, Integer i, int offset) {
private void selectMapResultInOneRound(Set<MapResult> mapResultRowSet, List<MapResult> mapResults) {
for (MapResult mapResult : mapResults) {
if (mapResultRowSet.contains(mapResult)) {
boolean isDeleted = mapResultRowSet.removeIf(
entry -> {
boolean deleted = getMapKey(mapResult).equals(getMapKey(entry))
&& entry.getDetectWord().length() < mapResult.getDetectWord().length();
if (deleted) {
log.info("deleted entry:{}", entry);
}
return deleted;
}
);
if (isDeleted) {
log.info("deleted, add mapResult:{}", mapResult);
mapResultRowSet.add(mapResult);
}
} else {
mapResultRowSet.add(mapResult);
}
}
}
private String getMapKey(MapResult a) {
return a.getName() + Constants.UNDERLINE + String.join(Constants.UNDERLINE, a.getNatures());
}
private List<MapResult> detectByStep(String text, Long detectmodelId, Integer index, Integer i, int offset) {
String detectSegment = text.substring(index, i);
Integer oneDetectionSize = mapperHelper.getOneDetectionSize();
// step1. pre search
@@ -100,17 +126,17 @@ public class QueryMatchStrategy implements MatchStrategy {
mapResults = mapResults.stream().sorted((a, b) -> -(b.getName().length() - a.getName().length()))
.collect(Collectors.toCollection(LinkedHashSet::new));
// step4. filter by classId
if (Objects.nonNull(detectDomainId) && detectDomainId > 0) {
log.debug("detectDomainId:{}, before parseResults:{}", mapResults);
if (Objects.nonNull(detectmodelId) && detectmodelId > 0) {
log.debug("detectmodelId:{}, before parseResults:{}", mapResults);
mapResults = mapResults.stream().map(entry -> {
List<String> natures = entry.getNatures().stream().filter(
nature -> nature.startsWith(DictWordType.NATURE_SPILT + detectDomainId) || (nature.startsWith(
nature -> nature.startsWith(DictWordType.NATURE_SPILT + detectmodelId) || (nature.startsWith(
DictWordType.NATURE_SPILT))
).collect(Collectors.toList());
entry.setNatures(natures);
return entry;
}).collect(Collectors.toCollection(LinkedHashSet::new));
log.info("after domainId parseResults:{}", mapResults);
log.info("after modelId parseResults:{}", mapResults);
}
// step5. filter by similarity
mapResults = mapResults.stream()

View File

@@ -2,10 +2,9 @@ package com.tencent.supersonic.chat.mapper;
import com.google.common.collect.Lists;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.dictionary.DictWordType;
import com.tencent.supersonic.knowledge.dictionary.MapResult;
import com.tencent.supersonic.knowledge.service.SearchService;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -25,7 +24,7 @@ public class SearchMatchStrategy implements MatchStrategy {
@Override
public Map<MatchText, List<MapResult>> match(String text, List<Term> originals,
Long detectDomainId) {
Long detectModelId) {
Map<Integer, Integer> regOffsetToLength = originals.stream()
.filter(entry -> !entry.nature.toString().startsWith(DictWordType.NATURE_SPILT))
@@ -60,10 +59,10 @@ public class SearchMatchStrategy implements MatchStrategy {
List<String> natures = entry.getNatures().stream()
.filter(nature -> !nature.endsWith(DictWordType.ENTITY.getType()))
.filter(nature -> {
if (Objects.isNull(detectDomainId) || detectDomainId <= 0) {
if (Objects.isNull(detectModelId) || detectModelId <= 0) {
return true;
}
if (nature.startsWith(DictWordType.NATURE_SPILT + detectDomainId)
if (nature.startsWith(DictWordType.NATURE_SPILT + detectModelId)
&& nature.startsWith(DictWordType.NATURE_SPILT)) {
return true;
}

View File

@@ -2,22 +2,9 @@ package com.tencent.supersonic.chat.parser;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.plugin.PluginParseResult;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.compress.utils.Lists;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
/**
* This checker can be used by semantic parsers to check if query intent
@@ -30,86 +17,20 @@ public class SatisfactionChecker {
private static final double LONG_TEXT_THRESHOLD = 0.8;
private static final double SHORT_TEXT_THRESHOLD = 0.5;
private static final int QUERY_TEXT_LENGTH_THRESHOLD = 10;
public static final double EMBEDDING_THRESHOLD = 0.2;
// check all the parse info in candidate
public static boolean check(QueryContext queryCtx) {
for (SemanticQuery query : queryCtx.getCandidateQueries()) {
if (query instanceof RuleSemanticQuery) {
if (checkRuleThreshHold(queryCtx.getRequest().getQueryText(), query.getParseInfo())) {
return true;
}
} else if (query instanceof PluginSemanticQuery) {
if (checkEmbeddingThreshold(query.getParseInfo())) {
log.info("query mode :{} satisfy check", query.getQueryMode());
return true;
}
public static boolean check(QueryContext queryContext) {
for (SemanticQuery query : queryContext.getCandidateQueries()) {
if (checkThreshold(queryContext.getRequest().getQueryText(), query.getParseInfo())) {
return true;
}
}
return false;
}
private static boolean checkEmbeddingThreshold(SemanticParseInfo semanticParseInfo) {
Object object = semanticParseInfo.getProperties().get(Constants.CONTEXT);
PluginParseResult pluginParseResult = JsonUtil.toObject(JsonUtil.toString(object), PluginParseResult.class);
return EMBEDDING_THRESHOLD > pluginParseResult.getDistance();
}
//check single parse info
private static boolean checkRuleThreshHold(String text, SemanticParseInfo semanticParseInfo) {
List<SchemaElementMatch> schemaElementMatches = semanticParseInfo.getElementMatches();
if (CollectionUtils.isEmpty(schemaElementMatches)) {
return false;
}
List<String> detectWords = Lists.newArrayList();
Map<Long, String> detectWordMap = schemaElementMatches.stream()
.collect(Collectors.toMap(m -> m.getElement().getId(), SchemaElementMatch::getDetectWord,
(id1, id2) -> id1));
// get detect word in text by element id in semantic layer
Long domainId = semanticParseInfo.getDomainId();
if (domainId != null && domainId > 0) {
for (SchemaElementMatch schemaElementMatch : schemaElementMatches) {
if (SchemaElementType.DOMAIN.equals(schemaElementMatch.getElement().getType())) {
detectWords.add(schemaElementMatch.getDetectWord());
}
}
}
for (SchemaElementMatch schemaElementMatch : schemaElementMatches) {
if (SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())) {
detectWords.add(schemaElementMatch.getDetectWord());
}
}
for (SchemaElementMatch schemaElementMatch : schemaElementMatches) {
if (SchemaElementType.ID.equals(schemaElementMatch.getElement().getType())) {
detectWords.add(schemaElementMatch.getDetectWord());
}
}
for (SchemaElement schemaItem : semanticParseInfo.getMetrics()) {
detectWords.add(
detectWordMap.getOrDefault(Optional.ofNullable(schemaItem.getId()).orElse(0L), ""));
// only first metric
break;
}
for (SchemaElement schemaItem : semanticParseInfo.getDimensions()) {
detectWords.add(
detectWordMap.getOrDefault(Optional.ofNullable(schemaItem.getId()).orElse(0L), ""));
// only first dimension
break;
}
String dateText = Optional.ofNullable(semanticParseInfo.getDateInfo()).orElse(new DateConf()).getText();
if (StringUtils.isNotBlank(dateText) && !dateText.equalsIgnoreCase(Constants.NULL)) {
detectWords.add(dateText);
}
detectWords.removeIf(word -> !text.contains(word) && !text.contains(StringUtils.reverse(word)));
//compare the length between detect words and query text
return checkThreshold(text, detectWords, semanticParseInfo);
}
private static boolean checkThreshold(String queryText, List<String> detectWords, SemanticParseInfo semanticParseInfo) {
String detectWordsDistinct = StringUtils.join(new HashSet<>(detectWords), "");
int detectWordsLength = detectWordsDistinct.length();
private static boolean checkThreshold(String queryText, SemanticParseInfo semanticParseInfo) {
int queryTextLength = queryText.length();
double degree = detectWordsLength * 1.0 / queryTextLength;
double degree = semanticParseInfo.getScore() / queryTextLength;
if (queryTextLength > QUERY_TEXT_LENGTH_THRESHOLD) {
if (degree < LONG_TEXT_THRESHOLD) {
return false;
@@ -117,8 +38,8 @@ public class SatisfactionChecker {
} else if (degree < SHORT_TEXT_THRESHOLD) {
return false;
}
log.info("queryMode:{}, degree:{}, detectWords:{}, parse info:{}",
semanticParseInfo.getQueryMode(), degree, detectWordsDistinct, semanticParseInfo);
log.info("queryMode:{}, degree:{}, parse info:{}",
semanticParseInfo.getQueryMode(), degree, semanticParseInfo);
return true;
}

View File

@@ -1,13 +1,18 @@
package com.tencent.supersonic.chat.parser.embedding;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
import com.tencent.supersonic.chat.parser.ParseMode;
import com.tencent.supersonic.chat.plugin.Plugin;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.plugin.PluginParseResult;
@@ -17,18 +22,20 @@ import com.tencent.supersonic.chat.service.PluginService;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.ContextUtils;
import java.util.*;
import java.util.stream.Collectors;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.util.CollectionUtils;
@Slf4j
public class EmbeddingBasedParser implements SemanticParser {
private final static double THRESHOLD = 0.2d;
@Override
public void parse(QueryContext queryContext, ChatContext chatContext) {
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
@@ -36,109 +43,15 @@ public class EmbeddingBasedParser implements SemanticParser {
return;
}
log.info("EmbeddingBasedParser parser query ctx: {}, chat ctx: {}", queryContext, chatContext);
Set<Long> domainIds = getDomainMatched(queryContext);
String text = queryContext.getRequest().getQueryText();
if (!CollectionUtils.isEmpty(domainIds)) {
for (Long domainId : domainIds) {
List<SchemaElementMatch> schemaElementMatches = getMatchedElements(queryContext, domainId);
String textReplaced = replaceText(text, schemaElementMatches);
List<RecallRetrieval> embeddingRetrievals = recallResult(textReplaced, hasCandidateQuery(queryContext));
Optional<Plugin> pluginOptional = choosePlugin(embeddingRetrievals, domainId);
log.info("domain id :{} embedding result, text:{} embeddingResp:{} ",domainId, textReplaced, embeddingRetrievals);
pluginOptional.ifPresent(plugin -> buildQuery(plugin, embeddingRetrievals, domainId, textReplaced, queryContext, schemaElementMatches));
}
} else {
List<RecallRetrieval> embeddingRetrievals = recallResult(text, hasCandidateQuery(queryContext));
Optional<Plugin> pluginOptional = choosePlugin(embeddingRetrievals, null);
pluginOptional.ifPresent(plugin -> buildQuery(plugin, embeddingRetrievals, null, text, queryContext, Lists.newArrayList()));
}
List<RecallRetrieval> embeddingRetrievals = recallResult(text);
choosePlugin(embeddingRetrievals, queryContext);
}
private void buildQuery(Plugin plugin, List<RecallRetrieval> embeddingRetrievals,
Long domainId, String text,
QueryContext queryContext, List<SchemaElementMatch> schemaElementMatches) {
Map<String, RecallRetrieval> embeddingRetrievalMap = embeddingRetrievals.stream()
.collect(Collectors.toMap(RecallRetrieval::getId, e -> e, (value1, value2) -> value1));
log.info("EmbeddingBasedParser text: {} domain: {} choose plugin: [{} {}]",
text, domainId, plugin.getId(), plugin.getName());
PluginSemanticQuery pluginQuery = QueryManager.createPluginQuery(plugin.getType());
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(domainId, plugin, text,
queryContext.getRequest(), embeddingRetrievalMap, schemaElementMatches);
semanticParseInfo.setQueryMode(pluginQuery.getQueryMode());
pluginQuery.setParseInfo(semanticParseInfo);
queryContext.getCandidateQueries().add(pluginQuery);
}
private Set<Long> getDomainMatched(QueryContext queryContext) {
Long queryDomainId = queryContext.getRequest().getDomainId();
if (queryDomainId != null && queryDomainId > 0) {
return Sets.newHashSet(queryDomainId);
}
return queryContext.getMapInfo().getMatchedDomains();
}
private SemanticParseInfo buildSemanticParseInfo(Long domainId, Plugin plugin, String text, QueryReq queryReq,
Map<String, RecallRetrieval> embeddingRetrievalMap,
List<SchemaElementMatch> schemaElementMatches) {
SchemaElement schemaElement = new SchemaElement();
schemaElement.setDomain(domainId);
schemaElement.setId(domainId);
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
semanticParseInfo.setElementMatches(schemaElementMatches);
semanticParseInfo.setDomain(schemaElement);
double distance = Double.parseDouble(embeddingRetrievalMap.get(plugin.getId().toString()).getDistance());
double score = text.length() * (1 - distance);
Map<String, Object> properties = new HashMap<>();
PluginParseResult pluginParseResult = new PluginParseResult();
pluginParseResult.setPlugin(plugin);
pluginParseResult.setRequest(queryReq);
pluginParseResult.setDistance(distance);
properties.put(Constants.CONTEXT, pluginParseResult);
semanticParseInfo.setProperties(properties);
semanticParseInfo.setScore(score);
fillSemanticParseInfo(semanticParseInfo);
setEntity(domainId, semanticParseInfo);
return semanticParseInfo;
}
private List<SchemaElementMatch> getMatchedElements(QueryContext queryContext, Long domainId) {
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(domainId);
if (schemaElementMatches == null) {
return Lists.newArrayList();
}
QueryReq queryReq = queryContext.getRequest();
QueryFilters queryFilters = queryReq.getQueryFilters();
if (queryFilters == null || CollectionUtils.isEmpty(queryFilters.getFilters())) {
return schemaElementMatches;
}
Map<Long, Object> element = queryFilters.getFilters().stream()
.collect(Collectors.toMap(QueryFilter::getElementID, QueryFilter::getValue, (v1, v2) -> v1));
return schemaElementMatches.stream().filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType())
|| SchemaElementType.ENTITY.equals(schemaElementMatch.getElement().getType()))
.filter(schemaElementMatch ->
!element.containsKey(schemaElementMatch.getElement().getId()) || (
element.containsKey(schemaElementMatch.getElement().getId()) &&
element.get(schemaElementMatch.getElement().getId()).toString()
.equalsIgnoreCase(schemaElementMatch.getWord())
))
.collect(Collectors.toList());
}
private void setEntity(Long domainId, SemanticParseInfo semanticParseInfo) {
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
DomainSchema domainSchema = semanticService.getDomainSchema(domainId);
if (domainSchema != null && domainSchema.getEntity() != null) {
semanticParseInfo.setEntity(domainSchema.getEntity());
}
}
private Optional<Plugin> choosePlugin(List<RecallRetrieval> embeddingRetrievals,
Long domainId) {
private void choosePlugin(List<RecallRetrieval> embeddingRetrievals,
QueryContext queryContext) {
if (CollectionUtils.isEmpty(embeddingRetrievals)) {
return Optional.empty();
return;
}
PluginService pluginService = ContextUtils.getBean(PluginService.class);
List<Plugin> plugins = pluginService.getPluginList();
@@ -148,27 +61,92 @@ public class EmbeddingBasedParser implements SemanticParser {
if (plugin == null) {
continue;
}
if (domainId == null) {
return Optional.of(plugin);
}
if (!CollectionUtils.isEmpty(plugin.getDomainList()) && plugin.getDomainList().contains(domainId)) {
return Optional.of(plugin);
Pair<Boolean, List<Long>> pair = PluginManager.resolve(plugin, queryContext);
log.info("embedding plugin resolve: {}", pair);
if (pair.getLeft()) {
List<Long> modelList = pair.getRight();
if (CollectionUtils.isEmpty(modelList)) {
return;
}
modelList = distinctModelList(plugin, queryContext.getMapInfo(), modelList);
for (Long modelId : modelList) {
buildQuery(plugin, Double.parseDouble(embeddingRetrieval.getDistance()), modelId, queryContext,
queryContext.getMapInfo().getMatchedElements(modelId));
}
return;
}
}
return Optional.empty();
}
public List<RecallRetrieval> recallResult(String embeddingText, boolean hasCandidateQuery) {
public List<Long> distinctModelList(Plugin plugin, SchemaMapInfo schemaMapInfo, List<Long> modelList) {
if (!plugin.isContainsAllModel()) {
return modelList;
}
boolean noElementMatch = true;
for (Long model : modelList) {
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(model);
if (!CollectionUtils.isEmpty(schemaElementMatches)) {
noElementMatch = false;
}
}
if (noElementMatch) {
return modelList.subList(0, 1);
}
return modelList;
}
private void buildQuery(Plugin plugin, double distance, Long modelId,
QueryContext queryContext, List<SchemaElementMatch> schemaElementMatches) {
log.info("EmbeddingBasedParser Model: {} choose plugin: [{} {}]", modelId, plugin.getId(), plugin.getName());
PluginSemanticQuery pluginQuery = QueryManager.createPluginQuery(plugin.getType());
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(modelId, plugin, queryContext.getRequest(),
schemaElementMatches, distance);
double score = queryContext.getRequest().getQueryText().length() * (1 - distance);
semanticParseInfo.setQueryMode(pluginQuery.getQueryMode());
semanticParseInfo.setScore(score);
pluginQuery.setParseInfo(semanticParseInfo);
queryContext.getCandidateQueries().add(pluginQuery);
}
private SemanticParseInfo buildSemanticParseInfo(Long modelId, Plugin plugin, QueryReq queryReq,
List<SchemaElementMatch> schemaElementMatches, double distance) {
if (modelId == null && !CollectionUtils.isEmpty(plugin.getModelList())) {
modelId = plugin.getModelList().get(0);
}
SchemaElement Model = new SchemaElement();
Model.setModel(modelId);
Model.setId(modelId);
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
semanticParseInfo.setElementMatches(schemaElementMatches);
semanticParseInfo.setModel(Model);
Map<String, Object> properties = new HashMap<>();
PluginParseResult pluginParseResult = new PluginParseResult();
pluginParseResult.setPlugin(plugin);
pluginParseResult.setRequest(queryReq);
pluginParseResult.setDistance(distance);
properties.put(Constants.CONTEXT, pluginParseResult);
semanticParseInfo.setProperties(properties);
semanticParseInfo.setScore(distance);
fillSemanticParseInfo(semanticParseInfo);
setEntity(modelId, semanticParseInfo);
return semanticParseInfo;
}
private void setEntity(Long modelId, SemanticParseInfo semanticParseInfo) {
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
ModelSchema ModelSchema = semanticService.getModelSchema(modelId);
if (ModelSchema != null && ModelSchema.getEntity() != null) {
semanticParseInfo.setEntity(ModelSchema.getEntity());
}
}
public List<RecallRetrieval> recallResult(String embeddingText) {
try {
PluginManager pluginManager = ContextUtils.getBean(PluginManager.class);
EmbeddingResp embeddingResp = pluginManager.recognize(embeddingText);
List<RecallRetrieval> embeddingRetrievals = embeddingResp.getRetrieval();
if(!CollectionUtils.isEmpty(embeddingRetrievals)){
if (hasCandidateQuery) {
embeddingRetrievals = embeddingRetrievals.stream()
.filter(llmRetrieval -> Double.parseDouble(llmRetrieval.getDistance())<THRESHOLD)
.collect(Collectors.toList());
}
if (!CollectionUtils.isEmpty(embeddingRetrievals)) {
embeddingRetrievals = embeddingRetrievals.stream().sorted(Comparator.comparingDouble(o ->
Math.abs(Double.parseDouble(o.getDistance())))).collect(Collectors.toList());
embeddingResp.setRetrieval(embeddingRetrievals);
@@ -180,45 +158,22 @@ public class EmbeddingBasedParser implements SemanticParser {
return Lists.newArrayList();
}
private boolean hasCandidateQuery(QueryContext queryContext) {
return !CollectionUtils.isEmpty(queryContext.getCandidateQueries());
}
private void fillSemanticParseInfo(SemanticParseInfo semanticParseInfo) {
List<SchemaElementMatch> schemaElementMatches = semanticParseInfo.getElementMatches();
if (!CollectionUtils.isEmpty(schemaElementMatches)) {
schemaElementMatches.stream().filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
.forEach(schemaElementMatch -> {
QueryFilter queryFilter = new QueryFilter();
queryFilter.setValue(schemaElementMatch.getWord());
queryFilter.setElementID(schemaElementMatch.getElement().getId());
queryFilter.setName(schemaElementMatch.getElement().getName());
queryFilter.setOperator(FilterOperatorEnum.EQUALS);
queryFilter.setBizName(schemaElementMatch.getElement().getBizName());
semanticParseInfo.getDimensionFilters().add(queryFilter);
QueryFilter queryFilter = new QueryFilter();
queryFilter.setValue(schemaElementMatch.getWord());
queryFilter.setElementID(schemaElementMatch.getElement().getId());
queryFilter.setName(schemaElementMatch.getElement().getName());
queryFilter.setOperator(FilterOperatorEnum.EQUALS);
queryFilter.setBizName(schemaElementMatch.getElement().getBizName());
semanticParseInfo.getDimensionFilters().add(queryFilter);
});
}
}
protected String replaceText(String text, List<SchemaElementMatch> schemaElementMatches) {
if (CollectionUtils.isEmpty(schemaElementMatches)) {
return text;
}
List<SchemaElementMatch> valueSchemaElementMatches = schemaElementMatches.stream()
.filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
.collect(Collectors.toList());
for (SchemaElementMatch schemaElementMatch : valueSchemaElementMatches) {
String detectWord = schemaElementMatch.getDetectWord();
if (StringUtils.isBlank(detectWord)) {
continue;
}
text = text.replace(detectWord, "");
}
return text;
}
}

View File

@@ -1,23 +1,26 @@
package com.tencent.supersonic.chat.parser.embedding;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.service.ConfigService;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.stereotype.Component;
import java.util.List;
import java.util.Objects;
import java.util.Set;
@Slf4j
@Component("EmbeddingEntityResolver")
public class EmbeddingEntityResolver {
private ConfigService configService;
public EmbeddingEntityResolver(ConfigService configService) {
@@ -25,18 +28,19 @@ public class EmbeddingEntityResolver {
}
private Long getEntityValue(Long domainId, Long entityElementId, QueryContext queryCtx, ChatContext chatCtx) {
private Long getEntityValue(Long modelId, Long entityElementId, QueryContext queryCtx, ChatContext chatCtx) {
Long entityId = null;
QueryFilters queryFilters = queryCtx.getRequest().getQueryFilters();
if (queryFilters != null) {
entityId = getEntityValueFromQueryFilter(queryFilters.getFilters());
if (entityId != null) {
log.info("get entity id:{} domain id:{} from query filter :{} ", entityId, domainId, queryFilters);
log.info("get entity id:{} model id:{} from query filter :{} ", entityId, modelId, queryFilters);
return entityId;
}
}
entityId = getEntityValueFromSchemaMapInfo(domainId, queryCtx.getMapInfo(), entityElementId);
log.info("get entity id:{} from schema map Info :{} ", entityId, JSONObject.toJSONString(queryCtx.getMapInfo()));
entityId = getEntityValueFromSchemaMapInfo(modelId, queryCtx.getMapInfo(), entityElementId);
log.info("get entity id:{} from schema map Info :{} ", entityId,
JSONObject.toJSONString(queryCtx.getMapInfo()));
if (entityId == null || entityId == 0) {
Long entityIdFromChat = getEntityValueFromParseInfo(chatCtx.getParseInfo(), entityElementId);
if (entityIdFromChat != null && entityIdFromChat > 0) {
@@ -75,8 +79,8 @@ public class EmbeddingEntityResolver {
}
private Long getEntityValueFromSchemaMapInfo(Long domainId, SchemaMapInfo schemaMapInfo, Long entityElementId) {
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(domainId);
private Long getEntityValueFromSchemaMapInfo(Long modelId, SchemaMapInfo schemaMapInfo, Long entityElementId) {
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(modelId);
if (CollectionUtils.isEmpty(schemaElementMatchList)) {
return null;
}

View File

@@ -1,9 +1,8 @@
package com.tencent.supersonic.chat.parser.embedding;
import lombok.Data;
import java.util.List;
import lombok.Data;
@Data
public class EmbeddingResp {

View File

@@ -7,18 +7,20 @@ import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.config.FunctionCallInfoConfig;
import com.tencent.supersonic.chat.parser.ParseMode;
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
import com.tencent.supersonic.chat.plugin.Plugin;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
import com.tencent.supersonic.chat.plugin.PluginParseResult;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.dsl.DSLQuery;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.query.plugin.dsl.DSLQuery;
import com.tencent.supersonic.chat.service.PluginService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import java.net.URI;
import java.util.ArrayList;
import java.util.HashMap;
@@ -29,10 +31,9 @@ import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import com.tencent.supersonic.common.util.JsonUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
@@ -55,21 +56,12 @@ public class FunctionBasedParser implements SemanticParser {
FunctionCallInfoConfig functionCallConfig = ContextUtils.getBean(FunctionCallInfoConfig.class);
PluginService pluginService = ContextUtils.getBean(PluginService.class);
String functionUrl = functionCallConfig.getUrl();
if (StringUtils.isBlank(functionUrl) || SatisfactionChecker.check(queryCtx)) {
log.info("functionUrl:{}, skip function parser, queryText:{}", functionUrl,
queryCtx.getRequest().getQueryText());
return;
}
Set<Long> matchedDomains = getMatchDomains(queryCtx);
List<String> functionNames = getFunctionNames(matchedDomains);
log.info("matchedDomains:{},functionNames:{}", matchedDomains, functionNames);
if (CollectionUtils.isEmpty(functionNames) || CollectionUtils.isEmpty(matchedDomains)) {
return;
}
List<PluginParseConfig> functionDOList = getFunctionDO(queryCtx.getRequest().getDomainId());
List<PluginParseConfig> functionDOList = getFunctionDO(queryCtx.getRequest().getModelId(), queryCtx);
FunctionReq functionReq = FunctionReq.builder()
.queryText(queryCtx.getRequest().getQueryText())
.pluginConfigs(functionDOList).build();
@@ -78,7 +70,6 @@ public class FunctionBasedParser implements SemanticParser {
if (skipFunction(queryCtx, functionResp)) {
return;
}
PluginParseResult functionCallParseResult = new PluginParseResult();
String toolSelection = functionResp.getToolSelection();
Optional<Plugin> pluginOptional = pluginService.getPluginByName(toolSelection);
@@ -87,24 +78,26 @@ public class FunctionBasedParser implements SemanticParser {
return;
}
Plugin plugin = pluginOptional.get();
plugin.setParseMode(ParseMode.FUNCTION_CALL);
toolSelection = plugin.getType();
functionCallParseResult.setPlugin(plugin);
log.info("QueryManager PluginQueryModes:{}", QueryManager.getPluginQueryModes());
PluginSemanticQuery semanticQuery = QueryManager.createPluginQuery(toolSelection);
DomainResolver domainResolver = ComponentFactory.getDomainResolver();
Long domainId = domainResolver.resolve(queryCtx, chatCtx, plugin.getDomainList());
log.info("FunctionBasedParser domainId:{}",domainId);
if ((Objects.isNull(domainId) || domainId <= 0) && !plugin.isContainsAllDomain()) {
log.info("domain is null, skip the parse, select tool: {}", toolSelection);
ModelResolver ModelResolver = ComponentFactory.getModelResolver();
log.info("plugin ModelList:{}", plugin.getModelList());
Pair<Boolean, List<Long>> pluginResolveResult = PluginManager.resolve(plugin, queryCtx);
Long modelId = ModelResolver.resolve(queryCtx, chatCtx, pluginResolveResult.getRight());
log.info("FunctionBasedParser modelId:{}", modelId);
if ((Objects.isNull(modelId) || modelId <= 0) && !plugin.isContainsAllModel()) {
log.info("Model is null, skip the parse, select tool: {}", toolSelection);
return;
}
if (!plugin.getDomainList().contains(domainId) && !plugin.isContainsAllDomain()) {
if (!plugin.getModelList().contains(modelId) && !plugin.isContainsAllModel()) {
return;
}
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
if (Objects.nonNull(domainId) && domainId > 0){
parseInfo.getElementMatches().addAll(queryCtx.getMapInfo().getMatchedElements(domainId));
if (Objects.nonNull(modelId) && modelId > 0) {
parseInfo.getElementMatches().addAll(queryCtx.getMapInfo().getMatchedElements(modelId));
}
functionCallParseResult.setRequest(queryCtx.getRequest());
Map<String, Object> properties = new HashMap<>();
@@ -112,21 +105,22 @@ public class FunctionBasedParser implements SemanticParser {
parseInfo.setProperties(properties);
parseInfo.setScore(FUNCTION_BONUS_THRESHOLD);
parseInfo.setQueryMode(semanticQuery.getQueryMode());
SchemaElement domain = new SchemaElement();
domain.setDomain(domainId);
domain.setId(domainId);
parseInfo.setDomain(domain);
SchemaElement Model = new SchemaElement();
Model.setModel(modelId);
Model.setId(modelId);
parseInfo.setModel(Model);
queryCtx.getCandidateQueries().add(semanticQuery);
}
private Set<Long> getMatchDomains(QueryContext queryCtx) {
private Set<Long> getMatchModels(QueryContext queryCtx) {
Set<Long> result = new HashSet<>();
Long domainId = queryCtx.getRequest().getDomainId();
if (Objects.nonNull(domainId) && domainId > 0) {
result.add(domainId);
Long modelId = queryCtx.getRequest().getModelId();
if (Objects.nonNull(modelId) && modelId > 0) {
result.add(modelId);
return result;
}
return queryCtx.getMapInfo().getMatchedDomains();
return queryCtx.getMapInfo().getMatchedModels();
}
private boolean skipFunction(QueryContext queryCtx, FunctionResp functionResp) {
@@ -144,36 +138,46 @@ public class FunctionBasedParser implements SemanticParser {
return false;
}
private List<PluginParseConfig> getFunctionDO(Long domainId) {
log.info("user decide domain:{}", domainId);
private List<PluginParseConfig> getFunctionDO(Long modelId, QueryContext queryContext) {
log.info("user decide Model:{}", modelId);
List<Plugin> plugins = PluginManager.getPlugins();
List<PluginParseConfig> functionDOList = plugins.stream().filter(o -> {
if (o.getParseModeConfig() == null) {
List<PluginParseConfig> functionDOList = plugins.stream().filter(plugin -> {
if (DSLQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType())) {
return false;
}
if (!CollectionUtils.isEmpty(o.getDomainList())) {//过滤掉没选主题域的插件
return true;
if (plugin.getParseModeConfig() == null) {
return false;
}
if (domainId == null || domainId <= 0L) {
return true;
} else {
return o.getDomainList().contains(domainId);
}
}).map(o -> {
PluginParseConfig functionCallConfig = JsonUtil.toObject(o.getParseModeConfig(),
PluginParseConfig pluginParseConfig = JsonUtil.toObject(plugin.getParseModeConfig(),
PluginParseConfig.class);
return functionCallConfig;
}).collect(Collectors.toList());
if (StringUtils.isBlank(pluginParseConfig.getName())) {
return false;
}
Pair<Boolean, List<Long>> pluginResolverResult = PluginManager.resolve(plugin, queryContext);
log.info("embedding plugin [{}-{}] resolve: {}", plugin.getId(), plugin.getName(), pluginResolverResult);
if (!pluginResolverResult.getLeft()) {
return false;
} else {
List<Long> resolveModel = pluginResolverResult.getRight();
if (modelId != null && modelId > 0) {
if (plugin.isContainsAllModel()) {
return true;
}
return resolveModel.contains(modelId);
}
return true;
}
}).map(o -> JsonUtil.toObject(o.getParseModeConfig(), PluginParseConfig.class)).collect(Collectors.toList());
log.info("getFunctionDO:{}", JsonUtil.toString(functionDOList));
return functionDOList;
}
private List<String> getFunctionNames(Set<Long> matchedDomains) {
private List<String> getFunctionNames(Set<Long> matchedModels) {
List<Plugin> plugins = PluginManager.getPlugins();
Set<String> functionNames = plugins.stream()
.filter(entry -> {
if (!CollectionUtils.isEmpty(entry.getDomainList()) && !CollectionUtils.isEmpty(matchedDomains)) {
return entry.getDomainList().stream().anyMatch(matchedDomains::contains);
if (!CollectionUtils.isEmpty(entry.getModelList()) && !CollectionUtils.isEmpty(matchedModels)) {
return entry.getModelList().stream().anyMatch(matchedModels::contains);
}
return true;
}

View File

@@ -1,8 +1,7 @@
package com.tencent.supersonic.chat.parser.function;
import java.util.List;
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
import java.util.List;
import lombok.Builder;
import lombok.Data;

View File

@@ -1,181 +0,0 @@
package com.tencent.supersonic.chat.parser.function;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import lombok.extern.slf4j.Slf4j;
import java.util.*;
import java.util.stream.Collectors;
import org.apache.commons.collections.CollectionUtils;
@Slf4j
public class HeuristicDomainResolver implements DomainResolver {
protected static Long selectDomainBySchemaElementCount(Map<Long, SemanticQuery> domainQueryModes,
SchemaMapInfo schemaMap) {
Map<Long, DomainMatchResult> domainTypeMap = getDomainTypeMap(schemaMap);
if (domainTypeMap.size() == 1) {
Long domainSelect = domainTypeMap.entrySet().stream().collect(Collectors.toList()).get(0).getKey();
if (domainQueryModes.containsKey(domainSelect)) {
log.info("selectDomain with only one domain [{}]", domainSelect);
return domainSelect;
}
} else {
Map.Entry<Long, DomainMatchResult> maxDomain = domainTypeMap.entrySet().stream()
.filter(entry -> domainQueryModes.containsKey(entry.getKey()))
.sorted((o1, o2) -> {
int difference = o2.getValue().getCount() - o1.getValue().getCount();
if (difference == 0) {
return (int) ((o2.getValue().getMaxSimilarity()
- o1.getValue().getMaxSimilarity()) * 100);
}
return difference;
}).findFirst().orElse(null);
if (maxDomain != null) {
log.info("selectDomain with multiple domains [{}]", maxDomain.getKey());
return maxDomain.getKey();
}
}
return 0L;
}
/**
* to check can switch domain if context exit domain
*
* @return false will use context domain, true will use other domain , maybe include context domain
*/
protected static boolean isAllowSwitch(Map<Long, SemanticQuery> domainQueryModes, SchemaMapInfo schemaMap,
ChatContext chatCtx, QueryReq searchCtx, Long domainId, List<Long> restrictiveDomains) {
if (!Objects.nonNull(domainId) || domainId <= 0) {
return true;
}
// except content domain, calculate the number of types for each domain, if numbers<=1 will not switch
Map<Long, DomainMatchResult> domainTypeMap = getDomainTypeMap(schemaMap);
log.info("isAllowSwitch domainTypeMap [{}]", domainTypeMap);
long otherDomainTypeNumBigOneCount = domainTypeMap.entrySet().stream()
.filter(entry -> domainQueryModes.containsKey(entry.getKey()) && !entry.getKey().equals(domainId))
.filter(entry -> entry.getValue().getCount() > 1).count();
if (otherDomainTypeNumBigOneCount >= 1) {
return true;
}
// if query text only contain time , will not switch
if (!CollectionUtils.isEmpty(domainQueryModes.values())) {
for (SemanticQuery semanticQuery : domainQueryModes.values()) {
if (semanticQuery == null) {
continue;
}
SemanticParseInfo semanticParseInfo = semanticQuery.getParseInfo();
if (semanticParseInfo == null) {
continue;
}
if (searchCtx.getQueryText() != null && semanticParseInfo.getDateInfo() != null) {
if (semanticParseInfo.getDateInfo().getText() != null) {
if (semanticParseInfo.getDateInfo().getText().equalsIgnoreCase(searchCtx.getQueryText())) {
log.info("timeParseResults is not null , can not switch context , timeParseResults:{},",
semanticParseInfo.getDateInfo());
return false;
}
}
}
}
}
if (CollectionUtils.isNotEmpty(restrictiveDomains) && !restrictiveDomains.contains(domainId)) {
return true;
}
// if context domain not in schemaMap , will switch
if (schemaMap.getMatchedElements(domainId) == null || schemaMap.getMatchedElements(domainId).size() <= 0) {
log.info("domainId not in schemaMap ");
return true;
}
// other will not switch
return false;
}
public static Map<Long, DomainMatchResult> getDomainTypeMap(SchemaMapInfo schemaMap) {
Map<Long, DomainMatchResult> domainCount = new HashMap<>();
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMap.getDomainElementMatches().entrySet()) {
List<SchemaElementMatch> schemaElementMatches = schemaMap.getMatchedElements(entry.getKey());
if (schemaElementMatches != null && schemaElementMatches.size() > 0) {
if (!domainCount.containsKey(entry.getKey())) {
domainCount.put(entry.getKey(), new DomainMatchResult());
}
DomainMatchResult domainMatchResult = domainCount.get(entry.getKey());
Set<SchemaElementType> schemaElementTypes = new HashSet<>();
schemaElementMatches.stream()
.forEach(schemaElementMatch -> schemaElementTypes.add(
schemaElementMatch.getElement().getType()));
SchemaElementMatch schemaElementMatchMax = schemaElementMatches.stream()
.sorted((o1, o2) ->
((int) ((o2.getSimilarity() - o1.getSimilarity()) * 100))
).findFirst().orElse(null);
if (schemaElementMatchMax != null) {
domainMatchResult.setMaxSimilarity(schemaElementMatchMax.getSimilarity());
}
domainMatchResult.setCount(schemaElementTypes.size());
}
}
return domainCount;
}
public Long resolve(QueryContext queryContext, ChatContext chatCtx, List<Long> restrictiveDomains) {
Long domainId = queryContext.getRequest().getDomainId();
if (Objects.nonNull(domainId) && domainId > 0) {
if (CollectionUtils.isNotEmpty(restrictiveDomains) && restrictiveDomains.contains(domainId)) {
return domainId;
} else {
return null;
}
}
SchemaMapInfo mapInfo = queryContext.getMapInfo();
Set<Long> matchedDomains = mapInfo.getMatchedDomains();
if (CollectionUtils.isNotEmpty(restrictiveDomains)) {
matchedDomains = matchedDomains.stream()
.filter(matchedDomain -> restrictiveDomains.contains(matchedDomain))
.collect(Collectors.toSet());
}
Map<Long, SemanticQuery> domainQueryModes = new HashMap<>();
for (Long matchedDomain : matchedDomains) {
domainQueryModes.put(matchedDomain, null);
}
if(domainQueryModes.size()==1){
return domainQueryModes.keySet().stream().findFirst().get();
}
return resolve(domainQueryModes, queryContext, chatCtx,
queryContext.getMapInfo(),restrictiveDomains);
}
public Long resolve(Map<Long, SemanticQuery> domainQueryModes, QueryContext queryContext,
ChatContext chatCtx, SchemaMapInfo schemaMap, List<Long> restrictiveDomains) {
Long selectDomain = selectDomain(domainQueryModes, queryContext.getRequest(), chatCtx, schemaMap,restrictiveDomains);
if (selectDomain > 0) {
log.info("selectDomain {} ", selectDomain);
return selectDomain;
}
// get the max SchemaElementType number
return selectDomainBySchemaElementCount(domainQueryModes, schemaMap);
}
public Long selectDomain(Map<Long, SemanticQuery> domainQueryModes, QueryReq queryContext,
ChatContext chatCtx,
SchemaMapInfo schemaMap, List<Long> restrictiveDomains) {
// if QueryContext has domainId and in domainQueryModes
if (domainQueryModes.containsKey(queryContext.getDomainId())) {
log.info("selectDomain from QueryContext [{}]", queryContext.getDomainId());
return queryContext.getDomainId();
}
// if ChatContext has domainId and in domainQueryModes
if (chatCtx.getParseInfo().getDomainId() > 0) {
Long domainId = chatCtx.getParseInfo().getDomainId();
if (!isAllowSwitch(domainQueryModes, schemaMap, chatCtx, queryContext, domainId,restrictiveDomains)) {
log.info("selectDomain from ChatContext [{}]", domainId);
return domainId;
}
}
// default 0
return 0L;
}
}

View File

@@ -0,0 +1,192 @@
package com.tencent.supersonic.chat.parser.function;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
@Slf4j
public class HeuristicModelResolver implements ModelResolver {
protected static Long selectModelBySchemaElementCount(Map<Long, SemanticQuery> ModelQueryModes,
SchemaMapInfo schemaMap) {
Map<Long, ModelMatchResult> ModelTypeMap = getModelTypeMap(schemaMap);
if (ModelTypeMap.size() == 1) {
Long ModelSelect = ModelTypeMap.entrySet().stream().collect(Collectors.toList()).get(0).getKey();
if (ModelQueryModes.containsKey(ModelSelect)) {
log.info("selectModel with only one Model [{}]", ModelSelect);
return ModelSelect;
}
} else {
Map.Entry<Long, ModelMatchResult> maxModel = ModelTypeMap.entrySet().stream()
.filter(entry -> ModelQueryModes.containsKey(entry.getKey()))
.sorted((o1, o2) -> {
int difference = o2.getValue().getCount() - o1.getValue().getCount();
if (difference == 0) {
return (int) ((o2.getValue().getMaxSimilarity()
- o1.getValue().getMaxSimilarity()) * 100);
}
return difference;
}).findFirst().orElse(null);
if (maxModel != null) {
log.info("selectModel with multiple Models [{}]", maxModel.getKey());
return maxModel.getKey();
}
}
return 0L;
}
/**
* to check can switch Model if context exit Model
*
* @return false will use context Model, true will use other Model , maybe include context Model
*/
protected static boolean isAllowSwitch(Map<Long, SemanticQuery> ModelQueryModes, SchemaMapInfo schemaMap,
ChatContext chatCtx, QueryReq searchCtx, Long modelId, List<Long> restrictiveModels) {
if (!Objects.nonNull(modelId) || modelId <= 0) {
return true;
}
// except content Model, calculate the number of types for each Model, if numbers<=1 will not switch
Map<Long, ModelMatchResult> ModelTypeMap = getModelTypeMap(schemaMap);
log.info("isAllowSwitch ModelTypeMap [{}]", ModelTypeMap);
long otherModelTypeNumBigOneCount = ModelTypeMap.entrySet().stream()
.filter(entry -> ModelQueryModes.containsKey(entry.getKey()) && !entry.getKey().equals(modelId))
.filter(entry -> entry.getValue().getCount() > 1).count();
if (otherModelTypeNumBigOneCount >= 1) {
return true;
}
// if query text only contain time , will not switch
if (!CollectionUtils.isEmpty(ModelQueryModes.values())) {
for (SemanticQuery semanticQuery : ModelQueryModes.values()) {
if (semanticQuery == null) {
continue;
}
SemanticParseInfo semanticParseInfo = semanticQuery.getParseInfo();
if (semanticParseInfo == null) {
continue;
}
if (searchCtx.getQueryText() != null && semanticParseInfo.getDateInfo() != null) {
if (semanticParseInfo.getDateInfo().getDetectWord() != null) {
if (semanticParseInfo.getDateInfo().getDetectWord()
.equalsIgnoreCase(searchCtx.getQueryText())) {
log.info("timeParseResults is not null , can not switch context , timeParseResults:{},",
semanticParseInfo.getDateInfo());
return false;
}
}
}
}
}
if (CollectionUtils.isNotEmpty(restrictiveModels) && !restrictiveModels.contains(modelId)) {
return true;
}
// if context Model not in schemaMap , will switch
if (schemaMap.getMatchedElements(modelId) == null || schemaMap.getMatchedElements(modelId).size() <= 0) {
log.info("modelId not in schemaMap ");
return true;
}
// other will not switch
return false;
}
public static Map<Long, ModelMatchResult> getModelTypeMap(SchemaMapInfo schemaMap) {
Map<Long, ModelMatchResult> ModelCount = new HashMap<>();
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMap.getModelElementMatches().entrySet()) {
List<SchemaElementMatch> schemaElementMatches = schemaMap.getMatchedElements(entry.getKey());
if (schemaElementMatches != null && schemaElementMatches.size() > 0) {
if (!ModelCount.containsKey(entry.getKey())) {
ModelCount.put(entry.getKey(), new ModelMatchResult());
}
ModelMatchResult ModelMatchResult = ModelCount.get(entry.getKey());
Set<SchemaElementType> schemaElementTypes = new HashSet<>();
schemaElementMatches.stream()
.forEach(schemaElementMatch -> schemaElementTypes.add(
schemaElementMatch.getElement().getType()));
SchemaElementMatch schemaElementMatchMax = schemaElementMatches.stream()
.sorted((o1, o2) ->
((int) ((o2.getSimilarity() - o1.getSimilarity()) * 100))
).findFirst().orElse(null);
if (schemaElementMatchMax != null) {
ModelMatchResult.setMaxSimilarity(schemaElementMatchMax.getSimilarity());
}
ModelMatchResult.setCount(schemaElementTypes.size());
}
}
return ModelCount;
}
public Long resolve(QueryContext queryContext, ChatContext chatCtx, List<Long> restrictiveModels) {
Long modelId = queryContext.getRequest().getModelId();
if (Objects.nonNull(modelId) && modelId > 0) {
if (CollectionUtils.isNotEmpty(restrictiveModels) && restrictiveModels.contains(modelId)) {
return modelId;
} else {
return null;
}
}
SchemaMapInfo mapInfo = queryContext.getMapInfo();
Set<Long> matchedModels = mapInfo.getMatchedModels();
if (CollectionUtils.isNotEmpty(restrictiveModels)) {
matchedModels = matchedModels.stream()
.filter(restrictiveModels::contains)
.collect(Collectors.toSet());
}
Map<Long, SemanticQuery> ModelQueryModes = new HashMap<>();
for (Long matchedModel : matchedModels) {
ModelQueryModes.put(matchedModel, null);
}
if (ModelQueryModes.size() == 1) {
return ModelQueryModes.keySet().stream().findFirst().get();
}
return resolve(ModelQueryModes, queryContext, chatCtx,
queryContext.getMapInfo(), restrictiveModels);
}
public Long resolve(Map<Long, SemanticQuery> ModelQueryModes, QueryContext queryContext,
ChatContext chatCtx, SchemaMapInfo schemaMap, List<Long> restrictiveModels) {
Long selectModel = selectModel(ModelQueryModes, queryContext.getRequest(), chatCtx, schemaMap,
restrictiveModels);
if (selectModel > 0) {
log.info("selectModel {} ", selectModel);
return selectModel;
}
// get the max SchemaElementType number
return selectModelBySchemaElementCount(ModelQueryModes, schemaMap);
}
public Long selectModel(Map<Long, SemanticQuery> ModelQueryModes, QueryReq queryContext,
ChatContext chatCtx,
SchemaMapInfo schemaMap, List<Long> restrictiveModels) {
// if QueryContext has modelId and in ModelQueryModes
if (ModelQueryModes.containsKey(queryContext.getModelId())) {
log.info("selectModel from QueryContext [{}]", queryContext.getModelId());
return queryContext.getModelId();
}
// if ChatContext has modelId and in ModelQueryModes
if (chatCtx.getParseInfo().getModelId() > 0) {
Long modelId = chatCtx.getParseInfo().getModelId();
if (!isAllowSwitch(ModelQueryModes, schemaMap, chatCtx, queryContext, modelId, restrictiveModels)) {
log.info("selectModel from ChatContext [{}]", modelId);
return modelId;
}
}
// default 0
return 0L;
}
}

View File

@@ -3,7 +3,8 @@ package com.tencent.supersonic.chat.parser.function;
import lombok.Data;
@Data
public class DomainMatchResult {
public class ModelMatchResult {
private Integer count = 0;
private double maxSimilarity;
}

View File

@@ -5,8 +5,8 @@ import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import java.util.List;
public interface DomainResolver {
public interface ModelResolver {
Long resolve(QueryContext queryContext, ChatContext chatCtx, List<Long> restrictiveDomains);
Long resolve(QueryContext queryContext, ChatContext chatCtx, List<Long> restrictiveModels);
}

View File

@@ -1,8 +1,8 @@
package com.tencent.supersonic.chat.parser.function;
import lombok.Data;
import java.util.List;
import java.util.Map;
import lombok.Data;
@Data
public class Parameters {

View File

@@ -0,0 +1,11 @@
package com.tencent.supersonic.chat.parser.llm;
import com.tencent.supersonic.chat.plugin.PluginParseResult;
import com.tencent.supersonic.chat.query.dsl.LLMResp;
import lombok.Data;
@Data
public class DSLParseResult extends PluginParseResult {
private LLMResp llmResp;
}

View File

@@ -0,0 +1,231 @@
package com.tencent.supersonic.chat.parser.llm;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.config.LLMConfig;
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
import com.tencent.supersonic.chat.parser.function.ModelResolver;
import com.tencent.supersonic.chat.plugin.Plugin;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.dsl.DSLBuilder;
import com.tencent.supersonic.chat.query.dsl.DSLQuery;
import com.tencent.supersonic.chat.query.dsl.LLMReq;
import com.tencent.supersonic.chat.query.dsl.LLMReq.ElementValue;
import com.tencent.supersonic.chat.query.dsl.LLMResp;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.knowledge.service.SchemaService;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.util.CollectionUtils;
import org.springframework.web.client.RestTemplate;
@Slf4j
public class LLMDSLParser implements SemanticParser {
public static final double FUNCTION_BONUS_THRESHOLD = 201;
@Override
public void parse(QueryContext queryCtx, ChatContext chatCtx) {
final LLMConfig llmConfig = ContextUtils.getBean(LLMConfig.class);
if (StringUtils.isEmpty(llmConfig.getUrl()) || SatisfactionChecker.check(queryCtx)) {
log.info("llmConfig:{}, skip function parser, queryText:{}", llmConfig,
queryCtx.getRequest().getQueryText());
return;
}
List<Plugin> dslPlugins = PluginManager.getPlugins().stream()
.filter(plugin -> DSLQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType()))
.collect(Collectors.toList());
if (CollectionUtils.isEmpty(dslPlugins)) {
return;
}
Plugin plugin = dslPlugins.get(0);
List<Long> dslModels = plugin.getModelList();
try {
ModelResolver modelResolver = ComponentFactory.getModelResolver();
Long modelId = modelResolver.resolve(queryCtx, chatCtx, dslModels);
log.info("resolve modelId:{},dslModels:{}", modelId, dslModels);
if (Objects.isNull(modelId)) {
return;
}
LLMResp llmResp = requestLLM(queryCtx, modelId);
if (Objects.isNull(llmResp)) {
return;
}
PluginSemanticQuery semanticQuery = QueryManager.createPluginQuery(DSLQuery.QUERY_MODE);
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
if (Objects.nonNull(modelId) && modelId > 0) {
parseInfo.getElementMatches().addAll(queryCtx.getMapInfo().getMatchedElements(modelId));
}
DSLParseResult dslParseResult = new DSLParseResult();
dslParseResult.setRequest(queryCtx.getRequest());
dslParseResult.setLlmResp(llmResp);
dslParseResult.setPlugin(plugin);
Map<String, Object> properties = new HashMap<>();
properties.put(Constants.CONTEXT, dslParseResult);
parseInfo.setProperties(properties);
parseInfo.setScore(FUNCTION_BONUS_THRESHOLD);
parseInfo.setQueryMode(semanticQuery.getQueryMode());
SchemaElement Model = new SchemaElement();
Model.setModel(modelId);
Model.setId(modelId);
parseInfo.setModel(Model);
queryCtx.getCandidateQueries().add(semanticQuery);
} catch (Exception e) {
log.error("LLMDSLParser error", e);
}
}
private LLMResp requestLLM(QueryContext queryCtx, Long modelId) {
long startTime = System.currentTimeMillis();
String queryText = queryCtx.getRequest().getQueryText();
final LLMConfig llmConfig = ContextUtils.getBean(LLMConfig.class);
if (StringUtils.isEmpty(llmConfig.getUrl())) {
log.warn("llmConfig url is null, skip llm parser");
return null;
}
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
Map<Long, String> modelIdToName = semanticSchema.getModelIdToName();
LLMReq llmReq = new LLMReq();
llmReq.setQueryText(queryText);
LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema();
llmSchema.setModelName(modelIdToName.get(modelId));
llmSchema.setDomainName(modelIdToName.get(modelId));
List<String> fieldNameList = getFieldNameList(queryCtx, modelId, semanticSchema);
fieldNameList.add(DSLBuilder.DATA_Field);
llmSchema.setFieldNameList(fieldNameList);
llmReq.setSchema(llmSchema);
List<ElementValue> linking = new ArrayList<>();
linking.addAll(getValueList(queryCtx, modelId, semanticSchema));
llmReq.setLinking(linking);
String currentDate = getCurrentDate(modelId);
llmReq.setCurrentDate(currentDate);
log.info("requestLLM request, modelId:{},llmReq:{}", modelId, llmReq);
String questUrl = llmConfig.getUrl() + llmConfig.getQueryToSqlPath();
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
try {
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
HttpEntity<String> entity = new HttpEntity<>(JsonUtil.toString(llmReq), headers);
ResponseEntity<LLMResp> responseEntity = restTemplate.exchange(questUrl, HttpMethod.POST, entity,
LLMResp.class);
log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
System.currentTimeMillis() - startTime, questUrl, entity, responseEntity.getBody());
return responseEntity.getBody();
} catch (Exception e) {
log.error("requestLLM error", e);
}
return null;
}
private String getCurrentDate(Long modelId) {
return DateUtils.getBeforeDate(4);
// ChatConfigFilter filter = new ChatConfigFilter();
// filter.setModelId(modelId);
//
// List<ChatConfigResp> configResps = ContextUtils.getBean(ConfigService.class).search(filter, null);
// if (CollectionUtils.isEmpty(configResps)) {
// return
// }
// ChatConfigResp chatConfigResp = configResps.get(0);
// chatConfigResp.getChatDetailConfig().getChatDefaultConfig().get
}
private List<ElementValue> getValueList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) {
Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema);
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId);
if (CollectionUtils.isEmpty(matchedElements)) {
return new ArrayList<>();
}
Set<ElementValue> valueMatches = matchedElements.stream()
.filter(schemaElementMatch -> {
SchemaElementType type = schemaElementMatch.getElement().getType();
return SchemaElementType.VALUE.equals(type) || SchemaElementType.ID.equals(type);
})
.map(elementMatch ->
{
ElementValue elementValue = new ElementValue();
elementValue.setFieldName(itemIdToName.get(elementMatch.getElement().getId()));
elementValue.setFieldValue(elementMatch.getWord());
return elementValue;
}
)
.collect(Collectors.toSet());
return new ArrayList<>(valueMatches);
}
private List<String> getFieldNameList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) {
Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema);
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId);
if (CollectionUtils.isEmpty(matchedElements)) {
return new ArrayList<>();
}
Set<String> fieldNameList = matchedElements.stream()
.filter(schemaElementMatch -> {
SchemaElementType elementType = schemaElementMatch.getElement().getType();
return SchemaElementType.METRIC.equals(elementType) ||
SchemaElementType.DIMENSION.equals(elementType) ||
SchemaElementType.VALUE.equals(elementType);
})
.map(schemaElementMatch -> {
SchemaElementType elementType = schemaElementMatch.getElement().getType();
if (!SchemaElementType.VALUE.equals(elementType)) {
return schemaElementMatch.getWord();
}
return itemIdToName.get(schemaElementMatch.getElement().getId());
})
.filter(name -> StringUtils.isNotEmpty(name) && !name.contains("%"))
.collect(Collectors.toSet());
return new ArrayList<>(fieldNameList);
}
private Map<Long, String> getItemIdToName(Long modelId, SemanticSchema semanticSchema) {
return semanticSchema.getDimensions().stream()
.filter(entry -> modelId.equals(entry.getModel()))
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
}
}

View File

@@ -17,7 +17,7 @@ public class LLMTimeEnhancementParse implements SemanticParser {
@Override
public void parse(QueryContext queryContext, ChatContext chatContext) {
log.info("before queryContext:{},chatContext:{}",queryContext,chatContext);
log.info("before queryContext:{},chatContext:{}", queryContext, chatContext);
ChatGptHelper chatGptHelper = ContextUtils.getBean(ChatGptHelper.class);
String inferredTime = chatGptHelper.inferredTime(queryContext.getRequest().getQueryText());
try {
@@ -25,12 +25,12 @@ public class LLMTimeEnhancementParse implements SemanticParser {
for (SemanticQuery query : queryContext.getCandidateQueries()) {
DateConf dateInfo = query.getParseInfo().getDateInfo();
JSONObject jsonObject = JSON.parseObject(inferredTime);
if (jsonObject.containsKey("date")){
if (jsonObject.containsKey("date")) {
dateInfo.setDateMode(DateConf.DateMode.BETWEEN);
dateInfo.setStartDate(jsonObject.getString("date"));
dateInfo.setEndDate(jsonObject.getString("date"));
query.getParseInfo().setDateInfo(dateInfo);
}else if (jsonObject.containsKey("start")){
} else if (jsonObject.containsKey("start")) {
dateInfo.setDateMode(DateConf.DateMode.BETWEEN);
dateInfo.setStartDate(jsonObject.getString("start"));
dateInfo.setEndDate(jsonObject.getString("end"));
@@ -38,11 +38,12 @@ public class LLMTimeEnhancementParse implements SemanticParser {
}
}
}
}catch (Exception exception){
log.error("{} parse error,this reason is:{}",LLMTimeEnhancementParse.class.getSimpleName(), (Object) exception.getStackTrace());
} catch (Exception exception) {
log.error("{} parse error,this reason is:{}", LLMTimeEnhancementParse.class.getSimpleName(),
(Object) exception.getStackTrace());
}
log.info("after queryContext:{},chatContext:{}",queryContext,chatContext);
log.info("after queryContext:{},chatContext:{}", queryContext, chatContext);
}

View File

@@ -21,7 +21,9 @@ import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
@Slf4j
public class AggregateTypeParser implements SemanticParser {
@@ -35,36 +37,60 @@ public class AggregateTypeParser implements SemanticParser {
new AbstractMap.SimpleEntry<>(DISTINCT, Pattern.compile("(?i)(uv)")),
new AbstractMap.SimpleEntry<>(COUNT, Pattern.compile("(?i)(总数|pv)")),
new AbstractMap.SimpleEntry<>(NONE, Pattern.compile("(?i)(明细)"))
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue,(k1,k2)->k2));
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (k1, k2) -> k2));
@Override
public void parse(QueryContext queryContext, ChatContext chatContext) {
String queryText = queryContext.getRequest().getQueryText();
AggregateConf aggregateConf = resolveAggregateConf(queryText);
for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) {
if (!AggregateTypeEnum.NONE.equals(semanticQuery.getParseInfo().getAggType())) {
continue;
}
String queryText = queryContext.getRequest().getQueryText();
semanticQuery.getParseInfo().setAggType(resolveAggregateType(queryText));
semanticQuery.getParseInfo().setAggType(aggregateConf.type);
int detectWordLength = 0;
if (StringUtils.isNotEmpty(aggregateConf.detectWord)) {
detectWordLength = aggregateConf.detectWord.length();
}
semanticQuery.getParseInfo().setScore(semanticQuery.getParseInfo().getScore() + detectWordLength);
}
}
public static AggregateTypeEnum resolveAggregateType(String queryText) {
public AggregateTypeEnum resolveAggregateType(String queryText) {
AggregateConf aggregateConf = resolveAggregateConf(queryText);
return aggregateConf.type;
}
private AggregateConf resolveAggregateConf(String queryText) {
Map<AggregateTypeEnum, Integer> aggregateCount = new HashMap<>(REGX_MAP.size());
Map<AggregateTypeEnum, String> aggregateWord = new HashMap<>(REGX_MAP.size());
for (Map.Entry<AggregateTypeEnum, Pattern> entry : REGX_MAP.entrySet()) {
Matcher matcher = entry.getValue().matcher(queryText);
int count = 0;
String detectWord = null;
while (matcher.find()) {
count++;
detectWord = matcher.group();
}
if (count > 0) {
aggregateCount.put(entry.getKey(), count);
aggregateWord.put(entry.getKey(), detectWord);
}
}
return aggregateCount.entrySet().stream().max(Map.Entry.comparingByValue())
AggregateTypeEnum type = aggregateCount.entrySet().stream().max(Map.Entry.comparingByValue())
.map(entry -> entry.getKey()).orElse(NONE);
String detectWord = aggregateWord.get(type);
return new AggregateConf(type, detectWord);
}
@AllArgsConstructor
class AggregateConf {
public AggregateTypeEnum type;
public String detectWord;
}
}

View File

@@ -1,5 +1,12 @@
package com.tencent.supersonic.chat.parser.rule;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.DIMENSION;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ENTITY;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ID;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.METRIC;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.MODEL;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.VALUE;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
@@ -8,8 +15,8 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricDomainQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricSemanticQuery;
import java.util.AbstractMap;
import java.util.ArrayList;
@@ -21,8 +28,6 @@ import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.extern.slf4j.Slf4j;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.*;
@Slf4j
public class ContextInheritParser implements SemanticParser {
@@ -31,7 +36,7 @@ public class ContextInheritParser implements SemanticParser {
new AbstractMap.SimpleEntry<>(DIMENSION, Arrays.asList(DIMENSION, VALUE)),
new AbstractMap.SimpleEntry<>(VALUE, Arrays.asList(VALUE, DIMENSION)),
new AbstractMap.SimpleEntry<>(ENTITY, Arrays.asList(ENTITY)),
new AbstractMap.SimpleEntry<>(DOMAIN, Arrays.asList(DOMAIN)),
new AbstractMap.SimpleEntry<>(MODEL, Arrays.asList(MODEL)),
new AbstractMap.SimpleEntry<>(ID, Arrays.asList(ID))
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
@@ -41,9 +46,9 @@ public class ContextInheritParser implements SemanticParser {
return;
}
Long domainId = chatContext.getParseInfo().getDomainId();
Long modelId = chatContext.getParseInfo().getModelId();
List<SchemaElementMatch> elementMatches = queryContext.getMapInfo()
.getMatchedElements(domainId);
.getMatchedElements(modelId);
List<SchemaElementMatch> matchesToInherit = new ArrayList<>();
for (SchemaElementMatch match : chatContext.getParseInfo().getElementMatches()) {
@@ -51,7 +56,7 @@ public class ContextInheritParser implements SemanticParser {
// mutual exclusive element types should not be inherited
RuleSemanticQuery ruleQuery = QueryManager.getRuleQuery(chatContext.getParseInfo().getQueryMode());
if (!containsTypes(elementMatches, matchType, ruleQuery)) {
match.setMode(SchemaElementMatch.MatchMode.INHERIT);
match.setInherited(true);
matchesToInherit.add(match);
}
}
@@ -59,11 +64,24 @@ public class ContextInheritParser implements SemanticParser {
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(elementMatches, queryContext);
for (RuleSemanticQuery query : queries) {
query.fillParseInfo(domainId, chatContext);
query.fillParseInfo(modelId, queryContext, chatContext);
if (existSameQuery(query.getParseInfo().getModelId(), query.getQueryMode(), queryContext)) {
continue;
}
queryContext.getCandidateQueries().add(query);
}
}
private boolean existSameQuery(Long modelId, String queryMode, QueryContext queryContext) {
for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) {
if (semanticQuery.getQueryMode().equals(queryMode)
&& semanticQuery.getParseInfo().getModelId().equals(modelId)) {
return true;
}
}
return false;
}
private boolean containsTypes(List<SchemaElementMatch> matches, SchemaElementType matchType,
RuleSemanticQuery ruleQuery) {
List<SchemaElementType> types = MUTUAL_EXCLUSIVE_MAP.get(matchType);
@@ -79,16 +97,18 @@ public class ContextInheritParser implements SemanticParser {
}
protected boolean shouldInherit(QueryContext queryContext, ChatContext chatContext) {
Long contextDomainId = chatContext.getParseInfo().getDomainId();
if (queryContext.getMapInfo().getMatchedElements(contextDomainId) == null) {
Long contextmodelId = chatContext.getParseInfo().getModelId();
// if map info doesn't contain the same Model of the context,
// no inheritance could be done
if (queryContext.getMapInfo().getMatchedElements(contextmodelId) == null) {
return false;
}
// if candidates have only one MetricDomain mode and context has value filter , count in context
List<SemanticQuery> candidateQueries = queryContext.getCandidateQueries().stream()
.filter(semanticQuery -> semanticQuery.getParseInfo().getDomainId().equals(contextDomainId)).collect(
// if candidates only have MetricModel mode, count in context
List<SemanticQuery> metricModelQueries = queryContext.getCandidateQueries().stream()
.filter(query -> query instanceof MetricModelQuery).collect(
Collectors.toList());
if (candidateQueries.size() == 1 && (candidateQueries.get(0) instanceof MetricDomainQuery)) {
if (metricModelQueries.size() == queryContext.getCandidateQueries().size()) {
return true;
} else {
return queryContext.getCandidateQueries().size() == 0;

View File

@@ -1,11 +1,12 @@
package com.tencent.supersonic.chat.parser.rule;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import java.util.*;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
@Slf4j
@@ -16,12 +17,12 @@ public class QueryModeParser implements SemanticParser {
SchemaMapInfo mapInfo = queryContext.getMapInfo();
// iterate all schemaElementMatches to resolve semantic query
for (Long domainId : mapInfo.getMatchedDomains()) {
List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(domainId);
for (Long modelId : mapInfo.getMatchedModels()) {
List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(modelId);
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(elementMatches, queryContext);
for (RuleSemanticQuery query : queries) {
query.fillParseInfo(domainId, chatContext);
query.fillParseInfo(modelId, queryContext, chatContext);
queryContext.getCandidateQueries().add(query);
}
}

View File

@@ -4,21 +4,21 @@ import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DateConf;
import com.xkzhangsan.time.nlp.TimeNLP;
import com.xkzhangsan.time.nlp.TimeNLPUtil;
import java.text.DateFormat;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.time.LocalDate;
import java.util.*;
import java.util.Date;
import java.util.List;
import java.util.Stack;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import com.xkzhangsan.time.nlp.TimeNLP;
import com.xkzhangsan.time.nlp.TimeNLPUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.logging.log4j.util.Strings;
@@ -45,12 +45,16 @@ public class TimeRangeParser implements SemanticParser {
if (queryContext.getCandidateQueries().size() > 0) {
for (SemanticQuery query : queryContext.getCandidateQueries()) {
query.getParseInfo().setDateInfo(dateConf);
query.getParseInfo().setScore(query.getParseInfo().getScore()
+ dateConf.getDetectWord().length());
}
} else if (QueryManager.containsRuleQuery(chatContext.getParseInfo().getQueryMode())) {
RuleSemanticQuery semanticQuery = QueryManager.createRuleQuery(
chatContext.getParseInfo().getQueryMode());
// inherit parse info from context
chatContext.getParseInfo().setDateInfo(dateConf);
chatContext.getParseInfo().setScore(chatContext.getParseInfo().getScore()
+ dateConf.getDetectWord().length());
semanticQuery.setParseInfo(chatContext.getParseInfo());
queryContext.getCandidateQueries().add(semanticQuery);
}
@@ -60,41 +64,48 @@ public class TimeRangeParser implements SemanticParser {
private DateConf parseDateCN(String queryText) {
Date startDate = null;
Date endDate;
String detectWord = null;
List<TimeNLP> times = TimeNLPUtil.parse(queryText);
if (times.size() > 0) {
startDate = times.get(0).getTime();
detectWord = times.get(0).getTimeExpression();
} else {
return null;
}
if (times.size() > 1) {
endDate = times.get(1).getTime();
detectWord += "~" + times.get(0).getTimeExpression();
} else {
endDate = startDate;
}
return getDateConf(startDate, endDate);
return getDateConf(startDate, endDate, detectWord);
}
private DateConf parseDateNumber(String queryText) {
String startDate;
String endDate = null;
String detectWord = null;
Matcher dateMatcher = DATE_PATTERN_NUMBER.matcher(queryText);
if (dateMatcher.find()) {
startDate = dateMatcher.group();
detectWord = startDate;
} else {
return null;
}
if (dateMatcher.find()) {
endDate = dateMatcher.group();
detectWord += "~" + endDate;
}
endDate = endDate != null ? endDate : startDate;
try {
return getDateConf(DATE_FORMAT_NUMBER.parse(startDate), DATE_FORMAT_NUMBER.parse(endDate));
return getDateConf(DATE_FORMAT_NUMBER.parse(startDate), DATE_FORMAT_NUMBER.parse(endDate), detectWord);
} catch (ParseException e) {
return null;
}
@@ -134,11 +145,11 @@ public class TimeRangeParser implements SemanticParser {
}
days = days * num;
info.setDateMode(DateConf.DateMode.RECENT);
String text = "" + num + zhPeriod;
String detectWord = "" + num + zhPeriod;
if (Strings.isNotEmpty(m.group("periodStr"))) {
text = m.group("periodStr");
detectWord = m.group("periodStr");
}
info.setText(text);
info.setDetectWord(detectWord);
info.setStartDate(LocalDate.now().minusDays(days).toString());
info.setUnit(num);
@@ -173,7 +184,7 @@ public class TimeRangeParser implements SemanticParser {
return stack.stream().mapToInt(s -> s).sum();
}
private DateConf getDateConf(Date startDate, Date endDate) {
private DateConf getDateConf(Date startDate, Date endDate, String detectWord) {
if (startDate == null || endDate == null) {
return null;
}
@@ -182,6 +193,7 @@ public class TimeRangeParser implements SemanticParser {
info.setDateMode(DateConf.DateMode.BETWEEN);
info.setStartDate(DATE_FORMAT.format(startDate));
info.setEndDate(DATE_FORMAT.format(endDate));
info.setDetectWord(detectWord);
return info;
}

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.chat.persistence.dataobject;
import java.util.Date;
import lombok.Data;
import lombok.ToString;
@@ -15,7 +14,7 @@ public class ChatConfigDO {
*/
private Long id;
private Long domainId;
private Long modelId;
private String chatDetailConfig;

View File

@@ -1,11 +1,10 @@
package com.tencent.supersonic.chat.persistence.dataobject;
import java.util.ArrayList;
import java.util.List;
import com.tencent.supersonic.chat.config.DefaultMetric;
import com.tencent.supersonic.chat.config.Dim4Dict;
import java.util.ArrayList;
import java.util.List;
import lombok.Data;
import lombok.ToString;
@@ -14,14 +13,14 @@ import lombok.ToString;
@ToString
public class DimValueDO {
private Long domainId;
private Long modelId;
private List<DefaultMetric> defaultMetricDescList = new ArrayList<>();
private List<Dim4Dict> dimensions = new ArrayList<>();
public DimValueDO setDomainId(Long domainId) {
this.domainId = domainId;
public DimValueDO setModelId(Long modelId) {
this.modelId = modelId;
return this;
}

View File

@@ -3,8 +3,9 @@ package com.tencent.supersonic.chat.persistence.dataobject;
import java.util.Date;
public class PluginDO {
/**
*
*
*/
private Long id;
@@ -14,71 +15,69 @@ public class PluginDO {
private String type;
/**
*
*
*/
private String domain;
private String model;
/**
*
*
*/
private String pattern;
/**
*
*
*/
private String parseMode;
/**
*
*
*/
private String name;
/**
*
*
*/
private Date createdAt;
/**
*
*
*/
private String createdBy;
/**
*
*
*/
private Date updatedAt;
/**
*
*
*/
private String updatedBy;
/**
*
*
*/
private String parseModeConfig;
/**
*
*
*/
private String config;
/**
*
*
*/
private String comment;
/**
*
* @return id
* @return id
*/
public Long getId() {
return id;
}
/**
*
* @param id
* @param id
*/
public void setId(Long id) {
this.id = id;
@@ -86,6 +85,7 @@ public class PluginDO {
/**
* DASHBOARD,WIDGET,URL
*
* @return type DASHBOARD,WIDGET,URL
*/
public String getType() {
@@ -94,6 +94,7 @@ public class PluginDO {
/**
* DASHBOARD,WIDGET,URL
*
* @param type DASHBOARD,WIDGET,URL
*/
public void setType(String type) {
@@ -101,176 +102,154 @@ public class PluginDO {
}
/**
*
* @return domain
* @return model
*/
public String getDomain() {
return domain;
public String getModel() {
return model;
}
/**
*
* @param domain
* @param model
*/
public void setDomain(String domain) {
this.domain = domain == null ? null : domain.trim();
public void setModel(String model) {
this.model = model == null ? null : model.trim();
}
/**
*
* @return pattern
* @return pattern
*/
public String getPattern() {
return pattern;
}
/**
*
* @param pattern
* @param pattern
*/
public void setPattern(String pattern) {
this.pattern = pattern == null ? null : pattern.trim();
}
/**
*
* @return parse_mode
* @return parse_mode
*/
public String getParseMode() {
return parseMode;
}
/**
*
* @param parseMode
* @param parseMode
*/
public void setParseMode(String parseMode) {
this.parseMode = parseMode == null ? null : parseMode.trim();
}
/**
*
* @return name
* @return name
*/
public String getName() {
return name;
}
/**
*
* @param name
* @param name
*/
public void setName(String name) {
this.name = name == null ? null : name.trim();
}
/**
*
* @return created_at
* @return created_at
*/
public Date getCreatedAt() {
return createdAt;
}
/**
*
* @param createdAt
* @param createdAt
*/
public void setCreatedAt(Date createdAt) {
this.createdAt = createdAt;
}
/**
*
* @return created_by
* @return created_by
*/
public String getCreatedBy() {
return createdBy;
}
/**
*
* @param createdBy
* @param createdBy
*/
public void setCreatedBy(String createdBy) {
this.createdBy = createdBy == null ? null : createdBy.trim();
}
/**
*
* @return updated_at
* @return updated_at
*/
public Date getUpdatedAt() {
return updatedAt;
}
/**
*
* @param updatedAt
* @param updatedAt
*/
public void setUpdatedAt(Date updatedAt) {
this.updatedAt = updatedAt;
}
/**
*
* @return updated_by
* @return updated_by
*/
public String getUpdatedBy() {
return updatedBy;
}
/**
*
* @param updatedBy
* @param updatedBy
*/
public void setUpdatedBy(String updatedBy) {
this.updatedBy = updatedBy == null ? null : updatedBy.trim();
}
/**
*
* @return parse_mode_config
* @return parse_mode_config
*/
public String getParseModeConfig() {
return parseModeConfig;
}
/**
*
* @param parseModeConfig
* @param parseModeConfig
*/
public void setParseModeConfig(String parseModeConfig) {
this.parseModeConfig = parseModeConfig == null ? null : parseModeConfig.trim();
}
/**
*
* @return config
* @return config
*/
public String getConfig() {
return config;
}
/**
*
* @param config
* @param config
*/
public void setConfig(String config) {
this.config = config == null ? null : config.trim();
}
/**
*
* @return comment
* @return comment
*/
public String getComment() {
return comment;
}
/**
*
* @param comment
* @param comment
*/
public void setComment(String comment) {
this.comment = comment == null ? null : comment.trim();

View File

@@ -5,6 +5,7 @@ import java.util.Date;
import java.util.List;
public class PluginDOExample {
/**
* s2_plugin
*/
@@ -31,7 +32,6 @@ public class PluginDOExample {
protected Integer limitEnd;
/**
*
* @mbg.generated
*/
public PluginDOExample() {
@@ -39,7 +39,6 @@ public class PluginDOExample {
}
/**
*
* @mbg.generated
*/
public void setOrderByClause(String orderByClause) {
@@ -47,7 +46,6 @@ public class PluginDOExample {
}
/**
*
* @mbg.generated
*/
public String getOrderByClause() {
@@ -55,7 +53,6 @@ public class PluginDOExample {
}
/**
*
* @mbg.generated
*/
public void setDistinct(boolean distinct) {
@@ -63,7 +60,6 @@ public class PluginDOExample {
}
/**
*
* @mbg.generated
*/
public boolean isDistinct() {
@@ -71,7 +67,6 @@ public class PluginDOExample {
}
/**
*
* @mbg.generated
*/
public List<Criteria> getOredCriteria() {
@@ -79,7 +74,6 @@ public class PluginDOExample {
}
/**
*
* @mbg.generated
*/
public void or(Criteria criteria) {
@@ -87,7 +81,6 @@ public class PluginDOExample {
}
/**
*
* @mbg.generated
*/
public Criteria or() {
@@ -97,7 +90,6 @@ public class PluginDOExample {
}
/**
*
* @mbg.generated
*/
public Criteria createCriteria() {
@@ -109,7 +101,6 @@ public class PluginDOExample {
}
/**
*
* @mbg.generated
*/
protected Criteria createCriteriaInternal() {
@@ -118,7 +109,6 @@ public class PluginDOExample {
}
/**
*
* @mbg.generated
*/
public void clear() {
@@ -128,15 +118,13 @@ public class PluginDOExample {
}
/**
*
* @mbg.generated
*/
public void setLimitStart(Integer limitStart) {
this.limitStart=limitStart;
this.limitStart = limitStart;
}
/**
*
* @mbg.generated
*/
public Integer getLimitStart() {
@@ -144,15 +132,13 @@ public class PluginDOExample {
}
/**
*
* @mbg.generated
*/
public void setLimitEnd(Integer limitEnd) {
this.limitEnd=limitEnd;
this.limitEnd = limitEnd;
}
/**
*
* @mbg.generated
*/
public Integer getLimitEnd() {
@@ -163,6 +149,7 @@ public class PluginDOExample {
* s2_plugin null
*/
protected abstract static class GeneratedCriteria {
protected List<Criterion> criteria;
protected GeneratedCriteria() {
@@ -333,73 +320,73 @@ public class PluginDOExample {
return (Criteria) this;
}
public Criteria andDomainIsNull() {
addCriterion("domain is null");
public Criteria andModelIsNull() {
addCriterion("model is null");
return (Criteria) this;
}
public Criteria andDomainIsNotNull() {
addCriterion("domain is not null");
public Criteria andModelIsNotNull() {
addCriterion("model is not null");
return (Criteria) this;
}
public Criteria andDomainEqualTo(String value) {
addCriterion("domain =", value, "domain");
public Criteria andModelEqualTo(String value) {
addCriterion("model =", value, "model");
return (Criteria) this;
}
public Criteria andDomainNotEqualTo(String value) {
addCriterion("domain <>", value, "domain");
public Criteria andModelNotEqualTo(String value) {
addCriterion("model <>", value, "model");
return (Criteria) this;
}
public Criteria andDomainGreaterThan(String value) {
addCriterion("domain >", value, "domain");
public Criteria andModelGreaterThan(String value) {
addCriterion("model >", value, "model");
return (Criteria) this;
}
public Criteria andDomainGreaterThanOrEqualTo(String value) {
addCriterion("domain >=", value, "domain");
public Criteria andModelGreaterThanOrEqualTo(String value) {
addCriterion("model >=", value, "model");
return (Criteria) this;
}
public Criteria andDomainLessThan(String value) {
addCriterion("domain <", value, "domain");
public Criteria andModelLessThan(String value) {
addCriterion("model <", value, "model");
return (Criteria) this;
}
public Criteria andDomainLessThanOrEqualTo(String value) {
addCriterion("domain <=", value, "domain");
public Criteria andModelLessThanOrEqualTo(String value) {
addCriterion("model <=", value, "model");
return (Criteria) this;
}
public Criteria andDomainLike(String value) {
addCriterion("domain like", value, "domain");
public Criteria andModelLike(String value) {
addCriterion("model like", value, "model");
return (Criteria) this;
}
public Criteria andDomainNotLike(String value) {
addCriterion("domain not like", value, "domain");
public Criteria andModelNotLike(String value) {
addCriterion("model not like", value, "model");
return (Criteria) this;
}
public Criteria andDomainIn(List<String> values) {
addCriterion("domain in", values, "domain");
public Criteria andModelIn(List<String> values) {
addCriterion("model in", values, "model");
return (Criteria) this;
}
public Criteria andDomainNotIn(List<String> values) {
addCriterion("domain not in", values, "domain");
public Criteria andModelNotIn(List<String> values) {
addCriterion("model not in", values, "model");
return (Criteria) this;
}
public Criteria andDomainBetween(String value1, String value2) {
addCriterion("domain between", value1, value2, "domain");
public Criteria andModelBetween(String value1, String value2) {
addCriterion("model between", value1, value2, "model");
return (Criteria) this;
}
public Criteria andDomainNotBetween(String value1, String value2) {
addCriterion("domain not between", value1, value2, "domain");
public Criteria andModelNotBetween(String value1, String value2) {
addCriterion("model not between", value1, value2, "model");
return (Criteria) this;
}
@@ -888,6 +875,7 @@ public class PluginDOExample {
* s2_plugin null
*/
public static class Criterion {
private String condition;
private Object value;

View File

@@ -14,5 +14,5 @@ public interface ChatConfigMapper {
List<ChatConfigDO> search(ChatConfigFilterInternal filterInternal);
ChatConfigDO fetchConfigByDomainId(Long domainId);
ChatConfigDO fetchConfigByModelId(Long modelId);
}

View File

@@ -2,9 +2,8 @@ package com.tencent.supersonic.chat.persistence.mapper;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDOExample;
import org.apache.ibatis.annotations.Mapper;
import java.util.List;
import org.apache.ibatis.annotations.Mapper;
@Mapper
public interface ChatQueryDOMapper {

View File

@@ -2,67 +2,58 @@ package com.tencent.supersonic.chat.persistence.mapper;
import com.tencent.supersonic.chat.persistence.dataobject.PluginDO;
import com.tencent.supersonic.chat.persistence.dataobject.PluginDOExample;
import org.apache.ibatis.annotations.Mapper;
import java.util.List;
import org.apache.ibatis.annotations.Mapper;
@Mapper
public interface PluginDOMapper {
/**
*
* @mbg.generated
*/
long countByExample(PluginDOExample example);
/**
*
* @mbg.generated
*/
int deleteByPrimaryKey(Long id);
/**
*
* @mbg.generated
*/
int insert(PluginDO record);
/**
*
* @mbg.generated
*/
int insertSelective(PluginDO record);
/**
*
* @mbg.generated
*/
List<PluginDO> selectByExampleWithBLOBs(PluginDOExample example);
/**
*
* @mbg.generated
*/
List<PluginDO> selectByExample(PluginDOExample example);
/**
*
* @mbg.generated
*/
PluginDO selectByPrimaryKey(Long id);
/**
*
* @mbg.generated
*/
int updateByPrimaryKeySelective(PluginDO record);
/**
*
* @mbg.generated
*/
int updateByPrimaryKeyWithBLOBs(PluginDO record);
/**
*
* @mbg.generated
*/
int updateByPrimaryKey(PluginDO record);

View File

@@ -1,10 +1,9 @@
package com.tencent.supersonic.chat.persistence.repository;
import com.tencent.supersonic.chat.config.ChatConfig;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.config.ChatConfig;
import java.util.List;
public interface ChatConfigRepository {
@@ -15,5 +14,5 @@ public interface ChatConfigRepository {
List<ChatConfigResp> getChatConfig(ChatConfigFilter filter);
ChatConfigResp getConfigByDomainId(Long domainId);
ChatConfigResp getConfigByModelId(Long modelId);
}

View File

@@ -2,11 +2,10 @@ package com.tencent.supersonic.chat.persistence.repository;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
public interface ChatQueryRepository {

View File

@@ -2,10 +2,10 @@ package com.tencent.supersonic.chat.persistence.repository;
import com.tencent.supersonic.chat.persistence.dataobject.PluginDO;
import com.tencent.supersonic.chat.persistence.dataobject.PluginDOExample;
import java.util.List;
public interface PluginRepository {
List<PluginDO> getPlugins();
List<PluginDO> fetchPluginDOs(String queryText, String type);

View File

@@ -1,14 +1,13 @@
package com.tencent.supersonic.chat.persistence.repository.impl;
import com.tencent.supersonic.chat.config.ChatConfig;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
import com.tencent.supersonic.chat.config.ChatConfigFilterInternal;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.config.ChatConfig;
import com.tencent.supersonic.chat.config.ChatConfigFilterInternal;
import com.tencent.supersonic.chat.persistence.dataobject.ChatConfigDO;
import com.tencent.supersonic.chat.persistence.mapper.ChatConfigMapper;
import com.tencent.supersonic.chat.persistence.repository.ChatConfigRepository;
import com.tencent.supersonic.chat.utils.ChatConfigHelper;
import com.tencent.supersonic.chat.persistence.mapper.ChatConfigMapper;
import java.util.ArrayList;
import java.util.List;
import org.springframework.beans.BeanUtils;
@@ -24,7 +23,7 @@ public class ChatConfigRepositoryImpl implements ChatConfigRepository {
private final ChatConfigMapper chatConfigMapper;
public ChatConfigRepositoryImpl(ChatConfigHelper chatConfigHelper,
ChatConfigMapper chatConfigMapper) {
ChatConfigMapper chatConfigMapper) {
this.chatConfigHelper = chatConfigHelper;
this.chatConfigMapper = chatConfigMapper;
}
@@ -53,15 +52,16 @@ public class ChatConfigRepositoryImpl implements ChatConfigRepository {
List<ChatConfigDO> chaConfigDOList = chatConfigMapper.search(filterInternal);
if (!CollectionUtils.isEmpty(chaConfigDOList)) {
chaConfigDOList.stream().forEach(chaConfigDO ->
chaConfigDescriptorList.add(chatConfigHelper.chatConfigDO2Descriptor(chaConfigDO.getDomainId(), chaConfigDO)));
chaConfigDescriptorList.add(
chatConfigHelper.chatConfigDO2Descriptor(chaConfigDO.getModelId(), chaConfigDO)));
}
return chaConfigDescriptorList;
}
@Override
public ChatConfigResp getConfigByDomainId(Long domainId) {
ChatConfigDO chaConfigPO = chatConfigMapper.fetchConfigByDomainId(domainId);
return chatConfigHelper.chatConfigDO2Descriptor(domainId, chaConfigPO);
public ChatConfigResp getConfigByModelId(Long modelId) {
ChatConfigDO chaConfigPO = chatConfigMapper.fetchConfigByModelId(modelId);
return chatConfigHelper.chatConfigDO2Descriptor(modelId, chaConfigPO);
}
}

View File

@@ -3,13 +3,12 @@ package com.tencent.supersonic.chat.persistence.repository.impl;
import com.github.pagehelper.PageHelper;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDOExample;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDOExample.Criteria;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.persistence.mapper.ChatQueryDOMapper;
import com.tencent.supersonic.chat.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.common.util.JsonUtil;

View File

@@ -2,8 +2,8 @@ package com.tencent.supersonic.chat.persistence.repository.impl;
import com.tencent.supersonic.chat.persistence.dataobject.ChatDO;
import com.tencent.supersonic.chat.persistence.dataobject.QueryDO;
import com.tencent.supersonic.chat.persistence.repository.ChatRepository;
import com.tencent.supersonic.chat.persistence.mapper.ChatMapper;
import com.tencent.supersonic.chat.persistence.repository.ChatRepository;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Primary;

View File

@@ -5,13 +5,13 @@ import com.tencent.supersonic.chat.persistence.dataobject.PluginDOExample;
import com.tencent.supersonic.chat.persistence.mapper.PluginDOMapper;
import com.tencent.supersonic.chat.persistence.repository.PluginRepository;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.logging.log4j.util.Strings;
import org.springframework.stereotype.Repository;
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import lombok.extern.slf4j.Slf4j;
import org.apache.logging.log4j.util.Strings;
import org.springframework.stereotype.Repository;
@Repository
@Slf4j
@@ -58,7 +58,7 @@ public class PluginRepositoryImpl implements PluginRepository {
}
@Override
public void updatePlugin(PluginDO pluginDO){
public void updatePlugin(PluginDO pluginDO) {
pluginDOMapper.updateByPrimaryKeyWithBLOBs(pluginDO);
}
@@ -68,12 +68,12 @@ public class PluginRepositoryImpl implements PluginRepository {
}
@Override
public List<PluginDO> query(PluginDOExample pluginDOExample){
public List<PluginDO> query(PluginDOExample pluginDOExample) {
return pluginDOMapper.selectByExampleWithBLOBs(pluginDOExample);
}
@Override
public void deletePlugin(Long id){
public void deletePlugin(Long id) {
pluginDOMapper.deleteByPrimaryKey(id);
}

View File

@@ -5,10 +5,10 @@ import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.parser.ParseMode;
import com.tencent.supersonic.common.pojo.RecordInfo;
import java.util.List;
import lombok.Data;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import java.util.List;
@Data
public class Plugin extends RecordInfo {
@@ -20,7 +20,7 @@ public class Plugin extends RecordInfo {
*/
private String type;
private List<Long> domainList = Lists.newArrayList();
private List<Long> modelList = Lists.newArrayList();
/**
* description, for parsing
@@ -51,8 +51,8 @@ public class Plugin extends RecordInfo {
return Lists.newArrayList();
}
public boolean isContainsAllDomain() {
return CollectionUtils.isNotEmpty(domainList) && domainList.contains(-1L);
public boolean isContainsAllModel() {
return CollectionUtils.isNotEmpty(modelList) && modelList.contains(-1L);
}
}

View File

@@ -2,24 +2,44 @@ package com.tencent.supersonic.chat.plugin;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.parser.ParseMode;
import com.tencent.supersonic.chat.parser.embedding.EmbeddingConfig;
import com.tencent.supersonic.chat.parser.embedding.EmbeddingResp;
import com.tencent.supersonic.chat.parser.embedding.RecallRetrieval;
import com.tencent.supersonic.chat.plugin.event.PluginAddEvent;
import com.tencent.supersonic.chat.plugin.event.PluginUpdateEvent;
import com.tencent.supersonic.chat.query.plugin.ParamOption;
import com.tencent.supersonic.chat.query.plugin.WebBase;
import com.tencent.supersonic.chat.service.PluginService;
import com.tencent.supersonic.common.util.ContextUtils;
import java.net.URI;
import java.util.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.util.Strings;
import org.springframework.context.event.EventListener;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.*;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;
@@ -42,7 +62,7 @@ public class PluginManager {
public static List<Plugin> getPlugins() {
PluginService pluginService = ContextUtils.getBean(PluginService.class);
List<Plugin> pluginList = pluginService.getPluginList().stream().filter(plugin ->
CollectionUtils.isNotEmpty(plugin.getDomainList())).collect(Collectors.toList());
CollectionUtils.isNotEmpty(plugin.getModelList())).collect(Collectors.toList());
pluginList.addAll(internalPluginMap.values());
return new ArrayList<>(pluginList);
}
@@ -89,9 +109,9 @@ public class PluginManager {
doRequest(embeddingConfig.getAddPath(), JSONObject.toJSONString(maps));
}
public void doRequest(String path, String jsonBody) {
public ResponseEntity<String> doRequest(String path, String jsonBody) {
if (Strings.isEmpty(embeddingConfig.getUrl())) {
return;
return ResponseEntity.of(Optional.empty());
}
String url = embeddingConfig.getUrl() + path;
HttpHeaders headers = new HttpHeaders();
@@ -105,6 +125,7 @@ public class PluginManager {
restTemplate.exchange(requestUrl, HttpMethod.POST, entity, new ParameterizedTypeReference<String>() {
});
log.info("[embedding] result body:{}", responseEntity);
return responseEntity;
}
public void requestEmbeddingPluginAddALL(List<Plugin> plugins) {
@@ -115,7 +136,8 @@ public class PluginManager {
}
public EmbeddingResp recognize(String embeddingText) {
String url = embeddingConfig.getUrl() + embeddingConfig.getRecognizePath() + "?n_results=" + embeddingConfig.getNResult();
String url = embeddingConfig.getUrl() + embeddingConfig.getRecognizePath() + "?n_results="
+ embeddingConfig.getNResult();
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
headers.setLocation(URI.create(url));
@@ -125,8 +147,9 @@ public class PluginManager {
HttpEntity<String> entity = new HttpEntity<>(jsonBody, headers);
log.info("[embedding] request body:{}, url:{}", jsonBody, url);
ResponseEntity<List<EmbeddingResp>> embeddingResponseEntity =
restTemplate.exchange(requestUrl, HttpMethod.POST, entity, new ParameterizedTypeReference<List<EmbeddingResp>>() {
});
restTemplate.exchange(requestUrl, HttpMethod.POST, entity,
new ParameterizedTypeReference<List<EmbeddingResp>>() {
});
log.info("[embedding] recognize result body:{}", embeddingResponseEntity);
List<EmbeddingResp> embeddingResps = embeddingResponseEntity.getBody();
if (CollectionUtils.isNotEmpty(embeddingResps)) {
@@ -178,4 +201,88 @@ public class PluginManager {
return String.valueOf(Integer.parseInt(id) / 1000);
}
public static Pair<Boolean, List<Long>> resolve(Plugin plugin, QueryContext queryContext) {
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
Set<Long> pluginMatchedModel = getPluginMatchedModel(plugin, queryContext);
if (CollectionUtils.isEmpty(pluginMatchedModel) && !plugin.isContainsAllModel()) {
return Pair.of(false, Lists.newArrayList());
}
List<ParamOption> paramOptions = getSemanticOption(plugin);
if (CollectionUtils.isEmpty(paramOptions)) {
return Pair.of(true, new ArrayList<>(pluginMatchedModel));
}
List<Long> matchedModel = Lists.newArrayList();
Map<Long, List<ParamOption>> paramOptionMap = paramOptions.stream().
collect(Collectors.groupingBy(ParamOption::getModelId));
for (Long modelId : paramOptionMap.keySet()) {
List<ParamOption> params = paramOptionMap.get(modelId);
if (CollectionUtils.isEmpty(params)) {
matchedModel.add(modelId);
continue;
}
boolean matched = true;
for (ParamOption paramOption : params) {
Set<Long> elementIdSet = getSchemaElementMatch(modelId, schemaMapInfo);
if (CollectionUtils.isEmpty(elementIdSet)) {
matched = false;
break;
}
if (!elementIdSet.contains(paramOption.getElementId())) {
matched = false;
break;
}
}
if (matched) {
matchedModel.add(modelId);
}
}
if (CollectionUtils.isEmpty(matchedModel)) {
return Pair.of(false, Lists.newArrayList());
}
return Pair.of(true, matchedModel);
}
private static Set<Long> getSchemaElementMatch(Long modelId, SchemaMapInfo schemaMapInfo) {
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(modelId);
if (org.springframework.util.CollectionUtils.isEmpty(schemaElementMatches)) {
return Sets.newHashSet();
}
return schemaElementMatches.stream().filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()) ||
SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
.map(SchemaElementMatch::getElement)
.map(SchemaElement::getId)
.collect(Collectors.toSet());
}
private static List<ParamOption> getSemanticOption(Plugin plugin) {
WebBase webBase = JSONObject.parseObject(plugin.getConfig(), WebBase.class);
if (Objects.isNull(webBase)) {
return null;
}
List<ParamOption> paramOptions = webBase.getParamOptions();
if (org.springframework.util.CollectionUtils.isEmpty(paramOptions)) {
return Lists.newArrayList();
}
return paramOptions.stream()
.filter(paramOption -> ParamOption.ParamType.SEMANTIC.equals(paramOption.getParamType()))
.collect(Collectors.toList());
}
private static Set<Long> getPluginMatchedModel(Plugin plugin, QueryContext queryContext) {
Set<Long> matchedModel = queryContext.getMapInfo().getMatchedModels();
if (plugin.isContainsAllModel()) {
return matchedModel;
}
List<Long> modelIds = plugin.getModelList();
Set<Long> pluginMatchedModel = Sets.newHashSet();
for (Long modelId : modelIds) {
if (matchedModel.contains(modelId)) {
pluginMatchedModel.add(modelId);
}
}
return pluginMatchedModel;
}
}

View File

@@ -2,12 +2,19 @@ package com.tencent.supersonic.chat.plugin;
import com.tencent.supersonic.chat.parser.function.Parameters;
import lombok.Data;
import java.io.Serializable;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.ToString;
@Data
@Builder
@AllArgsConstructor
@ToString
@NoArgsConstructor
public class PluginParseConfig implements Serializable {
private String name;

View File

@@ -0,0 +1,149 @@
package com.tencent.supersonic.chat.query.ContentInterpret;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticLayer;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.plugin.PluginParseResult;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.pojo.Aggregator;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.pojo.Filter;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.calcite.sql.parser.SqlParseException;
import org.springframework.beans.BeanUtils;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
@Slf4j
@Component
public class ContentInterpretQuery extends PluginSemanticQuery {
@Override
public String getQueryMode() {
return "CONTENT_INTERPRET";
}
public ContentInterpretQuery() {
QueryManager.register(this);
}
@Override
public QueryResult execute(User user) throws SqlParseException {
QueryResultWithSchemaResp queryResultWithSchemaResp = queryMetric(user);
String text = generateDataText(queryResultWithSchemaResp);
Map<String, Object> properties = parseInfo.getProperties();
PluginParseResult pluginParseResult = JsonUtil.toObject(JsonUtil.toString(properties.get(Constants.CONTEXT))
, PluginParseResult.class);
String answer = fetchInterpret(pluginParseResult.getRequest().getQueryText(), text);
QueryResult queryResult = new QueryResult();
List<QueryColumn> queryColumns = Lists.newArrayList(new QueryColumn("结果", "string", "answer"));
Map<String, Object> result = new HashMap<>();
result.put("answer", answer);
List<Map<String, Object>> resultList = Lists.newArrayList();
resultList.add(result);
queryResultWithSchemaResp.setResultList(resultList);
queryResultWithSchemaResp.setColumns(queryColumns);
queryResult.setResponse(queryResultWithSchemaResp);
queryResult.setQueryMode(getQueryMode());
queryResult.setQueryState(QueryState.SUCCESS);
return queryResult;
}
private QueryResultWithSchemaResp queryMetric(User user) {
SemanticLayer semanticLayer = ComponentFactory.getSemanticLayer();
QueryStructReq queryStructReq = new QueryStructReq();
queryStructReq.setModelId(parseInfo.getModelId());
queryStructReq.setGroups(Lists.newArrayList(TimeDimensionEnum.DAY.getName()));
ModelSchema modelSchema = semanticLayer.getModelSchema(parseInfo.getModelId(), true);
queryStructReq.setAggregators(buildAggregator(modelSchema));
List<Filter> filterList = Lists.newArrayList();
for (QueryFilter queryFilter : parseInfo.getDimensionFilters()) {
Filter filter = new Filter();
BeanUtils.copyProperties(queryFilter, filter);
filterList.add(filter);
}
queryStructReq.setDimensionFilters(filterList);
DateConf dateConf = new DateConf();
dateConf.setDateMode(DateConf.DateMode.RECENT);
dateConf.setUnit(7);
queryStructReq.setDateInfo(dateConf);
return semanticLayer.queryByStruct(queryStructReq, user);
}
private List<Aggregator> buildAggregator(ModelSchema modelSchema) {
List<Aggregator> aggregators = Lists.newArrayList();
Set<SchemaElement> metrics = modelSchema.getMetrics();
if (CollectionUtils.isEmpty(metrics)) {
return aggregators;
}
for (SchemaElement schemaElement : metrics) {
Aggregator aggregator = new Aggregator();
aggregator.setColumn(schemaElement.getBizName());
aggregator.setFunc(AggOperatorEnum.SUM);
aggregator.setNameCh(schemaElement.getName());
aggregators.add(aggregator);
}
return aggregators;
}
public String generateDataText(QueryResultWithSchemaResp queryResultWithSchemaResp) {
Map<String, String> map = queryResultWithSchemaResp.getColumns().stream()
.collect(Collectors.toMap(QueryColumn::getNameEn, QueryColumn::getName));
StringBuilder stringBuilder = new StringBuilder();
for (Map<String, Object> valueMap : queryResultWithSchemaResp.getResultList()) {
for (String key : valueMap.keySet()) {
String name = "";
if (TimeDimensionEnum.getNameList().contains(key)) {
name = "日期";
} else {
name = map.get(key);
}
String value = String.valueOf(valueMap.get(key));
stringBuilder.append(name).append(":").append(value).append(" ");
}
}
return stringBuilder.toString();
}
public String fetchInterpret(String queryText, String dataText) {
PluginManager pluginManager = ContextUtils.getBean(PluginManager.class);
LLmAnswerReq lLmAnswerReq = new LLmAnswerReq();
lLmAnswerReq.setQueryText(queryText);
lLmAnswerReq.setPluginOutput(dataText);
ResponseEntity<String> responseEntity = pluginManager.doRequest("answer_with_plugin_call",
JSONObject.toJSONString(lLmAnswerReq));
LLmAnswerResp lLmAnswerResp = JSONObject.parseObject(responseEntity.getBody(), LLmAnswerResp.class);
if (lLmAnswerResp != null) {
return lLmAnswerResp.getAssistant_message();
}
return null;
}
}

View File

@@ -0,0 +1,13 @@
package com.tencent.supersonic.chat.query.ContentInterpret;
import lombok.Data;
@Data
public class LLmAnswerReq {
private String queryText;
private String pluginOutput;
}

View File

@@ -0,0 +1,11 @@
package com.tencent.supersonic.chat.query.ContentInterpret;
import lombok.Data;
@Data
public class LLmAnswerResp {
private String assistant_message;
}

View File

@@ -2,20 +2,19 @@ package com.tencent.supersonic.chat.query;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.common.pojo.Constants;
import java.util.*;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery;
import java.util.ArrayList;
import java.util.List;
import java.util.OptionalDouble;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
@Slf4j
public class HeuristicQuerySelector implements QuerySelector {
private static final double MATCH_INHERIT_PENALTY = 0.5;
private static final double MATCH_CURRENT_REWORD = 2;
private static final double CANDIDATE_THRESHOLD = 0.2;
@Override
@@ -26,49 +25,53 @@ public class HeuristicQuerySelector implements QuerySelector {
selectedQueries.addAll(candidateQueries);
} else {
OptionalDouble maxScoreOp = candidateQueries.stream().mapToDouble(
q -> computeScore(q.getParseInfo())).max();
q -> q.getParseInfo().getScore()).max();
if (maxScoreOp.isPresent()) {
double maxScore = maxScoreOp.getAsDouble();
candidateQueries.stream().forEach(query -> {
SemanticParseInfo semanticParse = query.getParseInfo();
if ((maxScore - semanticParse.getScore()) / maxScore <= CANDIDATE_THRESHOLD) {
SemanticParseInfo parseInfo = query.getParseInfo();
if (!checkFullyInherited(query)
&& (maxScore - parseInfo.getScore()) / maxScore <= CANDIDATE_THRESHOLD
&& checkSatisfyOtherRules(query, candidateQueries)) {
selectedQueries.add(query);
}
log.info("candidate query (domain={}, queryMode={}) with score={}",
semanticParse.getDomainName(), semanticParse.getQueryMode(), semanticParse.getScore());
log.info("candidate query (Model={}, queryMode={}) with score={}",
parseInfo.getModelName(), parseInfo.getQueryMode(), parseInfo.getScore());
});
}
}
return selectedQueries;
}
private double computeScore(SemanticParseInfo semanticParse) {
double totalScore = 0;
Map<SchemaElementType, SchemaElementMatch> maxSimilarityMatch = new HashMap<>();
for (SchemaElementMatch match : semanticParse.getElementMatches()) {
SchemaElementType type = match.getElement().getType();
if (!maxSimilarityMatch.containsKey(type) ||
match.getSimilarity() > maxSimilarityMatch.get(type).getSimilarity()) {
maxSimilarityMatch.put(type, match);
private boolean checkSatisfyOtherRules(SemanticQuery semanticQuery, List<SemanticQuery> candidateQueries) {
if (!semanticQuery.getQueryMode().equals(MetricModelQuery.QUERY_MODE)) {
return true;
}
for (SemanticQuery candidateQuery : candidateQueries) {
if (candidateQuery.getQueryMode().equals(MetricEntityQuery.QUERY_MODE) &&
semanticQuery.getParseInfo().getScore() == candidateQuery.getParseInfo().getScore()) {
return false;
}
}
return true;
}
for (SchemaElementMatch match : maxSimilarityMatch.values()) {
double matchScore = Optional.ofNullable(match.getDetectWord()).orElse(Constants.EMPTY).length() * match.getSimilarity();
if (match.equals(SchemaElementMatch.MatchMode.INHERIT)) {
matchScore *= MATCH_INHERIT_PENALTY;
} else {
matchScore *= MATCH_CURRENT_REWORD;
}
totalScore += matchScore;
private boolean checkFullyInherited(SemanticQuery query) {
SemanticParseInfo parseInfo = query.getParseInfo();
if (!(query instanceof RuleSemanticQuery)) {
return false;
}
// original score in parse info acts like an extra bonus
totalScore += semanticParse.getScore();
semanticParse.setScore(totalScore);
for (SchemaElementMatch match : parseInfo.getElementMatches()) {
if (!match.isInherited()) {
return false;
}
}
if (parseInfo.getDateInfo() != null && !parseInfo.getDateInfo().isInherited()) {
return false;
}
return totalScore;
return true;
}
}

View File

@@ -5,7 +5,6 @@ import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.query.rule.entity.EntitySemanticQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricSemanticQuery;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@@ -56,6 +55,7 @@ public class QueryManager {
throw new RuntimeException("no supported queryMode :" + queryMode);
}
}
public static boolean containsRuleQuery(String queryMode) {
if (queryMode == null) {
return false;

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.chat.query;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import java.util.List;
/**

View File

@@ -0,0 +1,78 @@
package com.tencent.supersonic.chat.query.dsl;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.StringUtil;
import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
@Slf4j
public class DSLBuilder {
public static final String DATA_Field = "数据日期";
public static final String TABLE_PREFIX = "t_";
public String build(SemanticParseInfo parseInfo, QueryFilters queryFilters, LLMResp llmResp, Long modelId)
throws Exception {
String sqlOutput = llmResp.getSqlOutput();
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
List<SchemaElement> dbAllFields = new ArrayList<>();
dbAllFields.addAll(semanticSchema.getMetrics());
dbAllFields.addAll(semanticSchema.getDimensions());
Map<String, String> fieldToBizName = getMapInfo(modelId, dbAllFields);
fieldToBizName.put(DATA_Field, TimeDimensionEnum.DAY.getName());
sqlOutput = CCJSqlParserUtils.replaceFields(sqlOutput, fieldToBizName);
sqlOutput = CCJSqlParserUtils.replaceTable(sqlOutput, TABLE_PREFIX + modelId);
String queryFilter = getQueryFilter(queryFilters);
if (StringUtils.isNotEmpty(queryFilter)) {
log.info("add queryFilter to sql :{}", queryFilter);
Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
CCJSqlParserUtils.addWhere(sqlOutput, expression);
}
log.info("build sqlOutput:{}", sqlOutput);
return sqlOutput;
}
protected Map<String, String> getMapInfo(Long modelId, List<SchemaElement> metrics) {
return metrics.stream().filter(entry -> entry.getModel().equals(modelId))
.collect(Collectors.toMap(SchemaElement::getName, a -> a.getBizName(), (k1, k2) -> k1));
}
private String getQueryFilter(QueryFilters queryFilters) {
if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) {
return "";
}
List<QueryFilter> filters = queryFilters.getFilters();
return filters.stream()
.map(filter -> {
String bizNameWrap = StringUtil.getSpaceWrap(filter.getBizName());
String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue());
String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString());
return bizNameWrap + operatorWrap + valueWrap;
})
.collect(Collectors.joining(Constants.AND_UPPER));
}
}

View File

@@ -0,0 +1,96 @@
package com.tencent.supersonic.chat.query.dsl;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticLayer;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.parser.llm.DSLParseResult;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.request.QueryDslReq;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
@Slf4j
@Component
public class DSLQuery extends PluginSemanticQuery {
public static final String QUERY_MODE = "DSL";
private DSLBuilder dslBuilder = new DSLBuilder();
protected SemanticLayer semanticLayer = ComponentFactory.getSemanticLayer();
public DSLQuery() {
QueryManager.register(this);
}
@Override
public String getQueryMode() {
return QUERY_MODE;
}
@Override
public QueryResult execute(User user) {
String json = JsonUtil.toString(parseInfo.getProperties().get(Constants.CONTEXT));
DSLParseResult dslParseResult = JsonUtil.toObject(json, DSLParseResult.class);
LLMResp llmResp = dslParseResult.getLlmResp();
QueryReq queryReq = dslParseResult.getRequest();
Long modelId = parseInfo.getModelId();
String querySql = convertToSql(queryReq.getQueryFilters(), llmResp, parseInfo, modelId);
long startTime = System.currentTimeMillis();
QueryDslReq queryDslReq = QueryReqBuilder.buildDslReq(querySql, modelId);
QueryResultWithSchemaResp queryResp = semanticLayer.queryByDsl(queryDslReq, user);
log.info("queryByDsl cost:{},querySql:{}", System.currentTimeMillis() - startTime, querySql);
QueryResult queryResult = new QueryResult();
if (Objects.nonNull(queryResp)) {
queryResult.setQueryAuthorization(queryResp.getQueryAuthorization());
}
String resultQql = queryResp == null ? null : queryResp.getSql();
List<Map<String, Object>> resultList = queryResp == null ? new ArrayList<>() : queryResp.getResultList();
List<QueryColumn> columns = queryResp == null ? new ArrayList<>() : queryResp.getColumns();
queryResult.setQuerySql(resultQql);
queryResult.setQueryResults(resultList);
queryResult.setQueryColumns(columns);
queryResult.setQueryMode(QUERY_MODE);
queryResult.setQueryState(QueryState.SUCCESS);
// add model info
EntityInfo entityInfo = ContextUtils.getBean(SemanticService.class).getEntityInfo(parseInfo, user);
queryResult.setEntityInfo(entityInfo);
parseInfo.setProperties(null);
return queryResult;
}
protected String convertToSql(QueryFilters queryFilters, LLMResp llmResp, SemanticParseInfo parseInfo,
Long modelId) {
try {
return dslBuilder.build(parseInfo, queryFilters, llmResp, modelId);
} catch (Exception e) {
log.error("convertToSql error", e);
}
return null;
}
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.query.plugin.dsl;
package com.tencent.supersonic.chat.query.dsl;
import java.util.List;
import lombok.Data;
@@ -12,6 +12,8 @@ public class LLMReq {
private List<ElementValue> linking;
private String currentDate;
@Data
public static class ElementValue {
@@ -26,6 +28,8 @@ public class LLMReq {
private String domainName;
private String modelName;
private List<String> fieldNameList;
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.query.plugin.dsl;
package com.tencent.supersonic.chat.query.dsl;
import java.util.List;
import lombok.Data;
@@ -8,7 +8,7 @@ public class LLMResp {
private String query;
private String domainName;
private String modelName;
private String sqlOutput;

View File

@@ -15,7 +15,7 @@ public class ParamOption {
private String keyAlias;
private Long domainId;
private Long modelId;
private Long elementId;

View File

@@ -1,8 +1,8 @@
package com.tencent.supersonic.chat.query.plugin;
import com.google.common.collect.Lists;
import lombok.Data;
import java.util.List;
import lombok.Data;
@Data
public class WebBase {

Some files were not shown because too many files have changed in this diff Show More