[improvement][project] supersonic 0.7.0 version backend update (#24)

* [improvement][project] supersonic 0.7.0 version backend update

* [improvement][project] supersonic 0.7.0 version backend update

* [improvement][project] supersonic 0.7.0 version readme update

---------

Co-authored-by: jolunoluo <jolunoluo@tencent.com>
This commit is contained in:
SunDean
2023-08-05 22:17:56 +08:00
committed by GitHub
parent 6951eada9d
commit aa0a100a85
184 changed files with 2609 additions and 1238 deletions

View File

@@ -16,4 +16,6 @@ public interface SemanticQuery {
QueryResult execute(User user) throws SqlParseException;
SemanticParseInfo getParseInfo();
void setParseInfo(SemanticParseInfo parseInfo);
}

View File

@@ -13,11 +13,15 @@ public class DomainSchema {
private Set<SchemaElement> metrics = new HashSet<>();
private Set<SchemaElement> dimensions = new HashSet<>();
private Set<SchemaElement> dimensionValues = new HashSet<>();
private Set<SchemaElement> entities = new HashSet<>();
private SchemaElement entity = new SchemaElement();
public SchemaElement getElement(SchemaElementType elementType, long elementID) {
Optional<SchemaElement> element = Optional.empty();
switch (elementType) {
case ENTITY:
element = Optional.ofNullable(entity);
break;
case DOMAIN:
element = Optional.of(domain);
break;
@@ -27,9 +31,6 @@ public class DomainSchema {
case DIMENSION:
element = dimensions.stream().filter(e -> e.getId() == elementID).findFirst();
break;
case ENTITY:
element = entities.stream().filter(e -> e.getId() == elementID).findFirst();
break;
case VALUE:
element = dimensionValues.stream().filter(e -> e.getId() == elementID).findFirst();
default:

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.chat.api.pojo;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.request.QueryRequest;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import lombok.Data;
import java.util.ArrayList;
@@ -10,11 +10,11 @@ import java.util.List;
@Data
public class QueryContext {
private QueryRequest request;
private QueryReq request;
private List<SemanticQuery> candidateQueries = new ArrayList<>();
private SchemaMapInfo mapInfo = new SchemaMapInfo();
public QueryContext(QueryRequest request) {
public QueryContext(QueryReq request) {
this.request = request;
}
}

View File

@@ -5,11 +5,13 @@ import com.google.common.base.Objects;
import java.io.Serializable;
import java.util.List;
import lombok.Builder;
import lombok.Data;
import lombok.*;
@Data
@Getter
@Builder
@NoArgsConstructor
//@AllArgsConstructor
public class SchemaElement implements Serializable {
private Long domain;
@@ -20,8 +22,8 @@ public class SchemaElement implements Serializable {
private SchemaElementType type;
private List<String> alias;
public SchemaElement() {
}
// public SchemaElement() {
// }
public SchemaElement(Long domain, Long id, String name, String bizName,
Long useCnt, SchemaElementType type, List<String> alias) {

View File

@@ -18,4 +18,10 @@ public class SchemaElementMatch {
String detectWord;
String word;
Long frequency;
MatchMode mode = MatchMode.CURRENT;
public enum MatchMode {
CURRENT,
INHERIT
}
}

View File

@@ -1,36 +1,32 @@
package com.tencent.supersonic.chat.api.pojo;
import java.util.*;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.Data;
@Data
public class SemanticParseInfo {
String queryMode;
SchemaElement domain;
Set<SchemaElement> metrics = new LinkedHashSet();
Set<SchemaElement> dimensions = new LinkedHashSet();
Long entity = 0L;
AggregateTypeEnum aggType = AggregateTypeEnum.NONE;
Set<QueryFilter> dimensionFilters = new LinkedHashSet();
Set<QueryFilter> metricFilters = new LinkedHashSet();
private String queryMode;
private SchemaElement domain;
private Set<SchemaElement> metrics = new TreeSet<>(new SchemaNameLengthComparator());
private Set<SchemaElement> dimensions = new LinkedHashSet();
private SchemaElement entity;
private AggregateTypeEnum aggType = AggregateTypeEnum.NONE;
private Set<QueryFilter> dimensionFilters = new LinkedHashSet();
private Set<QueryFilter> metricFilters = new LinkedHashSet();
private Set<Order> orders = new LinkedHashSet();
private DateConf dateInfo;
private Long limit;
private Boolean nativeQuery = false;
private Double bonus = 0d;
private double score;
private List<SchemaElementMatch> elementMatches = new ArrayList<>();
private Map<String, Object> properties;
private Map<String, Object> properties = new HashMap<>();
public Long getDomainId() {
return domain != null ? domain.getId() : 0L;
@@ -40,8 +36,9 @@ public class SemanticParseInfo {
return domain != null ? domain.getName() : "null";
}
public Set<SchemaElement> getMetrics() {
this.metrics = this.metrics.stream().sorted((o1, o2) -> {
private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {
@Override
public int compare(SchemaElement o1, SchemaElement o2) {
int len1 = o1.getName().length();
int len2 = o2.getName().length();
if (len1 != len2) {
@@ -49,7 +46,7 @@ public class SemanticParseInfo {
} else {
return o1.getName().compareTo(o2.getName());
}
}).collect(Collectors.toCollection(LinkedHashSet::new));
return this.metrics;
}
}
}

View File

@@ -48,7 +48,7 @@ public class SemanticSchema implements Serializable {
public List<SchemaElement> getEntities() {
List<SchemaElement> entities = new ArrayList<>();
domainSchemaList.stream().forEach(d -> entities.addAll(d.getEntities()));
domainSchemaList.stream().forEach(d -> entities.add(d.getEntity()));
return entities;
}
}

View File

@@ -1,11 +1,11 @@
package com.tencent.supersonic.chat.config;
package com.tencent.supersonic.chat.api.pojo.request;
import lombok.Data;
import java.util.List;
@Data
public class ChatAggConfig {
public class ChatAggConfigReq {
/**
* invisible dimensions/metrics
@@ -15,10 +15,10 @@ public class ChatAggConfig {
/**
* information about dictionary about the domain
*/
private List<KnowledgeInfo> knowledgeInfos;
private List<KnowledgeInfoReq> knowledgeInfos;
private KnowledgeAdvancedConfig globalKnowledgeConfig;
private ChatDefaultConfig chatDefaultConfig;
private ChatDefaultConfigReq chatDefaultConfig;
}

View File

@@ -1,6 +1,5 @@
package com.tencent.supersonic.chat.config;
package com.tencent.supersonic.chat.api.pojo.request;
import com.tencent.supersonic.chat.api.pojo.request.RecommendedQuestion;
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import lombok.Data;
@@ -20,18 +19,18 @@ public class ChatConfigBaseReq {
/**
* the chatDetailConfig about the domain
*/
private ChatDetailConfig chatDetailConfig;
private ChatDetailConfigReq chatDetailConfig;
/**
* the chatAggConfig about the domain
*/
private ChatAggConfig chatAggConfig;
private ChatAggConfigReq chatAggConfig;
/**
* the recommended questions about the domain
*/
private List<RecommendedQuestion> recommendedQuestions;
private List<RecommendedQuestionReq> recommendedQuestions;
/**
* available status

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.config;
package com.tencent.supersonic.chat.api.pojo.request;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.config;
package com.tencent.supersonic.chat.api.pojo.request;
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.config;
package com.tencent.supersonic.chat.api.pojo.request;
import com.tencent.supersonic.common.pojo.Constants;
@@ -8,7 +8,7 @@ import java.util.ArrayList;
import java.util.List;
@Data
public class ChatDefaultConfig {
public class ChatDefaultConfigReq {
private List<Long> dimensionIds = new ArrayList<>();
private List<Long> metricIds = new ArrayList<>();
@@ -24,4 +24,15 @@ public class ChatDefaultConfig {
*/
private String period = Constants.DAY;
private TimeMode timeMode = TimeMode.LAST;
public enum TimeMode {
/**
* date mode
* LAST - a certain time
* RECENT - a period time
*/
LAST, RECENT
}
}

View File

@@ -1,11 +1,11 @@
package com.tencent.supersonic.chat.config;
package com.tencent.supersonic.chat.api.pojo.request;
import lombok.Data;
import java.util.List;
@Data
public class ChatDetailConfig {
public class ChatDetailConfigReq {
/**
* invisible dimensions/metrics
@@ -15,15 +15,10 @@ public class ChatDetailConfig {
/**
* information about dictionary about the domain
*/
private List<KnowledgeInfo> knowledgeInfos;
private List<KnowledgeInfoReq> knowledgeInfos;
private KnowledgeAdvancedConfig globalKnowledgeConfig;
private ChatDefaultConfig chatDefaultConfig;
/**
* the entity info about the domain
*/
private Entity entity;
private ChatDefaultConfigReq chatDefaultConfig;
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.config;
package com.tencent.supersonic.chat.api.pojo.request;
import java.util.List;
import lombok.AllArgsConstructor;

View File

@@ -0,0 +1,15 @@
package com.tencent.supersonic.chat.api.pojo.request;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import lombok.Data;
@Data
public class ExecuteQueryReq {
private User user;
private Integer chatId;
private String queryText;
private SemanticParseInfo parseInfo;
private boolean saveAnswer = true;
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.config;
package com.tencent.supersonic.chat.api.pojo.request;
import java.util.ArrayList;
import java.util.List;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.config;
package com.tencent.supersonic.chat.api.pojo.request;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.config;
package com.tencent.supersonic.chat.api.pojo.request;
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
@@ -11,7 +11,7 @@ import lombok.Data;
*/
@Data
public class KnowledgeInfo {
public class KnowledgeInfoReq {
/**
* metricIdDimensionIddomainId

View File

@@ -7,10 +7,9 @@ import lombok.Data;
public class PluginQueryReq {
private String showElementId;
private String name;
//DASHBOARD WIDGET
private String showType;
private String parseMode;
private String type;
@@ -18,5 +17,5 @@ public class PluginQueryReq {
private String pattern;
private String createdBy;
}

View File

@@ -10,7 +10,7 @@ import com.tencent.supersonic.common.pojo.Order;
import lombok.Data;
@Data
public class QueryDataRequest {
public class QueryDataReq {
String queryMode;
SchemaElement domain;
Set<SchemaElement> metrics = new HashSet<>();

View File

@@ -4,8 +4,7 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
import lombok.Data;
@Data
public class QueryRequest {
public class QueryReq {
private String queryText;
private Integer chatId;
private Long domainId = 0L;

View File

@@ -1,11 +1,15 @@
package com.tencent.supersonic.chat.api.pojo.request;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.ToString;
@Data
@ToString
public class RecommendedQuestion {
@AllArgsConstructor
@NoArgsConstructor
public class RecommendedQuestionReq {
private String question;

View File

@@ -0,0 +1,26 @@
package com.tencent.supersonic.chat.api.pojo.response;
import com.tencent.supersonic.chat.api.pojo.request.KnowledgeAdvancedConfig;
import com.tencent.supersonic.chat.api.pojo.request.KnowledgeInfoReq;
import lombok.Data;
import java.util.List;
@Data
public class ChatAggRichConfigResp {
/**
* invisible dimensions/metrics
*/
private ItemVisibilityInfo visibility;
/**
* information about dictionary about the domain
*/
private List<KnowledgeInfoReq> knowledgeInfos;
private KnowledgeAdvancedConfig globalKnowledgeConfig;
private ChatDefaultRichConfigResp chatDefaultConfig;
}

View File

@@ -1,6 +1,8 @@
package com.tencent.supersonic.chat.config;
package com.tencent.supersonic.chat.api.pojo.response;
import com.tencent.supersonic.chat.api.pojo.request.RecommendedQuestion;
import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatDetailConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.RecommendedQuestionReq;
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import java.util.Date;
@@ -15,11 +17,11 @@ public class ChatConfigResp {
private Long domainId;
private ChatDetailConfig chatDetailConfig;
private ChatDetailConfigReq chatDetailConfig;
private ChatAggConfig chatAggConfig;
private ChatAggConfigReq chatAggConfig;
private List<RecommendedQuestion> recommendedQuestions;
private List<RecommendedQuestionReq> recommendedQuestions;
/**
* available status

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.chat.config;
package com.tencent.supersonic.chat.api.pojo.response;
import com.tencent.supersonic.chat.api.pojo.request.RecommendedQuestion;
import com.tencent.supersonic.chat.api.pojo.request.RecommendedQuestionReq;
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import java.util.Date;
import java.util.List;
@@ -8,7 +8,7 @@ import java.util.List;
import lombok.Data;
@Data
public class ChatConfigRich {
public class ChatConfigRichResp {
private Long id;
@@ -17,11 +17,11 @@ public class ChatConfigRich {
private String domainName;
private String bizName;
private ChatAggRichConfig chatAggRichConfig;
private ChatAggRichConfigResp chatAggRichConfig;
private ChatDetailRichConfig chatDetailRichConfig;
private ChatDetailRichConfigResp chatDetailRichConfig;
private List<RecommendedQuestion> recommendedQuestions;
private List<RecommendedQuestionReq> recommendedQuestions;
/**
* available status

View File

@@ -1,14 +1,15 @@
package com.tencent.supersonic.chat.config;
package com.tencent.supersonic.chat.api.pojo.response;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq;
import com.tencent.supersonic.common.pojo.Constants;
import lombok.Data;
import java.util.List;
@Data
public class ChatDefaultRichConfig {
public class ChatDefaultRichConfigResp {
private List<SchemaElement> dimensions;
private List<SchemaElement> metrics;
@@ -25,4 +26,6 @@ public class ChatDefaultRichConfig {
*/
private String period = Constants.DAY;
private ChatDefaultConfigReq.TimeMode timeMode;
}

View File

@@ -0,0 +1,27 @@
package com.tencent.supersonic.chat.api.pojo.response;
import com.tencent.supersonic.chat.api.pojo.request.KnowledgeAdvancedConfig;
import com.tencent.supersonic.chat.api.pojo.request.KnowledgeInfoReq;
import lombok.Data;
import java.util.List;
@Data
public class ChatDetailRichConfigResp {
/**
* invisible dimensions/metrics
*/
private ItemVisibilityInfo visibility;
/**
* information about dictionary about the domain
*/
private List<KnowledgeInfoReq> knowledgeInfos;
private KnowledgeAdvancedConfig globalKnowledgeConfig;
private ChatDefaultRichConfigResp chatDefaultConfig;
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.config;
package com.tencent.supersonic.chat.api.pojo.response;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
@@ -6,7 +6,7 @@ import java.util.List;
import lombok.Data;
@Data
public class EntityRichInfo {
public class EntityRichInfoResp {
/**
* entity alias
*/

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.chat.config;
package com.tencent.supersonic.chat.api.pojo.response;
import java.util.List;
import lombok.Data;

View File

@@ -7,6 +7,7 @@ import lombok.Data;
public class MetricInfo {
private String name;
private String dimension;
private String value;
private String date;
private Map<String, String> statistics;

View File

@@ -0,0 +1,25 @@
package com.tencent.supersonic.chat.api.pojo.response;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import lombok.*;
import java.util.List;
@Data
@Getter
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class ParseResp {
private Integer chatId;
private String queryText;
private ParseState state;
private List<SemanticParseInfo> selectedParses;
private List<SemanticParseInfo> candidateParses;
public enum ParseState {
COMPLETED,
PENDING,
FAILED
}
}

View File

@@ -4,7 +4,7 @@ import java.util.Date;
import lombok.Data;
@Data
public class QueryResponse {
public class QueryResp {
private Long questionId;
private Date createTime;

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.chat.api.pojo.response;
import com.tencent.supersonic.chat.api.pojo.request.RecommendedQuestion;
import com.tencent.supersonic.chat.api.pojo.request.RecommendedQuestionReq;
import lombok.AllArgsConstructor;
import lombok.Data;
@@ -8,7 +8,7 @@ import java.util.List;
@Data
@AllArgsConstructor
public class RecommendQuestion {
public class RecommendQuestionResp {
private Long domainId;
private List<RecommendedQuestion> recommendedQuestions;
private List<RecommendedQuestionReq> recommendedQuestions;
}

View File

@@ -6,7 +6,7 @@ import lombok.Data;
import java.util.List;
@Data
public class RecommendResponse {
public class RecommendResp {
private List<SchemaElement> dimensions;
private List<SchemaElement> metrics;
}

View File

@@ -6,11 +6,11 @@ import lombok.Getter;
import lombok.Setter;
@Data
public class SearchResponse {
public class SearchResp {
private List<SearchResult> searchResults;
public SearchResponse(List<SearchResult> searchResults) {
public SearchResp(List<SearchResult> searchResults) {
this.searchResults = searchResults;
}
}

View File

@@ -1,7 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>chat</artifactId>
<groupId>com.tencent.supersonic</groupId>
@@ -40,6 +40,11 @@
<scope>compile</scope>
</dependency>
<dependency>
<groupId>com.github.plexpt</groupId>
<artifactId>chatgpt</artifactId>
<version>4.1.2</version>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>

View File

@@ -0,0 +1,13 @@
package com.tencent.supersonic.chat.config;
import lombok.Data;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration;
@Configuration
@Data
public class AggregatorConfig {
@Value("${metric.aggregator.ratio.enable:true}")
private Boolean enableRatio;
}

View File

@@ -1,24 +0,0 @@
package com.tencent.supersonic.chat.config;
import lombok.Data;
import java.util.List;
@Data
public class ChatAggRichConfig {
/**
* invisible dimensions/metrics
*/
private ItemVisibilityInfo visibility;
/**
* information about dictionary about the domain
*/
private List<KnowledgeInfo> knowledgeInfos;
private KnowledgeAdvancedConfig globalKnowledgeConfig;
private ChatDefaultRichConfig chatDefaultConfig;
}

View File

@@ -1,6 +1,8 @@
package com.tencent.supersonic.chat.config;
import com.tencent.supersonic.chat.api.pojo.request.RecommendedQuestion;
import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatDetailConfigReq;
import com.tencent.supersonic.chat.api.pojo.request.RecommendedQuestionReq;
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import com.tencent.supersonic.common.pojo.RecordInfo;
import lombok.Data;
@@ -22,14 +24,14 @@ public class ChatConfig {
/**
* the chatDetailConfig about the domain
*/
private ChatDetailConfig chatDetailConfig;
private ChatDetailConfigReq chatDetailConfig;
/**
* the chatAggConfig about the domain
*/
private ChatAggConfig chatAggConfig;
private ChatAggConfigReq chatAggConfig;
private List<RecommendedQuestion> recommendedQuestions;
private List<RecommendedQuestionReq> recommendedQuestions;
/**
* available status

View File

@@ -1,30 +0,0 @@
package com.tencent.supersonic.chat.config;
import lombok.Data;
import java.util.List;
@Data
public class ChatDetailRichConfig {
/**
* invisible dimensions/metrics
*/
private ItemVisibilityInfo visibility;
/**
* the entity info about the domain
*/
private EntityRichInfo entity;
/**
* information about dictionary about the domain
*/
private List<KnowledgeInfo> knowledgeInfos;
private KnowledgeAdvancedConfig globalKnowledgeConfig;
private ChatDefaultRichConfig chatDefaultConfig;
}

View File

@@ -6,7 +6,7 @@ import org.springframework.context.annotation.Configuration;
@Configuration
@Data
public class FunctionCallConfig {
public class FunctionCallInfoConfig {
@Value("${functionCall.url:}")
private String url;
}

View File

@@ -0,0 +1,69 @@
package com.tencent.supersonic.chat.mapper;
import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.stream.Collectors;
@Slf4j
public class EntityMapper implements SchemaMapper {
@Override
public void map(QueryContext queryContext) {
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
for (Long domainId : schemaMapInfo.getMatchedDomains()) {
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(domainId);
if (CollectionUtils.isEmpty(schemaElementMatchList)) {
continue;
}
SchemaElement entity = getEntity(domainId);
if (entity == null || entity.getId() == null) {
continue;
}
List<SchemaElementMatch> valueSchemaElements = schemaElementMatchList.stream().filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()))
.collect(Collectors.toList());
for (SchemaElementMatch schemaElementMatch : valueSchemaElements) {
if (!entity.getId().equals(schemaElementMatch.getElement().getId())){
continue;
}
if (!checkExistSameEntitySchemaElements(schemaElementMatch, schemaElementMatchList)) {
SchemaElementMatch entitySchemaElementMath = new SchemaElementMatch();
BeanUtils.copyProperties(schemaElementMatch, entitySchemaElementMath);
entitySchemaElementMath.setElement(entity);
schemaElementMatchList.add(entitySchemaElementMath);
}
schemaElementMatch.getElement().setType(SchemaElementType.ID);
}
}
}
private boolean checkExistSameEntitySchemaElements(SchemaElementMatch valueSchemaElementMatch,
List<SchemaElementMatch> schemaElementMatchList) {
List<SchemaElementMatch> entitySchemaElements = schemaElementMatchList.stream().filter(schemaElementMatch ->
SchemaElementType.ENTITY.equals(schemaElementMatch.getElement().getType()))
.collect(Collectors.toList());
for (SchemaElementMatch schemaElementMatch : entitySchemaElements) {
if (schemaElementMatch.getElement().getId().equals(valueSchemaElementMatch.getElement().getId())) {
return true;
}
}
return false;
}
private SchemaElement getEntity(Long domainId) {
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
DomainSchema domainSchema = semanticService.getDomainSchema(domainId);
if (domainSchema != null && domainSchema.getEntity() != null) {
return domainSchema.getEntity();
}
return null;
}
}

View File

@@ -2,13 +2,9 @@ package com.tencent.supersonic.chat.mapper;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
import java.util.ArrayList;
@@ -44,7 +40,7 @@ public class FuzzyNameMapper implements SchemaMapper {
}
private void detectAndAddToSchema(QueryContext queryContext, List<Term> terms, List<SchemaElement> domains,
SchemaElementType schemaElementType) {
SchemaElementType schemaElementType) {
try {
Map<String, Set<SchemaElement>> domainResultSet = getResultSet(queryContext, terms, domains);
@@ -57,7 +53,7 @@ public class FuzzyNameMapper implements SchemaMapper {
}
private Map<String, Set<SchemaElement>> getResultSet(QueryContext queryContext, List<Term> terms,
List<SchemaElement> domains) {
List<SchemaElement> domains) {
String queryText = queryContext.getRequest().getQueryText();

View File

@@ -74,6 +74,10 @@ public class HanlpDictMapper implements SchemaMapper {
Long frequency = wordNatureToFrequency.get(mapResult.getName() + nature);
SchemaElement element = domainSchema.getElement(elementType, elementID);
if(Objects.isNull(element)){
log.info("element is null, elementType:{},elementID:{}", elementType, elementID);
continue;
}
if (element.getType().equals(SchemaElementType.VALUE)) {
element.setName(mapResult.getName());
}

View File

@@ -1,27 +1,41 @@
package com.tencent.supersonic.chat.mapper;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.component.SchemaMapper;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.api.pojo.request.QueryRequest;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@Slf4j
public class QueryFilterMapper implements SchemaMapper {
private Long FREQUENCY = 9999999L;
private double SIMILARITY = 1.0;
@Override
public void map(QueryContext queryContext) {
QueryRequest queryReq = queryContext.getRequest();
QueryReq queryReq = queryContext.getRequest();
Long domainId = queryReq.getDomainId();
if (domainId == null || domainId <= 0) {
return;
}
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
clearOtherSchemaElementMatch(domainId, schemaMapInfo);
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(domainId);
if (schemaElementMatches == null) {
schemaElementMatches = Lists.newArrayList();
schemaMapInfo.setMatchedElements(domainId, schemaElementMatches);
}
addValueSchemaElementMatch(schemaElementMatches, queryReq.getQueryFilters());
}
private void clearOtherSchemaElementMatch(Long domainId, SchemaMapInfo schemaMapInfo) {
private void clearOtherSchemaElementMatch(Long domainId, SchemaMapInfo schemaMapInfo) {
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMapInfo.getDomainElementMatches().entrySet()) {
if (!entry.getKey().equals(domainId)) {
entry.getValue().clear();
@@ -29,4 +43,44 @@ public class QueryFilterMapper implements SchemaMapper {
}
}
private List<SchemaElementMatch> addValueSchemaElementMatch(List<SchemaElementMatch> candidateElementMatches,
QueryFilters queryFilter) {
if (queryFilter == null || CollectionUtils.isEmpty(queryFilter.getFilters())) {
return candidateElementMatches;
}
for (QueryFilter filter : queryFilter.getFilters()) {
if (checkExistSameValueSchemaElementMatch(filter, candidateElementMatches)) {
continue;
}
SchemaElement element = SchemaElement.builder()
.id(filter.getElementID())
.name(String.valueOf(filter.getValue()))
.type(SchemaElementType.VALUE)
.bizName(filter.getBizName())
.build();
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
.element(element)
.frequency(FREQUENCY)
.word(String.valueOf(filter.getValue()))
.similarity(SIMILARITY)
.detectWord(filter.getName())
.build();
candidateElementMatches.add(schemaElementMatch);
}
return candidateElementMatches;
}
private boolean checkExistSameValueSchemaElementMatch(QueryFilter queryFilter,
List<SchemaElementMatch> schemaElementMatches) {
List<SchemaElementMatch> valueSchemaElements = schemaElementMatches.stream().filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()))
.collect(Collectors.toList());
for (SchemaElementMatch schemaElementMatch : valueSchemaElements) {
if (schemaElementMatch.getElement().getId().equals(queryFilter.getElementID())
&& schemaElementMatch.getWord().equals(String.valueOf(queryFilter.getValue()))) {
return true;
}
}
return false;
}
}

View File

@@ -3,13 +3,16 @@ package com.tencent.supersonic.chat.parser;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.plugin.PluginParseResult;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.util.JsonUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.compress.utils.Lists;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
@@ -25,31 +28,36 @@ import java.util.stream.Collectors;
public class SatisfactionChecker {
private static final double LONG_TEXT_THRESHOLD = 0.8;
private static final double SHORT_TEXT_THRESHOLD = 0.6;
private static final double SHORT_TEXT_THRESHOLD = 0.5;
private static final int QUERY_TEXT_LENGTH_THRESHOLD = 10;
public static final double BONUS_THRESHOLD = 100;
public static final double EMBEDDING_THRESHOLD = 0.2;
// check all the parse info in candidate
public static boolean check(QueryContext queryCtx) {
for (SemanticQuery query : queryCtx.getCandidateQueries()) {
SemanticParseInfo semanticParseInfo = query.getParseInfo();
Long domainId = semanticParseInfo.getDomainId();
List<SchemaElementMatch> schemaElementMatches = queryCtx.getMapInfo()
.getMatchedElements(domainId);
if (check(queryCtx.getRequest().getQueryText(), semanticParseInfo, schemaElementMatches)) {
return true;
if (query instanceof RuleSemanticQuery) {
if (checkRuleThreshHold(queryCtx.getRequest().getQueryText(), query.getParseInfo())) {
return true;
}
} else if (query instanceof PluginSemanticQuery) {
if (checkEmbeddingThreshold(query.getParseInfo())) {
log.info("query mode :{} satisfy check", query.getQueryMode());
return true;
}
}
}
return false;
}
private static boolean checkEmbeddingThreshold(SemanticParseInfo semanticParseInfo) {
Object object = semanticParseInfo.getProperties().get(Constants.CONTEXT);
PluginParseResult pluginParseResult = JsonUtil.toObject(JsonUtil.toString(object), PluginParseResult.class);
return EMBEDDING_THRESHOLD > pluginParseResult.getDistance();
}
//check single parse info
private static boolean check(String text, SemanticParseInfo semanticParseInfo,
List<SchemaElementMatch> schemaElementMatches) {
if (semanticParseInfo.getBonus() != null && semanticParseInfo.getBonus() >= BONUS_THRESHOLD) {
return true;
}
private static boolean checkRuleThreshHold(String text, SemanticParseInfo semanticParseInfo) {
List<SchemaElementMatch> schemaElementMatches = semanticParseInfo.getElementMatches();
if (CollectionUtils.isEmpty(schemaElementMatches)) {
return false;
}
@@ -71,6 +79,11 @@ public class SatisfactionChecker {
detectWords.add(schemaElementMatch.getDetectWord());
}
}
for (SchemaElementMatch schemaElementMatch : schemaElementMatches) {
if (SchemaElementType.ID.equals(schemaElementMatch.getElement().getType())) {
detectWords.add(schemaElementMatch.getDetectWord());
}
}
for (SchemaElement schemaItem : semanticParseInfo.getMetrics()) {
detectWords.add(
detectWordMap.getOrDefault(Optional.ofNullable(schemaItem.getId()).orElse(0L), ""));
@@ -87,7 +100,7 @@ public class SatisfactionChecker {
if (StringUtils.isNotBlank(dateText) && !dateText.equalsIgnoreCase(Constants.NULL)) {
detectWords.add(dateText);
}
detectWords.removeIf(word -> !text.contains(word));
detectWords.removeIf(word -> !text.contains(word) && !text.contains(StringUtils.reverse(word)));
//compare the length between detect words and query text
return checkThreshold(text, detectWords, semanticParseInfo);
}

View File

@@ -5,19 +5,21 @@ import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.config.ChatConfigRich;
import com.tencent.supersonic.chat.config.EntityRichInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
import com.tencent.supersonic.chat.plugin.Plugin;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.plugin.PluginParseResult;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.service.ConfigService;
import com.tencent.supersonic.chat.service.PluginService;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.ContextUtils;
import java.util.*;
import java.util.stream.Collectors;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;
@@ -30,30 +32,43 @@ public class EmbeddingBasedParser implements SemanticParser {
@Override
public void parse(QueryContext queryContext, ChatContext chatContext) {
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
if (SatisfactionChecker.check(queryContext) || StringUtils.isBlank(embeddingConfig.getUrl())) {
if (StringUtils.isBlank(embeddingConfig.getUrl())) {
return;
}
log.info("EmbeddingBasedParser parser query ctx: {}, chat ctx: {}", queryContext, chatContext);
for (Long domainId : getDomainMatched(queryContext)) {
String text = replaceText(queryContext, domainId);
List<RecallRetrieval> embeddingRetrievals = recallResult(text, hasCandidateQuery(queryContext));
Optional<Plugin> pluginOptional = choosePlugin(embeddingRetrievals, domainId);
if (pluginOptional.isPresent()) {
Map<String, RecallRetrieval> embeddingRetrievalMap = embeddingRetrievals.stream()
.collect(Collectors.toMap(RecallRetrieval::getId, e -> e, (value1, value2) -> value1));
Plugin plugin = pluginOptional.get();
log.info("EmbeddingBasedParser text: {} domain: {} choose plugin: [{} {}]",
text, domainId, plugin.getId(), plugin.getName());
PluginSemanticQuery pluginQuery = QueryManager.createPluginQuery(plugin.getType());
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(queryContext, domainId,
plugin, embeddingRetrievalMap);
semanticParseInfo.setQueryMode(pluginQuery.getQueryMode());
pluginQuery.setParseInfo(semanticParseInfo);
queryContext.getCandidateQueries().add(pluginQuery);
Set<Long> domainIds = getDomainMatched(queryContext);
String text = queryContext.getRequest().getQueryText();
if (!CollectionUtils.isEmpty(domainIds)) {
for (Long domainId : domainIds) {
List<SchemaElementMatch> schemaElementMatches = getMatchedElements(queryContext, domainId);
String textReplaced = replaceText(text, schemaElementMatches);
List<RecallRetrieval> embeddingRetrievals = recallResult(textReplaced, hasCandidateQuery(queryContext));
Optional<Plugin> pluginOptional = choosePlugin(embeddingRetrievals, domainId);
log.info("domain id :{} embedding result, text:{} embeddingResp:{} ",domainId, textReplaced, embeddingRetrievals);
pluginOptional.ifPresent(plugin -> buildQuery(plugin, embeddingRetrievals, domainId, textReplaced, queryContext, schemaElementMatches));
}
} else {
List<RecallRetrieval> embeddingRetrievals = recallResult(text, hasCandidateQuery(queryContext));
Optional<Plugin> pluginOptional = choosePlugin(embeddingRetrievals, null);
pluginOptional.ifPresent(plugin -> buildQuery(plugin, embeddingRetrievals, null, text, queryContext, Lists.newArrayList()));
}
}
private void buildQuery(Plugin plugin, List<RecallRetrieval> embeddingRetrievals,
Long domainId, String text,
QueryContext queryContext, List<SchemaElementMatch> schemaElementMatches) {
Map<String, RecallRetrieval> embeddingRetrievalMap = embeddingRetrievals.stream()
.collect(Collectors.toMap(RecallRetrieval::getId, e -> e, (value1, value2) -> value1));
log.info("EmbeddingBasedParser text: {} domain: {} choose plugin: [{} {}]",
text, domainId, plugin.getId(), plugin.getName());
PluginSemanticQuery pluginQuery = QueryManager.createPluginQuery(plugin.getType());
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(domainId, plugin, text,
queryContext.getRequest(), embeddingRetrievalMap, schemaElementMatches);
semanticParseInfo.setQueryMode(pluginQuery.getQueryMode());
pluginQuery.setParseInfo(semanticParseInfo);
queryContext.getCandidateQueries().add(pluginQuery);
}
private Set<Long> getDomainMatched(QueryContext queryContext) {
Long queryDomainId = queryContext.getRequest().getDomainId();
if (queryDomainId != null && queryDomainId > 0) {
@@ -62,51 +77,61 @@ public class EmbeddingBasedParser implements SemanticParser {
return queryContext.getMapInfo().getMatchedDomains();
}
private SemanticParseInfo buildSemanticParseInfo(QueryContext queryContext, Long domainId, Plugin plugin,
Map<String, RecallRetrieval> embeddingRetrievalMap) {
private SemanticParseInfo buildSemanticParseInfo(Long domainId, Plugin plugin, String text, QueryReq queryReq,
Map<String, RecallRetrieval> embeddingRetrievalMap,
List<SchemaElementMatch> schemaElementMatches) {
SchemaElement schemaElement = new SchemaElement();
schemaElement.setDomain(domainId);
schemaElement.setId(domainId);
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
semanticParseInfo.setElementMatches(schemaElementMatches);
semanticParseInfo.setDomain(schemaElement);
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
if (Double.parseDouble(embeddingRetrievalMap.get(plugin.getId().toString()).getDistance()) < THRESHOLD) {
semanticParseInfo.setBonus(SatisfactionChecker.BONUS_THRESHOLD);
}
double distance = Double.parseDouble(embeddingRetrievalMap.get(plugin.getId().toString()).getDistance());
double score = text.length() * (1 - distance);
Map<String, Object> properties = new HashMap<>();
properties.put(Constants.CONTEXT, plugin);
PluginParseResult pluginParseResult = new PluginParseResult();
pluginParseResult.setPlugin(plugin);
pluginParseResult.setRequest(queryReq);
pluginParseResult.setDistance(distance);
properties.put(Constants.CONTEXT, pluginParseResult);
semanticParseInfo.setProperties(properties);
semanticParseInfo.setElementMatches(schemaMapInfo.getMatchedElements(domainId));
fillSemanticParseInfo(queryContext, semanticParseInfo);
setEntityId(domainId, semanticParseInfo);
semanticParseInfo.setScore(score);
fillSemanticParseInfo(semanticParseInfo);
setEntity(domainId, semanticParseInfo);
return semanticParseInfo;
}
private Optional<Long> getEntityElementId(Long domainId) {
ConfigService configService = ContextUtils.getBean(ConfigService.class);
ChatConfigRich chatConfigRich = configService.getConfigRichInfo(domainId);
EntityRichInfo entityRichInfo = chatConfigRich.getChatDetailRichConfig().getEntity();
if (entityRichInfo != null) {
SchemaElement schemaElement = entityRichInfo.getDimItem();
if (schemaElement != null) {
return Optional.of(schemaElement.getId());
}
private List<SchemaElementMatch> getMatchedElements(QueryContext queryContext, Long domainId) {
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(domainId);
if (schemaElementMatches == null) {
return Lists.newArrayList();
}
return Optional.empty();
QueryReq queryReq = queryContext.getRequest();
QueryFilters queryFilters = queryReq.getQueryFilters();
if (queryFilters == null || CollectionUtils.isEmpty(queryFilters.getFilters())) {
return schemaElementMatches;
}
Map<Long, Object> element = queryFilters.getFilters().stream()
.collect(Collectors.toMap(QueryFilter::getElementID, QueryFilter::getValue, (v1, v2) -> v1));
return schemaElementMatches.stream().filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType())
|| SchemaElementType.ENTITY.equals(schemaElementMatch.getElement().getType()))
.filter(schemaElementMatch ->
!element.containsKey(schemaElementMatch.getElement().getId()) || (
element.containsKey(schemaElementMatch.getElement().getId()) &&
element.get(schemaElementMatch.getElement().getId()).toString()
.equalsIgnoreCase(schemaElementMatch.getWord())
))
.collect(Collectors.toList());
}
private void setEntityId(Long domainId, SemanticParseInfo semanticParseInfo) {
Optional<Long> entityElementIdOptional = getEntityElementId(domainId);
if (entityElementIdOptional.isPresent()) {
Long entityElementId = entityElementIdOptional.get();
for (QueryFilter filter : semanticParseInfo.getDimensionFilters()) {
if (entityElementId.equals(filter.getElementID())) {
String value = String.valueOf(filter.getValue());
if (StringUtils.isNumeric(value)) {
semanticParseInfo.setEntity(Long.parseLong(value));
}
}
}
private void setEntity(Long domainId, SemanticParseInfo semanticParseInfo) {
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
DomainSchema domainSchema = semanticService.getDomainSchema(domainId);
if (domainSchema != null && domainSchema.getEntity() != null) {
semanticParseInfo.setEntity(domainSchema.getEntity());
}
}
@@ -120,6 +145,12 @@ public class EmbeddingBasedParser implements SemanticParser {
Map<Long, Plugin> pluginMap = plugins.stream().collect(Collectors.toMap(Plugin::getId, p -> p));
for (RecallRetrieval embeddingRetrieval : embeddingRetrievals) {
Plugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId()));
if (plugin == null) {
continue;
}
if (domainId == null) {
return Optional.of(plugin);
}
if (!CollectionUtils.isEmpty(plugin.getDomainList()) && plugin.getDomainList().contains(domainId)) {
return Optional.of(plugin);
}
@@ -131,7 +162,6 @@ public class EmbeddingBasedParser implements SemanticParser {
try {
PluginManager pluginManager = ContextUtils.getBean(PluginManager.class);
EmbeddingResp embeddingResp = pluginManager.recognize(embeddingText);
log.info("embedding result, text:{} embeddingResp:{}", embeddingText, embeddingResp);
List<RecallRetrieval> embeddingRetrievals = embeddingResp.getRetrieval();
if(!CollectionUtils.isEmpty(embeddingRetrievals)){
if (hasCandidateQuery) {
@@ -154,25 +184,38 @@ public class EmbeddingBasedParser implements SemanticParser {
return !CollectionUtils.isEmpty(queryContext.getCandidateQueries());
}
private void fillSemanticParseInfo(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
if (queryContext.getRequest().getQueryFilters() != null) {
semanticParseInfo.getDimensionFilters()
.addAll(queryContext.getRequest().getQueryFilters().getFilters());
private void fillSemanticParseInfo(SemanticParseInfo semanticParseInfo) {
List<SchemaElementMatch> schemaElementMatches = semanticParseInfo.getElementMatches();
if (!CollectionUtils.isEmpty(schemaElementMatches)) {
schemaElementMatches.stream().filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
.forEach(schemaElementMatch -> {
QueryFilter queryFilter = new QueryFilter();
queryFilter.setValue(schemaElementMatch.getWord());
queryFilter.setElementID(schemaElementMatch.getElement().getId());
queryFilter.setName(schemaElementMatch.getElement().getName());
queryFilter.setOperator(FilterOperatorEnum.EQUALS);
queryFilter.setBizName(schemaElementMatch.getElement().getBizName());
semanticParseInfo.getDimensionFilters().add(queryFilter);
});
}
}
protected String replaceText(QueryContext queryContext, Long domainId) {
String text = queryContext.getRequest().getQueryText();
List<SchemaElementMatch> schemaElementMatches = queryContext.getMapInfo().getMatchedElements(domainId);
protected String replaceText(String text, List<SchemaElementMatch> schemaElementMatches) {
if (CollectionUtils.isEmpty(schemaElementMatches)) {
return text;
}
List<SchemaElementMatch> valueSchemaElementMatches = schemaElementMatches.stream()
.filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()))
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
.collect(Collectors.toList());
for (SchemaElementMatch schemaElementMatch : valueSchemaElementMatches) {
String detectWord = schemaElementMatch.getDetectWord();
if (StringUtils.isBlank(detectWord)) {
continue;
}
text = text.replace(detectWord, "");
}
return text;

View File

@@ -4,14 +4,10 @@ import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.config.ChatConfigRich;
import com.tencent.supersonic.chat.parser.function.DomainResolver;
import com.tencent.supersonic.chat.service.ConfigService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.stereotype.Component;
import java.util.List;
@@ -28,18 +24,6 @@ public class EmbeddingEntityResolver {
this.configService = configService;
}
public Pair<Long, Long> getDomainEntityId(QueryContext queryCtx, ChatContext chatCtx) {
DomainResolver domainResolver = ComponentFactory.getDomainResolver();
Long domainId = domainResolver.resolve(queryCtx, chatCtx);
ChatConfigRich chatConfigRichResp = configService.getConfigRichInfo(domainId);
SchemaElement schemaElement = chatConfigRichResp.getChatDetailRichConfig().getEntity().getDimItem();
if (schemaElement == null) {
return Pair.of(domainId, null);
}
Long entityId = getEntityValue(domainId, schemaElement.getId(), queryCtx, chatCtx);
return Pair.of(domainId, entityId);
}
private Long getEntityValue(Long domainId, Long entityElementId, QueryContext queryCtx, ChatContext chatCtx) {
Long entityId = null;

View File

@@ -3,9 +3,10 @@ package com.tencent.supersonic.chat.parser.function;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import java.util.List;
public interface DomainResolver {
Long resolve(QueryContext queryContext, ChatContext chatCtx);
Long resolve(QueryContext queryContext, ChatContext chatCtx, List<Long> restrictiveDomains);
}

View File

@@ -6,11 +6,11 @@ import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.config.FunctionCallConfig;
import com.tencent.supersonic.chat.parser.ParseMode;
import com.tencent.supersonic.chat.config.FunctionCallInfoConfig;
import com.tencent.supersonic.chat.parser.SatisfactionChecker;
import com.tencent.supersonic.chat.plugin.Plugin;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
import com.tencent.supersonic.chat.plugin.PluginParseResult;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
@@ -22,12 +22,15 @@ import com.tencent.supersonic.common.util.ContextUtils;
import java.net.URI;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import com.tencent.supersonic.common.util.JsonUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.http.HttpEntity;
@@ -44,9 +47,12 @@ public class FunctionBasedParser implements SemanticParser {
public static final double FUNCTION_BONUS_THRESHOLD = 200;
public static final double SKIP_DSL_LENGTH = 10;
@Override
public void parse(QueryContext queryCtx, ChatContext chatCtx) {
FunctionCallConfig functionCallConfig = ContextUtils.getBean(FunctionCallConfig.class);
FunctionCallInfoConfig functionCallConfig = ContextUtils.getBean(FunctionCallInfoConfig.class);
PluginService pluginService = ContextUtils.getBean(PluginService.class);
String functionUrl = functionCallConfig.getUrl();
@@ -55,39 +61,57 @@ public class FunctionBasedParser implements SemanticParser {
queryCtx.getRequest().getQueryText());
return;
}
DomainResolver domainResolver = ComponentFactory.getDomainResolver();
Long domainId = domainResolver.resolve(queryCtx, chatCtx);
List<String> functionNames = getFunctionNames(domainId);
log.info("domainId:{},functionNames:{}", domainId, functionNames);
if (Objects.isNull(domainId) || domainId <= 0) {
Set<Long> matchedDomains = getMatchDomains(queryCtx);
List<String> functionNames = getFunctionNames(matchedDomains);
log.info("matchedDomains:{},functionNames:{}", matchedDomains, functionNames);
if (CollectionUtils.isEmpty(functionNames) || CollectionUtils.isEmpty(matchedDomains)) {
return;
}
List<PluginParseConfig> functionDOList = getFunctionDO(queryCtx.getRequest().getDomainId());
FunctionReq functionReq = FunctionReq.builder()
.queryText(queryCtx.getRequest().getQueryText())
.functionNames(functionNames).build();
.pluginConfigs(functionDOList).build();
FunctionResp functionResp = requestFunction(functionUrl, functionReq);
log.info("requestFunction result:{}", functionResp.getToolSelection());
if (Objects.isNull(functionResp) || StringUtils.isBlank(functionResp.getToolSelection())) {
if (skipFunction(queryCtx, functionResp)) {
return;
}
PluginParseResult functionCallParseResult = new PluginParseResult();
String toolSelection = functionResp.getToolSelection();
Optional<Plugin> pluginOptional = pluginService.getPluginByName(toolSelection);
if (pluginOptional.isPresent()) {
toolSelection = pluginOptional.get().getType();
functionCallParseResult.setPlugin(pluginOptional.get());
if (!pluginOptional.isPresent()) {
log.info("pluginOptional is not exist:{}, skip the parse", toolSelection);
return;
}
Plugin plugin = pluginOptional.get();
toolSelection = plugin.getType();
functionCallParseResult.setPlugin(plugin);
log.info("QueryManager PluginQueryModes:{}", QueryManager.getPluginQueryModes());
PluginSemanticQuery semanticQuery = QueryManager.createPluginQuery(toolSelection);
DomainResolver domainResolver = ComponentFactory.getDomainResolver();
Long domainId = domainResolver.resolve(queryCtx, chatCtx, plugin.getDomainList());
log.info("FunctionBasedParser domainId:{}",domainId);
if ((Objects.isNull(domainId) || domainId <= 0) && !plugin.isContainsAllDomain()) {
log.info("domain is null, skip the parse, select tool: {}", toolSelection);
return;
}
if (!plugin.getDomainList().contains(domainId) && !plugin.isContainsAllDomain()) {
return;
}
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
parseInfo.getElementMatches().addAll(queryCtx.getMapInfo().getMatchedElements(domainId));
if (Objects.nonNull(domainId) && domainId > 0){
parseInfo.getElementMatches().addAll(queryCtx.getMapInfo().getMatchedElements(domainId));
}
functionCallParseResult.setRequest(queryCtx.getRequest());
Map<String, Object> properties = new HashMap<>();
properties.put(Constants.CONTEXT, functionCallParseResult);
parseInfo.setProperties(properties);
parseInfo.setBonus(FUNCTION_BONUS_THRESHOLD);
parseInfo.setScore(FUNCTION_BONUS_THRESHOLD);
parseInfo.setQueryMode(semanticQuery.getQueryMode());
SchemaElement domain = new SchemaElement();
domain.setDomain(domainId);
domain.setId(domainId);
@@ -95,13 +119,61 @@ public class FunctionBasedParser implements SemanticParser {
queryCtx.getCandidateQueries().add(semanticQuery);
}
private List<String> getFunctionNames(Long domainId) {
private Set<Long> getMatchDomains(QueryContext queryCtx) {
Set<Long> result = new HashSet<>();
Long domainId = queryCtx.getRequest().getDomainId();
if (Objects.nonNull(domainId) && domainId > 0) {
result.add(domainId);
return result;
}
return queryCtx.getMapInfo().getMatchedDomains();
}
private boolean skipFunction(QueryContext queryCtx, FunctionResp functionResp) {
if (Objects.isNull(functionResp) || StringUtils.isBlank(functionResp.getToolSelection())) {
return true;
}
String queryText = queryCtx.getRequest().getQueryText();
if (functionResp.getToolSelection().equalsIgnoreCase(DSLQuery.QUERY_MODE)
&& queryText.length() < SKIP_DSL_LENGTH) {
log.info("queryText length is :{}, less than the threshold :{}, skip dsl.", queryText.length(),
SKIP_DSL_LENGTH);
return true;
}
return false;
}
private List<PluginParseConfig> getFunctionDO(Long domainId) {
log.info("user decide domain:{}", domainId);
List<Plugin> plugins = PluginManager.getPlugins();
List<PluginParseConfig> functionDOList = plugins.stream().filter(o -> {
if (o.getParseModeConfig() == null) {
return false;
}
if (!CollectionUtils.isEmpty(o.getDomainList())) {//过滤掉没选主题域的插件
return true;
}
if (domainId == null || domainId <= 0L) {
return true;
} else {
return o.getDomainList().contains(domainId);
}
}).map(o -> {
PluginParseConfig functionCallConfig = JsonUtil.toObject(o.getParseModeConfig(),
PluginParseConfig.class);
return functionCallConfig;
}).collect(Collectors.toList());
log.info("getFunctionDO:{}", JsonUtil.toString(functionDOList));
return functionDOList;
}
private List<String> getFunctionNames(Set<Long> matchedDomains) {
List<Plugin> plugins = PluginManager.getPlugins();
Set<String> functionNames = plugins.stream()
.filter(entry -> ParseMode.FUNCTION_CALL.equals(entry.getParseMode()))
.filter(entry -> {
if (!CollectionUtils.isEmpty(entry.getDomainList())) {
return entry.getDomainList().contains(domainId);
if (!CollectionUtils.isEmpty(entry.getDomainList()) && !CollectionUtils.isEmpty(matchedDomains)) {
return entry.getDomainList().stream().anyMatch(matchedDomains::contains);
}
return true;
}
@@ -118,7 +190,7 @@ public class FunctionBasedParser implements SemanticParser {
URI requestUrl = UriComponentsBuilder.fromHttpUrl(url).build().encode().toUri();
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
try {
log.info("requestFunction functionReq:{}", functionReq);
log.info("requestFunction functionReq:{}", JsonUtil.toString(functionReq));
ResponseEntity<FunctionResp> responseEntity = restTemplate.exchange(requestUrl, HttpMethod.POST, entity,
FunctionResp.class);
log.info("requestFunction responseEntity:{},cost:{}", responseEntity,

View File

@@ -0,0 +1,13 @@
package com.tencent.supersonic.chat.parser.function;
import lombok.Data;
@Data
public class FunctionFiled {
private String type;
private String description;
}

View File

@@ -1,6 +1,8 @@
package com.tencent.supersonic.chat.parser.function;
import java.util.List;
import com.tencent.supersonic.chat.plugin.PluginParseConfig;
import lombok.Builder;
import lombok.Data;
@@ -10,6 +12,6 @@ public class FunctionReq {
private String queryText;
private List<String> functionNames;
private List<PluginParseConfig> pluginConfigs;
}

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.chat.parser.function;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.api.pojo.request.QueryRequest;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import lombok.extern.slf4j.Slf4j;
@@ -25,10 +25,10 @@ public class HeuristicDomainResolver implements DomainResolver {
Map.Entry<Long, DomainMatchResult> maxDomain = domainTypeMap.entrySet().stream()
.filter(entry -> domainQueryModes.containsKey(entry.getKey()))
.sorted((o1, o2) -> {
int difference = o1.getValue().getCount() - o2.getValue().getCount();
int difference = o2.getValue().getCount() - o1.getValue().getCount();
if (difference == 0) {
return (int) ((o1.getValue().getMaxSimilarity()
- o2.getValue().getMaxSimilarity()) * 100);
return (int) ((o2.getValue().getMaxSimilarity()
- o1.getValue().getMaxSimilarity()) * 100);
}
return difference;
}).findFirst().orElse(null);
@@ -46,7 +46,7 @@ public class HeuristicDomainResolver implements DomainResolver {
* @return false will use context domain, true will use other domain , maybe include context domain
*/
protected static boolean isAllowSwitch(Map<Long, SemanticQuery> domainQueryModes, SchemaMapInfo schemaMap,
ChatContext chatCtx, QueryRequest searchCtx, Long domainId) {
ChatContext chatCtx, QueryReq searchCtx, Long domainId, List<Long> restrictiveDomains) {
if (!Objects.nonNull(domainId) || domainId <= 0) {
return true;
}
@@ -81,6 +81,9 @@ public class HeuristicDomainResolver implements DomainResolver {
}
}
if (CollectionUtils.isNotEmpty(restrictiveDomains) && !restrictiveDomains.contains(domainId)) {
return true;
}
// if context domain not in schemaMap , will switch
if (schemaMap.getMatchedElements(domainId) == null || schemaMap.getMatchedElements(domainId).size() <= 0) {
log.info("domainId not in schemaMap ");
@@ -118,24 +121,36 @@ public class HeuristicDomainResolver implements DomainResolver {
}
public Long resolve(QueryContext queryContext, ChatContext chatCtx) {
public Long resolve(QueryContext queryContext, ChatContext chatCtx, List<Long> restrictiveDomains) {
Long domainId = queryContext.getRequest().getDomainId();
if (Objects.nonNull(domainId) && domainId > 0) {
return domainId;
if (CollectionUtils.isNotEmpty(restrictiveDomains) && restrictiveDomains.contains(domainId)) {
return domainId;
} else {
return null;
}
}
SchemaMapInfo mapInfo = queryContext.getMapInfo();
Set<Long> matchedDomains = mapInfo.getMatchedDomains();
if (CollectionUtils.isNotEmpty(restrictiveDomains)) {
matchedDomains = matchedDomains.stream()
.filter(matchedDomain -> restrictiveDomains.contains(matchedDomain))
.collect(Collectors.toSet());
}
Map<Long, SemanticQuery> domainQueryModes = new HashMap<>();
for (Long matchedDomain : matchedDomains) {
domainQueryModes.put(matchedDomain, null);
}
if(domainQueryModes.size()==1){
return domainQueryModes.keySet().stream().findFirst().get();
}
return resolve(domainQueryModes, queryContext, chatCtx,
queryContext.getMapInfo());
queryContext.getMapInfo(),restrictiveDomains);
}
public Long resolve(Map<Long, SemanticQuery> domainQueryModes, QueryContext queryContext,
ChatContext chatCtx, SchemaMapInfo schemaMap) {
Long selectDomain = selectDomain(domainQueryModes, queryContext.getRequest(), chatCtx, schemaMap);
ChatContext chatCtx, SchemaMapInfo schemaMap, List<Long> restrictiveDomains) {
Long selectDomain = selectDomain(domainQueryModes, queryContext.getRequest(), chatCtx, schemaMap,restrictiveDomains);
if (selectDomain > 0) {
log.info("selectDomain {} ", selectDomain);
return selectDomain;
@@ -144,9 +159,9 @@ public class HeuristicDomainResolver implements DomainResolver {
return selectDomainBySchemaElementCount(domainQueryModes, schemaMap);
}
public Long selectDomain(Map<Long, SemanticQuery> domainQueryModes, QueryRequest queryContext,
public Long selectDomain(Map<Long, SemanticQuery> domainQueryModes, QueryReq queryContext,
ChatContext chatCtx,
SchemaMapInfo schemaMap) {
SchemaMapInfo schemaMap, List<Long> restrictiveDomains) {
// if QueryContext has domainId and in domainQueryModes
if (domainQueryModes.containsKey(queryContext.getDomainId())) {
log.info("selectDomain from QueryContext [{}]", queryContext.getDomainId());
@@ -155,7 +170,7 @@ public class HeuristicDomainResolver implements DomainResolver {
// if ChatContext has domainId and in domainQueryModes
if (chatCtx.getParseInfo().getDomainId() > 0) {
Long domainId = chatCtx.getParseInfo().getDomainId();
if (!isAllowSwitch(domainQueryModes, schemaMap, chatCtx, queryContext, domainId)) {
if (!isAllowSwitch(domainQueryModes, schemaMap, chatCtx, queryContext, domainId,restrictiveDomains)) {
log.info("selectDomain from ChatContext [{}]", domainId);
return domainId;
}
@@ -163,4 +178,4 @@ public class HeuristicDomainResolver implements DomainResolver {
// default 0
return 0L;
}
}
}

View File

@@ -0,0 +1,17 @@
package com.tencent.supersonic.chat.parser.function;
import lombok.Data;
import java.util.List;
import java.util.Map;
@Data
public class Parameters {
//default: object
private String type = "object";
private Map<String, FunctionFiled> properties;
private List<String> required;
}

View File

@@ -0,0 +1,49 @@
package com.tencent.supersonic.chat.parser.llm;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.utils.ChatGptHelper;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class LLMTimeEnhancementParse implements SemanticParser {
@Override
public void parse(QueryContext queryContext, ChatContext chatContext) {
log.info("before queryContext:{},chatContext:{}",queryContext,chatContext);
ChatGptHelper chatGptHelper = ContextUtils.getBean(ChatGptHelper.class);
String inferredTime = chatGptHelper.inferredTime(queryContext.getRequest().getQueryText());
try {
if (!queryContext.getCandidateQueries().isEmpty()) {
for (SemanticQuery query : queryContext.getCandidateQueries()) {
DateConf dateInfo = query.getParseInfo().getDateInfo();
JSONObject jsonObject = JSON.parseObject(inferredTime);
if (jsonObject.containsKey("date")){
dateInfo.setDateMode(DateConf.DateMode.BETWEEN);
dateInfo.setStartDate(jsonObject.getString("date"));
dateInfo.setEndDate(jsonObject.getString("date"));
query.getParseInfo().setDateInfo(dateInfo);
}else if (jsonObject.containsKey("start")){
dateInfo.setDateMode(DateConf.DateMode.BETWEEN);
dateInfo.setStartDate(jsonObject.getString("start"));
dateInfo.setEndDate(jsonObject.getString("end"));
query.getParseInfo().setDateInfo(dateInfo);
}
}
}
}catch (Exception exception){
log.error("{} parse error,this reason is:{}",LLMTimeEnhancementParse.class.getSimpleName(), (Object) exception.getStackTrace());
}
log.info("after queryContext:{},chatContext:{}",queryContext,chatContext);
}
}

View File

@@ -1,22 +1,28 @@
package com.tencent.supersonic.chat.parser.rule;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.AVG;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.COUNT;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.DISTINCT;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.MAX;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.MIN;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.TOPN;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import java.util.*;
import java.util.AbstractMap;
import java.util.HashMap;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import com.tencent.supersonic.common.util.JsonUtil;
import lombok.extern.slf4j.Slf4j;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.*;
@Slf4j
public class AggregateTypeParser implements SemanticParser {
@@ -29,7 +35,7 @@ public class AggregateTypeParser implements SemanticParser {
new AbstractMap.SimpleEntry<>(DISTINCT, Pattern.compile("(?i)(uv)")),
new AbstractMap.SimpleEntry<>(COUNT, Pattern.compile("(?i)(总数|pv)")),
new AbstractMap.SimpleEntry<>(NONE, Pattern.compile("(?i)(明细)"))
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue,(k1,k2)->k2));
@Override
public void parse(QueryContext queryContext, ChatContext chatContext) {

View File

@@ -1,17 +1,25 @@
package com.tencent.supersonic.chat.parser.rule;
import com.tencent.supersonic.chat.api.component.SemanticParser;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.query.rule.metric.MetricDomainQuery;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.common.util.JsonUtil;
import java.util.*;
import com.tencent.supersonic.chat.query.rule.metric.MetricDomainQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricSemanticQuery;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.*;
@@ -23,7 +31,8 @@ public class ContextInheritParser implements SemanticParser {
new AbstractMap.SimpleEntry<>(DIMENSION, Arrays.asList(DIMENSION, VALUE)),
new AbstractMap.SimpleEntry<>(VALUE, Arrays.asList(VALUE, DIMENSION)),
new AbstractMap.SimpleEntry<>(ENTITY, Arrays.asList(ENTITY)),
new AbstractMap.SimpleEntry<>(DOMAIN, Arrays.asList(DOMAIN))
new AbstractMap.SimpleEntry<>(DOMAIN, Arrays.asList(DOMAIN)),
new AbstractMap.SimpleEntry<>(ID, Arrays.asList(ID))
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
@Override
@@ -40,7 +49,9 @@ public class ContextInheritParser implements SemanticParser {
for (SchemaElementMatch match : chatContext.getParseInfo().getElementMatches()) {
SchemaElementType matchType = match.getElement().getType();
// mutual exclusive element types should not be inherited
if (!containsTypes(elementMatches, MUTUAL_EXCLUSIVE_MAP.get(matchType))) {
RuleSemanticQuery ruleQuery = QueryManager.getRuleQuery(chatContext.getParseInfo().getQueryMode());
if (!containsTypes(elementMatches, matchType, ruleQuery)) {
match.setMode(SchemaElementMatch.MatchMode.INHERIT);
matchesToInherit.add(match);
}
}
@@ -53,22 +64,31 @@ public class ContextInheritParser implements SemanticParser {
}
}
private boolean containsTypes(List<SchemaElementMatch> matches, List<SchemaElementType> types) {
return matches.stream().anyMatch(m -> types.contains(m.getElement().getType()));
private boolean containsTypes(List<SchemaElementMatch> matches, SchemaElementType matchType,
RuleSemanticQuery ruleQuery) {
List<SchemaElementType> types = MUTUAL_EXCLUSIVE_MAP.get(matchType);
return matches.stream().anyMatch(m -> {
SchemaElementType type = m.getElement().getType();
if (Objects.nonNull(ruleQuery) && ruleQuery instanceof MetricSemanticQuery
&& !(ruleQuery instanceof MetricEntityQuery)) {
return types.contains(type);
}
return type.equals(matchType);
});
}
protected boolean shouldInherit(QueryContext queryContext, ChatContext chatContext) {
if (queryContext.getMapInfo().getMatchedElements(
chatContext.getParseInfo().getDomainId()) == null) {
Long contextDomainId = chatContext.getParseInfo().getDomainId();
if (queryContext.getMapInfo().getMatchedElements(contextDomainId) == null) {
return false;
}
// if candidates have only one MetricDomain mode and context has value filter , count in context
if (queryContext.getCandidateQueries().size() == 1 && (queryContext.getCandidateQueries()
.get(0) instanceof MetricDomainQuery)
&& queryContext.getCandidateQueries().get(0).getParseInfo().getDomainId()
.equals(chatContext.getParseInfo().getDomainId())
&& !CollectionUtils.isEmpty(chatContext.getParseInfo().getDimensionFilters())) {
List<SemanticQuery> candidateQueries = queryContext.getCandidateQueries().stream()
.filter(semanticQuery -> semanticQuery.getParseInfo().getDomainId().equals(contextDomainId)).collect(
Collectors.toList());
if (candidateQueries.size() == 1 && (candidateQueries.get(0) instanceof MetricDomainQuery)) {
return true;
} else {
return queryContext.getCandidateQueries().size() == 0;

View File

@@ -6,7 +6,6 @@ import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import java.util.*;
import com.tencent.supersonic.common.util.JsonUtil;
import lombok.extern.slf4j.Slf4j;
@Slf4j

View File

@@ -46,7 +46,7 @@ public class TimeRangeParser implements SemanticParser {
for (SemanticQuery query : queryContext.getCandidateQueries()) {
query.getParseInfo().setDateInfo(dateConf);
}
} else if(QueryManager.containsRuleQuery(chatContext.getParseInfo().getQueryMode())) {
} else if (QueryManager.containsRuleQuery(chatContext.getParseInfo().getQueryMode())) {
RuleSemanticQuery semanticQuery = QueryManager.createRuleQuery(
chatContext.getParseInfo().getQueryMode());
// inherit parse info from context
@@ -64,7 +64,7 @@ public class TimeRangeParser implements SemanticParser {
List<TimeNLP> times = TimeNLPUtil.parse(queryText);
if (times.size() > 0) {
startDate = times.get(0).getTime();
}else {
} else {
return null;
}
@@ -133,7 +133,7 @@ public class TimeRangeParser implements SemanticParser {
info.setPeriod(Constants.DAY);
}
days = days * num;
info.setDateMode(DateConf.DateMode.RECENT_UNITS);
info.setDateMode(DateConf.DateMode.RECENT);
String text = "" + num + zhPeriod;
if (Strings.isNotEmpty(m.group("periodStr"))) {
text = m.group("periodStr");
@@ -175,11 +175,11 @@ public class TimeRangeParser implements SemanticParser {
private DateConf getDateConf(Date startDate, Date endDate) {
if (startDate == null || endDate == null) {
return null;
return null;
}
DateConf info = new DateConf();
info.setDateMode(DateConf.DateMode.BETWEEN_CONTINUOUS);
info.setDateMode(DateConf.DateMode.BETWEEN);
info.setStartDate(DATE_FORMAT.format(startDate));
info.setEndDate(DATE_FORMAT.format(endDate));
return info;

View File

@@ -53,11 +53,21 @@ public class PluginDO {
*/
private String updatedBy;
/**
*
*/
private String parseModeConfig;
/**
*
*/
private String config;
/**
*
*/
private String comment;
/**
*
* @return id
@@ -218,6 +228,22 @@ public class PluginDO {
this.updatedBy = updatedBy == null ? null : updatedBy.trim();
}
/**
*
* @return parse_mode_config
*/
public String getParseModeConfig() {
return parseModeConfig;
}
/**
*
* @param parseModeConfig
*/
public void setParseModeConfig(String parseModeConfig) {
this.parseModeConfig = parseModeConfig == null ? null : parseModeConfig.trim();
}
/**
*
* @return config
@@ -233,4 +259,20 @@ public class PluginDO {
public void setConfig(String config) {
this.config = config == null ? null : config.trim();
}
/**
*
* @return comment
*/
public String getComment() {
return comment;
}
/**
*
* @param comment
*/
public void setComment(String comment) {
this.comment = comment == null ? null : comment.trim();
}
}

View File

@@ -1,6 +1,5 @@
package com.tencent.supersonic.chat.persistence.mapper;
import com.tencent.supersonic.chat.persistence.dataobject.PluginDO;
import com.tencent.supersonic.chat.persistence.dataobject.PluginDOExample;
import org.apache.ibatis.annotations.Mapper;

View File

@@ -2,8 +2,8 @@ package com.tencent.supersonic.chat.persistence.repository;
import com.tencent.supersonic.chat.config.ChatConfig;
import com.tencent.supersonic.chat.config.ChatConfigFilter;
import com.tencent.supersonic.chat.config.ChatConfigResp;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import java.util.List;

View File

@@ -2,17 +2,17 @@ package com.tencent.supersonic.chat.persistence.repository;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.request.QueryRequest;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.api.pojo.response.QueryResponse;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
public interface ChatQueryRepository {
PageInfo<QueryResponse> getChatQuery(PageQueryInfoReq pageQueryInfoCommend, long chatId);
PageInfo<QueryResp> getChatQuery(PageQueryInfoReq pageQueryInfoCommend, long chatId);
void createChatQuery(QueryResult queryResult, QueryRequest queryContext, ChatContext chatCtx);
void createChatQuery(QueryResult queryResult, ChatContext chatCtx);
ChatQueryDO getLastChatQuery(long chatId);

View File

@@ -1,9 +1,9 @@
package com.tencent.supersonic.chat.persistence.repository.impl;
import com.tencent.supersonic.chat.config.ChatConfig;
import com.tencent.supersonic.chat.config.ChatConfigFilter;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
import com.tencent.supersonic.chat.config.ChatConfigFilterInternal;
import com.tencent.supersonic.chat.config.ChatConfigResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.persistence.dataobject.ChatConfigDO;
import com.tencent.supersonic.chat.persistence.repository.ChatConfigRepository;
import com.tencent.supersonic.chat.utils.ChatConfigHelper;

View File

@@ -3,12 +3,12 @@ package com.tencent.supersonic.chat.persistence.repository.impl;
import com.github.pagehelper.PageHelper;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.request.QueryRequest;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDOExample;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDOExample.Criteria;
import com.tencent.supersonic.chat.api.pojo.response.QueryResponse;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.persistence.mapper.ChatQueryDOMapper;
import com.tencent.supersonic.chat.persistence.repository.ChatQueryRepository;
@@ -35,7 +35,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
}
@Override
public PageInfo<QueryResponse> getChatQuery(PageQueryInfoReq pageQueryInfoCommend, long chatId) {
public PageInfo<QueryResp> getChatQuery(PageQueryInfoReq pageQueryInfoCommend, long chatId) {
ChatQueryDOExample example = new ChatQueryDOExample();
example.setOrderByClause("question_id desc");
Criteria criteria = example.createCriteria();
@@ -46,7 +46,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
pageQueryInfoCommend.getPageSize())
.doSelectPageInfo(() -> chatQueryDOMapper.selectByExampleWithBLOBs(example));
PageInfo<QueryResponse> chatQueryVOPageInfo = PageUtils.pageInfo2PageInfoVo(pageInfo);
PageInfo<QueryResp> chatQueryVOPageInfo = PageUtils.pageInfo2PageInfoVo(pageInfo);
chatQueryVOPageInfo.setList(
pageInfo.getList().stream().map(this::convertTo)
.sorted(Comparator.comparingInt(o -> o.getQuestionId().intValue()))
@@ -54,8 +54,8 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
return chatQueryVOPageInfo;
}
private QueryResponse convertTo(ChatQueryDO chatQueryDO) {
QueryResponse queryResponse = new QueryResponse();
private QueryResp convertTo(ChatQueryDO chatQueryDO) {
QueryResp queryResponse = new QueryResp();
BeanUtils.copyProperties(chatQueryDO, queryResponse);
QueryResult queryResult = JsonUtil.toObject(chatQueryDO.getQueryResult(), QueryResult.class);
queryResult.setQueryId(chatQueryDO.getQuestionId());
@@ -64,16 +64,16 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
}
@Override
public void createChatQuery(QueryResult queryResult, QueryRequest queryRequest, ChatContext chatCtx) {
public void createChatQuery(QueryResult queryResult, ChatContext chatCtx) {
ChatQueryDO chatQueryDO = new ChatQueryDO();
chatQueryDO.setChatId(Long.valueOf(queryRequest.getChatId()));
chatQueryDO.setChatId(Long.valueOf(chatCtx.getChatId()));
chatQueryDO.setCreateTime(new java.util.Date());
chatQueryDO.setUserName(queryRequest.getUser().getName());
chatQueryDO.setUserName(chatCtx.getUser());
chatQueryDO.setQueryState(queryResult.getQueryState().ordinal());
chatQueryDO.setQueryText(queryRequest.getQueryText());
chatQueryDO.setQueryText(chatCtx.getQueryText());
chatQueryDO.setQueryResult(JsonUtil.toString(queryResult));
chatQueryDOMapper.insert(chatQueryDO);
ChatQueryDO lastChatQuery = getLastChatQuery(queryRequest.getChatId());
ChatQueryDO lastChatQuery = getLastChatQuery(chatCtx.getChatId());
Long queryId = lastChatQuery.getQuestionId();
queryResult.setQueryId(queryId);
}

View File

@@ -1,37 +1,58 @@
package com.tencent.supersonic.chat.plugin;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.parser.ParseMode;
import com.tencent.supersonic.common.pojo.RecordInfo;
import lombok.Data;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@Data
public class Plugin extends RecordInfo {
private Long id;
//plugin type WEB_PAGE WEB_SERVICE
/***
* plugin type WEB_PAGE WEB_SERVICE
*/
private String type;
private List<Long> domainList;
private List<Long> domainList = Lists.newArrayList();
//description, for parsing
/**
* description, for parsing
*/
private String pattern;
//parse
/**
* parse
*/
private ParseMode parseMode;
private String parseModeConfig;
private String name;
//config for different plugin type
/**
* config for different plugin type
*/
private String config;
public List<String> getPatterns() {
return Stream.of(getPattern().split("\\|")).collect(Collectors.toList());
private String comment;
public List<String> getExampleQuestionList() {
if (StringUtils.isNotBlank(parseModeConfig)) {
PluginParseConfig pluginParseConfig = JSONObject.parseObject(parseModeConfig, PluginParseConfig.class);
return pluginParseConfig.getExamples();
}
return Lists.newArrayList();
}
public boolean isContainsAllDomain() {
return CollectionUtils.isNotEmpty(domainList) && domainList.contains(-1L);
}
}

View File

@@ -16,6 +16,7 @@ import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.logging.log4j.util.Strings;
import org.springframework.context.event.EventListener;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.*;
@@ -40,7 +41,8 @@ public class PluginManager {
public static List<Plugin> getPlugins() {
PluginService pluginService = ContextUtils.getBean(PluginService.class);
List<Plugin> pluginList = pluginService.getPluginList();
List<Plugin> pluginList = pluginService.getPluginList().stream().filter(plugin ->
CollectionUtils.isNotEmpty(plugin.getDomainList())).collect(Collectors.toList());
pluginList.addAll(internalPluginMap.values());
return new ArrayList<>(pluginList);
}
@@ -48,7 +50,7 @@ public class PluginManager {
@EventListener
public void addPlugin(PluginAddEvent pluginAddEvent) {
Plugin plugin = pluginAddEvent.getPlugin();
if (ParseMode.EMBEDDING_RECALL.equals(plugin.getParseMode())) {
if (CollectionUtils.isNotEmpty(plugin.getExampleQuestionList())) {
requestEmbeddingPluginAdd(convert(Lists.newArrayList(plugin)));
}
}
@@ -57,10 +59,10 @@ public class PluginManager {
public void updatePlugin(PluginUpdateEvent pluginUpdateEvent) {
Plugin oldPlugin = pluginUpdateEvent.getOldPlugin();
Plugin newPlugin = pluginUpdateEvent.getNewPlugin();
if (ParseMode.EMBEDDING_RECALL.equals(oldPlugin.getParseMode())) {
if (CollectionUtils.isNotEmpty(oldPlugin.getExampleQuestionList())) {
requestEmbeddingPluginDelete(getEmbeddingId(Lists.newArrayList(oldPlugin)));
}
if (ParseMode.EMBEDDING_RECALL.equals(newPlugin.getParseMode())) {
if (CollectionUtils.isNotEmpty(newPlugin.getExampleQuestionList())) {
requestEmbeddingPluginAdd(convert(Lists.newArrayList(newPlugin)));
}
}
@@ -68,29 +70,30 @@ public class PluginManager {
@EventListener
public void delPlugin(PluginAddEvent pluginAddEvent) {
Plugin plugin = pluginAddEvent.getPlugin();
if (ParseMode.EMBEDDING_RECALL.equals(plugin.getParseMode())) {
if (CollectionUtils.isNotEmpty(plugin.getExampleQuestionList())) {
requestEmbeddingPluginDelete(getEmbeddingId(Lists.newArrayList(plugin)));
}
}
public void requestEmbeddingPluginDelete(Set<String> ids) {
if(CollectionUtils.isEmpty(ids)){
if (CollectionUtils.isEmpty(ids)) {
return;
}
doRequest(embeddingConfig.getDeletePath(), JSONObject.toJSONString(ids));
}
public void requestEmbeddingPluginAdd(List<Map<String,String>> maps) {
if(CollectionUtils.isEmpty(maps)){
public void requestEmbeddingPluginAdd(List<Map<String, String>> maps) {
if (CollectionUtils.isEmpty(maps)) {
return;
}
doRequest(embeddingConfig.getAddPath(), JSONObject.toJSONString(maps));
doRequest(embeddingConfig.getAddPath(), JSONObject.toJSONString(maps));
}
public void doRequest(String path, String jsonBody) {
String url = embeddingConfig.getUrl()+ path;
if (Strings.isEmpty(embeddingConfig.getUrl())) {
return;
}
String url = embeddingConfig.getUrl() + path;
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
headers.setLocation(URI.create(url));
@@ -99,7 +102,8 @@ public class PluginManager {
HttpEntity<String> entity = new HttpEntity<>(jsonBody, headers);
log.info("[embedding] equest body :{}, url:{}", jsonBody, url);
ResponseEntity<String> responseEntity =
restTemplate.exchange(requestUrl, HttpMethod.POST, entity, new ParameterizedTypeReference<String>() {});
restTemplate.exchange(requestUrl, HttpMethod.POST, entity, new ParameterizedTypeReference<String>() {
});
log.info("[embedding] result body:{}", responseEntity);
}
@@ -111,7 +115,7 @@ public class PluginManager {
}
public EmbeddingResp recognize(String embeddingText) {
String url = embeddingConfig.getUrl()+ embeddingConfig.getRecognizePath() + "?n_results=" + embeddingConfig.getNResult();
String url = embeddingConfig.getUrl() + embeddingConfig.getRecognizePath() + "?n_results=" + embeddingConfig.getNResult();
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
headers.setLocation(URI.create(url));
@@ -121,10 +125,11 @@ public class PluginManager {
HttpEntity<String> entity = new HttpEntity<>(jsonBody, headers);
log.info("[embedding] request body:{}, url:{}", jsonBody, url);
ResponseEntity<List<EmbeddingResp>> embeddingResponseEntity =
restTemplate.exchange(requestUrl, HttpMethod.POST, entity, new ParameterizedTypeReference<List<EmbeddingResp>>() {});
log.info("[embedding] recognize result body:{}",embeddingResponseEntity);
restTemplate.exchange(requestUrl, HttpMethod.POST, entity, new ParameterizedTypeReference<List<EmbeddingResp>>() {
});
log.info("[embedding] recognize result body:{}", embeddingResponseEntity);
List<EmbeddingResp> embeddingResps = embeddingResponseEntity.getBody();
if(CollectionUtils.isNotEmpty(embeddingResps)){
if (CollectionUtils.isNotEmpty(embeddingResps)) {
for (EmbeddingResp embeddingResp : embeddingResps) {
List<RecallRetrieval> embeddingRetrievals = embeddingResp.getRetrieval();
for (RecallRetrieval embeddingRetrieval : embeddingRetrievals) {
@@ -136,13 +141,13 @@ public class PluginManager {
throw new RuntimeException("get embedding result failed");
}
public List<Map<String, String>> convert(List<Plugin> plugins){
public List<Map<String, String>> convert(List<Plugin> plugins) {
List<Map<String, String>> maps = Lists.newArrayList();
for(Plugin plugin : plugins){
List<String> patterns = plugin.getPatterns();
for (Plugin plugin : plugins) {
List<String> exampleQuestions = plugin.getExampleQuestionList();
int num = 0;
for(String pattern : patterns){
Map<String,String> map = new HashMap<>();
for (String pattern : exampleQuestions) {
Map<String, String> map = new HashMap<>();
map.put("preset_query_id", generateUniqueEmbeddingId(num, plugin.getId()));
map.put("preset_query", pattern);
maps.add(map);
@@ -155,7 +160,7 @@ public class PluginManager {
private Set<String> getEmbeddingId(List<Plugin> plugins) {
Set<String> embeddingIdSet = new HashSet<>();
for (Map<String, String> map : convert(plugins)) {
embeddingIdSet.addAll(map.keySet());
embeddingIdSet.add(map.get("preset_query_id"));
}
return embeddingIdSet;
}

View File

@@ -0,0 +1,21 @@
package com.tencent.supersonic.chat.plugin;
import com.tencent.supersonic.chat.parser.function.Parameters;
import lombok.Data;
import java.io.Serializable;
import java.util.List;
@Data
public class PluginParseConfig implements Serializable {
private String name;
private String description;
public Parameters parameters;
public List<String> examples;
}

View File

@@ -1,11 +1,12 @@
package com.tencent.supersonic.chat.plugin;
import com.tencent.supersonic.chat.api.pojo.request.QueryRequest;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import lombok.Data;
@Data
public class PluginParseResult {
private Plugin plugin;
private QueryRequest request;
private QueryReq request;
private double distance;
}

View File

@@ -5,39 +5,46 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.common.pojo.Constants;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.*;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
@Slf4j
public class HeuristicQuerySelector implements QuerySelector {
@Override
public SemanticQuery select(List<SemanticQuery> candidateQueries) {
double maxScore = 0;
SemanticQuery pickedQuery = null;
if (CollectionUtils.isNotEmpty(candidateQueries) && candidateQueries.size() == 1) {
return candidateQueries.get(0);
}
for (SemanticQuery query : candidateQueries) {
SemanticParseInfo semanticParse = query.getParseInfo();
double score = computeScore(semanticParse);
if (score > maxScore) {
maxScore = score;
pickedQuery = query;
}
log.info("candidate query (domain={}, queryMode={}) with score={}",
semanticParse.getDomainName(), semanticParse.getQueryMode(), score);
}
private static final double MATCH_INHERIT_PENALTY = 0.5;
private static final double MATCH_CURRENT_REWORD = 2;
private static final double CANDIDATE_THRESHOLD = 0.2;
return pickedQuery;
@Override
public List<SemanticQuery> select(List<SemanticQuery> candidateQueries) {
List<SemanticQuery> selectedQueries = new ArrayList<>();
if (CollectionUtils.isNotEmpty(candidateQueries) && candidateQueries.size() == 1) {
selectedQueries.addAll(candidateQueries);
} else {
OptionalDouble maxScoreOp = candidateQueries.stream().mapToDouble(
q -> computeScore(q.getParseInfo())).max();
if (maxScoreOp.isPresent()) {
double maxScore = maxScoreOp.getAsDouble();
candidateQueries.stream().forEach(query -> {
SemanticParseInfo semanticParse = query.getParseInfo();
if ((maxScore - semanticParse.getScore()) / maxScore <= CANDIDATE_THRESHOLD) {
selectedQueries.add(query);
}
log.info("candidate query (domain={}, queryMode={}) with score={}",
semanticParse.getDomainName(), semanticParse.getQueryMode(), semanticParse.getScore());
});
}
}
return selectedQueries;
}
private double computeScore(SemanticParseInfo semanticParse) {
double score = 0;
double totalScore = 0;
Map<SchemaElementType, SchemaElementMatch> maxSimilarityMatch = new HashMap<>();
for (SchemaElementMatch match : semanticParse.getElementMatches()) {
@@ -49,13 +56,19 @@ public class HeuristicQuerySelector implements QuerySelector {
}
for (SchemaElementMatch match : maxSimilarityMatch.values()) {
score +=
Optional.ofNullable(match.getDetectWord()).orElse(Constants.EMPTY).length() * match.getSimilarity();
double matchScore = Optional.ofNullable(match.getDetectWord()).orElse(Constants.EMPTY).length() * match.getSimilarity();
if (match.equals(SchemaElementMatch.MatchMode.INHERIT)) {
matchScore *= MATCH_INHERIT_PENALTY;
} else {
matchScore *= MATCH_CURRENT_REWORD;
}
totalScore += matchScore;
}
// bonus is a special construct to control the final score
score += semanticParse.getBonus();
// original score in parse info acts like an extra bonus
totalScore += semanticParse.getScore();
semanticParse.setScore(totalScore);
return score;
return totalScore;
}
}

View File

@@ -3,6 +3,9 @@ package com.tencent.supersonic.chat.query;
import com.tencent.supersonic.chat.api.component.SemanticQuery;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.query.rule.entity.EntitySemanticQuery;
import com.tencent.supersonic.chat.query.rule.metric.MetricSemanticQuery;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@@ -22,6 +25,14 @@ public class QueryManager {
}
}
public static SemanticQuery createQuery(String queryMode) {
if (containsRuleQuery(queryMode)) {
return createRuleQuery(queryMode);
} else {
return createPluginQuery(queryMode);
}
}
public static RuleSemanticQuery createRuleQuery(String queryMode) {
RuleSemanticQuery semanticQuery = ruleQueryMap.get(queryMode);
if (Objects.isNull(semanticQuery)) {
@@ -45,7 +56,6 @@ public class QueryManager {
throw new RuntimeException("no supported queryMode :" + queryMode);
}
}
public static boolean containsRuleQuery(String queryMode) {
if (queryMode == null) {
return false;
@@ -53,6 +63,27 @@ public class QueryManager {
return ruleQueryMap.containsKey(queryMode);
}
public static boolean isMetricQuery(String queryMode) {
if (queryMode == null || !ruleQueryMap.containsKey(queryMode)) {
return false;
}
return ruleQueryMap.get(queryMode) instanceof MetricSemanticQuery;
}
public static boolean isEntityQuery(String queryMode) {
if (queryMode == null || !ruleQueryMap.containsKey(queryMode)) {
return false;
}
return ruleQueryMap.get(queryMode) instanceof EntitySemanticQuery;
}
public static RuleSemanticQuery getRuleQuery(String queryMode) {
if (queryMode == null) {
return null;
}
return ruleQueryMap.get(queryMode);
}
public static List<RuleSemanticQuery> getRuleQueries() {
return new ArrayList<>(ruleQueryMap.values());
}

View File

@@ -9,5 +9,5 @@ import java.util.List;
**/
public interface QuerySelector {
SemanticQuery select(List<SemanticQuery> candidateQueries);
List<SemanticQuery> select(List<SemanticQuery> candidateQueries);
}

View File

@@ -0,0 +1,37 @@
package com.tencent.supersonic.chat.query.plugin;
import lombok.Data;
@Data
public class ParamOption {
private ParamType paramType;
private OptionType optionType;
private String key;
private String name;
private String keyAlias;
private Long domainId;
private Long elementId;
private Object value;
/**
* CUSTOM: the value is specified by the user
* SEMANTIC: the value of element
* FORWARD: only forward
*/
public enum ParamType {
CUSTOM, SEMANTIC, FORWARD
}
public enum OptionType {
REQUIRED, OPTIONAL
}
}

View File

@@ -1,22 +1,14 @@
package com.tencent.supersonic.chat.query.plugin;
import com.google.common.collect.Lists;
import lombok.Data;
import java.util.HashMap;
import java.util.Map;
import java.util.List;
@Data
public class WebBase {
private String url;
//key, id of schema element
private Map<String, Object> params = new HashMap<>();
//key, value of shcema element
private Map<String, Object> valueParams = new HashMap<>();
//only forward
private Map<String, Object> forwardParam = new HashMap<>();
private List<ParamOption> paramOptions = Lists.newArrayList();
}

View File

@@ -0,0 +1,14 @@
package com.tencent.supersonic.chat.query.plugin;
import com.google.common.collect.Lists;
import lombok.Data;
import java.util.List;
@Data
public class WebBaseResult {
private String url;
private List<ParamOption> params = Lists.newArrayList();
}

View File

@@ -62,8 +62,7 @@ public class DSLQuery extends PluginSemanticQuery {
@Override
public QueryResult execute(User user) {
PluginParseResult functionCallParseResult = (PluginParseResult) parseInfo.getProperties()
.get(Constants.CONTEXT);
PluginParseResult functionCallParseResult =JsonUtil.toObject(JsonUtil.toString(parseInfo.getProperties().get(Constants.CONTEXT)),PluginParseResult.class);
Long domainId = parseInfo.getDomainId();
LLMResp llmResp = requestLLM(functionCallParseResult, domainId);
if (Objects.isNull(llmResp)) {

View File

@@ -1,15 +1,19 @@
package com.tencent.supersonic.chat.query.plugin.webpage;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.*;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.config.ChatConfigRich;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
import com.tencent.supersonic.chat.plugin.Plugin;
import com.tencent.supersonic.chat.plugin.PluginParseResult;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.plugin.ParamOption;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.query.plugin.WebBase;
import com.tencent.supersonic.chat.query.plugin.WebBaseResult;
import com.tencent.supersonic.chat.service.ConfigService;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.pojo.Constants;
@@ -18,10 +22,9 @@ import com.tencent.supersonic.common.util.JsonUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.*;
import java.util.stream.Collectors;
@Slf4j
@Component
@@ -44,13 +47,14 @@ public class WebPageQuery extends PluginSemanticQuery {
QueryResult queryResult = new QueryResult();
queryResult.setQueryMode(QUERY_MODE);
Map<String, Object> properties = parseInfo.getProperties();
Plugin plugin = (Plugin) properties.get(Constants.CONTEXT);
WebPageResponse webPageResponse = buildResponse(plugin);
PluginParseResult pluginParseResult = JsonUtil.toObject(JsonUtil.toString(properties.get(Constants.CONTEXT)), PluginParseResult.class);
WebPageResponse webPageResponse = buildResponse(pluginParseResult.getPlugin());
queryResult.setResponse(webPageResponse);
if (parseInfo.getDomainId() != null && parseInfo.getDomainId() > 0
&& parseInfo.getEntity() != null && parseInfo.getEntity() > 0) {
ChatConfigRich chatConfigRichResp = configService.getConfigRichInfo(parseInfo.getDomainId());
updateSemanticParse(chatConfigRichResp, parseInfo.getEntity());
&& parseInfo.getEntity() != null && Objects.nonNull(parseInfo.getEntity().getId())
&& parseInfo.getEntity().getId() > 0) {
ChatConfigRichResp chatConfigRichResp = configService.getConfigRichInfo(parseInfo.getDomainId());
updateSemanticParse(chatConfigRichResp);
EntityInfo entityInfo = ContextUtils.getBean(SemanticService.class).getEntityInfo(parseInfo, user);
queryResult.setEntityInfo(entityInfo);
} else {
@@ -60,8 +64,7 @@ public class WebPageQuery extends PluginSemanticQuery {
return queryResult;
}
private void updateSemanticParse(ChatConfigRich chatConfigRichResp, Long entityId) {
parseInfo.setEntity(entityId);
private void updateSemanticParse(ChatConfigRichResp chatConfigRichResp) {
SchemaElement domain = new SchemaElement();
domain.setId(chatConfigRichResp.getDomainId());
domain.setName(chatConfigRichResp.getDomainName());
@@ -74,35 +77,43 @@ public class WebPageQuery extends PluginSemanticQuery {
webPageResponse.setPluginId(plugin.getId());
webPageResponse.setPluginType(plugin.getType());
WebBase webPage = JsonUtil.toObject(plugin.getConfig(), WebBase.class);
fillWebPage(webPage);
webPageResponse.setWebPage(webPage);
WebBaseResult webBaseResult = buildWebPageResult(webPage);
webPageResponse.setWebPage(webBaseResult);
return webPageResponse;
}
private void fillWebPage(WebBase webPage) {
List<SchemaElementMatch> schemaElementMatchList = parseInfo.getElementMatches();
private WebBaseResult buildWebPageResult(WebBase webPage) {
WebBaseResult webBaseResult = new WebBaseResult();
webBaseResult.setUrl(webPage.getUrl());
Map<String, Object> elementValueMap = getElementMap();
if (!CollectionUtils.isEmpty(webPage.getParamOptions()) && !CollectionUtils.isEmpty(elementValueMap)) {
for (ParamOption paramOption : webPage.getParamOptions()) {
if (!ParamOption.ParamType.SEMANTIC.equals(paramOption.getParamType())) {
continue;
}
String elementId = String.valueOf(paramOption.getElementId());
Object elementValue = elementValueMap.get(elementId);
paramOption.setValue(elementValue);
}
}
webBaseResult.setParams(webPage.getParamOptions());
return webBaseResult;
}
private Map<String, Object> getElementMap() {
Map<String, Object> elementValueMap = new HashMap<>();
if (!CollectionUtils.isEmpty(schemaElementMatchList) && !CollectionUtils.isEmpty(webPage.getParams()) ) {
List<SchemaElementMatch> schemaElementMatchList = parseInfo.getElementMatches();
if (!CollectionUtils.isEmpty(schemaElementMatchList)) {
schemaElementMatchList.stream()
.filter(schemaElementMatch ->
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()))
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
.sorted(Comparator.comparingDouble(SchemaElementMatch::getSimilarity))
.forEach(schemaElementMatch ->
elementValueMap.put(String.valueOf(schemaElementMatch.getElement().getId()),
schemaElementMatch.getWord()));
}
if (!CollectionUtils.isEmpty(parseInfo.getDimensionFilters())) {
parseInfo.getDimensionFilters().forEach(
filter -> elementValueMap.put(String.valueOf(filter.getElementID()), filter.getValue())
);
}
Map<String, Object> params = webPage.getParams();
for (Map.Entry<String, Object> entry : params.entrySet()) {
String key = entry.getKey();
String elementId = String.valueOf(entry.getValue());
Object elementValue = elementValueMap.get(elementId);
webPage.getValueParams().put(key, elementValue);
}
return elementValueMap;
}
}

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.chat.query.plugin.webpage;
import com.tencent.supersonic.chat.query.plugin.WebBase;
import com.tencent.supersonic.chat.query.plugin.WebBaseResult;
import lombok.Data;
import java.util.List;
@@ -15,8 +16,8 @@ public class WebPageResponse {
private String description;
private WebBase webPage;
private WebBaseResult webPage;
private List<WebBase> moreWebPage;
private List<WebBaseResult> moreWebPage;
}

View File

@@ -1,22 +1,28 @@
package com.tencent.supersonic.chat.query.plugin.webservice;
import com.alibaba.fastjson.JSON;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.plugin.Plugin;
import com.tencent.supersonic.chat.plugin.PluginParseResult;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.plugin.ParamOption;
import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery;
import com.tencent.supersonic.chat.query.plugin.WebBase;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.util.*;
import java.net.URI;
import java.util.List;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import org.apache.calcite.sql.parser.SqlParseException;
import org.springframework.http.*;
import org.springframework.stereotype.Component;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;
import java.util.HashMap;
import java.util.Map;
@@ -27,7 +33,7 @@ public class WebServiceQuery extends PluginSemanticQuery {
public static String QUERY_MODE = "WEB_SERVICE";
private S2ThreadContext s2ThreadContext;
private RestTemplate restTemplate;
public WebServiceQuery() {
QueryManager.register(this);
@@ -43,13 +49,9 @@ public class WebServiceQuery extends PluginSemanticQuery {
QueryResult queryResult = new QueryResult();
queryResult.setQueryMode(QUERY_MODE);
Map<String, Object> properties = parseInfo.getProperties();
PluginParseResult pluginParseResult = (PluginParseResult) properties.get(Constants.CONTEXT);
PluginParseResult pluginParseResult =JsonUtil.toObject(JsonUtil.toString(properties.get(Constants.CONTEXT)),PluginParseResult.class);
WebServiceResponse webServiceResponse = buildResponse(pluginParseResult);
Object object = webServiceResponse.getResult();
Map<String,Object> data=JsonUtil.toMap(JsonUtil.toString(object),String.class,Object.class);
queryResult.setQueryResults((List<Map<String, Object>>) data.get("resultList"));
queryResult.setQueryColumns((List<QueryColumn>) data.get("columns"));
//queryResult.setResponse(webServiceResponse);
queryResult.setResponse(webServiceResponse);
queryResult.setQueryState(QueryState.SUCCESS);
parseInfo.setProperties(null);
return queryResult;
@@ -60,22 +62,22 @@ public class WebServiceQuery extends PluginSemanticQuery {
Plugin plugin = pluginParseResult.getPlugin();
WebBase webBase = JsonUtil.toObject(plugin.getConfig(), WebBase.class);
webServiceResponse.setWebBase(webBase);
//http todo
s2ThreadContext = ContextUtils.getBean(S2ThreadContext.class);
String authHeader = s2ThreadContext.get().getToken();
log.info("authHeader:{}", authHeader);
List<ParamOption> paramOptions = webBase.getParamOptions();
Map<String, Object> params = new HashMap<>();
paramOptions.forEach(o -> params.put(o.getKey(), o.getValue()));
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
HttpEntity<String> entity = new HttpEntity<>(JSON.toJSONString(params), headers);
URI requestUrl = UriComponentsBuilder.fromHttpUrl(webBase.getUrl()).build().encode().toUri();
ResponseEntity responseEntity = null;
Object objectResponse = null;
restTemplate = ContextUtils.getBean(RestTemplate.class);
try {
Map<String, String> headers = new HashMap<>();
headers.put("Authorization", authHeader);
Map<String, String> params = new HashMap<>();
params.put("queryText", pluginParseResult.getRequest().getQueryText());
HttpClientResult httpClientResult = HttpClientUtils.doGet(webBase.getUrl(), headers, params);
log.info(" response body:{}", httpClientResult.getContent());
Map<String, Object> result = JsonUtil.toMap(JsonUtil.toString(httpClientResult.getContent()), String.class, Object.class);
log.info(" result:{}", result);
Map<String, Object> data = JsonUtil.toMap(JsonUtil.toString(result.get("data")), String.class, Object.class);
log.info(" data:{}", data);
webServiceResponse.setResult(data);
responseEntity = restTemplate.exchange(requestUrl, HttpMethod.POST, entity, Object.class);
objectResponse = responseEntity.getBody();
log.info("objectResponse:{}", objectResponse);
Map<String, Object> response = JsonUtil.objectToMap(objectResponse);
webServiceResponse.setResult(response);
} catch (Exception e) {
log.info("Exception:{}", e.getMessage());
}

View File

@@ -26,8 +26,6 @@ public class QueryMatcher {
private boolean supportOrderBy;
private List<AggregateTypeEnum> orderByTypes = Arrays.asList(AggregateTypeEnum.MAX, AggregateTypeEnum.MIN,
AggregateTypeEnum.TOPN);
private Long FREQUENCY = 9999999L;
private double SIMILARITY = 1.0;
public QueryMatcher() {
for (SchemaElementType type : SchemaElementType.values()) {
@@ -52,11 +50,10 @@ public class QueryMatcher {
* @return a list of all matched schema elements,
* empty list if no matches can be found
*/
public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches, QueryFilters queryFilters) {
public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches) {
List<SchemaElementMatch> elementMatches = new ArrayList<>();
List<SchemaElementMatch> schemaElementMatchWithQueryFilter = addSchemaElementMatch(candidateElementMatches, queryFilters);
HashMap<SchemaElementType, Integer> schemaElementTypeCount = new HashMap<>();
for (SchemaElementMatch schemaElementMatch : schemaElementMatchWithQueryFilter) {
for (SchemaElementMatch schemaElementMatch : candidateElementMatches) {
SchemaElementType schemaElementType = schemaElementMatch.getElement().getType();
if (schemaElementTypeCount.containsKey(schemaElementType)) {
schemaElementTypeCount.put(schemaElementType, schemaElementTypeCount.get(schemaElementType) + 1);
@@ -75,7 +72,7 @@ public class QueryMatcher {
}
// add element match if its element type is not declared as unused
for (SchemaElementMatch elementMatch : schemaElementMatchWithQueryFilter) {
for (SchemaElementMatch elementMatch : candidateElementMatches) {
QueryMatchOption elementOption = elementOptionMap.get(elementMatch.getElement().getType());
if (Objects.nonNull(elementOption) && !elementOption.getSchemaElementOption()
.equals(QueryMatchOption.OptionType.UNUSED)) {
@@ -86,32 +83,6 @@ public class QueryMatcher {
return elementMatches;
}
private List<SchemaElementMatch> addSchemaElementMatch(List<SchemaElementMatch> candidateElementMatches, QueryFilters queryFilter) {
List<SchemaElementMatch> schemaElementMatchWithQueryFilter = new ArrayList<>(candidateElementMatches);
if (queryFilter == null || CollectionUtils.isEmpty(queryFilter.getFilters())) {
return schemaElementMatchWithQueryFilter;
}
QueryMatchOption queryMatchOption = elementOptionMap.get(SchemaElementType.VALUE);
if (queryMatchOption != null && QueryMatchOption.OptionType.REQUIRED.equals(queryMatchOption.getSchemaElementOption())) {
for (QueryFilter filter : queryFilter.getFilters()) {
SchemaElement element = SchemaElement.builder()
.id(filter.getElementID())
.name(String.valueOf(filter.getValue()))
.type(SchemaElementType.VALUE)
.build();
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
.element(element)
.frequency(FREQUENCY)
.word(String.valueOf(filter.getValue()))
.similarity(SIMILARITY)
.detectWord(Constants.EMPTY)
.build();
schemaElementMatchWithQueryFilter.add(schemaElementMatch);
}
}
return schemaElementMatchWithQueryFilter;
}
private int getCount(HashMap<SchemaElementType, Integer> schemaElementTypeCount,
SchemaElementType schemaElementType) {
if (schemaElementTypeCount.containsKey(schemaElementType)) {

View File

@@ -9,9 +9,7 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.config.ChatConfigRich;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.service.ConfigService;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
@@ -24,6 +22,7 @@ import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import lombok.ToString;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.util.Strings;
import java.io.Serializable;
import java.util.*;
@@ -42,33 +41,51 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable {
public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches,
QueryContext queryCtx) {
return queryMatcher.match(candidateElementMatches, queryCtx.getRequest().getQueryFilters());
return queryMatcher.match(candidateElementMatches);
}
public void fillParseInfo(Long domainId, ChatContext chatContext){
public void fillParseInfo(Long domainId, ChatContext chatContext) {
parseInfo.setQueryMode(getQueryMode());
ConfigService configService = ContextUtils.getBean(ConfigService.class);
ChatConfigRich chatConfig = configService.getConfigRichInfo(domainId);
SemanticService schemaService = ContextUtils.getBean(SemanticService.class);
DomainSchema domainSchema = schemaService.getDomainSchema(domainId);
fillSchemaElement(parseInfo, domainSchema, chatConfig);
fillSchemaElement(parseInfo, domainSchema);
// inherit date info from context
if (parseInfo.getDateInfo() == null && chatContext.getParseInfo().getDateInfo() != null) {
if (parseInfo.getDateInfo() == null && chatContext.getParseInfo().getDateInfo() != null
&& isSameQueryMode(getQueryMode(), chatContext.getParseInfo().getQueryMode())) {
log.info("inherit date info from context");
parseInfo.setDateInfo(chatContext.getParseInfo().getDateInfo());
}
}
private void fillSchemaElement(SemanticParseInfo parseInfo, DomainSchema domainSchema, ChatConfigRich chaConfigRich) {
public boolean isSameQueryMode(String queryModeQuery, String queryModeChat) {
if (Strings.isNotEmpty(queryModeQuery) && Strings.isNotEmpty(queryModeChat)) {
return QueryManager.isEntityQuery(queryModeQuery) && QueryManager.isEntityQuery(queryModeChat)
|| QueryManager.isMetricQuery(queryModeQuery) && QueryManager.isMetricQuery(queryModeChat);
}
return true;
}
private void fillSchemaElement(SemanticParseInfo parseInfo, DomainSchema domainSchema) {
parseInfo.setDomain(domainSchema.getDomain());
Map<Long, List<SchemaElementMatch>> dim2Values = new HashMap<>();
Map<Long, List<SchemaElementMatch>> id2Values = new HashMap<>();
for (SchemaElementMatch schemaMatch : parseInfo.getElementMatches()) {
SchemaElement element = schemaMatch.getElement();
switch (element.getType()) {
case ID:
SchemaElement entityElement = domainSchema.getElement(SchemaElementType.ENTITY, element.getId());
if (entityElement != null) {
if (id2Values.containsKey(element.getId())) {
id2Values.get(element.getId()).add(schemaMatch);
} else {
id2Values.put(element.getId(), new ArrayList<>(Arrays.asList(schemaMatch)));
}
}
break;
case VALUE:
SchemaElement dimElement = domainSchema.getElement(SchemaElementType.DIMENSION, element.getId());
if (dimElement != null) {
@@ -85,10 +102,41 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable {
case METRIC:
parseInfo.getMetrics().add(element);
break;
case ENTITY:
parseInfo.setEntity(element);
break;
default:
}
}
if (!id2Values.isEmpty()) {
for (Map.Entry<Long, List<SchemaElementMatch>> entry : id2Values.entrySet()) {
SchemaElement entity = domainSchema.getElement(SchemaElementType.ENTITY, entry.getKey());
if (entry.getValue().size() == 1) {
SchemaElementMatch schemaMatch = entry.getValue().get(0);
QueryFilter dimensionFilter = new QueryFilter();
dimensionFilter.setValue(schemaMatch.getWord());
dimensionFilter.setBizName(entity.getBizName());
dimensionFilter.setName(entity.getName());
dimensionFilter.setOperator(FilterOperatorEnum.EQUALS);
dimensionFilter.setElementID(schemaMatch.getElement().getId());
parseInfo.getDimensionFilters().add(dimensionFilter);
parseInfo.setEntity(domainSchema.getEntity());
} else {
QueryFilter dimensionFilter = new QueryFilter();
List<String> vals = new ArrayList<>();
entry.getValue().stream().forEach(i -> vals.add(i.getWord()));
dimensionFilter.setValue(vals);
dimensionFilter.setBizName(entity.getBizName());
dimensionFilter.setName(entity.getName());
dimensionFilter.setOperator(FilterOperatorEnum.IN);
dimensionFilter.setElementID(entry.getKey());
parseInfo.getDimensionFilters().add(dimensionFilter);
}
}
}
if (!dim2Values.isEmpty()) {
for (Map.Entry<Long, List<SchemaElementMatch>> entry : dim2Values.entrySet()) {
SchemaElement dimension = domainSchema.getElement(SchemaElementType.DIMENSION, entry.getKey());
@@ -102,7 +150,7 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable {
dimensionFilter.setOperator(FilterOperatorEnum.EQUALS);
dimensionFilter.setElementID(schemaMatch.getElement().getId());
parseInfo.getDimensionFilters().add(dimensionFilter);
setEntityId(schemaMatch.getWord(), chaConfigRich, parseInfo);
parseInfo.setEntity(domainSchema.getEntity());
} else {
QueryFilter dimensionFilter = new QueryFilter();
List<String> vals = new ArrayList<>();
@@ -118,16 +166,6 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable {
}
}
public void setEntityId(String value, ChatConfigRich chaConfigRichDesc,
SemanticParseInfo semanticParseInfo) {
if (chaConfigRichDesc != null && chaConfigRichDesc.getChatDetailRichConfig() != null
&& chaConfigRichDesc.getChatDetailRichConfig().getEntity() != null) {
SchemaElement dimSchemaResp = chaConfigRichDesc.getChatDetailRichConfig().getEntity().getDimItem();
if (Objects.nonNull(dimSchemaResp) && StringUtils.isNumeric(value)) {
semanticParseInfo.setEntity(Long.valueOf(value));
}
}
}
@Override
public QueryResult execute(User user) {
@@ -202,12 +240,13 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable {
return parseInfo;
}
@Override
public void setParseInfo(SemanticParseInfo parseInfo) {
this.parseInfo = parseInfo;
}
public static List<RuleSemanticQuery> resolve(List<SchemaElementMatch> candidateElementMatches,
QueryContext queryContext) {
QueryContext queryContext) {
List<RuleSemanticQuery> matchedQueries = new ArrayList<>();
for (RuleSemanticQuery semanticQuery : QueryManager.getRuleQueries()) {
List<SchemaElementMatch> matches = semanticQuery.match(candidateElementMatches, queryContext);

View File

@@ -14,7 +14,7 @@ public class EntityDetailQuery extends EntitySemanticQuery {
public EntityDetailQuery() {
super();
queryMatcher.addOption(DIMENSION, REQUIRED, AT_LEAST, 1)
.addOption(VALUE, REQUIRED, AT_LEAST, 1);
.addOption(ID, REQUIRED, AT_LEAST, 1);
}
@Override

View File

@@ -1,8 +1,8 @@
package com.tencent.supersonic.chat.query.rule.entity;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.*;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.OPTIONAL;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.*;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
@@ -16,7 +16,8 @@ public class EntityFilterQuery extends EntityListQuery {
public EntityFilterQuery() {
super();
queryMatcher.addOption(VALUE, REQUIRED, AT_LEAST, 1);
queryMatcher.addOption(VALUE, OPTIONAL, AT_LEAST, 0);
queryMatcher.addOption(ID, OPTIONAL, AT_LEAST, 0);
}
@Override

View File

@@ -1,37 +1,43 @@
package com.tencent.supersonic.chat.query.rule.entity;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.DomainSchema;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.config.ChatConfigRich;
import com.tencent.supersonic.chat.config.ChatDefaultRichConfig;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatDefaultRichConfigResp;
import com.tencent.supersonic.chat.service.ConfigService;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.Order;
import com.tencent.supersonic.common.util.ContextUtils;
import java.util.LinkedHashSet;
import java.util.Objects;
import java.util.Set;
public abstract class EntityListQuery extends EntitySemanticQuery{
public abstract class EntityListQuery extends EntitySemanticQuery {
@Override
public void fillParseInfo(Long domainId, ChatContext chatContext){
public void fillParseInfo(Long domainId, ChatContext chatContext) {
super.fillParseInfo(domainId, chatContext);
this.addEntityDetailAndOrderByMetric(parseInfo);
}
private void addEntityDetailAndOrderByMetric(SemanticParseInfo parseInfo) {
if (parseInfo.getDomainId() > 0L) {
Long domainId = parseInfo.getDomainId();
if (Objects.nonNull(domainId) && domainId > 0L) {
ConfigService configService = ContextUtils.getBean(ConfigService.class);
ChatConfigRich chaConfigRichDesc = configService.getConfigRichInfo(
parseInfo.getDomainId());
ChatConfigRichResp chaConfigRichDesc = configService.getConfigRichInfo(parseInfo.getDomainId());
SemanticService schemaService = ContextUtils.getBean(SemanticService.class);
DomainSchema domainSchema = schemaService.getDomainSchema(domainId);
if (chaConfigRichDesc != null && chaConfigRichDesc.getChatDetailRichConfig() != null
&& chaConfigRichDesc.getChatDetailRichConfig().getEntity() != null) {
&& Objects.nonNull(domainSchema) && Objects.nonNull(domainSchema.getEntity())) {
Set<SchemaElement> dimensions = new LinkedHashSet();
Set<SchemaElement> metrics = new LinkedHashSet();
Set<Order> orders = new LinkedHashSet();
ChatDefaultRichConfig chatDefaultConfig = chaConfigRichDesc.getChatDetailRichConfig().getChatDefaultConfig();
ChatDefaultRichConfigResp chatDefaultConfig = chaConfigRichDesc.getChatDetailRichConfig().getChatDefaultConfig();
if (chatDefaultConfig != null) {
chatDefaultConfig.getMetrics().stream()
.forEach(metric -> {

View File

@@ -4,9 +4,9 @@ import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.config.ChatConfigResp;
import com.tencent.supersonic.chat.config.ChatConfigRich;
import com.tencent.supersonic.chat.config.ChatDefaultRichConfig;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatDefaultRichConfigResp;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.service.ConfigService;
import com.tencent.supersonic.common.pojo.DateConf;
@@ -85,8 +85,8 @@ public abstract class EntitySemanticQuery extends RuleSemanticQuery {
parseInfo.setLimit(ENTITY_MAX_RESULTS);
if (parseInfo.getDateInfo() == null) {
ConfigService configService = ContextUtils.getBean(ConfigService.class);
ChatConfigRich chatConfig = configService.getConfigRichInfo(parseInfo.getDomainId());
ChatDefaultRichConfig defaultConfig = chatConfig.getChatDetailRichConfig().getChatDefaultConfig();
ChatConfigRichResp chatConfig = configService.getConfigRichInfo(parseInfo.getDomainId());
ChatDefaultRichConfigResp defaultConfig = chatConfig.getChatDetailRichConfig().getChatDefaultConfig();
int unit = 1;
if (Objects.nonNull(defaultConfig) && Objects.nonNull(defaultConfig.getUnit())) {
@@ -94,7 +94,7 @@ public abstract class EntitySemanticQuery extends RuleSemanticQuery {
}
String date = LocalDate.now().plusDays(-unit).toString();
DateConf dateInfo = new DateConf();
dateInfo.setDateMode(DateConf.DateMode.BETWEEN_CONTINUOUS);
dateInfo.setDateMode(DateConf.DateMode.BETWEEN);
dateInfo.setStartDate(date);
dateInfo.setEndDate(date);

View File

@@ -1,25 +0,0 @@
package com.tencent.supersonic.chat.query.rule.entity;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.*;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import org.springframework.stereotype.Component;
@Component
public class EntityTopNQuery extends EntityListQuery {
public static final String QUERY_MODE = "ENTITY_LIST_TOPN";
public EntityTopNQuery() {
super();
queryMatcher.addOption(METRIC, REQUIRED, AT_LEAST, 1)
.setSupportOrderBy(true);
}
@Override
public String getQueryMode() {
return QUERY_MODE;
}
}

View File

@@ -1,17 +1,12 @@
package com.tencent.supersonic.chat.query.rule.metric;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.response.AggregateInfo;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import java.util.Objects;
import org.springframework.stereotype.Component;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.*;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.DOMAIN;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.OPTIONAL;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.*;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_MOST;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import org.springframework.stereotype.Component;
@Component
public class MetricDomainQuery extends MetricSemanticQuery {
@@ -31,14 +26,7 @@ public class MetricDomainQuery extends MetricSemanticQuery {
@Override
public QueryResult execute(User user) {
QueryResult queryResult = super.execute(user);
if (!Objects.isNull(queryResult)) {
QueryResultWithSchemaResp queryResp = new QueryResultWithSchemaResp();
queryResp.setColumns(queryResult.getQueryColumns());
queryResp.setResultList(queryResult.getQueryResults());
AggregateInfo aggregateInfo = ContextUtils.getBean(SemanticService.class)
.getAggregateInfo(user, parseInfo, queryResp);
queryResult.setAggregateInfo(aggregateInfo);
}
fillAggregateInfo(user, queryResult);
return queryResult;
}

View File

@@ -0,0 +1,92 @@
package com.tencent.supersonic.chat.query.rule.metric;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import com.tencent.supersonic.semantic.api.query.pojo.Filter;
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.*;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.OPTIONAL;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
@Slf4j
@Component
public class MetricEntityQuery extends MetricSemanticQuery {
public static final String QUERY_MODE = "METRIC_ENTITY";
public MetricEntityQuery() {
super();
queryMatcher.addOption(ID, REQUIRED, AT_LEAST, 1)
.addOption(ENTITY, REQUIRED, AT_LEAST, 1);
}
@Override
public String getQueryMode() {
return QUERY_MODE;
}
@Override
public QueryResult execute(User user) {
if (!isMultiStructQuery()) {
QueryResult queryResult = super.execute(user);
fillAggregateInfo(user, queryResult);
return queryResult;
}
return super.multiStructExecute(user);
}
protected boolean isMultiStructQuery() {
Set<String> filterBizName = new HashSet<>();
parseInfo.getDimensionFilters().stream()
.filter(filter -> filter.getElementID() != null)
.forEach(filter -> filterBizName.add(filter.getBizName()));
return filterBizName.size() > 1;
}
@Override
protected QueryStructReq convertQueryStruct() {
QueryStructReq queryStructReq = super.convertQueryStruct();
addDimension(queryStructReq, true);
return queryStructReq;
}
@Override
protected QueryMultiStructReq convertQueryMultiStruct() {
QueryMultiStructReq queryMultiStructReq = super.convertQueryMultiStruct();
for (QueryStructReq queryStructReq : queryMultiStructReq.getQueryStructReqs()) {
addDimension(queryStructReq, false);
}
return queryMultiStructReq;
}
private void addDimension(QueryStructReq queryStructReq, boolean onlyOperateInFilter) {
if (!queryStructReq.getDimensionFilters().isEmpty()) {
List<String> dimensions = queryStructReq.getGroups();
log.info("addDimension before [{}]", queryStructReq.getGroups());
List<Filter> filters = new ArrayList<>(queryStructReq.getDimensionFilters());
if (onlyOperateInFilter) {
filters = filters.stream().filter(filter
-> filter.getOperator().equals(FilterOperatorEnum.IN)).collect(Collectors.toList());
}
filters.forEach(d -> {
if (!dimensions.contains(d.getBizName())) {
dimensions.add(d.getBizName());
}});
queryStructReq.setGroups(dimensions);
log.info("addDimension after [{}]", queryStructReq.getGroups());
}
}
}

View File

@@ -1,23 +1,22 @@
package com.tencent.supersonic.chat.query.rule.metric;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.VALUE;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.response.AggregateInfo;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum;
import com.tencent.supersonic.semantic.api.query.pojo.Filter;
import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import java.util.*;
import java.util.stream.Collectors;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.*;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.*;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.OPTIONAL;
@Slf4j
@Component
@@ -27,8 +26,7 @@ public class MetricFilterQuery extends MetricSemanticQuery {
public MetricFilterQuery() {
super();
queryMatcher.addOption(VALUE, REQUIRED, AT_LEAST, 1)
.addOption(ENTITY, OPTIONAL, AT_MOST, 1);
queryMatcher.addOption(VALUE, REQUIRED, AT_LEAST, 1);
}
@Override
@@ -40,14 +38,7 @@ public class MetricFilterQuery extends MetricSemanticQuery {
public QueryResult execute(User user) {
if (!isMultiStructQuery()) {
QueryResult queryResult = super.execute(user);
if (Objects.nonNull(queryResult)) {
QueryResultWithSchemaResp queryResp = new QueryResultWithSchemaResp();
queryResp.setColumns(queryResult.getQueryColumns());
queryResp.setResultList(queryResult.getQueryResults());
AggregateInfo aggregateInfo = ContextUtils.getBean(SemanticService.class)
.getAggregateInfo(user,parseInfo,queryResp);
queryResult.setAggregateInfo(aggregateInfo);
}
fillAggregateInfo(user, queryResult);
return queryResult;
}
return super.multiStructExecute(user);

View File

@@ -1,27 +1,32 @@
package com.tencent.supersonic.chat.query.rule.metric;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.METRIC;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.config.ChatConfigResp;
import com.tencent.supersonic.chat.config.ChatConfigRich;
import com.tencent.supersonic.chat.config.ChatDefaultRichConfig;
import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq;
import com.tencent.supersonic.chat.api.pojo.response.AggregateInfo;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatDefaultRichConfigResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.chat.service.ConfigService;
import com.tencent.supersonic.chat.service.SemanticService;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.util.ContextUtils;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import java.time.LocalDate;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.METRIC;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
@Slf4j
public abstract class MetricSemanticQuery extends RuleSemanticQuery {
@@ -78,23 +83,30 @@ public abstract class MetricSemanticQuery extends RuleSemanticQuery {
}
@Override
public void fillParseInfo(Long domainId, ChatContext chatContext){
public void fillParseInfo(Long domainId, ChatContext chatContext) {
super.fillParseInfo(domainId, chatContext);
parseInfo.setLimit(METRIC_MAX_RESULTS);
if (parseInfo.getDateInfo() == null) {
ConfigService configService = ContextUtils.getBean(ConfigService.class);
ChatConfigRich chatConfig = configService.getConfigRichInfo(parseInfo.getDomainId());
ChatDefaultRichConfig defaultConfig = chatConfig.getChatAggRichConfig().getChatDefaultConfig();
ChatConfigRichResp chatConfig = configService.getConfigRichInfo(parseInfo.getDomainId());
ChatDefaultRichConfigResp defaultConfig = chatConfig.getChatAggRichConfig().getChatDefaultConfig();
DateConf dateInfo = new DateConf();
int unit = 1;
if (Objects.nonNull(defaultConfig) && Objects.nonNull(defaultConfig.getUnit())) {
unit = defaultConfig.getUnit();
}
String startDate = LocalDate.now().plusDays(-unit).toString();
String endDate = LocalDate.now().plusDays(-1).toString();
DateConf dateInfo = new DateConf();
dateInfo.setDateMode(DateConf.DateMode.BETWEEN_CONTINUOUS);
String endDate = startDate;
if (ChatDefaultConfigReq.TimeMode.LAST.equals(defaultConfig.getTimeMode())) {
dateInfo.setDateMode(DateConf.DateMode.BETWEEN);
} else if (ChatDefaultConfigReq.TimeMode.RECENT.equals(defaultConfig.getTimeMode())) {
dateInfo.setDateMode(DateConf.DateMode.RECENT);
endDate = LocalDate.now().plusDays(-1).toString();
}
dateInfo.setUnit(unit);
dateInfo.setPeriod(defaultConfig.getPeriod());
dateInfo.setStartDate(startDate);
dateInfo.setEndDate(endDate);
@@ -102,4 +114,15 @@ public abstract class MetricSemanticQuery extends RuleSemanticQuery {
}
}
public void fillAggregateInfo(User user, QueryResult queryResult) {
if (Objects.nonNull(queryResult)) {
QueryResultWithSchemaResp queryResp = new QueryResultWithSchemaResp();
queryResp.setColumns(queryResult.getQueryColumns());
queryResp.setResultList(queryResult.getQueryResults());
AggregateInfo aggregateInfo = ContextUtils.getBean(SemanticService.class)
.getAggregateInfo(user, parseInfo, queryResp);
queryResult.setAggregateInfo(aggregateInfo);
}
}
}

View File

@@ -54,7 +54,7 @@ public class MetricTopNQuery extends MetricSemanticQuery {
super.fillParseInfo(domainId, chatContext);
parseInfo.setLimit(ORDERBY_MAX_RESULTS);
parseInfo.setBonus(2.0);
parseInfo.setScore(2.0);
parseInfo.setAggType(AggregateTypeEnum.SUM);
SchemaElement metric = parseInfo.getMetrics().iterator().next();

View File

@@ -4,7 +4,11 @@ import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.api.component.SemanticLayer;
import com.tencent.supersonic.chat.config.*;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigBaseReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigEditReqReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.semantic.api.model.request.PageDimensionReq;
import com.tencent.supersonic.semantic.api.model.request.PageMetricReq;
@@ -65,12 +69,12 @@ public class ChatConfigController {
@GetMapping("/richDesc/{domainId}")
public ChatConfigRich getDomainExtendRichInfo(@PathVariable("domainId") Long domainId) {
public ChatConfigRichResp getDomainExtendRichInfo(@PathVariable("domainId") Long domainId) {
return configService.getConfigRichInfo(domainId);
}
@GetMapping("/richDesc/all")
public List<ChatConfigRich> getAllChatRichConfig() {
public List<ChatConfigRichResp> getAllChatRichConfig() {
return configService.getAllChatRichConfig();
}

View File

@@ -4,7 +4,7 @@ package com.tencent.supersonic.chat.rest;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.persistence.dataobject.ChatDO;
import com.tencent.supersonic.chat.api.pojo.response.QueryResponse;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.service.ChatService;
import java.util.List;
@@ -68,10 +68,10 @@ public class ChatController {
}
@PostMapping("/pageQueryInfo")
public PageInfo<QueryResponse> pageQueryInfo(@RequestBody PageQueryInfoReq pageQueryInfoCommand,
@RequestParam(value = "chatId") long chatId,
HttpServletRequest request,
HttpServletResponse response) {
public PageInfo<QueryResp> pageQueryInfo(@RequestBody PageQueryInfoReq pageQueryInfoCommand,
@RequestParam(value = "chatId") long chatId,
HttpServletRequest request,
HttpServletResponse response) {
pageQueryInfoCommand.setUserName(UserHolder.findUser(request, response).getName());
return chatService.queryInfo(pageQueryInfoCommand, chatId);
}

View File

@@ -2,8 +2,9 @@ package com.tencent.supersonic.chat.rest;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.api.pojo.request.QueryRequest;
import com.tencent.supersonic.chat.api.pojo.request.QueryDataRequest;
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.chat.service.QueryService;
import com.tencent.supersonic.chat.service.SearchService;
import javax.servlet.http.HttpServletRequest;
@@ -31,28 +32,42 @@ public class ChatQueryController {
@PostMapping("search")
public Object search(@RequestBody QueryRequest queryCtx, HttpServletRequest request,
public Object search(@RequestBody QueryReq queryCtx, HttpServletRequest request,
HttpServletResponse response) {
queryCtx.setUser(UserHolder.findUser(request, response));
return searchService.search(queryCtx);
}
@PostMapping("query")
public Object query(@RequestBody QueryRequest queryCtx, HttpServletRequest request, HttpServletResponse response)
public Object query(@RequestBody QueryReq queryCtx, HttpServletRequest request, HttpServletResponse response)
throws Exception {
queryCtx.setUser(UserHolder.findUser(request, response));
return queryService.executeQuery(queryCtx);
}
@PostMapping("parse")
public Object parse(@RequestBody QueryReq queryCtx, HttpServletRequest request, HttpServletResponse response)
throws Exception {
queryCtx.setUser(UserHolder.findUser(request, response));
return queryService.performParsing(queryCtx);
}
@PostMapping("execute")
public Object execute(@RequestBody ExecuteQueryReq queryCtx, HttpServletRequest request, HttpServletResponse response)
throws Exception {
queryCtx.setUser(UserHolder.findUser(request, response));
return queryService.performExecution(queryCtx);
}
@PostMapping("queryContext")
public Object queryContext(@RequestBody QueryRequest queryCtx, HttpServletRequest request,
public Object queryContext(@RequestBody QueryReq queryCtx, HttpServletRequest request,
HttpServletResponse response) throws Exception {
queryCtx.setUser(UserHolder.findUser(request, response));
return queryService.queryContext(queryCtx);
}
@PostMapping("queryData")
public Object queryData(@RequestBody QueryDataRequest queryData, HttpServletRequest request, HttpServletResponse response)
public Object queryData(@RequestBody QueryDataReq queryData, HttpServletRequest request, HttpServletResponse response)
throws Exception {
return queryService.executeDirectQuery(queryData, UserHolder.findUser(request, response));
}

View File

@@ -50,7 +50,7 @@ public class PluginController {
}
@PostMapping("/query")
List<Plugin> query(PluginQueryReq pluginQueryReq) {
List<Plugin> query(@RequestBody PluginQueryReq pluginQueryReq) {
return pluginService.queryWithAuthCheck(pluginQueryReq);
}

View File

@@ -1,9 +1,9 @@
package com.tencent.supersonic.chat.rest;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.api.pojo.request.QueryRequest;
import com.tencent.supersonic.chat.api.pojo.response.RecommendQuestion;
import com.tencent.supersonic.chat.api.pojo.response.RecommendResponse;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.RecommendQuestionResp;
import com.tencent.supersonic.chat.api.pojo.response.RecommendResp;
import com.tencent.supersonic.chat.service.RecommendService;
import javax.servlet.http.HttpServletRequest;
@@ -25,31 +25,31 @@ public class RecommendController {
private RecommendService recommendService;
@GetMapping("recommend/{domainId}")
public RecommendResponse recommend(@PathVariable("domainId") Long domainId,
@RequestParam(value = "limit", required = false) Long limit,
HttpServletRequest request,
HttpServletResponse response) {
QueryRequest queryCtx = new QueryRequest();
public RecommendResp recommend(@PathVariable("domainId") Long domainId,
@RequestParam(value = "limit", required = false) Long limit,
HttpServletRequest request,
HttpServletResponse response) {
QueryReq queryCtx = new QueryReq();
queryCtx.setUser(UserHolder.findUser(request, response));
queryCtx.setDomainId(domainId);
return recommendService.recommend(queryCtx, limit);
}
@GetMapping("recommend/metric/{domainId}")
public RecommendResponse recommendMetricMode(@PathVariable("domainId") Long domainId,
@RequestParam(value = "limit", required = false) Long limit,
HttpServletRequest request,
HttpServletResponse response) {
QueryRequest queryCtx = new QueryRequest();
public RecommendResp recommendMetricMode(@PathVariable("domainId") Long domainId,
@RequestParam(value = "limit", required = false) Long limit,
HttpServletRequest request,
HttpServletResponse response) {
QueryReq queryCtx = new QueryReq();
queryCtx.setUser(UserHolder.findUser(request, response));
queryCtx.setDomainId(domainId);
return recommendService.recommendMetricMode(queryCtx, limit);
}
@GetMapping("recommend/question")
public List<RecommendQuestion> recommendQuestion(@RequestParam(value = "domainId", required = false) Long domainId,
HttpServletRequest request,
HttpServletResponse response) {
public List<RecommendQuestionResp> recommendQuestion(@RequestParam(value = "domainId", required = false) Long domainId,
HttpServletRequest request,
HttpServletResponse response) {
return recommendService.recommendQuestion(domainId);
}
}

View File

@@ -8,7 +8,7 @@ import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.persistence.dataobject.ChatDO;
import com.tencent.supersonic.chat.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.api.pojo.response.QueryResponse;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import java.util.List;
@@ -25,8 +25,6 @@ public interface ChatService {
public void updateContext(ChatContext chatCtx);
public void updateContext(ChatContext chatCtx, QueryContext queryCtx, SemanticParseInfo semanticParseInfo);
public void switchContext(ChatContext chatCtx);
public Boolean addChat(User user, String chatName);
@@ -41,9 +39,9 @@ public interface ChatService {
Boolean deleteChat(Long chatId, String userName);
PageInfo<QueryResponse> queryInfo(PageQueryInfoReq pageQueryInfoCommend, long chatId);
PageInfo<QueryResp> queryInfo(PageQueryInfoReq pageQueryInfoCommend, long chatId);
public void addQuery(QueryResult queryResult, QueryContext queryContext, ChatContext chatCtx);
public void addQuery(QueryResult queryResult, ChatContext chatCtx);
public ChatQueryDO getLastQuery(long chatId);

View File

@@ -2,7 +2,11 @@ package com.tencent.supersonic.chat.service;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.config.*;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigBaseReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigEditReqReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
import java.util.List;
@@ -14,9 +18,9 @@ public interface ConfigService {
List<ChatConfigResp> search(ChatConfigFilter filter, User user);
ChatConfigRich getConfigRichInfo(Long domainId);
ChatConfigRichResp getConfigRichInfo(Long domainId);
ChatConfigResp fetchConfigByDomainId(Long domainId);
List<ChatConfigRich> getAllChatRichConfig();
List<ChatConfigRichResp> getAllChatRichConfig();
}

View File

@@ -2,9 +2,11 @@ package com.tencent.supersonic.chat.service;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryRequest;
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.request.QueryDataRequest;
import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq;
import org.apache.calcite.sql.parser.SqlParseException;
/***
@@ -12,9 +14,14 @@ import org.apache.calcite.sql.parser.SqlParseException;
*/
public interface QueryService {
QueryResult executeQuery(QueryRequest queryCtx) throws Exception;
ParseResp performParsing(QueryReq queryReq);
SemanticParseInfo queryContext(QueryRequest queryCtx);
QueryResult performExecution(ExecuteQueryReq queryReq) throws Exception;
QueryResult executeQuery(QueryReq queryReq) throws Exception;
SemanticParseInfo queryContext(QueryReq queryReq);
QueryResult executeDirectQuery(QueryDataReq queryData, User user) throws SqlParseException;
QueryResult executeDirectQuery(QueryDataRequest queryData, User user) throws SqlParseException;
}

View File

@@ -1,9 +1,9 @@
package com.tencent.supersonic.chat.service;
import com.tencent.supersonic.chat.api.pojo.request.QueryRequest;
import com.tencent.supersonic.chat.api.pojo.response.RecommendQuestion;
import com.tencent.supersonic.chat.api.pojo.response.RecommendResponse;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.RecommendQuestionResp;
import com.tencent.supersonic.chat.api.pojo.response.RecommendResp;
import java.util.List;
@@ -12,9 +12,9 @@ import java.util.List;
*/
public interface RecommendService {
RecommendResponse recommend(QueryRequest queryCtx, Long limit);
RecommendResp recommend(QueryReq queryCtx, Long limit);
RecommendResponse recommendMetricMode(QueryRequest queryCtx, Long limit);
RecommendResp recommendMetricMode(QueryReq queryCtx, Long limit);
List<RecommendQuestion> recommendQuestion(Long domainId);
List<RecommendQuestionResp> recommendQuestion(Long domainId);
}

Some files were not shown because too many files have changed in this diff Show More