[improvement][Headless] Embedding supports Chinese by default and fixes the issue of abnormal number recognition (#726)

This commit is contained in:
lexluo09
2024-02-18 19:51:19 +08:00
committed by GitHub
parent 39158d6877
commit fdb69547e6
19 changed files with 62 additions and 59 deletions

View File

@@ -52,7 +52,7 @@ public class OptimizationConfig {
@Value("${embedding.mapper.round.number:10}")
private int embeddingMapperRoundNumber;
@Value("${embedding.mapper.distance.threshold:0.58}")
@Value("${embedding.mapper.distance.threshold:0.01}")
private Double embeddingMapperDistanceThreshold;
@Value("${s2SQL.linking.value.switch:true}")

View File

@@ -3,12 +3,6 @@ package com.tencent.supersonic.chat.core.mapper;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.knowledge.helper.NatureHelper;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
@@ -19,6 +13,11 @@ import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
@Service
@Slf4j
@@ -57,9 +56,9 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
int offset = mapperHelper.getStepOffset(terms, startIndex);
index = mapperHelper.getStepIndex(regOffsetToLength, index);
if (index <= text.length()) {
String detectSegment = text.substring(startIndex, index);
String detectSegment = text.substring(startIndex, index).trim();
detectSegments.add(detectSegment);
detectByStep(queryContext, results, detectViewIds, startIndex, index, offset);
detectByStep(queryContext, results, detectViewIds, detectSegment, offset);
}
}
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
@@ -151,7 +150,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
public abstract String getMapKey(T a);
public abstract void detectByStep(QueryContext queryContext, Set<T> results,
Set<Long> detectViewIds, Integer startIndex, Integer index, int offset);
public abstract void detectByStep(QueryContext queryContext, Set<T> existResults, Set<Long> detectViewIds,
String detectSegment, int offset);
}

View File

@@ -55,15 +55,12 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
}
public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> existResults, Set<Long> detectViewIds,
Integer startIndex, Integer index, int offset) {
String detectSegment = queryContext.getQueryText().substring(startIndex, index);
String detectSegment, int offset) {
if (StringUtils.isBlank(detectSegment)) {
return;
}
Set<Long> viewIds = mapperHelper.getViewIds(queryContext.getViewId(), queryContext.getAgent());
Double metricDimensionThresholdConfig = getThreshold(queryContext);
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(allElements);
for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
@@ -73,9 +70,9 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
continue;
}
Set<SchemaElement> schemaElements = entry.getValue();
if (!CollectionUtils.isEmpty(viewIds)) {
if (!CollectionUtils.isEmpty(detectViewIds)) {
schemaElements = schemaElements.stream()
.filter(schemaElement -> viewIds.contains(schemaElement.getView()))
.filter(schemaElement -> detectViewIds.contains(schemaElement.getView()))
.collect(Collectors.toSet());
}
for (SchemaElement schemaElement : schemaElements) {

View File

@@ -1,11 +1,11 @@
package com.tencent.supersonic.chat.core.mapper;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult;
import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder;
@@ -34,14 +34,12 @@ public class EmbeddingMapper extends BaseMapper {
//2. build SchemaElementMatch by info
for (EmbeddingResult matchResult : matchResults) {
Long elementId = Retrieval.getLongId(matchResult.getId());
SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(matchResult.getMetadata()),
SchemaElement.class);
Long viewId = Retrieval.getLongId(matchResult.getMetadata().get("viewId"));
if (Objects.isNull(viewId)) {
continue;
}
schemaElement = getSchemaElement(viewId, schemaElement.getType(), elementId,
SchemaElementType elementType = SchemaElementType.valueOf(matchResult.getMetadata().get("type"));
SchemaElement schemaElement = getSchemaElement(viewId, elementType, elementId,
queryContext.getSemanticSchema());
if (schemaElement == null) {
continue;

View File

@@ -2,12 +2,12 @@ package com.tencent.supersonic.chat.core.mapper;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult;
import com.tencent.supersonic.headless.server.service.MetaEmbeddingService;
import java.util.ArrayList;
import java.util.Comparator;
@@ -47,6 +47,12 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
return a.getName() + Constants.UNDERLINE + a.getId();
}
@Override
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectViewIds,
String detectSegment, int offset) {
}
@Override
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectViewIds,
Set<String> detectSegments) {
@@ -111,9 +117,4 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
selectResultInOneRound(results, oneRoundResults);
}
@Override
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectViewIds,
Integer startIndex, Integer index, int offset) {
return;
}
}

View File

@@ -58,10 +58,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
}
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectViewIds,
Integer startIndex, Integer index, int offset) {
String text = queryContext.getQueryText();
String detectSegment = text.substring(startIndex, index);
String detectSegment, int offset) {
// step1. pre search
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
LinkedHashSet<HanlpMapResult> hanlpMapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize,

View File

@@ -88,9 +88,9 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
}
@Override
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> results, Set<Long> detectViewIds,
Integer startIndex,
Integer i, int offset) {
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectViewIds,
String detectSegment, int offset) {
}
}

View File

@@ -32,8 +32,7 @@ public class SchemaDictUpdateListener implements ApplicationListener<DataEvent>
DictWord dictWord = new DictWord();
dictWord.setWord(dataItem.getName());
String sign = DictWordType.NATURE_SPILT;
String nature = sign + 1 + sign + dataItem.getId()
+ sign + dataItem.getType().name().toLowerCase();
String nature = sign + 1 + sign + dataItem.getId() + dataItem.getType().name().toLowerCase();
String natureWithFrequency = nature + " " + Constants.DEFAULT_FREQUENCY;
dictWord.setNature(nature);
dictWord.setNatureWithFrequency(natureWithFrequency);

View File

@@ -8,7 +8,10 @@ import lombok.Data;
@Builder
public class DataItem {
private Long id;
/***
* This field uses an underscore (_) at the end.
*/
private String id;
private String bizName;
@@ -18,6 +21,9 @@ public class DataItem {
private TypeEnums type;
/***
* This field uses an underscore (_) at the end.
*/
private String modelId;
private String defaultAgg;

View File

@@ -75,7 +75,7 @@ public class SysParameter {
parameters.add(new Parameter("embedding.mapper.number", "5",
"批量向量召回文本返回结果个数", "每个文本进行向量语义召回的文本结果个数", "number", "Mapper相关配置"));
parameters.add(new Parameter("embedding.mapper.distance.threshold",
"0.58", "向量召回相似度阈值", "相似度大于该阈值的则舍弃", "number", "Mapper相关配置"));
"0.01", "向量召回相似度阈值", "相似度大于该阈值的则舍弃", "number", "Mapper相关配置"));
//parser config
Parameter s2SQLParameter = new Parameter("s2SQL.generation", "TWO_PASS_AUTO_COT",

View File

@@ -129,7 +129,7 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
List<Retrieval> retrievals = new ArrayList<>();
for (EmbeddingMatch<EmbeddingQuery> embeddingMatch : relevant) {
Retrieval retrieval = new Retrieval();
retrieval.setDistance(embeddingMatch.score());
retrieval.setDistance(1 - embeddingMatch.score());
retrieval.setId(embeddingMatch.embeddingId());
retrieval.setQuery(embeddingMatch.embedded().getQuery());
Map<String, Object> metadata = embeddingMatch.embedded().getMetadata();

View File

@@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.server.listener;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DataEvent;
import com.tencent.supersonic.common.pojo.enums.EventType;
import com.tencent.supersonic.common.util.ComponentFactory;
@@ -43,8 +42,7 @@ public class MetaEmbeddingListener implements ApplicationListener<DataEvent> {
.map(dataItem -> {
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
embeddingQuery.setQueryId(
dataItem.getId().toString() + Constants.UNDERLINE
+ dataItem.getType().name().toLowerCase());
dataItem.getId() + dataItem.getType().name().toLowerCase());
embeddingQuery.setQuery(dataItem.getName());
Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class);
embeddingQuery.setMetadata(meta);

View File

@@ -132,7 +132,7 @@ public class DimensionServiceImpl implements DimensionService {
if (!oldName.equals(dimensionDO.getName())) {
sendEvent(DataItem.builder().modelId(dimensionDO.getModelId() + Constants.UNDERLINE)
.newName(dimensionReq.getName()).name(oldName).type(TypeEnums.DIMENSION)
.id(dimensionDO.getId()).build(), EventType.UPDATE);
.id(dimensionDO.getId() + Constants.UNDERLINE).build(), EventType.UPDATE);
}
}
@@ -366,8 +366,9 @@ public class DimensionServiceImpl implements DimensionService {
private void sendEventBatch(List<DimensionDO> dimensionDOS, EventType eventType) {
List<DataItem> dataItems = dimensionDOS.stream()
.map(dimensionDO -> DataItem.builder().id(dimensionDO.getId()).name(dimensionDO.getName())
.modelId(dimensionDO.getModelId() + Constants.UNDERLINE).type(TypeEnums.DIMENSION).build())
.map(dimensionDO -> DataItem.builder().id(dimensionDO.getId() + Constants.UNDERLINE)
.name(dimensionDO.getName()).modelId(dimensionDO.getModelId() + Constants.UNDERLINE)
.type(TypeEnums.DIMENSION).build())
.collect(Collectors.toList());
eventPublisher.publishEvent(new DataEvent(this, dataItems, eventType));
}

View File

@@ -449,7 +449,7 @@ public class MetricServiceImpl implements MetricService {
private DataItem getDataItem(MetricDO metricDO) {
MetricResp metricResp = MetricConverter.convert2MetricResp(metricDO,
new HashMap<>(), Lists.newArrayList());
return DataItem.builder().id(metricDO.getId()).name(metricDO.getName())
return DataItem.builder().id(metricDO.getId() + Constants.UNDERLINE).name(metricDO.getName())
.bizName(metricDO.getBizName())
.modelId(metricDO.getModelId() + Constants.UNDERLINE)
.type(TypeEnums.METRIC).defaultAgg(metricResp.getDefaultAgg()).build();

View File

@@ -51,7 +51,7 @@
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-all-minilm-l6-v2</artifactId>
<artifactId>langchain4j-embeddings-bge-small-zh</artifactId>
</dependency>
</dependencies>

View File

@@ -5,9 +5,9 @@ import static dev.langchain4j.exception.IllegalConfigurationException.illegalCon
import static dev.langchain4j.internal.Utils.isNullOrBlank;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.embedding.S2OnnxEmbeddingModel;
import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel;
import dev.langchain4j.model.huggingface.HuggingFaceChatModel;
import dev.langchain4j.model.huggingface.HuggingFaceEmbeddingModel;
import dev.langchain4j.model.huggingface.HuggingFaceLanguageModel;
@@ -248,7 +248,7 @@ public class S2LangChain4jAutoConfiguration {
case IN_PROCESS:
InProcess inProcess = properties.getEmbeddingModel().getInProcess();
if (Objects.isNull(inProcess) || isNullOrBlank(inProcess.getModelPath())) {
return new AllMiniLmL6V2EmbeddingModel();
return new BgeSmallZhEmbeddingModel();
}
return new S2OnnxEmbeddingModel(inProcess.getModelPath(), inProcess.getVocabularyPath());

View File

@@ -1,5 +1,8 @@
package com.tencent.supersonic.chat;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
@@ -8,21 +11,17 @@ import com.tencent.supersonic.chat.core.query.rule.metric.MetricFilterQuery;
import com.tencent.supersonic.chat.core.query.rule.metric.MetricGroupByQuery;
import com.tencent.supersonic.chat.core.query.rule.metric.MetricModelQuery;
import com.tencent.supersonic.chat.core.query.rule.metric.MetricTopNQuery;
import com.tencent.supersonic.util.DataUtils;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
import com.tencent.supersonic.util.DataUtils;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
public class MetricTest extends BaseTest {
@@ -41,6 +40,7 @@ public class MetricTest extends BaseTest {
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数"));
expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name",
FilterOperatorEnum.EQUALS, "alice", "用户", 2L));
@@ -76,6 +76,7 @@ public class MetricTest extends BaseTest {
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数"));
expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.RECENT, unit, period, startDay, endDay));
expectedParseInfo.setQueryType(QueryType.METRIC);
@@ -105,6 +106,7 @@ public class MetricTest extends BaseTest {
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数"));
expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("部门"));
expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.RECENT, unit, period, startDay, endDay));
@@ -126,6 +128,7 @@ public class MetricTest extends BaseTest {
expectedParseInfo.setAggType(NONE);
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数"));
List<String> list = new ArrayList<>();
list.add("alice");
list.add("lucy");
@@ -151,6 +154,7 @@ public class MetricTest extends BaseTest {
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数"));
expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("用户"));
expectedParseInfo.setDateInfo(DataUtils.getDateConf(3, DateConf.DateMode.RECENT, "DAY"));
@@ -171,6 +175,7 @@ public class MetricTest extends BaseTest {
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数"));
expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("部门"));
expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.RECENT, unit, period, startDay, endDay));
@@ -197,6 +202,7 @@ public class MetricTest extends BaseTest {
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数"));
expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name",
FilterOperatorEnum.EQUALS, "alice", "用户", 2L));

View File

@@ -30,6 +30,7 @@ public class MultiTurnsTest extends BaseTest {
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数"));
expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name",
FilterOperatorEnum.EQUALS, "alice", "用户", 2L));

View File

@@ -145,7 +145,7 @@
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-all-minilm-l6-v2</artifactId>
<artifactId>langchain4j-embeddings-bge-small-zh</artifactId>
<version>${langchain4j.version}</version>
</dependency>
</dependencies>