diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java index 4b87e34ba..98a846a99 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java @@ -33,6 +33,7 @@ public class SchemaElement implements Serializable { private double order; private int isTag; private String description; + private boolean descriptionMapped; @Override public boolean equals(Object o) { diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElementMatch.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElementMatch.java index e23ab11aa..1ec8511ed 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElementMatch.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElementMatch.java @@ -20,4 +20,8 @@ public class SchemaElementMatch { Long frequency; boolean isInherited; + public boolean isFullMatched() { + return 1.0 == similarity; + } + } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaMapInfo.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaMapInfo.java index 95d28a28f..06db1f387 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaMapInfo.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaMapInfo.java @@ -1,6 +1,8 @@ package com.tencent.supersonic.headless.api.pojo; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.google.common.collect.Lists; +import org.apache.commons.collections4.CollectionUtils; import java.util.HashMap; import java.util.List; @@ -26,4 +28,25 @@ public class SchemaMapInfo { public void setMatchedElements(Long dataSet, List elementMatches) { dataSetElementMatches.put(dataSet, elementMatches); } + + @JsonIgnore + public List getTermDescriptionToMap() { + List termElements = Lists.newArrayList(); + for (Long dataSetId : getDataSetElementMatches().keySet()) { + List matchedElements = getMatchedElements(dataSetId); + for (SchemaElementMatch schemaElementMatch : matchedElements) { + if (SchemaElementType.TERM.equals(schemaElementMatch.getElement().getType()) + && schemaElementMatch.isFullMatched() + && !schemaElementMatch.getElement().isDescriptionMapped()) { + termElements.add(schemaElementMatch.getElement()); + } + } + } + return termElements; + } + + public boolean needContinueMap() { + return CollectionUtils.isNotEmpty(getTermDescriptionToMap()); + } + } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java index 7be08622c..ee4a50c91 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java @@ -35,6 +35,7 @@ import java.util.stream.Collectors; public class ChatQueryContext { private String queryText; + private String oriQueryText; private Set dataSetIds; private Map> modelIdToDataSetIds; private User user; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/TermDescMapper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/TermDescMapper.java new file mode 100644 index 000000000..80a62fefd --- /dev/null +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/TermDescMapper.java @@ -0,0 +1,37 @@ +package com.tencent.supersonic.headless.chat.mapper; + +import com.tencent.supersonic.headless.api.pojo.SchemaElement; +import com.tencent.supersonic.headless.chat.ChatQueryContext; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.lang3.StringUtils; +import java.util.List; + +/*** + * A mapper that map the description of the term. + */ +@Slf4j +public class TermDescMapper extends BaseMapper { + + @Override + public void doMap(ChatQueryContext chatQueryContext) { + List termDescriptionToMap = chatQueryContext.getMapInfo().getTermDescriptionToMap(); + if (CollectionUtils.isEmpty(termDescriptionToMap)) { + if (StringUtils.isNotBlank(chatQueryContext.getOriQueryText())) { + chatQueryContext.setQueryText(chatQueryContext.getOriQueryText()); + } + return; + } + if (StringUtils.isBlank(chatQueryContext.getOriQueryText())) { + chatQueryContext.setOriQueryText(chatQueryContext.getQueryText()); + } + for (SchemaElement schemaElement : termDescriptionToMap) { + if (chatQueryContext.getQueryText().equals(schemaElement.getDescription())) { + schemaElement.setDescriptionMapped(true); + continue; + } + chatQueryContext.setQueryText(schemaElement.getDescription()); + } + } + +} diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java index d2097e431..7ede61b8a 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java @@ -48,6 +48,8 @@ public class ChatWorkflowEngine { parseResult.setState(ParseResp.ParseState.FAILED); parseResult.setErrorMsg("No semantic entities can be mapped against user question."); queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED); + } else if (queryCtx.getMapInfo().needContinueMap()) { + queryCtx.setChatWorkflowState(ChatWorkflowState.MAPPING); } else { queryCtx.setChatWorkflowState(ChatWorkflowState.PARSING); } @@ -89,7 +91,8 @@ public class ChatWorkflowEngine { private void performMapping(ChatQueryContext queryCtx) { if (Objects.isNull(queryCtx.getMapInfo()) - || MapUtils.isEmpty(queryCtx.getMapInfo().getDataSetElementMatches())) { + || MapUtils.isEmpty(queryCtx.getMapInfo().getDataSetElementMatches()) + || queryCtx.getMapInfo().needContinueMap()) { schemaMappers.forEach(mapper -> mapper.map(queryCtx)); } } diff --git a/launchers/standalone/src/main/resources/META-INF/spring.factories b/launchers/standalone/src/main/resources/META-INF/spring.factories index a2421f8d4..967c4aec5 100644 --- a/launchers/standalone/src/main/resources/META-INF/spring.factories +++ b/launchers/standalone/src/main/resources/META-INF/spring.factories @@ -4,7 +4,8 @@ com.tencent.supersonic.headless.chat.mapper.SchemaMapper=\ com.tencent.supersonic.headless.chat.mapper.EmbeddingMapper, \ com.tencent.supersonic.headless.chat.mapper.KeywordMapper, \ com.tencent.supersonic.headless.chat.mapper.QueryFilterMapper, \ - com.tencent.supersonic.headless.chat.mapper.EntityMapper + com.tencent.supersonic.headless.chat.mapper.EntityMapper, \ + com.tencent.supersonic.headless.chat.mapper.TermDescMapper com.tencent.supersonic.headless.chat.parser.SemanticParser=\ com.tencent.supersonic.headless.chat.parser.rule.RuleSqlParser, \