[improvement][Headless] Embedding supports Chinese by default and fixes the issue of abnormal number recognition (#726)

This commit is contained in:
lexluo09
2024-02-18 19:51:19 +08:00
committed by GitHub
parent 39158d6877
commit fdb69547e6
19 changed files with 62 additions and 59 deletions

View File

@@ -52,7 +52,7 @@ public class OptimizationConfig {
@Value("${embedding.mapper.round.number:10}")
private int embeddingMapperRoundNumber;
@Value("${embedding.mapper.distance.threshold:0.58}")
@Value("${embedding.mapper.distance.threshold:0.01}")
private Double embeddingMapperDistanceThreshold;
@Value("${s2SQL.linking.value.switch:true}")

View File

@@ -3,12 +3,6 @@ package com.tencent.supersonic.chat.core.mapper;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.knowledge.helper.NatureHelper;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
@@ -19,6 +13,11 @@ import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
@Service
@Slf4j
@@ -29,7 +28,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
@Override
public Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms,
Set<Long> detectViewIds) {
Set<Long> detectViewIds) {
String text = queryContext.getQueryText();
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
return null;
@@ -57,9 +56,9 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
int offset = mapperHelper.getStepOffset(terms, startIndex);
index = mapperHelper.getStepIndex(regOffsetToLength, index);
if (index <= text.length()) {
String detectSegment = text.substring(startIndex, index);
String detectSegment = text.substring(startIndex, index).trim();
detectSegments.add(detectSegment);
detectByStep(queryContext, results, detectViewIds, startIndex, index, offset);
detectByStep(queryContext, results, detectViewIds, detectSegment, offset);
}
}
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
@@ -151,7 +150,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
public abstract String getMapKey(T a);
public abstract void detectByStep(QueryContext queryContext, Set<T> results,
Set<Long> detectViewIds, Integer startIndex, Integer index, int offset);
public abstract void detectByStep(QueryContext queryContext, Set<T> existResults, Set<Long> detectViewIds,
String detectSegment, int offset);
}

View File

@@ -55,15 +55,12 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
}
public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> existResults, Set<Long> detectViewIds,
Integer startIndex, Integer index, int offset) {
String detectSegment = queryContext.getQueryText().substring(startIndex, index);
String detectSegment, int offset) {
if (StringUtils.isBlank(detectSegment)) {
return;
}
Set<Long> viewIds = mapperHelper.getViewIds(queryContext.getViewId(), queryContext.getAgent());
Double metricDimensionThresholdConfig = getThreshold(queryContext);
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(allElements);
for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
@@ -73,9 +70,9 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
continue;
}
Set<SchemaElement> schemaElements = entry.getValue();
if (!CollectionUtils.isEmpty(viewIds)) {
if (!CollectionUtils.isEmpty(detectViewIds)) {
schemaElements = schemaElements.stream()
.filter(schemaElement -> viewIds.contains(schemaElement.getView()))
.filter(schemaElement -> detectViewIds.contains(schemaElement.getView()))
.collect(Collectors.toSet());
}
for (SchemaElement schemaElement : schemaElements) {

View File

@@ -1,11 +1,11 @@
package com.tencent.supersonic.chat.core.mapper;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult;
import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder;
@@ -34,14 +34,12 @@ public class EmbeddingMapper extends BaseMapper {
//2. build SchemaElementMatch by info
for (EmbeddingResult matchResult : matchResults) {
Long elementId = Retrieval.getLongId(matchResult.getId());
SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(matchResult.getMetadata()),
SchemaElement.class);
Long viewId = Retrieval.getLongId(matchResult.getMetadata().get("viewId"));
if (Objects.isNull(viewId)) {
continue;
}
schemaElement = getSchemaElement(viewId, schemaElement.getType(), elementId,
SchemaElementType elementType = SchemaElementType.valueOf(matchResult.getMetadata().get("type"));
SchemaElement schemaElement = getSchemaElement(viewId, elementType, elementId,
queryContext.getSemanticSchema());
if (schemaElement == null) {
continue;

View File

@@ -2,12 +2,12 @@ package com.tencent.supersonic.chat.core.mapper;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult;
import com.tencent.supersonic.headless.server.service.MetaEmbeddingService;
import java.util.ArrayList;
import java.util.Comparator;
@@ -47,6 +47,12 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
return a.getName() + Constants.UNDERLINE + a.getId();
}
@Override
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectViewIds,
String detectSegment, int offset) {
}
@Override
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectViewIds,
Set<String> detectSegments) {
@@ -111,9 +117,4 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
selectResultInOneRound(results, oneRoundResults);
}
@Override
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectViewIds,
Integer startIndex, Integer index, int offset) {
return;
}
}

View File

@@ -58,10 +58,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
}
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectViewIds,
Integer startIndex, Integer index, int offset) {
String text = queryContext.getQueryText();
String detectSegment = text.substring(startIndex, index);
String detectSegment, int offset) {
// step1. pre search
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
LinkedHashSet<HanlpMapResult> hanlpMapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize,

View File

@@ -88,9 +88,9 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
}
@Override
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> results, Set<Long> detectViewIds,
Integer startIndex,
Integer i, int offset) {
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectViewIds,
String detectSegment, int offset) {
}
}