(improvement)(headless) Add semantic retrieval to term descriptions and extract relevant semantic information (#1468)

Co-authored-by: lxwcodemonkey
This commit is contained in:
LXW
2024-07-29 09:15:18 +08:00
committed by GitHub
parent ccd79e4830
commit 26f682cc45
7 changed files with 72 additions and 2 deletions

View File

@@ -33,6 +33,7 @@ public class SchemaElement implements Serializable {
private double order; private double order;
private int isTag; private int isTag;
private String description; private String description;
private boolean descriptionMapped;
@Override @Override
public boolean equals(Object o) { public boolean equals(Object o) {

View File

@@ -20,4 +20,8 @@ public class SchemaElementMatch {
Long frequency; Long frequency;
boolean isInherited; boolean isInherited;
public boolean isFullMatched() {
return 1.0 == similarity;
}
} }

View File

@@ -1,6 +1,8 @@
package com.tencent.supersonic.headless.api.pojo; package com.tencent.supersonic.headless.api.pojo;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import org.apache.commons.collections4.CollectionUtils;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
@@ -26,4 +28,25 @@ public class SchemaMapInfo {
public void setMatchedElements(Long dataSet, List<SchemaElementMatch> elementMatches) { public void setMatchedElements(Long dataSet, List<SchemaElementMatch> elementMatches) {
dataSetElementMatches.put(dataSet, elementMatches); dataSetElementMatches.put(dataSet, elementMatches);
} }
@JsonIgnore
public List<SchemaElement> getTermDescriptionToMap() {
List<SchemaElement> termElements = Lists.newArrayList();
for (Long dataSetId : getDataSetElementMatches().keySet()) {
List<SchemaElementMatch> matchedElements = getMatchedElements(dataSetId);
for (SchemaElementMatch schemaElementMatch : matchedElements) {
if (SchemaElementType.TERM.equals(schemaElementMatch.getElement().getType())
&& schemaElementMatch.isFullMatched()
&& !schemaElementMatch.getElement().isDescriptionMapped()) {
termElements.add(schemaElementMatch.getElement());
}
}
}
return termElements;
}
public boolean needContinueMap() {
return CollectionUtils.isNotEmpty(getTermDescriptionToMap());
}
} }

View File

@@ -35,6 +35,7 @@ import java.util.stream.Collectors;
public class ChatQueryContext { public class ChatQueryContext {
private String queryText; private String queryText;
private String oriQueryText;
private Set<Long> dataSetIds; private Set<Long> dataSetIds;
private Map<Long, List<Long>> modelIdToDataSetIds; private Map<Long, List<Long>> modelIdToDataSetIds;
private User user; private User user;

View File

@@ -0,0 +1,37 @@
package com.tencent.supersonic.headless.chat.mapper;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import java.util.List;
/***
* A mapper that map the description of the term.
*/
@Slf4j
public class TermDescMapper extends BaseMapper {
@Override
public void doMap(ChatQueryContext chatQueryContext) {
List<SchemaElement> termDescriptionToMap = chatQueryContext.getMapInfo().getTermDescriptionToMap();
if (CollectionUtils.isEmpty(termDescriptionToMap)) {
if (StringUtils.isNotBlank(chatQueryContext.getOriQueryText())) {
chatQueryContext.setQueryText(chatQueryContext.getOriQueryText());
}
return;
}
if (StringUtils.isBlank(chatQueryContext.getOriQueryText())) {
chatQueryContext.setOriQueryText(chatQueryContext.getQueryText());
}
for (SchemaElement schemaElement : termDescriptionToMap) {
if (chatQueryContext.getQueryText().equals(schemaElement.getDescription())) {
schemaElement.setDescriptionMapped(true);
continue;
}
chatQueryContext.setQueryText(schemaElement.getDescription());
}
}
}

View File

@@ -48,6 +48,8 @@ public class ChatWorkflowEngine {
parseResult.setState(ParseResp.ParseState.FAILED); parseResult.setState(ParseResp.ParseState.FAILED);
parseResult.setErrorMsg("No semantic entities can be mapped against user question."); parseResult.setErrorMsg("No semantic entities can be mapped against user question.");
queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED); queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED);
} else if (queryCtx.getMapInfo().needContinueMap()) {
queryCtx.setChatWorkflowState(ChatWorkflowState.MAPPING);
} else { } else {
queryCtx.setChatWorkflowState(ChatWorkflowState.PARSING); queryCtx.setChatWorkflowState(ChatWorkflowState.PARSING);
} }
@@ -89,7 +91,8 @@ public class ChatWorkflowEngine {
private void performMapping(ChatQueryContext queryCtx) { private void performMapping(ChatQueryContext queryCtx) {
if (Objects.isNull(queryCtx.getMapInfo()) if (Objects.isNull(queryCtx.getMapInfo())
|| MapUtils.isEmpty(queryCtx.getMapInfo().getDataSetElementMatches())) { || MapUtils.isEmpty(queryCtx.getMapInfo().getDataSetElementMatches())
|| queryCtx.getMapInfo().needContinueMap()) {
schemaMappers.forEach(mapper -> mapper.map(queryCtx)); schemaMappers.forEach(mapper -> mapper.map(queryCtx));
} }
} }

View File

@@ -4,7 +4,8 @@ com.tencent.supersonic.headless.chat.mapper.SchemaMapper=\
com.tencent.supersonic.headless.chat.mapper.EmbeddingMapper, \ com.tencent.supersonic.headless.chat.mapper.EmbeddingMapper, \
com.tencent.supersonic.headless.chat.mapper.KeywordMapper, \ com.tencent.supersonic.headless.chat.mapper.KeywordMapper, \
com.tencent.supersonic.headless.chat.mapper.QueryFilterMapper, \ com.tencent.supersonic.headless.chat.mapper.QueryFilterMapper, \
com.tencent.supersonic.headless.chat.mapper.EntityMapper com.tencent.supersonic.headless.chat.mapper.EntityMapper, \
com.tencent.supersonic.headless.chat.mapper.TermDescMapper
com.tencent.supersonic.headless.chat.parser.SemanticParser=\ com.tencent.supersonic.headless.chat.parser.SemanticParser=\
com.tencent.supersonic.headless.chat.parser.rule.RuleSqlParser, \ com.tencent.supersonic.headless.chat.parser.rule.RuleSqlParser, \