mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-14 22:25:19 +00:00
[improvement][Headless] Embedding supports Chinese by default and fixes the issue of abnormal number recognition (#726)
This commit is contained in:
@@ -52,7 +52,7 @@ public class OptimizationConfig {
|
||||
@Value("${embedding.mapper.round.number:10}")
|
||||
private int embeddingMapperRoundNumber;
|
||||
|
||||
@Value("${embedding.mapper.distance.threshold:0.58}")
|
||||
@Value("${embedding.mapper.distance.threshold:0.01}")
|
||||
private Double embeddingMapperDistanceThreshold;
|
||||
|
||||
@Value("${s2SQL.linking.value.switch:true}")
|
||||
|
||||
@@ -3,12 +3,6 @@ package com.tencent.supersonic.chat.core.mapper;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.core.knowledge.helper.NatureHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
@@ -19,6 +13,11 @@ import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
@@ -29,7 +28,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms,
|
||||
Set<Long> detectViewIds) {
|
||||
Set<Long> detectViewIds) {
|
||||
String text = queryContext.getQueryText();
|
||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||
return null;
|
||||
@@ -57,9 +56,9 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
int offset = mapperHelper.getStepOffset(terms, startIndex);
|
||||
index = mapperHelper.getStepIndex(regOffsetToLength, index);
|
||||
if (index <= text.length()) {
|
||||
String detectSegment = text.substring(startIndex, index);
|
||||
String detectSegment = text.substring(startIndex, index).trim();
|
||||
detectSegments.add(detectSegment);
|
||||
detectByStep(queryContext, results, detectViewIds, startIndex, index, offset);
|
||||
detectByStep(queryContext, results, detectViewIds, detectSegment, offset);
|
||||
}
|
||||
}
|
||||
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
|
||||
@@ -151,7 +150,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
|
||||
public abstract String getMapKey(T a);
|
||||
|
||||
public abstract void detectByStep(QueryContext queryContext, Set<T> results,
|
||||
Set<Long> detectViewIds, Integer startIndex, Integer index, int offset);
|
||||
public abstract void detectByStep(QueryContext queryContext, Set<T> existResults, Set<Long> detectViewIds,
|
||||
String detectSegment, int offset);
|
||||
|
||||
}
|
||||
|
||||
@@ -55,15 +55,12 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
||||
}
|
||||
|
||||
public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> existResults, Set<Long> detectViewIds,
|
||||
Integer startIndex, Integer index, int offset) {
|
||||
String detectSegment = queryContext.getQueryText().substring(startIndex, index);
|
||||
String detectSegment, int offset) {
|
||||
if (StringUtils.isBlank(detectSegment)) {
|
||||
return;
|
||||
}
|
||||
Set<Long> viewIds = mapperHelper.getViewIds(queryContext.getViewId(), queryContext.getAgent());
|
||||
|
||||
Double metricDimensionThresholdConfig = getThreshold(queryContext);
|
||||
|
||||
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(allElements);
|
||||
|
||||
for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
|
||||
@@ -73,9 +70,9 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
||||
continue;
|
||||
}
|
||||
Set<SchemaElement> schemaElements = entry.getValue();
|
||||
if (!CollectionUtils.isEmpty(viewIds)) {
|
||||
if (!CollectionUtils.isEmpty(detectViewIds)) {
|
||||
schemaElements = schemaElements.stream()
|
||||
.filter(schemaElement -> viewIds.contains(schemaElement.getView()))
|
||||
.filter(schemaElement -> detectViewIds.contains(schemaElement.getView()))
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
for (SchemaElement schemaElement : schemaElements) {
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult;
|
||||
import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder;
|
||||
@@ -34,14 +34,12 @@ public class EmbeddingMapper extends BaseMapper {
|
||||
//2. build SchemaElementMatch by info
|
||||
for (EmbeddingResult matchResult : matchResults) {
|
||||
Long elementId = Retrieval.getLongId(matchResult.getId());
|
||||
|
||||
SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(matchResult.getMetadata()),
|
||||
SchemaElement.class);
|
||||
Long viewId = Retrieval.getLongId(matchResult.getMetadata().get("viewId"));
|
||||
if (Objects.isNull(viewId)) {
|
||||
continue;
|
||||
}
|
||||
schemaElement = getSchemaElement(viewId, schemaElement.getType(), elementId,
|
||||
SchemaElementType elementType = SchemaElementType.valueOf(matchResult.getMetadata().get("type"));
|
||||
SchemaElement schemaElement = getSchemaElement(viewId, elementType, elementId,
|
||||
queryContext.getSemanticSchema());
|
||||
if (schemaElement == null) {
|
||||
continue;
|
||||
|
||||
@@ -2,12 +2,12 @@ package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||
import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult;
|
||||
import com.tencent.supersonic.headless.server.service.MetaEmbeddingService;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
@@ -47,6 +47,12 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
return a.getName() + Constants.UNDERLINE + a.getId();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectViewIds,
|
||||
String detectSegment, int offset) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectViewIds,
|
||||
Set<String> detectSegments) {
|
||||
@@ -111,9 +117,4 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
selectResultInOneRound(results, oneRoundResults);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectViewIds,
|
||||
Integer startIndex, Integer index, int offset) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,10 +58,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
}
|
||||
|
||||
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectViewIds,
|
||||
Integer startIndex, Integer index, int offset) {
|
||||
String text = queryContext.getQueryText();
|
||||
String detectSegment = text.substring(startIndex, index);
|
||||
|
||||
String detectSegment, int offset) {
|
||||
// step1. pre search
|
||||
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
|
||||
LinkedHashSet<HanlpMapResult> hanlpMapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize,
|
||||
|
||||
@@ -88,9 +88,9 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> results, Set<Long> detectViewIds,
|
||||
Integer startIndex,
|
||||
Integer i, int offset) {
|
||||
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectViewIds,
|
||||
String detectSegment, int offset) {
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user