mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-14 22:08:56 +00:00
(improvement)(headless) Headless integration embedding functionality, with support for viewId in embeddings. (#725)
This commit is contained in:
@@ -44,7 +44,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
return result;
|
||||
}
|
||||
|
||||
public List<T> detect(QueryContext queryContext, List<S2Term> terms, Set<Long> detectModelIds) {
|
||||
public List<T> detect(QueryContext queryContext, List<S2Term> terms, Set<Long> detectViewIds) {
|
||||
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(terms);
|
||||
String text = queryContext.getQueryText();
|
||||
Set<T> results = new HashSet<>();
|
||||
@@ -59,16 +59,16 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
if (index <= text.length()) {
|
||||
String detectSegment = text.substring(startIndex, index);
|
||||
detectSegments.add(detectSegment);
|
||||
detectByStep(queryContext, results, detectModelIds, startIndex, index, offset);
|
||||
detectByStep(queryContext, results, detectViewIds, startIndex, index, offset);
|
||||
}
|
||||
}
|
||||
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
|
||||
}
|
||||
detectByBatch(queryContext, results, detectModelIds, detectSegments);
|
||||
detectByBatch(queryContext, results, detectViewIds, detectSegments);
|
||||
return new ArrayList<>(results);
|
||||
}
|
||||
|
||||
protected void detectByBatch(QueryContext queryContext, Set<T> results, Set<Long> detectModelIds,
|
||||
protected void detectByBatch(QueryContext queryContext, Set<T> results, Set<Long> detectViewIds,
|
||||
Set<String> detectSegments) {
|
||||
return;
|
||||
}
|
||||
@@ -152,6 +152,6 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
public abstract String getMapKey(T a);
|
||||
|
||||
public abstract void detectByStep(QueryContext queryContext, Set<T> results,
|
||||
Set<Long> detectModelIds, Integer startIndex, Integer index, int offset);
|
||||
Set<Long> detectViewIds, Integer startIndex, Integer index, int offset);
|
||||
|
||||
}
|
||||
|
||||
@@ -37,9 +37,9 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<DatabaseMapResult>> match(QueryContext queryContext, List<S2Term> terms,
|
||||
Set<Long> detectModelIds) {
|
||||
Set<Long> detectViewIds) {
|
||||
this.allElements = getSchemaElements(queryContext);
|
||||
return super.match(queryContext, terms, detectModelIds);
|
||||
return super.match(queryContext, terms, detectViewIds);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -54,7 +54,7 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
||||
+ Constants.UNDERLINE + a.getSchemaElement().getName();
|
||||
}
|
||||
|
||||
public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> existResults, Set<Long> detectModelIds,
|
||||
public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> existResults, Set<Long> detectViewIds,
|
||||
Integer startIndex, Integer index, int offset) {
|
||||
String detectSegment = queryContext.getQueryText().substring(startIndex, index);
|
||||
if (StringUtils.isBlank(detectSegment)) {
|
||||
|
||||
@@ -1,20 +1,18 @@
|
||||
package com.tencent.supersonic.chat.core.mapper;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult;
|
||||
import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder;
|
||||
import com.tencent.supersonic.headless.core.knowledge.helper.HanlpHelper;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
/***
|
||||
* A mapper that recognizes schema elements with vector embedding.
|
||||
@@ -39,15 +37,11 @@ public class EmbeddingMapper extends BaseMapper {
|
||||
|
||||
SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(matchResult.getMetadata()),
|
||||
SchemaElement.class);
|
||||
if (Objects.isNull(matchResult.getMetadata())) {
|
||||
Long viewId = Retrieval.getLongId(matchResult.getMetadata().get("viewId"));
|
||||
if (Objects.isNull(viewId)) {
|
||||
continue;
|
||||
}
|
||||
String modelIdStr = matchResult.getMetadata().get("modelId");
|
||||
if (StringUtils.isBlank(modelIdStr)) {
|
||||
continue;
|
||||
}
|
||||
long modelId = Long.parseLong(modelIdStr);
|
||||
schemaElement = getSchemaElement(modelId, schemaElement.getType(), elementId,
|
||||
schemaElement = getSchemaElement(viewId, schemaElement.getType(), elementId,
|
||||
queryContext.getSemanticSchema());
|
||||
if (schemaElement == null) {
|
||||
continue;
|
||||
@@ -60,7 +54,7 @@ public class EmbeddingMapper extends BaseMapper {
|
||||
.detectWord(matchResult.getDetectWord())
|
||||
.build();
|
||||
//3. add to mapInfo
|
||||
addToSchemaMap(queryContext.getMapInfo(), modelId, schemaElementMatch);
|
||||
addToSchemaMap(queryContext.getMapInfo(), viewId, schemaElementMatch);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,15 +4,13 @@ import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.ComponentFactory;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
|
||||
import com.tencent.supersonic.headless.server.service.MetaEmbeddingService;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
@@ -36,9 +34,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
private OptimizationConfig optimizationConfig;
|
||||
|
||||
@Autowired
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
|
||||
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
|
||||
private MetaEmbeddingService metaEmbeddingService;
|
||||
|
||||
@Override
|
||||
public boolean needDelete(EmbeddingResult oneRoundResult, EmbeddingResult existResult) {
|
||||
@@ -52,7 +48,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectModelIds,
|
||||
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectViewIds,
|
||||
Set<String> detectSegments) {
|
||||
|
||||
List<String> queryTextsList = detectSegments.stream()
|
||||
@@ -66,51 +62,29 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
optimizationConfig.getEmbeddingMapperBatch());
|
||||
|
||||
for (List<String> queryTextsSub : queryTextsSubList) {
|
||||
detectByQueryTextsSub(results, detectModelIds, queryTextsSub);
|
||||
detectByQueryTextsSub(results, detectViewIds, queryTextsSub);
|
||||
}
|
||||
}
|
||||
|
||||
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectModelIds,
|
||||
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectViewIds,
|
||||
List<String> queryTextsSub) {
|
||||
int embeddingNumber = optimizationConfig.getEmbeddingMapperNumber();
|
||||
Double distance = optimizationConfig.getEmbeddingMapperDistanceThreshold();
|
||||
Map<String, String> filterCondition = null;
|
||||
// step1. build query params
|
||||
// if only one modelId, add to filterCondition
|
||||
if (CollectionUtils.isNotEmpty(detectModelIds) && detectModelIds.size() == 1) {
|
||||
filterCondition = new HashMap<>();
|
||||
filterCondition.put("modelId", detectModelIds.stream().findFirst().get().toString());
|
||||
}
|
||||
|
||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder()
|
||||
.queryTextsList(queryTextsSub)
|
||||
.filterCondition(filterCondition)
|
||||
.queryEmbeddings(null)
|
||||
.build();
|
||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
|
||||
// step2. retrieveQuery by detectSegment
|
||||
List<RetrieveQueryResult> retrieveQueryResults = s2EmbeddingStore.retrieveQuery(
|
||||
embeddingConfig.getMetaCollectionName(), retrieveQuery, embeddingNumber);
|
||||
List<RetrieveQueryResult> retrieveQueryResults = metaEmbeddingService.retrieveQuery(
|
||||
new ArrayList<>(detectViewIds), retrieveQuery, embeddingNumber);
|
||||
|
||||
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
|
||||
return;
|
||||
}
|
||||
// step3. build EmbeddingResults. filter by modelId
|
||||
// step3. build EmbeddingResults
|
||||
List<EmbeddingResult> collect = retrieveQueryResults.stream()
|
||||
.map(retrieveQueryResult -> {
|
||||
List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
|
||||
if (CollectionUtils.isNotEmpty(retrievals)) {
|
||||
retrievals.removeIf(retrieval -> retrieval.getDistance() > distance.doubleValue());
|
||||
if (CollectionUtils.isNotEmpty(detectModelIds)) {
|
||||
retrievals.removeIf(retrieval -> {
|
||||
String modelIdStr = retrieval.getMetadata().get("modelId").toString();
|
||||
if (StringUtils.isBlank(modelIdStr)) {
|
||||
return true;
|
||||
}
|
||||
//return detectModelIds.contains(Long.parseLong(modelIdStr));
|
||||
Double modelId = Double.parseDouble(modelIdStr);
|
||||
return detectModelIds.contains(modelId.longValue());
|
||||
});
|
||||
}
|
||||
}
|
||||
return retrieveQueryResult;
|
||||
})
|
||||
@@ -121,6 +95,9 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
BeanUtils.copyProperties(retrieval, embeddingResult);
|
||||
embeddingResult.setDetectWord(retrieveQueryResult.getQuery());
|
||||
embeddingResult.setName(retrieval.getQuery());
|
||||
Map<String, String> convertedMap = retrieval.getMetadata().entrySet().stream()
|
||||
.collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().toString()));
|
||||
embeddingResult.setMetadata(convertedMap);
|
||||
return embeddingResult;
|
||||
}))
|
||||
.collect(Collectors.toList());
|
||||
@@ -135,7 +112,7 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectModelIds,
|
||||
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectViewIds,
|
||||
Integer startIndex, Integer index, int offset) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -36,15 +36,15 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> terms,
|
||||
Set<Long> detectModelIds) {
|
||||
Set<Long> detectViewIds) {
|
||||
String text = queryContext.getQueryText();
|
||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
log.debug("retryCount:{},terms:{},,detectModelIds:{}", terms, detectModelIds);
|
||||
log.debug("retryCount:{},terms:{},,detectModelIds:{}", terms, detectViewIds);
|
||||
|
||||
List<HanlpMapResult> detects = detect(queryContext, terms, detectModelIds);
|
||||
List<HanlpMapResult> detects = detect(queryContext, terms, detectViewIds);
|
||||
Map<MatchText, List<HanlpMapResult>> result = new HashMap<>();
|
||||
|
||||
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
||||
@@ -57,7 +57,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
&& existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length();
|
||||
}
|
||||
|
||||
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectModelIds,
|
||||
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectViewIds,
|
||||
Integer startIndex, Integer index, int offset) {
|
||||
String text = queryContext.getQueryText();
|
||||
String detectSegment = text.substring(startIndex, index);
|
||||
@@ -65,11 +65,10 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
// step1. pre search
|
||||
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
|
||||
LinkedHashSet<HanlpMapResult> hanlpMapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize,
|
||||
detectModelIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
detectViewIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
// step2. suffix search
|
||||
LinkedHashSet<HanlpMapResult> suffixHanlpMapResults = SearchService.suffixSearch(detectSegment,
|
||||
oneDetectionMaxSize, detectModelIds).stream()
|
||||
.collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
oneDetectionMaxSize, detectViewIds).stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
|
||||
hanlpMapResults.addAll(suffixHanlpMapResults);
|
||||
|
||||
|
||||
@@ -13,6 +13,6 @@ import java.util.Set;
|
||||
*/
|
||||
public interface MatchStrategy<T> {
|
||||
|
||||
Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms, Set<Long> detectModelId);
|
||||
Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms, Set<Long> detectViewIds);
|
||||
|
||||
}
|
||||
@@ -27,7 +27,7 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> originals,
|
||||
Set<Long> detectModelIds) {
|
||||
Set<Long> detectViewIds) {
|
||||
String text = queryContext.getQueryText();
|
||||
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(originals);
|
||||
|
||||
@@ -52,9 +52,9 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
|
||||
if (StringUtils.isNotEmpty(detectSegment)) {
|
||||
List<HanlpMapResult> hanlpMapResults = SearchService.prefixSearch(detectSegment,
|
||||
SearchService.SEARCH_SIZE, detectModelIds);
|
||||
SearchService.SEARCH_SIZE, detectViewIds);
|
||||
List<HanlpMapResult> suffixHanlpMapResults = SearchService.suffixSearch(
|
||||
detectSegment, SEARCH_SIZE, detectModelIds);
|
||||
detectSegment, SEARCH_SIZE, detectViewIds);
|
||||
hanlpMapResults.addAll(suffixHanlpMapResults);
|
||||
// remove entity name where search
|
||||
hanlpMapResults = hanlpMapResults.stream().filter(entry -> {
|
||||
@@ -88,7 +88,7 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> results, Set<Long> detectModelIds,
|
||||
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> results, Set<Long> detectViewIds,
|
||||
Integer startIndex,
|
||||
Integer i, int offset) {
|
||||
|
||||
|
||||
Reference in New Issue
Block a user