From cf359f3e2ffe2315d4f976aebaa222d2d253a9bb Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Thu, 31 Oct 2024 11:18:35 +0800 Subject: [PATCH] [improvement][chat] Optimize and modify the mapper method for terminology (#1866) --- .../supersonic/common/pojo/ChatApp.java | 4 +- .../common/pojo/Text2SQLExemplar.java | 4 +- .../tencent/supersonic/common/pojo/User.java | 4 +- .../supersonic/common/util/DeepCopyUtil.java | 12 ++++ .../api/pojo/AggregateTypeDefaultConfig.java | 4 +- .../headless/api/pojo/DataSetSchema.java | 3 +- .../headless/api/pojo/DefaultDisplayInfo.java | 3 +- .../api/pojo/DetailTypeDefaultConfig.java | 4 +- .../headless/api/pojo/QueryConfig.java | 4 +- .../api/pojo/RelatedSchemaElement.java | 4 +- .../headless/api/pojo/SchemaElement.java | 1 - .../headless/api/pojo/SchemaElementMatch.java | 4 +- .../headless/api/pojo/SchemaMapInfo.java | 30 ++++++--- .../headless/api/pojo/SemanticParseInfo.java | 3 +- .../headless/api/pojo/TimeDefaultConfig.java | 4 +- .../api/pojo/request/QueryFilter.java | 4 +- .../api/pojo/request/QueryFilters.java | 3 +- .../headless/api/pojo/request/QueryNLReq.java | 3 +- .../headless/chat/ChatQueryContext.java | 4 +- .../headless/chat/mapper/TermDescMapper.java | 47 ++++++------- .../headless/chat/utils/ComponentFactory.java | 34 ++++++++-- .../service/impl/S2ChatLayerService.java | 2 +- .../server/service/impl/ModelServiceImpl.java | 4 +- .../server/utils/ChatWorkflowEngine.java | 14 ++-- .../server/utils/ComponentFactory.java | 66 ------------------- .../server/utils/CoreComponentFactory.java | 31 +++++++++ .../supersonic/evaluation/Text2SQLEval.java | 3 +- 27 files changed, 172 insertions(+), 131 deletions(-) create mode 100644 common/src/main/java/com/tencent/supersonic/common/util/DeepCopyUtil.java delete mode 100644 headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ComponentFactory.java create mode 100644 
headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/CoreComponentFactory.java diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/ChatApp.java b/common/src/main/java/com/tencent/supersonic/common/pojo/ChatApp.java index 0a9e8e93b..18b2d6cb3 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/ChatApp.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/ChatApp.java @@ -7,11 +7,13 @@ import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; +import java.io.Serializable; + @Data @Builder @AllArgsConstructor @NoArgsConstructor -public class ChatApp { +public class ChatApp implements Serializable { private String name; private String description; private String prompt; diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/Text2SQLExemplar.java b/common/src/main/java/com/tencent/supersonic/common/pojo/Text2SQLExemplar.java index e24ab979a..d878c13c2 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/Text2SQLExemplar.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/Text2SQLExemplar.java @@ -5,11 +5,13 @@ import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; +import java.io.Serializable; + @Data @Builder @NoArgsConstructor @AllArgsConstructor -public class Text2SQLExemplar { +public class Text2SQLExemplar implements Serializable { public static final String PROPERTY_KEY = "sql_exemplar"; diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/User.java b/common/src/main/java/com/tencent/supersonic/common/pojo/User.java index 96978810e..b85194984 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/User.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/User.java @@ -5,10 +5,12 @@ import lombok.Data; import lombok.NoArgsConstructor; import org.apache.commons.lang3.StringUtils; +import java.io.Serializable; + @Data @NoArgsConstructor @AllArgsConstructor -public class User 
{ +public class User implements Serializable { private Long id; diff --git a/common/src/main/java/com/tencent/supersonic/common/util/DeepCopyUtil.java b/common/src/main/java/com/tencent/supersonic/common/util/DeepCopyUtil.java new file mode 100644 index 000000000..93b3035d8 --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/util/DeepCopyUtil.java @@ -0,0 +1,19 @@ +package com.tencent.supersonic.common.util; + +import org.apache.commons.lang3.SerializationUtils; + +import java.io.Serializable; + +public class DeepCopyUtil { + + /** + * Returns a deep copy of {@code object} produced by a Java serialization round trip. + * + * @param object the object to copy; may be {@code null} + * @return an independent copy of {@code object}, or {@code null} if {@code object} is null + * @throws org.apache.commons.lang3.SerializationException if the object graph cannot be serialized + */ + public static <T extends Serializable> T deepCopy(T object) { + return SerializationUtils.clone(object); + } +} diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/AggregateTypeDefaultConfig.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/AggregateTypeDefaultConfig.java index f2f9db72e..ad7570f8f 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/AggregateTypeDefaultConfig.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/AggregateTypeDefaultConfig.java @@ -5,8 +5,10 @@ import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum; import com.tencent.supersonic.common.pojo.enums.TimeMode; import lombok.Data; +import java.io.Serializable; + @Data -public class AggregateTypeDefaultConfig { +public class AggregateTypeDefaultConfig implements Serializable { private TimeDefaultConfig timeDefaultConfig = new TimeDefaultConfig(7, DatePeriodEnum.DAY, TimeMode.RECENT); diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DataSetSchema.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DataSetSchema.java index 76b9c9f87..cc26aa74c 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DataSetSchema.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DataSetSchema.java @@ -4,6 +4,7 @@ import lombok.Data; import
org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.StringUtils; +import java.io.Serializable; import java.util.ArrayList; import java.util.HashSet; import java.util.List; @@ -14,7 +15,7 @@ import java.util.Set; import java.util.stream.Collectors; @Data -public class DataSetSchema { +public class DataSetSchema implements Serializable { private String databaseType; private SchemaElement dataSet; diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DefaultDisplayInfo.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DefaultDisplayInfo.java index b54aec90d..ef2d8e2e1 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DefaultDisplayInfo.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DefaultDisplayInfo.java @@ -2,11 +2,12 @@ package com.tencent.supersonic.headless.api.pojo; import lombok.Data; +import java.io.Serializable; import java.util.ArrayList; import java.util.List; @Data -public class DefaultDisplayInfo { +public class DefaultDisplayInfo implements Serializable { // When displaying tag selection results, the information displayed by default private List dimensionIds = new ArrayList<>(); diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DetailTypeDefaultConfig.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DetailTypeDefaultConfig.java index 0a59eb934..7977db800 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DetailTypeDefaultConfig.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DetailTypeDefaultConfig.java @@ -3,8 +3,10 @@ package com.tencent.supersonic.headless.api.pojo; import com.tencent.supersonic.common.pojo.Constants; import lombok.Data; +import java.io.Serializable; + @Data -public class DetailTypeDefaultConfig { +public class DetailTypeDefaultConfig implements Serializable { private 
DefaultDisplayInfo defaultDisplayInfo; diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/QueryConfig.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/QueryConfig.java index d78daefc4..563542640 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/QueryConfig.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/QueryConfig.java @@ -2,8 +2,10 @@ package com.tencent.supersonic.headless.api.pojo; import lombok.Data; +import java.io.Serializable; + @Data -public class QueryConfig { +public class QueryConfig implements Serializable { private DetailTypeDefaultConfig detailTypeDefaultConfig = new DetailTypeDefaultConfig(); diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/RelatedSchemaElement.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/RelatedSchemaElement.java index 77bbd12eb..e383906c8 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/RelatedSchemaElement.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/RelatedSchemaElement.java @@ -5,11 +5,13 @@ import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; +import java.io.Serializable; + @Data @Builder @NoArgsConstructor @AllArgsConstructor -public class RelatedSchemaElement { +public class RelatedSchemaElement implements Serializable { private Long dimensionId; diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java index 3d3156fda..7cb45b13e 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java @@ -39,7 +39,6 @@ public class SchemaElement implements Serializable { private double order; private int isTag; 
private String description; - private boolean descriptionMapped; @Builder.Default private Map extInfo = new HashMap<>(); private DimensionTimeTypeParams typeParams; diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElementMatch.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElementMatch.java index 95e6e688f..42f30000f 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElementMatch.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElementMatch.java @@ -6,12 +6,14 @@ import lombok.Data; import lombok.NoArgsConstructor; import lombok.ToString; +import java.io.Serializable; + @Data @ToString @Builder @AllArgsConstructor @NoArgsConstructor -public class SchemaElementMatch { +public class SchemaElementMatch implements Serializable { private SchemaElement element; private double offset; private double similarity; diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaMapInfo.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaMapInfo.java index a6661af67..1ecf0f6ee 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaMapInfo.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaMapInfo.java @@ -3,15 +3,17 @@ package com.tencent.supersonic.headless.api.pojo; import com.fasterxml.jackson.annotation.JsonIgnore; import com.google.common.collect.Lists; import lombok.Getter; -import org.apache.commons.collections4.CollectionUtils; +import java.io.Serializable; +import java.util.ArrayList; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @Getter -public class SchemaMapInfo { +public class SchemaMapInfo implements Serializable { private final Map> dataSetElementMatches = new HashMap<>(); @@ -31,6 +33,23 @@ public class SchemaMapInfo { 
dataSetElementMatches.put(dataSet, elementMatches); } + public void addMatchedElements(SchemaMapInfo schemaMapInfo) { + for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMapInfo.dataSetElementMatches + .entrySet()) { + Long dataSet = entry.getKey(); + List<SchemaElementMatch> newMatches = entry.getValue(); + + if (dataSetElementMatches.containsKey(dataSet)) { + List<SchemaElementMatch> existingMatches = dataSetElementMatches.get(dataSet); + Set<SchemaElementMatch> mergedMatches = new HashSet<>(existingMatches); + mergedMatches.addAll(newMatches); + dataSetElementMatches.put(dataSet, new ArrayList<>(mergedMatches)); + } else { + dataSetElementMatches.put(dataSet, new ArrayList<>(new HashSet<>(newMatches))); + } + } + } + @JsonIgnore public List getTermDescriptionToMap() { List termElements = Lists.newArrayList(); @@ -38,16 +57,11 @@ public class SchemaMapInfo { List matchedElements = getMatchedElements(dataSetId); for (SchemaElementMatch schemaElementMatch : matchedElements) { if (SchemaElementType.TERM.equals(schemaElementMatch.getElement().getType()) - && schemaElementMatch.isFullMatched() - && !schemaElementMatch.getElement().isDescriptionMapped()) { + && schemaElementMatch.isFullMatched()) { termElements.add(schemaElementMatch.getElement()); } } } return termElements; } - - public boolean needContinueMap() { - return CollectionUtils.isNotEmpty(getTermDescriptionToMap()); - } } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticParseInfo.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticParseInfo.java index 8645f7776..92e33d5c7 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticParseInfo.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticParseInfo.java @@ -12,6 +12,7 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import lombok.Builder; import lombok.Data; +import java.io.Serializable; import java.util.Comparator; import java.util.List; import java.util.Map; @@ -22,7 +23,7 @@
import static com.tencent.supersonic.common.pojo.Constants.DEFAULT_DETAIL_LIMIT; import static com.tencent.supersonic.common.pojo.Constants.DEFAULT_METRIC_LIMIT; @Data -public class SemanticParseInfo { +public class SemanticParseInfo implements Serializable { private Integer id; private String queryMode = "PLAIN_TEXT"; diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/TimeDefaultConfig.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/TimeDefaultConfig.java index 5fa5d2917..282a98bf1 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/TimeDefaultConfig.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/TimeDefaultConfig.java @@ -6,10 +6,12 @@ import lombok.AllArgsConstructor; import lombok.Data; import lombok.NoArgsConstructor; +import java.io.Serializable; + @Data @AllArgsConstructor @NoArgsConstructor -public class TimeDefaultConfig { +public class TimeDefaultConfig implements Serializable { /** default time span unit */ private Integer unit = 1; diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryFilter.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryFilter.java index bdfbc0427..b06814aef 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryFilter.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryFilter.java @@ -5,9 +5,11 @@ import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; import lombok.Data; import lombok.ToString; +import java.io.Serializable; + @Data @ToString(callSuper = true) -public class QueryFilter { +public class QueryFilter implements Serializable { private String bizName; diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryFilters.java 
b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryFilters.java index 1323254a6..0977690d9 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryFilters.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryFilters.java @@ -2,13 +2,14 @@ package com.tencent.supersonic.headless.api.pojo.request; import lombok.Data; +import java.io.Serializable; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @Data -public class QueryFilters { +public class QueryFilters implements Serializable { private List filters = new ArrayList<>(); private Map params = new HashMap<>(); } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryNLReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryNLReq.java index 83116f642..990effc72 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryNLReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryNLReq.java @@ -11,12 +11,13 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum; import lombok.Data; +import java.io.Serializable; import java.util.List; import java.util.Map; import java.util.Set; @Data -public class QueryNLReq extends SemanticQueryReq { +public class QueryNLReq extends SemanticQueryReq implements Serializable { private String queryText; private Set dataSetIds = Sets.newHashSet(); private User user; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java index 1d585ff99..a55e0c409 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java +++ 
b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java @@ -10,6 +10,7 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq; import com.tencent.supersonic.headless.chat.query.SemanticQuery; import lombok.Data; +import java.io.Serializable; import java.util.ArrayList; import java.util.Comparator; import java.util.List; @@ -18,10 +19,9 @@ import java.util.Objects; import java.util.stream.Collectors; @Data -public class ChatQueryContext { +public class ChatQueryContext implements Serializable { private QueryNLReq request; - private String oriQueryText; private Map> modelIdToDataSetIds; private List candidateQueries = new ArrayList<>(); private SchemaMapInfo mapInfo = new SchemaMapInfo(); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/TermDescMapper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/TermDescMapper.java index e37136327..dbb79d722 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/TermDescMapper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/TermDescMapper.java @@ -1,41 +1,44 @@ package com.tencent.supersonic.headless.chat.mapper; +import com.tencent.supersonic.common.util.DeepCopyUtil; import com.tencent.supersonic.headless.api.pojo.SchemaElement; +import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.chat.ChatQueryContext; +import com.tencent.supersonic.headless.chat.utils.ComponentFactory; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections4.CollectionUtils; -import org.apache.commons.lang3.StringUtils; import java.util.List; -/** * A mapper that map the description of the term. */ +/** + * A mapper that map the description of the term. 
+ */ @Slf4j public class TermDescMapper extends BaseMapper { @Override public void doMap(ChatQueryContext chatQueryContext) { - List termDescriptionToMap = - chatQueryContext.getMapInfo().getTermDescriptionToMap(); - if (CollectionUtils.isEmpty(termDescriptionToMap)) { + SchemaMapInfo mapInfo = chatQueryContext.getMapInfo(); + List termElements = mapInfo.getTermDescriptionToMap(); + if (CollectionUtils.isEmpty(termElements)) { return; } - if (StringUtils.isBlank(chatQueryContext.getOriQueryText())) { - chatQueryContext.setOriQueryText(chatQueryContext.getRequest().getQueryText()); - } - for (SchemaElement schemaElement : termDescriptionToMap) { - if (schemaElement.isDescriptionMapped()) { - continue; - } - if (chatQueryContext.getRequest().getQueryText() - .equals(schemaElement.getDescription())) { - schemaElement.setDescriptionMapped(true); - continue; - } - chatQueryContext.getRequest().setQueryText(schemaElement.getDescription()); - break; - } - if (CollectionUtils.isEmpty(chatQueryContext.getMapInfo().getTermDescriptionToMap())) { - chatQueryContext.getRequest().setQueryText(chatQueryContext.getOriQueryText()); + for (SchemaElement schemaElement : termElements) { + ChatQueryContext queryCtx = + buildQueryContext(chatQueryContext, schemaElement.getDescription()); + ComponentFactory.getSchemaMappers().forEach(mapper -> mapper.map(queryCtx)); + chatQueryContext.getMapInfo().addMatchedElements(queryCtx.getMapInfo()); } } + + private static ChatQueryContext buildQueryContext(ChatQueryContext chatQueryContext, + String queryText) { + ChatQueryContext queryContext = DeepCopyUtil.deepCopy(chatQueryContext); + queryContext.getRequest().setQueryText(queryText); + queryContext.setMapInfo(new SchemaMapInfo()); + queryContext.setSemanticSchema(chatQueryContext.getSemanticSchema()); + queryContext.setModelIdToDataSetIds(chatQueryContext.getModelIdToDataSetIds()); + queryContext.setChatWorkflowState(chatQueryContext.getChatWorkflowState()); + return queryContext; + } } diff 
--git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/ComponentFactory.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/ComponentFactory.java index 92f0c9717..9b0ed918b 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/ComponentFactory.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/ComponentFactory.java @@ -1,19 +1,45 @@ package com.tencent.supersonic.headless.chat.utils; import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.headless.chat.corrector.SemanticCorrector; +import com.tencent.supersonic.headless.chat.mapper.SchemaMapper; +import com.tencent.supersonic.headless.chat.parser.SemanticParser; import com.tencent.supersonic.headless.chat.parser.llm.DataSetResolver; import lombok.extern.slf4j.Slf4j; +import org.apache.commons.collections.CollectionUtils; import org.springframework.core.io.support.SpringFactoriesLoader; +import java.util.ArrayList; import java.util.List; import java.util.Objects; -/** HeadlessConverter QueryOptimizer QueryExecutor object factory */ +/** + * QueryConverter QueryOptimizer QueryExecutor object factory + */ @Slf4j public class ComponentFactory { - + private static List schemaMappers = new ArrayList<>(); + private static List semanticParsers = new ArrayList<>(); + private static List semanticCorrectors = new ArrayList<>(); private static DataSetResolver modelResolver; + public static List getSchemaMappers() { + return CollectionUtils.isEmpty(schemaMappers) ? init(SchemaMapper.class, schemaMappers) + : schemaMappers; + } + + public static List getSemanticParsers() { + return CollectionUtils.isEmpty(semanticParsers) + ? init(SemanticParser.class, semanticParsers) + : semanticParsers; + } + + public static List getSemanticCorrectors() { + return CollectionUtils.isEmpty(semanticCorrectors) + ? 
init(SemanticCorrector.class, semanticCorrectors) + : semanticCorrectors; + } + public static DataSetResolver getModelResolver() { if (Objects.isNull(modelResolver)) { modelResolver = init(DataSetResolver.class); @@ -25,13 +51,13 @@ public class ComponentFactory { return ContextUtils.getContext().getBean(name, tClass); } - private static List init(Class factoryType, List list) { + protected static List init(Class factoryType, List list) { list.addAll(SpringFactoriesLoader.loadFactories(factoryType, Thread.currentThread().getContextClassLoader())); return list; } - private static T init(Class factoryType) { + protected static T init(Class factoryType) { return SpringFactoriesLoader .loadFactories(factoryType, Thread.currentThread().getContextClassLoader()).get(0); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2ChatLayerService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2ChatLayerService.java index 22305b8d9..ff4c4541a 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2ChatLayerService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2ChatLayerService.java @@ -29,12 +29,12 @@ import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.corrector.GrammarCorrector; import com.tencent.supersonic.headless.chat.corrector.SchemaCorrector; import com.tencent.supersonic.headless.chat.knowledge.builder.BaseWordBuilder; +import com.tencent.supersonic.headless.chat.utils.ComponentFactory; import com.tencent.supersonic.headless.server.facade.service.ChatLayerService; import com.tencent.supersonic.headless.server.service.DataSetService; import com.tencent.supersonic.headless.server.service.RetrieveService; import com.tencent.supersonic.headless.server.service.SchemaService; import 
com.tencent.supersonic.headless.server.utils.ChatWorkflowEngine; -import com.tencent.supersonic.headless.server.utils.ComponentFactory; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; import org.springframework.beans.BeanUtils; diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java index c4d0e5fc1..75594c745 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java @@ -45,7 +45,7 @@ import com.tencent.supersonic.headless.server.service.DomainService; import com.tencent.supersonic.headless.server.service.MetricService; import com.tencent.supersonic.headless.server.service.ModelRelaService; import com.tencent.supersonic.headless.server.service.ModelService; -import com.tencent.supersonic.headless.server.utils.ComponentFactory; +import com.tencent.supersonic.headless.server.utils.CoreComponentFactory; import com.tencent.supersonic.headless.server.utils.ModelConverter; import com.tencent.supersonic.headless.server.utils.NameCheckUtils; import lombok.extern.slf4j.Slf4j; @@ -222,7 +222,7 @@ public class ModelServiceImpl implements ModelService { private void doBuild(ModelBuildReq modelBuildReq, DbSchema curSchema, List dbSchemas, Map modelSchemaMap) { - SemanticModeller semanticModeller = ComponentFactory.getSemanticModeller(); + SemanticModeller semanticModeller = CoreComponentFactory.getSemanticModeller(); ModelSchema modelSchema = semanticModeller.build(curSchema, dbSchemas, modelBuildReq); modelSchemaMap.put(curSchema.getTable(), modelSchema); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java 
b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java index 8308d9655..3ea151549 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java @@ -30,11 +30,12 @@ import java.util.stream.Collectors; @Slf4j public class ChatWorkflowEngine { - private final List schemaMappers = ComponentFactory.getSchemaMappers(); - private final List semanticParsers = ComponentFactory.getSemanticParsers(); + private final List schemaMappers = CoreComponentFactory.getSchemaMappers(); + private final List semanticParsers = CoreComponentFactory.getSemanticParsers(); private final List semanticCorrectors = - ComponentFactory.getSemanticCorrectors(); - private final List resultProcessors = ComponentFactory.getResultProcessors(); + CoreComponentFactory.getSemanticCorrectors(); + private final List resultProcessors = + CoreComponentFactory.getResultProcessors(); public void start(ChatWorkflowState initialState, ChatQueryContext queryCtx, ParseResp parseResult) { @@ -48,8 +49,6 @@ public class ChatWorkflowEngine { parseResult.setErrorMsg( "No semantic entities can be mapped against user question."); queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED); - } else if (queryCtx.getMapInfo().needContinueMap()) { - queryCtx.setChatWorkflowState(ChatWorkflowState.MAPPING); } else { queryCtx.setChatWorkflowState(ChatWorkflowState.PARSING); } @@ -91,8 +90,7 @@ public class ChatWorkflowEngine { private void performMapping(ChatQueryContext queryCtx) { if (Objects.isNull(queryCtx.getMapInfo()) - || MapUtils.isEmpty(queryCtx.getMapInfo().getDataSetElementMatches()) - || queryCtx.getMapInfo().needContinueMap()) { + || MapUtils.isEmpty(queryCtx.getMapInfo().getDataSetElementMatches())) { schemaMappers.forEach(mapper -> mapper.map(queryCtx)); } } diff --git 
a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ComponentFactory.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ComponentFactory.java deleted file mode 100644 index c882d5784..000000000 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ComponentFactory.java +++ /dev/null @@ -1,66 +0,0 @@ -package com.tencent.supersonic.headless.server.utils; - -import com.tencent.supersonic.common.util.ContextUtils; -import com.tencent.supersonic.headless.chat.corrector.SemanticCorrector; -import com.tencent.supersonic.headless.chat.mapper.SchemaMapper; -import com.tencent.supersonic.headless.chat.parser.SemanticParser; -import com.tencent.supersonic.headless.server.modeller.SemanticModeller; -import com.tencent.supersonic.headless.server.processor.ResultProcessor; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.collections.CollectionUtils; -import org.springframework.core.io.support.SpringFactoriesLoader; - -import java.util.ArrayList; -import java.util.List; - -/** QueryConverter QueryOptimizer QueryExecutor object factory */ -@Slf4j -public class ComponentFactory { - private static List resultProcessors = new ArrayList<>(); - private static List schemaMappers = new ArrayList<>(); - private static List semanticParsers = new ArrayList<>(); - private static List semanticCorrectors = new ArrayList<>(); - private static SemanticModeller semanticModeller; - - public static List getResultProcessors() { - return CollectionUtils.isEmpty(resultProcessors) - ? init(ResultProcessor.class, resultProcessors) - : resultProcessors; - } - - public static List getSchemaMappers() { - return CollectionUtils.isEmpty(schemaMappers) ? init(SchemaMapper.class, schemaMappers) - : schemaMappers; - } - - public static List getSemanticParsers() { - return CollectionUtils.isEmpty(semanticParsers) - ? 
init(SemanticParser.class, semanticParsers) - : semanticParsers; - } - - public static List getSemanticCorrectors() { - return CollectionUtils.isEmpty(semanticCorrectors) - ? init(SemanticCorrector.class, semanticCorrectors) - : semanticCorrectors; - } - - public static SemanticModeller getSemanticModeller() { - return semanticModeller == null ? init(SemanticModeller.class) : semanticModeller; - } - - public static T getBean(String name, Class tClass) { - return ContextUtils.getContext().getBean(name, tClass); - } - - private static List init(Class factoryType, List list) { - list.addAll(SpringFactoriesLoader.loadFactories(factoryType, - Thread.currentThread().getContextClassLoader())); - return list; - } - - private static T init(Class factoryType) { - return SpringFactoriesLoader - .loadFactories(factoryType, Thread.currentThread().getContextClassLoader()).get(0); - } -} diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/CoreComponentFactory.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/CoreComponentFactory.java new file mode 100644 index 000000000..1845ee8dd --- /dev/null +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/CoreComponentFactory.java @@ -0,0 +1,31 @@ +package com.tencent.supersonic.headless.server.utils; + +import com.tencent.supersonic.headless.chat.utils.ComponentFactory; +import com.tencent.supersonic.headless.server.modeller.SemanticModeller; +import com.tencent.supersonic.headless.server.processor.ResultProcessor; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.collections.CollectionUtils; + +import java.util.ArrayList; +import java.util.List; + +/** + * QueryConverter QueryOptimizer QueryExecutor object factory + */ +@Slf4j +public class CoreComponentFactory extends ComponentFactory { + + private static List resultProcessors = new ArrayList<>(); + + private static SemanticModeller semanticModeller; + + public static List 
getResultProcessors() { + return CollectionUtils.isEmpty(resultProcessors) + ? init(ResultProcessor.class, resultProcessors) + : resultProcessors; + } + + public static SemanticModeller getSemanticModeller() { + return semanticModeller == null ? (semanticModeller = init(SemanticModeller.class)) : semanticModeller; + } +} diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java b/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java index bc713c125..21a6869c6 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java @@ -59,8 +59,7 @@ public class Text2SQLEval extends BaseTest { durations.add(System.currentTimeMillis() - start); assert result.getQueryColumns().size() == 2; assert result.getQueryResults().size() == 30; - assert result.getTextResult().contains("date") - || result.getTextResult().contains("日期"); + assert result.getTextResult().contains("date") || result.getTextResult().contains("日期"); } @Test