(improvement)(chat) The embedding service now uniformly uses TextSegment (replacing EmbeddingQuery) and remains compatible with the queryId parameter. (#1202)

This commit is contained in:
lexluo09
2024-06-24 13:27:03 +08:00
committed by GitHub
parent a7d367baa3
commit 4b288d9815
13 changed files with 134 additions and 127 deletions

View File

@@ -1,8 +1,9 @@
package com.tencent.supersonic.common.service;
import dev.langchain4j.store.embedding.EmbeddingQuery;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult;
import java.util.List;
/**
@@ -13,9 +14,9 @@ public interface EmbeddingService {
void addCollection(String collectionName);
void addQuery(String collectionName, List<EmbeddingQuery> queries);
void addQuery(String collectionName, List<TextSegment> queries);
void deleteQuery(String collectionName, List<EmbeddingQuery> queries);
void deleteQuery(String collectionName, List<TextSegment> queries);
List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num);

View File

@@ -2,9 +2,9 @@ package com.tencent.supersonic.common.service.impl;
import com.tencent.supersonic.common.service.EmbeddingService;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingQuery;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore;
@@ -12,8 +12,13 @@ import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
import dev.langchain4j.store.embedding.Retrieval;
import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult;
import dev.langchain4j.store.embedding.TextSegmentConvert;
import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
import org.apache.commons.collections.MapUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
@@ -21,9 +26,6 @@ import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.commons.collections.MapUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
@Service
public class EmbeddingServiceImpl implements EmbeddingService {
@@ -39,17 +41,17 @@ public class EmbeddingServiceImpl implements EmbeddingService {
}
@Override
public void addQuery(String collectionName, List<EmbeddingQuery> queries) {
public void addQuery(String collectionName, List<TextSegment> queries) {
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
for (EmbeddingQuery query : queries) {
String question = query.getQuery();
for (TextSegment query : queries) {
String question = query.text();
Embedding embedding = embeddingModel.embed(question).content();
embeddingStore.add(embedding, query);
}
}
@Override
public void deleteQuery(String collectionName, List<EmbeddingQuery> queries) {
public void deleteQuery(String collectionName, List<TextSegment> queries) {
}
@Override
@@ -66,21 +68,21 @@ public class EmbeddingServiceImpl implements EmbeddingService {
.queryEmbedding(embeddedText).filter(filter).maxResults(num).build();
EmbeddingSearchResult result = embeddingStore.search(request);
List<EmbeddingMatch<EmbeddingQuery>> relevant = result.matches();
List<EmbeddingMatch<TextSegment>> relevant = result.matches();
RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult();
retrieveQueryResult.setQuery(queryText);
List<Retrieval> retrievals = new ArrayList<>();
for (EmbeddingMatch<EmbeddingQuery> embeddingMatch : relevant) {
for (EmbeddingMatch<TextSegment> embeddingMatch : relevant) {
Retrieval retrieval = new Retrieval();
EmbeddingQuery embedded = embeddingMatch.embedded();
TextSegment embedded = embeddingMatch.embedded();
retrieval.setDistance(1 - embeddingMatch.score());
retrieval.setId(embedded.getQueryId());
retrieval.setQuery(embedded.getQuery());
retrieval.setId(TextSegmentConvert.getQueryId(embedded));
retrieval.setQuery(embedded.text());
Map<String, Object> metadata = new HashMap<>();
if (Objects.nonNull(embedded)
&& MapUtils.isNotEmpty(embedded.getMetadata())) {
metadata.putAll(embedded.getMetadata());
&& MapUtils.isNotEmpty(embedded.metadata().toMap())) {
metadata.putAll(embedded.metadata().toMap());
}
retrieval.setMetadata(metadata);
retrievals.add(retrieval);