mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
[improvement][chat] Optimize and modify the mapper method for terminology (#1866)
This commit is contained in:
@@ -7,11 +7,13 @@ import lombok.Builder;
|
|||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@Builder
|
@Builder
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
public class ChatApp {
|
public class ChatApp implements Serializable {
|
||||||
private String name;
|
private String name;
|
||||||
private String description;
|
private String description;
|
||||||
private String prompt;
|
private String prompt;
|
||||||
|
|||||||
@@ -5,11 +5,13 @@ import lombok.Builder;
|
|||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@Builder
|
@Builder
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
public class Text2SQLExemplar {
|
public class Text2SQLExemplar implements Serializable {
|
||||||
|
|
||||||
public static final String PROPERTY_KEY = "sql_exemplar";
|
public static final String PROPERTY_KEY = "sql_exemplar";
|
||||||
|
|
||||||
|
|||||||
@@ -5,10 +5,12 @@ import lombok.Data;
|
|||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
public class User {
|
public class User implements Serializable {
|
||||||
|
|
||||||
private Long id;
|
private Long id;
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,12 @@
|
|||||||
|
package com.tencent.supersonic.common.util;
|
||||||
|
|
||||||
|
import org.apache.commons.lang3.SerializationUtils;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
|
||||||
|
public class DeepCopyUtil {
|
||||||
|
|
||||||
|
public static <T extends Serializable> T deepCopy(T object) {
|
||||||
|
return SerializationUtils.clone(object);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -5,8 +5,10 @@ import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum;
|
|||||||
import com.tencent.supersonic.common.pojo.enums.TimeMode;
|
import com.tencent.supersonic.common.pojo.enums.TimeMode;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class AggregateTypeDefaultConfig {
|
public class AggregateTypeDefaultConfig implements Serializable {
|
||||||
|
|
||||||
private TimeDefaultConfig timeDefaultConfig =
|
private TimeDefaultConfig timeDefaultConfig =
|
||||||
new TimeDefaultConfig(7, DatePeriodEnum.DAY, TimeMode.RECENT);
|
new TimeDefaultConfig(7, DatePeriodEnum.DAY, TimeMode.RECENT);
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import lombok.Data;
|
|||||||
import org.apache.commons.collections.CollectionUtils;
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -14,7 +15,7 @@ import java.util.Set;
|
|||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class DataSetSchema {
|
public class DataSetSchema implements Serializable {
|
||||||
|
|
||||||
private String databaseType;
|
private String databaseType;
|
||||||
private SchemaElement dataSet;
|
private SchemaElement dataSet;
|
||||||
|
|||||||
@@ -2,11 +2,12 @@ package com.tencent.supersonic.headless.api.pojo;
|
|||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class DefaultDisplayInfo {
|
public class DefaultDisplayInfo implements Serializable {
|
||||||
|
|
||||||
// When displaying tag selection results, the information displayed by default
|
// When displaying tag selection results, the information displayed by default
|
||||||
private List<Long> dimensionIds = new ArrayList<>();
|
private List<Long> dimensionIds = new ArrayList<>();
|
||||||
|
|||||||
@@ -3,8 +3,10 @@ package com.tencent.supersonic.headless.api.pojo;
|
|||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class DetailTypeDefaultConfig {
|
public class DetailTypeDefaultConfig implements Serializable {
|
||||||
|
|
||||||
private DefaultDisplayInfo defaultDisplayInfo;
|
private DefaultDisplayInfo defaultDisplayInfo;
|
||||||
|
|
||||||
|
|||||||
@@ -2,8 +2,10 @@ package com.tencent.supersonic.headless.api.pojo;
|
|||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class QueryConfig {
|
public class QueryConfig implements Serializable {
|
||||||
|
|
||||||
private DetailTypeDefaultConfig detailTypeDefaultConfig = new DetailTypeDefaultConfig();
|
private DetailTypeDefaultConfig detailTypeDefaultConfig = new DetailTypeDefaultConfig();
|
||||||
|
|
||||||
|
|||||||
@@ -5,11 +5,13 @@ import lombok.Builder;
|
|||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@Builder
|
@Builder
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
public class RelatedSchemaElement {
|
public class RelatedSchemaElement implements Serializable {
|
||||||
|
|
||||||
private Long dimensionId;
|
private Long dimensionId;
|
||||||
|
|
||||||
|
|||||||
@@ -39,7 +39,6 @@ public class SchemaElement implements Serializable {
|
|||||||
private double order;
|
private double order;
|
||||||
private int isTag;
|
private int isTag;
|
||||||
private String description;
|
private String description;
|
||||||
private boolean descriptionMapped;
|
|
||||||
@Builder.Default
|
@Builder.Default
|
||||||
private Map<String, Object> extInfo = new HashMap<>();
|
private Map<String, Object> extInfo = new HashMap<>();
|
||||||
private DimensionTimeTypeParams typeParams;
|
private DimensionTimeTypeParams typeParams;
|
||||||
|
|||||||
@@ -6,12 +6,14 @@ import lombok.Data;
|
|||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
import lombok.ToString;
|
import lombok.ToString;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@ToString
|
@ToString
|
||||||
@Builder
|
@Builder
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
public class SchemaElementMatch {
|
public class SchemaElementMatch implements Serializable {
|
||||||
private SchemaElement element;
|
private SchemaElement element;
|
||||||
private double offset;
|
private double offset;
|
||||||
private double similarity;
|
private double similarity;
|
||||||
|
|||||||
@@ -3,15 +3,17 @@ package com.tencent.supersonic.headless.api.pojo;
|
|||||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import org.apache.commons.collections4.CollectionUtils;
|
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
public class SchemaMapInfo {
|
public class SchemaMapInfo implements Serializable {
|
||||||
|
|
||||||
private final Map<Long, List<SchemaElementMatch>> dataSetElementMatches = new HashMap<>();
|
private final Map<Long, List<SchemaElementMatch>> dataSetElementMatches = new HashMap<>();
|
||||||
|
|
||||||
@@ -31,6 +33,23 @@ public class SchemaMapInfo {
|
|||||||
dataSetElementMatches.put(dataSet, elementMatches);
|
dataSetElementMatches.put(dataSet, elementMatches);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void addMatchedElements(SchemaMapInfo schemaMapInfo) {
|
||||||
|
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMapInfo.dataSetElementMatches
|
||||||
|
.entrySet()) {
|
||||||
|
Long dataSet = entry.getKey();
|
||||||
|
List<SchemaElementMatch> newMatches = entry.getValue();
|
||||||
|
|
||||||
|
if (dataSetElementMatches.containsKey(dataSet)) {
|
||||||
|
List<SchemaElementMatch> existingMatches = dataSetElementMatches.get(dataSet);
|
||||||
|
Set<SchemaElementMatch> mergedMatches = new HashSet<>(existingMatches);
|
||||||
|
mergedMatches.addAll(newMatches);
|
||||||
|
dataSetElementMatches.put(dataSet, new ArrayList<>(mergedMatches));
|
||||||
|
} else {
|
||||||
|
dataSetElementMatches.put(dataSet, new ArrayList<>(new HashSet<>(newMatches)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@JsonIgnore
|
@JsonIgnore
|
||||||
public List<SchemaElement> getTermDescriptionToMap() {
|
public List<SchemaElement> getTermDescriptionToMap() {
|
||||||
List<SchemaElement> termElements = Lists.newArrayList();
|
List<SchemaElement> termElements = Lists.newArrayList();
|
||||||
@@ -38,16 +57,11 @@ public class SchemaMapInfo {
|
|||||||
List<SchemaElementMatch> matchedElements = getMatchedElements(dataSetId);
|
List<SchemaElementMatch> matchedElements = getMatchedElements(dataSetId);
|
||||||
for (SchemaElementMatch schemaElementMatch : matchedElements) {
|
for (SchemaElementMatch schemaElementMatch : matchedElements) {
|
||||||
if (SchemaElementType.TERM.equals(schemaElementMatch.getElement().getType())
|
if (SchemaElementType.TERM.equals(schemaElementMatch.getElement().getType())
|
||||||
&& schemaElementMatch.isFullMatched()
|
&& schemaElementMatch.isFullMatched()) {
|
||||||
&& !schemaElementMatch.getElement().isDescriptionMapped()) {
|
|
||||||
termElements.add(schemaElementMatch.getElement());
|
termElements.add(schemaElementMatch.getElement());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return termElements;
|
return termElements;
|
||||||
}
|
}
|
||||||
|
|
||||||
public boolean needContinueMap() {
|
|
||||||
return CollectionUtils.isNotEmpty(getTermDescriptionToMap());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
|||||||
import lombok.Builder;
|
import lombok.Builder;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
@@ -22,7 +23,7 @@ import static com.tencent.supersonic.common.pojo.Constants.DEFAULT_DETAIL_LIMIT;
|
|||||||
import static com.tencent.supersonic.common.pojo.Constants.DEFAULT_METRIC_LIMIT;
|
import static com.tencent.supersonic.common.pojo.Constants.DEFAULT_METRIC_LIMIT;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class SemanticParseInfo {
|
public class SemanticParseInfo implements Serializable {
|
||||||
|
|
||||||
private Integer id;
|
private Integer id;
|
||||||
private String queryMode = "PLAIN_TEXT";
|
private String queryMode = "PLAIN_TEXT";
|
||||||
|
|||||||
@@ -6,10 +6,12 @@ import lombok.AllArgsConstructor;
|
|||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.NoArgsConstructor;
|
import lombok.NoArgsConstructor;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
public class TimeDefaultConfig {
|
public class TimeDefaultConfig implements Serializable {
|
||||||
|
|
||||||
/** default time span unit */
|
/** default time span unit */
|
||||||
private Integer unit = 1;
|
private Integer unit = 1;
|
||||||
|
|||||||
@@ -5,9 +5,11 @@ import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
|||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.ToString;
|
import lombok.ToString;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@ToString(callSuper = true)
|
@ToString(callSuper = true)
|
||||||
public class QueryFilter {
|
public class QueryFilter implements Serializable {
|
||||||
|
|
||||||
private String bizName;
|
private String bizName;
|
||||||
|
|
||||||
|
|||||||
@@ -2,13 +2,14 @@ package com.tencent.supersonic.headless.api.pojo.request;
|
|||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class QueryFilters {
|
public class QueryFilters implements Serializable {
|
||||||
private List<QueryFilter> filters = new ArrayList<>();
|
private List<QueryFilter> filters = new ArrayList<>();
|
||||||
private Map<String, Object> params = new HashMap<>();
|
private Map<String, Object> params = new HashMap<>();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,12 +11,13 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class QueryNLReq extends SemanticQueryReq {
|
public class QueryNLReq extends SemanticQueryReq implements Serializable {
|
||||||
private String queryText;
|
private String queryText;
|
||||||
private Set<Long> dataSetIds = Sets.newHashSet();
|
private Set<Long> dataSetIds = Sets.newHashSet();
|
||||||
private User user;
|
private User user;
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
|||||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.io.Serializable;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
@@ -18,10 +19,9 @@ import java.util.Objects;
|
|||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class ChatQueryContext {
|
public class ChatQueryContext implements Serializable {
|
||||||
|
|
||||||
private QueryNLReq request;
|
private QueryNLReq request;
|
||||||
private String oriQueryText;
|
|
||||||
private Map<Long, List<Long>> modelIdToDataSetIds;
|
private Map<Long, List<Long>> modelIdToDataSetIds;
|
||||||
private List<SemanticQuery> candidateQueries = new ArrayList<>();
|
private List<SemanticQuery> candidateQueries = new ArrayList<>();
|
||||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||||
|
|||||||
@@ -1,41 +1,44 @@
|
|||||||
package com.tencent.supersonic.headless.chat.mapper;
|
package com.tencent.supersonic.headless.chat.mapper;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.common.util.DeepCopyUtil;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
|
import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections4.CollectionUtils;
|
import org.apache.commons.collections4.CollectionUtils;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
/** * A mapper that map the description of the term. */
|
/**
|
||||||
|
* A mapper that map the description of the term.
|
||||||
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class TermDescMapper extends BaseMapper {
|
public class TermDescMapper extends BaseMapper {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void doMap(ChatQueryContext chatQueryContext) {
|
public void doMap(ChatQueryContext chatQueryContext) {
|
||||||
List<SchemaElement> termDescriptionToMap =
|
SchemaMapInfo mapInfo = chatQueryContext.getMapInfo();
|
||||||
chatQueryContext.getMapInfo().getTermDescriptionToMap();
|
List<SchemaElement> termElements = mapInfo.getTermDescriptionToMap();
|
||||||
if (CollectionUtils.isEmpty(termDescriptionToMap)) {
|
if (CollectionUtils.isEmpty(termElements)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (StringUtils.isBlank(chatQueryContext.getOriQueryText())) {
|
for (SchemaElement schemaElement : termElements) {
|
||||||
chatQueryContext.setOriQueryText(chatQueryContext.getRequest().getQueryText());
|
ChatQueryContext queryCtx =
|
||||||
}
|
buildQueryContext(chatQueryContext, schemaElement.getDescription());
|
||||||
for (SchemaElement schemaElement : termDescriptionToMap) {
|
ComponentFactory.getSchemaMappers().forEach(mapper -> mapper.map(queryCtx));
|
||||||
if (schemaElement.isDescriptionMapped()) {
|
chatQueryContext.getMapInfo().addMatchedElements(queryCtx.getMapInfo());
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if (chatQueryContext.getRequest().getQueryText()
|
|
||||||
.equals(schemaElement.getDescription())) {
|
|
||||||
schemaElement.setDescriptionMapped(true);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
chatQueryContext.getRequest().setQueryText(schemaElement.getDescription());
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if (CollectionUtils.isEmpty(chatQueryContext.getMapInfo().getTermDescriptionToMap())) {
|
|
||||||
chatQueryContext.getRequest().setQueryText(chatQueryContext.getOriQueryText());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static ChatQueryContext buildQueryContext(ChatQueryContext chatQueryContext,
|
||||||
|
String queryText) {
|
||||||
|
ChatQueryContext queryContext = DeepCopyUtil.deepCopy(chatQueryContext);
|
||||||
|
queryContext.getRequest().setQueryText(queryText);
|
||||||
|
queryContext.setMapInfo(new SchemaMapInfo());
|
||||||
|
queryContext.setSemanticSchema(chatQueryContext.getSemanticSchema());
|
||||||
|
queryContext.setModelIdToDataSetIds(chatQueryContext.getModelIdToDataSetIds());
|
||||||
|
queryContext.setChatWorkflowState(chatQueryContext.getChatWorkflowState());
|
||||||
|
return queryContext;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,19 +1,45 @@
|
|||||||
package com.tencent.supersonic.headless.chat.utils;
|
package com.tencent.supersonic.headless.chat.utils;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import com.tencent.supersonic.headless.chat.corrector.SemanticCorrector;
|
||||||
|
import com.tencent.supersonic.headless.chat.mapper.SchemaMapper;
|
||||||
|
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||||
import com.tencent.supersonic.headless.chat.parser.llm.DataSetResolver;
|
import com.tencent.supersonic.headless.chat.parser.llm.DataSetResolver;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
import org.springframework.core.io.support.SpringFactoriesLoader;
|
import org.springframework.core.io.support.SpringFactoriesLoader;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
|
||||||
/** HeadlessConverter QueryOptimizer QueryExecutor object factory */
|
/**
|
||||||
|
* QueryConverter QueryOptimizer QueryExecutor object factory
|
||||||
|
*/
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class ComponentFactory {
|
public class ComponentFactory {
|
||||||
|
private static List<SchemaMapper> schemaMappers = new ArrayList<>();
|
||||||
|
private static List<SemanticParser> semanticParsers = new ArrayList<>();
|
||||||
|
private static List<SemanticCorrector> semanticCorrectors = new ArrayList<>();
|
||||||
private static DataSetResolver modelResolver;
|
private static DataSetResolver modelResolver;
|
||||||
|
|
||||||
|
public static List<SchemaMapper> getSchemaMappers() {
|
||||||
|
return CollectionUtils.isEmpty(schemaMappers) ? init(SchemaMapper.class, schemaMappers)
|
||||||
|
: schemaMappers;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static List<SemanticParser> getSemanticParsers() {
|
||||||
|
return CollectionUtils.isEmpty(semanticParsers)
|
||||||
|
? init(SemanticParser.class, semanticParsers)
|
||||||
|
: semanticParsers;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static List<SemanticCorrector> getSemanticCorrectors() {
|
||||||
|
return CollectionUtils.isEmpty(semanticCorrectors)
|
||||||
|
? init(SemanticCorrector.class, semanticCorrectors)
|
||||||
|
: semanticCorrectors;
|
||||||
|
}
|
||||||
|
|
||||||
public static DataSetResolver getModelResolver() {
|
public static DataSetResolver getModelResolver() {
|
||||||
if (Objects.isNull(modelResolver)) {
|
if (Objects.isNull(modelResolver)) {
|
||||||
modelResolver = init(DataSetResolver.class);
|
modelResolver = init(DataSetResolver.class);
|
||||||
@@ -25,13 +51,13 @@ public class ComponentFactory {
|
|||||||
return ContextUtils.getContext().getBean(name, tClass);
|
return ContextUtils.getContext().getBean(name, tClass);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static <T> List<T> init(Class<T> factoryType, List list) {
|
protected static <T> List<T> init(Class<T> factoryType, List list) {
|
||||||
list.addAll(SpringFactoriesLoader.loadFactories(factoryType,
|
list.addAll(SpringFactoriesLoader.loadFactories(factoryType,
|
||||||
Thread.currentThread().getContextClassLoader()));
|
Thread.currentThread().getContextClassLoader()));
|
||||||
return list;
|
return list;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static <T> T init(Class<T> factoryType) {
|
protected static <T> T init(Class<T> factoryType) {
|
||||||
return SpringFactoriesLoader
|
return SpringFactoriesLoader
|
||||||
.loadFactories(factoryType, Thread.currentThread().getContextClassLoader()).get(0);
|
.loadFactories(factoryType, Thread.currentThread().getContextClassLoader()).get(0);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,12 +29,12 @@ import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
|||||||
import com.tencent.supersonic.headless.chat.corrector.GrammarCorrector;
|
import com.tencent.supersonic.headless.chat.corrector.GrammarCorrector;
|
||||||
import com.tencent.supersonic.headless.chat.corrector.SchemaCorrector;
|
import com.tencent.supersonic.headless.chat.corrector.SchemaCorrector;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.builder.BaseWordBuilder;
|
import com.tencent.supersonic.headless.chat.knowledge.builder.BaseWordBuilder;
|
||||||
|
import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
|
||||||
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
|
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
|
||||||
import com.tencent.supersonic.headless.server.service.DataSetService;
|
import com.tencent.supersonic.headless.server.service.DataSetService;
|
||||||
import com.tencent.supersonic.headless.server.service.RetrieveService;
|
import com.tencent.supersonic.headless.server.service.RetrieveService;
|
||||||
import com.tencent.supersonic.headless.server.service.SchemaService;
|
import com.tencent.supersonic.headless.server.service.SchemaService;
|
||||||
import com.tencent.supersonic.headless.server.utils.ChatWorkflowEngine;
|
import com.tencent.supersonic.headless.server.utils.ChatWorkflowEngine;
|
||||||
import com.tencent.supersonic.headless.server.utils.ComponentFactory;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
import org.springframework.beans.BeanUtils;
|
import org.springframework.beans.BeanUtils;
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ import com.tencent.supersonic.headless.server.service.DomainService;
|
|||||||
import com.tencent.supersonic.headless.server.service.MetricService;
|
import com.tencent.supersonic.headless.server.service.MetricService;
|
||||||
import com.tencent.supersonic.headless.server.service.ModelRelaService;
|
import com.tencent.supersonic.headless.server.service.ModelRelaService;
|
||||||
import com.tencent.supersonic.headless.server.service.ModelService;
|
import com.tencent.supersonic.headless.server.service.ModelService;
|
||||||
import com.tencent.supersonic.headless.server.utils.ComponentFactory;
|
import com.tencent.supersonic.headless.server.utils.CoreComponentFactory;
|
||||||
import com.tencent.supersonic.headless.server.utils.ModelConverter;
|
import com.tencent.supersonic.headless.server.utils.ModelConverter;
|
||||||
import com.tencent.supersonic.headless.server.utils.NameCheckUtils;
|
import com.tencent.supersonic.headless.server.utils.NameCheckUtils;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
@@ -222,7 +222,7 @@ public class ModelServiceImpl implements ModelService {
|
|||||||
|
|
||||||
private void doBuild(ModelBuildReq modelBuildReq, DbSchema curSchema, List<DbSchema> dbSchemas,
|
private void doBuild(ModelBuildReq modelBuildReq, DbSchema curSchema, List<DbSchema> dbSchemas,
|
||||||
Map<String, ModelSchema> modelSchemaMap) {
|
Map<String, ModelSchema> modelSchemaMap) {
|
||||||
SemanticModeller semanticModeller = ComponentFactory.getSemanticModeller();
|
SemanticModeller semanticModeller = CoreComponentFactory.getSemanticModeller();
|
||||||
ModelSchema modelSchema = semanticModeller.build(curSchema, dbSchemas, modelBuildReq);
|
ModelSchema modelSchema = semanticModeller.build(curSchema, dbSchemas, modelBuildReq);
|
||||||
modelSchemaMap.put(curSchema.getTable(), modelSchema);
|
modelSchemaMap.put(curSchema.getTable(), modelSchema);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -30,11 +30,12 @@ import java.util.stream.Collectors;
|
|||||||
@Slf4j
|
@Slf4j
|
||||||
public class ChatWorkflowEngine {
|
public class ChatWorkflowEngine {
|
||||||
|
|
||||||
private final List<SchemaMapper> schemaMappers = ComponentFactory.getSchemaMappers();
|
private final List<SchemaMapper> schemaMappers = CoreComponentFactory.getSchemaMappers();
|
||||||
private final List<SemanticParser> semanticParsers = ComponentFactory.getSemanticParsers();
|
private final List<SemanticParser> semanticParsers = CoreComponentFactory.getSemanticParsers();
|
||||||
private final List<SemanticCorrector> semanticCorrectors =
|
private final List<SemanticCorrector> semanticCorrectors =
|
||||||
ComponentFactory.getSemanticCorrectors();
|
CoreComponentFactory.getSemanticCorrectors();
|
||||||
private final List<ResultProcessor> resultProcessors = ComponentFactory.getResultProcessors();
|
private final List<ResultProcessor> resultProcessors =
|
||||||
|
CoreComponentFactory.getResultProcessors();
|
||||||
|
|
||||||
public void start(ChatWorkflowState initialState, ChatQueryContext queryCtx,
|
public void start(ChatWorkflowState initialState, ChatQueryContext queryCtx,
|
||||||
ParseResp parseResult) {
|
ParseResp parseResult) {
|
||||||
@@ -48,8 +49,6 @@ public class ChatWorkflowEngine {
|
|||||||
parseResult.setErrorMsg(
|
parseResult.setErrorMsg(
|
||||||
"No semantic entities can be mapped against user question.");
|
"No semantic entities can be mapped against user question.");
|
||||||
queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED);
|
queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED);
|
||||||
} else if (queryCtx.getMapInfo().needContinueMap()) {
|
|
||||||
queryCtx.setChatWorkflowState(ChatWorkflowState.MAPPING);
|
|
||||||
} else {
|
} else {
|
||||||
queryCtx.setChatWorkflowState(ChatWorkflowState.PARSING);
|
queryCtx.setChatWorkflowState(ChatWorkflowState.PARSING);
|
||||||
}
|
}
|
||||||
@@ -91,8 +90,7 @@ public class ChatWorkflowEngine {
|
|||||||
|
|
||||||
private void performMapping(ChatQueryContext queryCtx) {
|
private void performMapping(ChatQueryContext queryCtx) {
|
||||||
if (Objects.isNull(queryCtx.getMapInfo())
|
if (Objects.isNull(queryCtx.getMapInfo())
|
||||||
|| MapUtils.isEmpty(queryCtx.getMapInfo().getDataSetElementMatches())
|
|| MapUtils.isEmpty(queryCtx.getMapInfo().getDataSetElementMatches())) {
|
||||||
|| queryCtx.getMapInfo().needContinueMap()) {
|
|
||||||
schemaMappers.forEach(mapper -> mapper.map(queryCtx));
|
schemaMappers.forEach(mapper -> mapper.map(queryCtx));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,66 +0,0 @@
|
|||||||
package com.tencent.supersonic.headless.server.utils;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
|
||||||
import com.tencent.supersonic.headless.chat.corrector.SemanticCorrector;
|
|
||||||
import com.tencent.supersonic.headless.chat.mapper.SchemaMapper;
|
|
||||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
|
||||||
import com.tencent.supersonic.headless.server.modeller.SemanticModeller;
|
|
||||||
import com.tencent.supersonic.headless.server.processor.ResultProcessor;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
|
||||||
import org.springframework.core.io.support.SpringFactoriesLoader;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/** QueryConverter QueryOptimizer QueryExecutor object factory */
|
|
||||||
@Slf4j
|
|
||||||
public class ComponentFactory {
|
|
||||||
private static List<ResultProcessor> resultProcessors = new ArrayList<>();
|
|
||||||
private static List<SchemaMapper> schemaMappers = new ArrayList<>();
|
|
||||||
private static List<SemanticParser> semanticParsers = new ArrayList<>();
|
|
||||||
private static List<SemanticCorrector> semanticCorrectors = new ArrayList<>();
|
|
||||||
private static SemanticModeller semanticModeller;
|
|
||||||
|
|
||||||
public static List<ResultProcessor> getResultProcessors() {
|
|
||||||
return CollectionUtils.isEmpty(resultProcessors)
|
|
||||||
? init(ResultProcessor.class, resultProcessors)
|
|
||||||
: resultProcessors;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static List<SchemaMapper> getSchemaMappers() {
|
|
||||||
return CollectionUtils.isEmpty(schemaMappers) ? init(SchemaMapper.class, schemaMappers)
|
|
||||||
: schemaMappers;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static List<SemanticParser> getSemanticParsers() {
|
|
||||||
return CollectionUtils.isEmpty(semanticParsers)
|
|
||||||
? init(SemanticParser.class, semanticParsers)
|
|
||||||
: semanticParsers;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static List<SemanticCorrector> getSemanticCorrectors() {
|
|
||||||
return CollectionUtils.isEmpty(semanticCorrectors)
|
|
||||||
? init(SemanticCorrector.class, semanticCorrectors)
|
|
||||||
: semanticCorrectors;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static SemanticModeller getSemanticModeller() {
|
|
||||||
return semanticModeller == null ? init(SemanticModeller.class) : semanticModeller;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static <T> T getBean(String name, Class<T> tClass) {
|
|
||||||
return ContextUtils.getContext().getBean(name, tClass);
|
|
||||||
}
|
|
||||||
|
|
||||||
private static <T> List<T> init(Class<T> factoryType, List list) {
|
|
||||||
list.addAll(SpringFactoriesLoader.loadFactories(factoryType,
|
|
||||||
Thread.currentThread().getContextClassLoader()));
|
|
||||||
return list;
|
|
||||||
}
|
|
||||||
|
|
||||||
private static <T> T init(Class<T> factoryType) {
|
|
||||||
return SpringFactoriesLoader
|
|
||||||
.loadFactories(factoryType, Thread.currentThread().getContextClassLoader()).get(0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,31 @@
|
|||||||
|
package com.tencent.supersonic.headless.server.utils;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
|
||||||
|
import com.tencent.supersonic.headless.server.modeller.SemanticModeller;
|
||||||
|
import com.tencent.supersonic.headless.server.processor.ResultProcessor;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* QueryConverter QueryOptimizer QueryExecutor object factory
|
||||||
|
*/
|
||||||
|
@Slf4j
|
||||||
|
public class CoreComponentFactory extends ComponentFactory {
|
||||||
|
|
||||||
|
private static List<ResultProcessor> resultProcessors = new ArrayList<>();
|
||||||
|
|
||||||
|
private static SemanticModeller semanticModeller;
|
||||||
|
|
||||||
|
public static List<ResultProcessor> getResultProcessors() {
|
||||||
|
return CollectionUtils.isEmpty(resultProcessors)
|
||||||
|
? init(ResultProcessor.class, resultProcessors)
|
||||||
|
: resultProcessors;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static SemanticModeller getSemanticModeller() {
|
||||||
|
return semanticModeller == null ? init(SemanticModeller.class) : semanticModeller;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -59,8 +59,7 @@ public class Text2SQLEval extends BaseTest {
|
|||||||
durations.add(System.currentTimeMillis() - start);
|
durations.add(System.currentTimeMillis() - start);
|
||||||
assert result.getQueryColumns().size() == 2;
|
assert result.getQueryColumns().size() == 2;
|
||||||
assert result.getQueryResults().size() == 30;
|
assert result.getQueryResults().size() == 30;
|
||||||
assert result.getTextResult().contains("date")
|
assert result.getTextResult().contains("date") || result.getTextResult().contains("日期");
|
||||||
|| result.getTextResult().contains("日期");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|||||||
Reference in New Issue
Block a user