mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
[improvement](chat) Refactor the code for the Map phase (#1646)
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
package com.tencent.supersonic.headless.chat.knowledge;
|
package com.tencent.supersonic.headless.chat.knowledge;
|
||||||
|
|
||||||
import com.google.common.base.Objects;
|
import com.google.common.base.Objects;
|
||||||
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.ToString;
|
import lombok.ToString;
|
||||||
@@ -27,4 +28,13 @@ public class DatabaseMapResult extends MapResult {
|
|||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Objects.hashCode(name, schemaElement);
|
return Objects.hashCode(name, schemaElement);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getMapKey() {
|
||||||
|
return this.getName()
|
||||||
|
+ Constants.UNDERLINE
|
||||||
|
+ this.getSchemaElement().getId()
|
||||||
|
+ Constants.UNDERLINE
|
||||||
|
+ this.getSchemaElement().getName();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package com.tencent.supersonic.headless.chat.knowledge;
|
package com.tencent.supersonic.headless.chat.knowledge;
|
||||||
|
|
||||||
import com.google.common.base.Objects;
|
import com.google.common.base.Objects;
|
||||||
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.ToString;
|
import lombok.ToString;
|
||||||
|
|
||||||
@@ -32,4 +33,9 @@ public class EmbeddingResult extends MapResult {
|
|||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Objects.hashCode(id);
|
return Objects.hashCode(id);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getMapKey() {
|
||||||
|
return this.getName() + Constants.UNDERLINE + this.getId();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package com.tencent.supersonic.headless.chat.knowledge;
|
package com.tencent.supersonic.headless.chat.knowledge;
|
||||||
|
|
||||||
import com.google.common.base.Objects;
|
import com.google.common.base.Objects;
|
||||||
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.ToString;
|
import lombok.ToString;
|
||||||
|
|
||||||
@@ -42,4 +43,11 @@ public class HanlpMapResult extends MapResult {
|
|||||||
public void setOffset(int offset) {
|
public void setOffset(int offset) {
|
||||||
this.offset = offset;
|
this.offset = offset;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String getMapKey() {
|
||||||
|
return this.getName()
|
||||||
|
+ Constants.UNDERLINE
|
||||||
|
+ String.join(Constants.UNDERLINE, this.getNatures());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,8 +7,19 @@ import java.io.Serializable;
|
|||||||
|
|
||||||
@Data
|
@Data
|
||||||
@ToString
|
@ToString
|
||||||
public class MapResult implements Serializable {
|
public abstract class MapResult implements Serializable {
|
||||||
|
|
||||||
protected String name;
|
protected String name;
|
||||||
|
|
||||||
|
protected int index;
|
||||||
protected String detectWord;
|
protected String detectWord;
|
||||||
|
|
||||||
|
/**
 * Returns the key that identifies the item this result was mapped to.
 *
 * <p>{@link #lessSimilar(MapResult)} treats two results as candidates for the same item
 * only when their map keys are equal.
 *
 * @return the map key of this result
 */
public abstract String getMapKey();
|
||||||
|
|
||||||
|
public Boolean lessSimilar(MapResult otherResult) {
|
||||||
|
String mapKey = this.getMapKey();
|
||||||
|
String otherMapKey = otherResult.getMapKey();
|
||||||
|
return mapKey.equals(otherMapKey)
|
||||||
|
&& otherResult.getDetectWord().length() < otherResult.getDetectWord().length();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,6 +31,8 @@ import java.util.Arrays;
|
|||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
/** HanLP helper */
|
/** HanLP helper */
|
||||||
@@ -87,7 +89,7 @@ public class HanlpHelper {
|
|||||||
return CustomDictionary;
|
return CustomDictionary;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** * reload custom dictionary */
|
/** reload custom dictionary */
|
||||||
public static boolean reloadCustomDictionary() throws IOException {
|
public static boolean reloadCustomDictionary() throws IOException {
|
||||||
|
|
||||||
final long startTime = System.currentTimeMillis();
|
final long startTime = System.currentTimeMillis();
|
||||||
@@ -316,6 +318,28 @@ public class HanlpHelper {
|
|||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static List<S2Term> getTerms(List<S2Term> terms, Set<Long> dataSetIds) {
|
||||||
|
logTerms(terms);
|
||||||
|
if (!CollectionUtils.isEmpty(dataSetIds)) {
|
||||||
|
terms =
|
||||||
|
terms.stream()
|
||||||
|
.filter(
|
||||||
|
term -> {
|
||||||
|
Long dataSetId =
|
||||||
|
NatureHelper.getDataSetId(
|
||||||
|
term.getNature().toString());
|
||||||
|
if (Objects.nonNull(dataSetId)) {
|
||||||
|
return dataSetIds.contains(dataSetId);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
})
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
log.debug("terms filter by dataSetId:{}", dataSetIds);
|
||||||
|
logTerms(terms);
|
||||||
|
}
|
||||||
|
return terms;
|
||||||
|
}
|
||||||
|
|
||||||
public static List<S2Term> transform2ApiTerm(
|
public static List<S2Term> transform2ApiTerm(
|
||||||
Term term, Map<Long, List<Long>> modelIdToDataSetIds) {
|
Term term, Map<Long, List<Long>> modelIdToDataSetIds) {
|
||||||
List<S2Term> s2Terms = Lists.newArrayList();
|
List<S2Term> s2Terms = Lists.newArrayList();
|
||||||
@@ -331,4 +355,17 @@ public class HanlpHelper {
|
|||||||
}
|
}
|
||||||
return s2Terms;
|
return s2Terms;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static void logTerms(List<S2Term> terms) {
|
||||||
|
if (CollectionUtils.isEmpty(terms)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (S2Term term : terms) {
|
||||||
|
log.debug(
|
||||||
|
"word:{},nature:{},frequency:{}",
|
||||||
|
term.word,
|
||||||
|
term.nature.toString(),
|
||||||
|
term.getFrequency());
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,20 +6,20 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
|
import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.springframework.beans.BeanUtils;
|
import org.springframework.beans.BeanUtils;
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.HashSet;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.Set;
|
import java.util.Optional;
|
||||||
import java.util.concurrent.atomic.AtomicBoolean;
|
import java.util.concurrent.atomic.AtomicBoolean;
|
||||||
import java.util.function.Predicate;
|
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@@ -50,85 +50,6 @@ public abstract class BaseMapper implements SchemaMapper {
|
|||||||
chatQueryContext.getMapInfo().getDataSetElementMatches());
|
chatQueryContext.getMapInfo().getDataSetElementMatches());
|
||||||
}
|
}
|
||||||
|
|
||||||
private void filter(ChatQueryContext chatQueryContext) {
|
|
||||||
filterByDataSetId(chatQueryContext);
|
|
||||||
filterByDetectWordLenLessThanOne(chatQueryContext);
|
|
||||||
switch (chatQueryContext.getQueryDataType()) {
|
|
||||||
case TAG:
|
|
||||||
filterByQueryDataType(chatQueryContext, element -> !(element.getIsTag() > 0));
|
|
||||||
break;
|
|
||||||
case METRIC:
|
|
||||||
filterByQueryDataType(
|
|
||||||
chatQueryContext,
|
|
||||||
element -> !SchemaElementType.METRIC.equals(element.getType()));
|
|
||||||
break;
|
|
||||||
case DIMENSION:
|
|
||||||
filterByQueryDataType(
|
|
||||||
chatQueryContext,
|
|
||||||
element -> {
|
|
||||||
boolean isDimensionOrValue =
|
|
||||||
SchemaElementType.DIMENSION.equals(element.getType())
|
|
||||||
|| SchemaElementType.VALUE.equals(element.getType());
|
|
||||||
return !isDimensionOrValue;
|
|
||||||
});
|
|
||||||
break;
|
|
||||||
case ALL:
|
|
||||||
default:
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static void filterByDataSetId(ChatQueryContext chatQueryContext) {
|
|
||||||
Set<Long> dataSetIds = chatQueryContext.getDataSetIds();
|
|
||||||
if (CollectionUtils.isEmpty(dataSetIds)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
Set<Long> dataSetIdInMapInfo =
|
|
||||||
new HashSet<>(chatQueryContext.getMapInfo().getDataSetElementMatches().keySet());
|
|
||||||
for (Long dataSetId : dataSetIdInMapInfo) {
|
|
||||||
if (!dataSetIds.contains(dataSetId)) {
|
|
||||||
chatQueryContext.getMapInfo().getDataSetElementMatches().remove(dataSetId);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static void filterByDetectWordLenLessThanOne(ChatQueryContext chatQueryContext) {
|
|
||||||
Map<Long, List<SchemaElementMatch>> dataSetElementMatches =
|
|
||||||
chatQueryContext.getMapInfo().getDataSetElementMatches();
|
|
||||||
for (Map.Entry<Long, List<SchemaElementMatch>> entry : dataSetElementMatches.entrySet()) {
|
|
||||||
List<SchemaElementMatch> value = entry.getValue();
|
|
||||||
if (!CollectionUtils.isEmpty(value)) {
|
|
||||||
value.removeIf(
|
|
||||||
schemaElementMatch ->
|
|
||||||
StringUtils.length(schemaElementMatch.getDetectWord()) <= 1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private static void filterByQueryDataType(
|
|
||||||
ChatQueryContext chatQueryContext, Predicate<SchemaElement> needRemovePredicate) {
|
|
||||||
chatQueryContext
|
|
||||||
.getMapInfo()
|
|
||||||
.getDataSetElementMatches()
|
|
||||||
.values()
|
|
||||||
.forEach(
|
|
||||||
schemaElementMatches -> {
|
|
||||||
schemaElementMatches.removeIf(
|
|
||||||
schemaElementMatch -> {
|
|
||||||
SchemaElement element = schemaElementMatch.getElement();
|
|
||||||
SchemaElementType type = element.getType();
|
|
||||||
|
|
||||||
boolean isEntityOrDatasetOrId =
|
|
||||||
SchemaElementType.ENTITY.equals(type)
|
|
||||||
|| SchemaElementType.DATASET.equals(type)
|
|
||||||
|| SchemaElementType.ID.equals(type);
|
|
||||||
|
|
||||||
return !isEntityOrDatasetOrId
|
|
||||||
&& needRemovePredicate.test(element);
|
|
||||||
});
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
public abstract void doMap(ChatQueryContext chatQueryContext);
|
public abstract void doMap(ChatQueryContext chatQueryContext);
|
||||||
|
|
||||||
public void addToSchemaMap(
|
public void addToSchemaMap(
|
||||||
@@ -202,4 +123,57 @@ public abstract class BaseMapper implements SchemaMapper {
|
|||||||
}
|
}
|
||||||
return element.getAlias();
|
return element.getAlias();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public <T> List<T> getMatches(
|
||||||
|
ChatQueryContext chatQueryContext, BaseMatchStrategy matchStrategy) {
|
||||||
|
String queryText = chatQueryContext.getQueryText();
|
||||||
|
List<S2Term> terms =
|
||||||
|
HanlpHelper.getTerms(queryText, chatQueryContext.getModelIdToDataSetIds());
|
||||||
|
terms = HanlpHelper.getTerms(terms, chatQueryContext.getDataSetIds());
|
||||||
|
Map<MatchText, List<T>> matchResult =
|
||||||
|
matchStrategy.match(chatQueryContext, terms, chatQueryContext.getDataSetIds());
|
||||||
|
List<T> matches = new ArrayList<>();
|
||||||
|
if (Objects.isNull(matchResult)) {
|
||||||
|
return matches;
|
||||||
|
}
|
||||||
|
Optional<List<T>> first =
|
||||||
|
matchResult.entrySet().stream()
|
||||||
|
.filter(entry -> CollectionUtils.isNotEmpty(entry.getValue()))
|
||||||
|
.map(entry -> entry.getValue())
|
||||||
|
.findFirst();
|
||||||
|
|
||||||
|
if (first.isPresent()) {
|
||||||
|
matches = first.get();
|
||||||
|
}
|
||||||
|
return matches;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void filter(ChatQueryContext chatQueryContext) {
|
||||||
|
MapFilter.filterByDataSetId(chatQueryContext);
|
||||||
|
MapFilter.filterByDetectWordLenLessThanOne(chatQueryContext);
|
||||||
|
switch (chatQueryContext.getQueryDataType()) {
|
||||||
|
case TAG:
|
||||||
|
MapFilter.filterByQueryDataType(
|
||||||
|
chatQueryContext, element -> !(element.getIsTag() > 0));
|
||||||
|
break;
|
||||||
|
case METRIC:
|
||||||
|
MapFilter.filterByQueryDataType(
|
||||||
|
chatQueryContext,
|
||||||
|
element -> !SchemaElementType.METRIC.equals(element.getType()));
|
||||||
|
break;
|
||||||
|
case DIMENSION:
|
||||||
|
MapFilter.filterByQueryDataType(
|
||||||
|
chatQueryContext,
|
||||||
|
element -> {
|
||||||
|
boolean isDimensionOrValue =
|
||||||
|
SchemaElementType.DIMENSION.equals(element.getType())
|
||||||
|
|| SchemaElementType.VALUE.equals(element.getType());
|
||||||
|
return !isDimensionOrValue;
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
case ALL:
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,32 +3,21 @@ package com.tencent.supersonic.headless.chat.mapper;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper;
|
import com.tencent.supersonic.headless.chat.knowledge.MapResult;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Comparator;
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.HashSet;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.Optional;
|
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
public abstract class BaseMatchStrategy<T extends MapResult> implements MatchStrategy<T> {
|
||||||
|
|
||||||
@Autowired protected MapperHelper mapperHelper;
|
|
||||||
|
|
||||||
@Autowired protected MapperConfig mapperConfig;
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<MatchText, List<T>> match(
|
public Map<MatchText, List<T>> match(
|
||||||
ChatQueryContext chatQueryContext, List<S2Term> terms, Set<Long> detectDataSetIds) {
|
ChatQueryContext chatQueryContext, List<S2Term> terms, Set<Long> detectDataSetIds) {
|
||||||
@@ -48,37 +37,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
|||||||
|
|
||||||
public List<T> detect(
|
public List<T> detect(
|
||||||
ChatQueryContext chatQueryContext, List<S2Term> terms, Set<Long> detectDataSetIds) {
|
ChatQueryContext chatQueryContext, List<S2Term> terms, Set<Long> detectDataSetIds) {
|
||||||
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(terms);
|
throw new RuntimeException("Not implemented");
|
||||||
String text = chatQueryContext.getQueryText();
|
|
||||||
Set<T> results = new HashSet<>();
|
|
||||||
|
|
||||||
Set<String> detectSegments = new HashSet<>();
|
|
||||||
|
|
||||||
for (Integer startIndex = 0; startIndex <= text.length() - 1; ) {
|
|
||||||
|
|
||||||
for (Integer index = startIndex; index <= text.length(); ) {
|
|
||||||
int offset = mapperHelper.getStepOffset(terms, startIndex);
|
|
||||||
index = mapperHelper.getStepIndex(regOffsetToLength, index);
|
|
||||||
if (index <= text.length()) {
|
|
||||||
String detectSegment = text.substring(startIndex, index).trim();
|
|
||||||
detectSegments.add(detectSegment);
|
|
||||||
detectByStep(
|
|
||||||
chatQueryContext, results, detectDataSetIds, detectSegment, offset);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
|
|
||||||
}
|
|
||||||
return new ArrayList<>(results);
|
|
||||||
}
|
|
||||||
|
|
||||||
public Map<Integer, Integer> getRegOffsetToLength(List<S2Term> terms) {
|
|
||||||
return terms.stream()
|
|
||||||
.sorted(Comparator.comparing(S2Term::length))
|
|
||||||
.collect(
|
|
||||||
Collectors.toMap(
|
|
||||||
S2Term::getOffset,
|
|
||||||
term -> term.word.length(),
|
|
||||||
(value1, value2) -> value2));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void selectResultInOneRound(Set<T> existResults, List<T> oneRoundResults) {
|
public void selectResultInOneRound(Set<T> existResults, List<T> oneRoundResults) {
|
||||||
@@ -90,7 +49,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
|||||||
boolean isDeleted =
|
boolean isDeleted =
|
||||||
existResults.removeIf(
|
existResults.removeIf(
|
||||||
existResult -> {
|
existResult -> {
|
||||||
boolean delete = needDelete(oneRoundResult, existResult);
|
boolean delete = oneRoundResult.lessSimilar(existResult);
|
||||||
if (delete) {
|
if (delete) {
|
||||||
log.info("deleted existResult:{}", existResult);
|
log.info("deleted existResult:{}", existResult);
|
||||||
}
|
}
|
||||||
@@ -106,72 +65,6 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<T> getMatches(ChatQueryContext chatQueryContext, List<S2Term> terms) {
|
|
||||||
Set<Long> dataSetIds = chatQueryContext.getDataSetIds();
|
|
||||||
terms = filterByDataSetId(terms, dataSetIds);
|
|
||||||
Map<MatchText, List<T>> matchResult = match(chatQueryContext, terms, dataSetIds);
|
|
||||||
List<T> matches = new ArrayList<>();
|
|
||||||
if (Objects.isNull(matchResult)) {
|
|
||||||
return matches;
|
|
||||||
}
|
|
||||||
Optional<List<T>> first =
|
|
||||||
matchResult.entrySet().stream()
|
|
||||||
.filter(entry -> CollectionUtils.isNotEmpty(entry.getValue()))
|
|
||||||
.map(entry -> entry.getValue())
|
|
||||||
.findFirst();
|
|
||||||
|
|
||||||
if (first.isPresent()) {
|
|
||||||
matches = first.get();
|
|
||||||
}
|
|
||||||
return matches;
|
|
||||||
}
|
|
||||||
|
|
||||||
public List<S2Term> filterByDataSetId(List<S2Term> terms, Set<Long> dataSetIds) {
|
|
||||||
logTerms(terms);
|
|
||||||
if (CollectionUtils.isNotEmpty(dataSetIds)) {
|
|
||||||
terms =
|
|
||||||
terms.stream()
|
|
||||||
.filter(
|
|
||||||
term -> {
|
|
||||||
Long dataSetId =
|
|
||||||
NatureHelper.getDataSetId(
|
|
||||||
term.getNature().toString());
|
|
||||||
if (Objects.nonNull(dataSetId)) {
|
|
||||||
return dataSetIds.contains(dataSetId);
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
})
|
|
||||||
.collect(Collectors.toList());
|
|
||||||
log.debug("terms filter by dataSetId:{}", dataSetIds);
|
|
||||||
logTerms(terms);
|
|
||||||
}
|
|
||||||
return terms;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void logTerms(List<S2Term> terms) {
|
|
||||||
if (CollectionUtils.isEmpty(terms)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
for (S2Term term : terms) {
|
|
||||||
log.debug(
|
|
||||||
"word:{},nature:{},frequency:{}",
|
|
||||||
term.word,
|
|
||||||
term.nature.toString(),
|
|
||||||
term.getFrequency());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public abstract boolean needDelete(T oneRoundResult, T existResult);
|
|
||||||
|
|
||||||
public abstract String getMapKey(T a);
|
|
||||||
|
|
||||||
public abstract void detectByStep(
|
|
||||||
ChatQueryContext chatQueryContext,
|
|
||||||
Set<T> existResults,
|
|
||||||
Set<Long> detectDataSetIds,
|
|
||||||
String detectSegment,
|
|
||||||
int offset);
|
|
||||||
|
|
||||||
public double getThreshold(Double threshold, Double minThreshold, MapModeEnum mapModeEnum) {
|
public double getThreshold(Double threshold, Double minThreshold, MapModeEnum mapModeEnum) {
|
||||||
double decreaseAmount = (threshold - minThreshold) / 4;
|
double decreaseAmount = (threshold - minThreshold) / 4;
|
||||||
double divideThreshold = threshold - mapModeEnum.threshold * decreaseAmount;
|
double divideThreshold = threshold - mapModeEnum.threshold * decreaseAmount;
|
||||||
|
|||||||
@@ -0,0 +1,47 @@
|
|||||||
|
package com.tencent.supersonic.headless.chat.mapper;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
|
import com.tencent.supersonic.headless.chat.knowledge.MapResult;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
@Service
|
||||||
|
@Slf4j
|
||||||
|
public abstract class BatchMatchStrategy<T extends MapResult> extends BaseMatchStrategy<T> {
|
||||||
|
|
||||||
|
@Autowired protected MapperConfig mapperConfig;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<T> detect(
|
||||||
|
ChatQueryContext chatQueryContext, List<S2Term> terms, Set<Long> detectDataSetIds) {
|
||||||
|
|
||||||
|
String text = chatQueryContext.getQueryText();
|
||||||
|
Set<String> detectSegments = new HashSet<>();
|
||||||
|
|
||||||
|
int embeddingTextSize =
|
||||||
|
Integer.valueOf(
|
||||||
|
mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_TEXT_SIZE));
|
||||||
|
|
||||||
|
int embeddingTextStep =
|
||||||
|
Integer.valueOf(
|
||||||
|
mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_TEXT_STEP));
|
||||||
|
|
||||||
|
for (int startIndex = 0; startIndex < text.length(); startIndex += embeddingTextStep) {
|
||||||
|
int endIndex = Math.min(startIndex + embeddingTextSize, text.length());
|
||||||
|
String detectSegment = text.substring(startIndex, endIndex).trim();
|
||||||
|
detectSegments.add(detectSegment);
|
||||||
|
}
|
||||||
|
return detectByBatch(chatQueryContext, detectDataSetIds, detectSegments);
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract List<T> detectByBatch(
|
||||||
|
ChatQueryContext chatQueryContext,
|
||||||
|
Set<Long> detectDataSetIds,
|
||||||
|
Set<String> detectSegments);
|
||||||
|
}
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
package com.tencent.supersonic.headless.chat.mapper;
|
package com.tencent.supersonic.headless.chat.mapper;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
@@ -25,7 +24,7 @@ import java.util.stream.Collectors;
|
|||||||
*/
|
*/
|
||||||
@Service
|
@Service
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult> {
|
public class DatabaseMatchStrategy extends SingleMatchStrategy<DatabaseMapResult> {
|
||||||
|
|
||||||
private List<SchemaElement> allElements;
|
private List<SchemaElement> allElements;
|
||||||
|
|
||||||
@@ -36,34 +35,18 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
|||||||
return super.match(chatQueryContext, terms, detectDataSetIds);
|
return super.match(chatQueryContext, terms, detectDataSetIds);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
public List<DatabaseMapResult> detectByStep(
|
||||||
public boolean needDelete(DatabaseMapResult oneRoundResult, DatabaseMapResult existResult) {
|
|
||||||
return getMapKey(oneRoundResult).equals(getMapKey(existResult))
|
|
||||||
&& existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String getMapKey(DatabaseMapResult a) {
|
|
||||||
return a.getName()
|
|
||||||
+ Constants.UNDERLINE
|
|
||||||
+ a.getSchemaElement().getId()
|
|
||||||
+ Constants.UNDERLINE
|
|
||||||
+ a.getSchemaElement().getName();
|
|
||||||
}
|
|
||||||
|
|
||||||
public void detectByStep(
|
|
||||||
ChatQueryContext chatQueryContext,
|
ChatQueryContext chatQueryContext,
|
||||||
Set<DatabaseMapResult> existResults,
|
|
||||||
Set<Long> detectDataSetIds,
|
Set<Long> detectDataSetIds,
|
||||||
String detectSegment,
|
String detectSegment,
|
||||||
int offset) {
|
int offset) {
|
||||||
if (StringUtils.isBlank(detectSegment)) {
|
if (StringUtils.isBlank(detectSegment)) {
|
||||||
return;
|
return new ArrayList<>();
|
||||||
}
|
}
|
||||||
|
|
||||||
Double metricDimensionThresholdConfig = getThreshold(chatQueryContext);
|
Double metricDimensionThresholdConfig = getThreshold(chatQueryContext);
|
||||||
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(allElements);
|
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(allElements);
|
||||||
|
List<DatabaseMapResult> results = new ArrayList<>();
|
||||||
for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
|
for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
|
||||||
String name = entry.getKey();
|
String name = entry.getKey();
|
||||||
if (!name.contains(detectSegment)
|
if (!name.contains(detectSegment)
|
||||||
@@ -86,9 +69,10 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
|||||||
databaseMapResult.setDetectWord(detectSegment);
|
databaseMapResult.setDetectWord(detectSegment);
|
||||||
databaseMapResult.setName(schemaElement.getName());
|
databaseMapResult.setName(schemaElement.getName());
|
||||||
databaseMapResult.setSchemaElement(schemaElement);
|
databaseMapResult.setSchemaElement(schemaElement);
|
||||||
existResults.add(databaseMapResult);
|
results.add(databaseMapResult);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return results;
|
||||||
}
|
}
|
||||||
|
|
||||||
private List<SchemaElement> getSchemaElements(ChatQueryContext chatQueryContext) {
|
private List<SchemaElement> getSchemaElements(ChatQueryContext chatQueryContext) {
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import com.tencent.supersonic.common.util.ContextUtils;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
|
||||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult;
|
import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.builder.BaseWordBuilder;
|
import com.tencent.supersonic.headless.chat.knowledge.builder.BaseWordBuilder;
|
||||||
@@ -15,19 +14,15 @@ import lombok.extern.slf4j.Slf4j;
|
|||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
|
||||||
/** * A mapper that recognizes schema elements with vector embedding. */
|
/** A mapper that recognizes schema elements with vector embedding. */
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class EmbeddingMapper extends BaseMapper {
|
public class EmbeddingMapper extends BaseMapper {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void doMap(ChatQueryContext chatQueryContext) {
|
public void doMap(ChatQueryContext chatQueryContext) {
|
||||||
// 1. query from embedding by queryText
|
// 1. query from embedding by queryText
|
||||||
String queryText = chatQueryContext.getQueryText();
|
|
||||||
List<S2Term> terms =
|
|
||||||
HanlpHelper.getTerms(queryText, chatQueryContext.getModelIdToDataSetIds());
|
|
||||||
|
|
||||||
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
|
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
|
||||||
List<EmbeddingResult> matchResults = matchStrategy.getMatches(chatQueryContext, terms);
|
List<EmbeddingResult> matchResults = getMatches(chatQueryContext, matchStrategy);
|
||||||
|
|
||||||
HanlpHelper.transLetterOriginal(matchResults);
|
HanlpHelper.transLetterOriginal(matchResults);
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
package com.tencent.supersonic.headless.chat.mapper;
|
package com.tencent.supersonic.headless.chat.mapper;
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
|
||||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult;
|
import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.MetaEmbeddingService;
|
import com.tencent.supersonic.headless.chat.knowledge.MetaEmbeddingService;
|
||||||
@@ -35,54 +33,12 @@ import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING
|
|||||||
*/
|
*/
|
||||||
@Service
|
@Service
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult> {
|
||||||
|
|
||||||
@Autowired private MetaEmbeddingService metaEmbeddingService;
|
@Autowired private MetaEmbeddingService metaEmbeddingService;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean needDelete(EmbeddingResult oneRoundResult, EmbeddingResult existResult) {
|
public List<EmbeddingResult> detectByBatch(
|
||||||
return getMapKey(oneRoundResult).equals(getMapKey(existResult))
|
|
||||||
&& existResult.getDistance() > oneRoundResult.getDistance();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String getMapKey(EmbeddingResult a) {
|
|
||||||
return a.getName() + Constants.UNDERLINE + a.getId();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void detectByStep(
|
|
||||||
ChatQueryContext chatQueryContext,
|
|
||||||
Set<EmbeddingResult> existResults,
|
|
||||||
Set<Long> detectDataSetIds,
|
|
||||||
String detectSegment,
|
|
||||||
int offset) {}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<EmbeddingResult> detect(
|
|
||||||
ChatQueryContext chatQueryContext, List<S2Term> terms, Set<Long> detectDataSetIds) {
|
|
||||||
String text = chatQueryContext.getQueryText();
|
|
||||||
Set<String> detectSegments = new HashSet<>();
|
|
||||||
|
|
||||||
int embeddingTextSize =
|
|
||||||
Integer.valueOf(
|
|
||||||
mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_TEXT_SIZE));
|
|
||||||
|
|
||||||
int embeddingTextStep =
|
|
||||||
Integer.valueOf(
|
|
||||||
mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_TEXT_STEP));
|
|
||||||
|
|
||||||
for (int startIndex = 0; startIndex < text.length(); startIndex += embeddingTextStep) {
|
|
||||||
int endIndex = Math.min(startIndex + embeddingTextSize, text.length());
|
|
||||||
String detectSegment = text.substring(startIndex, endIndex).trim();
|
|
||||||
detectSegments.add(detectSegment);
|
|
||||||
}
|
|
||||||
Set<EmbeddingResult> results =
|
|
||||||
detectByBatch(chatQueryContext, detectDataSetIds, detectSegments);
|
|
||||||
return new ArrayList<>(results);
|
|
||||||
}
|
|
||||||
|
|
||||||
protected Set<EmbeddingResult> detectByBatch(
|
|
||||||
ChatQueryContext chatQueryContext,
|
ChatQueryContext chatQueryContext,
|
||||||
Set<Long> detectDataSetIds,
|
Set<Long> detectDataSetIds,
|
||||||
Set<String> detectSegments) {
|
Set<String> detectSegments) {
|
||||||
@@ -103,7 +59,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
|||||||
for (List<String> queryTextsSub : queryTextsSubList) {
|
for (List<String> queryTextsSub : queryTextsSubList) {
|
||||||
detectByQueryTextsSub(results, detectDataSetIds, queryTextsSub, chatQueryContext);
|
detectByQueryTextsSub(results, detectDataSetIds, queryTextsSub, chatQueryContext);
|
||||||
}
|
}
|
||||||
return results;
|
return new ArrayList<>(results);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void detectByQueryTextsSub(
|
private void detectByQueryTextsSub(
|
||||||
|
|||||||
@@ -1,22 +1,16 @@
|
|||||||
package com.tencent.supersonic.headless.chat.mapper;
|
package com.tencent.supersonic.headless.chat.mapper;
|
||||||
|
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
|
||||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.HanlpMapResult;
|
import com.tencent.supersonic.headless.chat.knowledge.HanlpMapResult;
|
||||||
import com.tencent.supersonic.headless.chat.knowledge.KnowledgeBaseService;
|
import com.tencent.supersonic.headless.chat.knowledge.KnowledgeBaseService;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.LinkedHashSet;
|
import java.util.LinkedHashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@@ -30,36 +24,12 @@ import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.MAPPER_DI
|
|||||||
*/
|
*/
|
||||||
@Service
|
@Service
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
public class HanlpDictMatchStrategy extends SingleMatchStrategy<HanlpMapResult> {
|
||||||
|
|
||||||
@Autowired private KnowledgeBaseService knowledgeBaseService;
|
@Autowired private KnowledgeBaseService knowledgeBaseService;
|
||||||
|
|
||||||
@Override
|
public List<HanlpMapResult> detectByStep(
|
||||||
public Map<MatchText, List<HanlpMapResult>> match(
|
|
||||||
ChatQueryContext chatQueryContext, List<S2Term> terms, Set<Long> detectDataSetIds) {
|
|
||||||
String text = chatQueryContext.getQueryText();
|
|
||||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
log.debug("terms:{},detectModelIds:{}", terms, detectDataSetIds);
|
|
||||||
|
|
||||||
List<HanlpMapResult> detects = detect(chatQueryContext, terms, detectDataSetIds);
|
|
||||||
Map<MatchText, List<HanlpMapResult>> result = new HashMap<>();
|
|
||||||
|
|
||||||
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean needDelete(HanlpMapResult oneRoundResult, HanlpMapResult existResult) {
|
|
||||||
return getMapKey(oneRoundResult).equals(getMapKey(existResult))
|
|
||||||
&& existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length();
|
|
||||||
}
|
|
||||||
|
|
||||||
public void detectByStep(
|
|
||||||
ChatQueryContext chatQueryContext,
|
ChatQueryContext chatQueryContext,
|
||||||
Set<HanlpMapResult> existResults,
|
|
||||||
Set<Long> detectDataSetIds,
|
Set<Long> detectDataSetIds,
|
||||||
String detectSegment,
|
String detectSegment,
|
||||||
int offset) {
|
int offset) {
|
||||||
@@ -89,7 +59,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
|||||||
hanlpMapResults.addAll(suffixHanlpMapResults);
|
hanlpMapResults.addAll(suffixHanlpMapResults);
|
||||||
|
|
||||||
if (CollectionUtils.isEmpty(hanlpMapResults)) {
|
if (CollectionUtils.isEmpty(hanlpMapResults)) {
|
||||||
return;
|
return new ArrayList<>();
|
||||||
}
|
}
|
||||||
// step3. merge pre/suffix result
|
// step3. merge pre/suffix result
|
||||||
hanlpMapResults =
|
hanlpMapResults =
|
||||||
@@ -155,12 +125,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
|||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
oneRoundResults.addAll(additionalResults);
|
oneRoundResults.addAll(additionalResults);
|
||||||
}
|
}
|
||||||
// step6. select mapResul in one round
|
return oneRoundResults;
|
||||||
selectResultInOneRound(existResults, oneRoundResults);
|
|
||||||
}
|
|
||||||
|
|
||||||
public String getMapKey(HanlpMapResult a) {
|
|
||||||
return a.getName() + Constants.UNDERLINE + String.join(Constants.UNDERLINE, a.getNatures());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public double getThresholdMatch(List<String> natures, ChatQueryContext chatQueryContext) {
|
public double getThresholdMatch(List<String> natures, ChatQueryContext chatQueryContext) {
|
||||||
|
|||||||
@@ -38,16 +38,15 @@ public class KeywordMapper extends BaseMapper {
|
|||||||
HanlpDictMatchStrategy hanlpMatchStrategy =
|
HanlpDictMatchStrategy hanlpMatchStrategy =
|
||||||
ContextUtils.getBean(HanlpDictMatchStrategy.class);
|
ContextUtils.getBean(HanlpDictMatchStrategy.class);
|
||||||
|
|
||||||
List<HanlpMapResult> hanlpMapResults =
|
List<HanlpMapResult> matchResults = getMatches(chatQueryContext, hanlpMatchStrategy);
|
||||||
hanlpMatchStrategy.getMatches(chatQueryContext, terms);
|
|
||||||
convertHanlpMapResultToMapInfo(hanlpMapResults, chatQueryContext, terms);
|
convertHanlpMapResultToMapInfo(matchResults, chatQueryContext, terms);
|
||||||
|
|
||||||
// 2.database Match
|
// 2.database Match
|
||||||
DatabaseMatchStrategy databaseMatchStrategy =
|
DatabaseMatchStrategy databaseMatchStrategy =
|
||||||
ContextUtils.getBean(DatabaseMatchStrategy.class);
|
ContextUtils.getBean(DatabaseMatchStrategy.class);
|
||||||
|
|
||||||
List<DatabaseMapResult> databaseResults =
|
List<DatabaseMapResult> databaseResults =
|
||||||
databaseMatchStrategy.getMatches(chatQueryContext, terms);
|
getMatches(chatQueryContext, databaseMatchStrategy);
|
||||||
convertDatabaseMapResultToMapInfo(chatQueryContext, databaseResults);
|
convertDatabaseMapResultToMapInfo(chatQueryContext, databaseResults);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,68 @@
|
|||||||
|
package com.tencent.supersonic.headless.chat.mapper;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.function.Predicate;
|
||||||
|
|
||||||
|
public class MapFilter {
|
||||||
|
|
||||||
|
public static void filterByDataSetId(ChatQueryContext chatQueryContext) {
|
||||||
|
Set<Long> dataSetIds = chatQueryContext.getDataSetIds();
|
||||||
|
if (CollectionUtils.isEmpty(dataSetIds)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Set<Long> dataSetIdInMapInfo =
|
||||||
|
new HashSet<>(chatQueryContext.getMapInfo().getDataSetElementMatches().keySet());
|
||||||
|
for (Long dataSetId : dataSetIdInMapInfo) {
|
||||||
|
if (!dataSetIds.contains(dataSetId)) {
|
||||||
|
chatQueryContext.getMapInfo().getDataSetElementMatches().remove(dataSetId);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void filterByDetectWordLenLessThanOne(ChatQueryContext chatQueryContext) {
|
||||||
|
Map<Long, List<SchemaElementMatch>> dataSetElementMatches =
|
||||||
|
chatQueryContext.getMapInfo().getDataSetElementMatches();
|
||||||
|
for (Map.Entry<Long, List<SchemaElementMatch>> entry : dataSetElementMatches.entrySet()) {
|
||||||
|
List<SchemaElementMatch> value = entry.getValue();
|
||||||
|
if (!CollectionUtils.isEmpty(value)) {
|
||||||
|
value.removeIf(
|
||||||
|
schemaElementMatch ->
|
||||||
|
StringUtils.length(schemaElementMatch.getDetectWord()) <= 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void filterByQueryDataType(
|
||||||
|
ChatQueryContext chatQueryContext, Predicate<SchemaElement> needRemovePredicate) {
|
||||||
|
chatQueryContext
|
||||||
|
.getMapInfo()
|
||||||
|
.getDataSetElementMatches()
|
||||||
|
.values()
|
||||||
|
.forEach(
|
||||||
|
schemaElementMatches -> {
|
||||||
|
schemaElementMatches.removeIf(
|
||||||
|
schemaElementMatch -> {
|
||||||
|
SchemaElement element = schemaElementMatch.getElement();
|
||||||
|
SchemaElementType type = element.getType();
|
||||||
|
|
||||||
|
boolean isEntityOrDatasetOrId =
|
||||||
|
SchemaElementType.ENTITY.equals(type)
|
||||||
|
|| SchemaElementType.DATASET.equals(type)
|
||||||
|
|| SchemaElementType.ID.equals(type);
|
||||||
|
|
||||||
|
return !isEntityOrDatasetOrId
|
||||||
|
&& needRemovePredicate.test(element);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -43,6 +43,16 @@ public class MapperHelper {
|
|||||||
return index;
|
return index;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Map<Integer, Integer> getRegOffsetToLength(List<S2Term> terms) {
|
||||||
|
return terms.stream()
|
||||||
|
.sorted(Comparator.comparing(S2Term::length))
|
||||||
|
.collect(
|
||||||
|
Collectors.toMap(
|
||||||
|
S2Term::getOffset,
|
||||||
|
term -> term.word.length(),
|
||||||
|
(value1, value2) -> value2));
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* * exist dimension values
|
* * exist dimension values
|
||||||
*
|
*
|
||||||
@@ -58,15 +68,6 @@ public class MapperHelper {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
public boolean existTerms(List<String> natures) {
|
|
||||||
for (String nature : natures) {
|
|
||||||
if (NatureHelper.isTermNature(nature)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* * get similarity
|
* * get similarity
|
||||||
*
|
*
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.chat.mapper;
|
|||||||
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
|
import com.tencent.supersonic.headless.chat.knowledge.MapResult;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
@@ -10,7 +11,7 @@ import java.util.Set;
|
|||||||
/**
|
/**
|
||||||
* MatchStrategy encapsulates a concrete matching algorithm executed during query or search process.
|
* MatchStrategy encapsulates a concrete matching algorithm executed during query or search process.
|
||||||
*/
|
*/
|
||||||
public interface MatchStrategy<T> {
|
public interface MatchStrategy<T extends MapResult> {
|
||||||
|
|
||||||
Map<MatchText, List<T>> match(
|
Map<MatchText, List<T>> match(
|
||||||
ChatQueryContext chatQueryContext, List<S2Term> terms, Set<Long> detectDataSetIds);
|
ChatQueryContext chatQueryContext, List<S2Term> terms, Set<Long> detectDataSetIds);
|
||||||
|
|||||||
@@ -29,11 +29,13 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
|||||||
|
|
||||||
@Autowired private KnowledgeBaseService knowledgeBaseService;
|
@Autowired private KnowledgeBaseService knowledgeBaseService;
|
||||||
|
|
||||||
|
@Autowired private MapperHelper mapperHelper;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<MatchText, List<HanlpMapResult>> match(
|
public Map<MatchText, List<HanlpMapResult>> match(
|
||||||
ChatQueryContext chatQueryContext, List<S2Term> originals, Set<Long> detectDataSetIds) {
|
ChatQueryContext chatQueryContext, List<S2Term> originals, Set<Long> detectDataSetIds) {
|
||||||
String text = chatQueryContext.getQueryText();
|
String text = chatQueryContext.getQueryText();
|
||||||
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(originals);
|
Map<Integer, Integer> regOffsetToLength = mapperHelper.getRegOffsetToLength(originals);
|
||||||
|
|
||||||
List<Integer> detectIndexList = Lists.newArrayList();
|
List<Integer> detectIndexList = Lists.newArrayList();
|
||||||
|
|
||||||
@@ -104,22 +106,4 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
|||||||
});
|
});
|
||||||
return regTextMap;
|
return regTextMap;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean needDelete(HanlpMapResult oneRoundResult, HanlpMapResult existResult) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String getMapKey(HanlpMapResult a) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void detectByStep(
|
|
||||||
ChatQueryContext chatQueryContext,
|
|
||||||
Set<HanlpMapResult> existResults,
|
|
||||||
Set<Long> detectDataSetIds,
|
|
||||||
String detectSegment,
|
|
||||||
int offset) {}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,53 @@
|
|||||||
|
package com.tencent.supersonic.headless.chat.mapper;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
|
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||||
|
import com.tencent.supersonic.headless.chat.knowledge.MapResult;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
@Service
|
||||||
|
@Slf4j
|
||||||
|
public abstract class SingleMatchStrategy<T extends MapResult> extends BaseMatchStrategy<T> {
|
||||||
|
@Autowired protected MapperConfig mapperConfig;
|
||||||
|
@Autowired protected MapperHelper mapperHelper;
|
||||||
|
|
||||||
|
public List<T> detect(
|
||||||
|
ChatQueryContext chatQueryContext, List<S2Term> terms, Set<Long> detectDataSetIds) {
|
||||||
|
Map<Integer, Integer> regOffsetToLength = mapperHelper.getRegOffsetToLength(terms);
|
||||||
|
String text = chatQueryContext.getQueryText();
|
||||||
|
Set<T> results = new HashSet<>();
|
||||||
|
|
||||||
|
Set<String> detectSegments = new HashSet<>();
|
||||||
|
|
||||||
|
for (Integer startIndex = 0; startIndex <= text.length() - 1; ) {
|
||||||
|
|
||||||
|
for (Integer index = startIndex; index <= text.length(); ) {
|
||||||
|
int offset = mapperHelper.getStepOffset(terms, startIndex);
|
||||||
|
index = mapperHelper.getStepIndex(regOffsetToLength, index);
|
||||||
|
if (index <= text.length()) {
|
||||||
|
String detectSegment = text.substring(startIndex, index).trim();
|
||||||
|
detectSegments.add(detectSegment);
|
||||||
|
List<T> oneRoundResults =
|
||||||
|
detectByStep(chatQueryContext, detectDataSetIds, detectSegment, offset);
|
||||||
|
selectResultInOneRound(results, oneRoundResults);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
|
||||||
|
}
|
||||||
|
return new ArrayList<>(results);
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract List<T> detectByStep(
|
||||||
|
ChatQueryContext chatQueryContext,
|
||||||
|
Set<Long> detectDataSetIds,
|
||||||
|
String detectSegment,
|
||||||
|
int offset);
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user