diff --git a/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java b/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java index dd82d52f0..4752c180d 100644 --- a/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java +++ b/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java @@ -26,7 +26,6 @@ import org.apache.commons.collections.MapUtils; import org.apache.commons.collections4.CollectionUtils; import org.springframework.stereotype.Service; -import java.util.ArrayList; import java.util.Comparator; import java.util.HashMap; import java.util.List; @@ -114,51 +113,57 @@ public class EmbeddingServiceImpl implements EmbeddingService { @Override public List retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) { - List results = new ArrayList<>(); - EmbeddingStoreFactory embeddingStoreFactory = EmbeddingStoreFactoryProvider.getFactory(); - EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName); - List queryTextsList = retrieveQuery.getQueryTextsList(); + EmbeddingStore embeddingStore = EmbeddingStoreFactoryProvider.getFactory().create(collectionName); + EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(); Map filterCondition = retrieveQuery.getFilterCondition(); - for (String queryText : queryTextsList) { - EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(); - Embedding embeddedText = embeddingModel.embed(queryText).content(); - Filter filter = createCombinedFilter(filterCondition); - EmbeddingSearchRequest request = EmbeddingSearchRequest.builder() - .queryEmbedding(embeddedText).filter(filter).maxResults(num).build(); - EmbeddingSearchResult result = embeddingStore.search(request); - List> relevant = result.matches(); - RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult(); - retrieveQueryResult.setQuery(queryText); - List retrievals = new ArrayList<>(); - for (EmbeddingMatch embeddingMatch : relevant) { - Retrieval retrieval = new Retrieval(); - TextSegment embedded = embeddingMatch.embedded(); - retrieval.setDistance(1 - embeddingMatch.score()); - retrieval.setId(TextSegmentConvert.getQueryId(embedded)); - retrieval.setQuery(embedded.text()); - Map metadata = new HashMap<>(); - if (Objects.nonNull(embedded) - && MapUtils.isNotEmpty(embedded.metadata().toMap())) { - metadata.putAll(embedded.metadata().toMap()); - } - retrieval.setMetadata(metadata); - retrievals.add(retrieval); - } - retrievals = retrievals.stream() - .sorted(Comparator.comparingDouble(Retrieval::getDistance).reversed()) - .limit(num) - .collect(Collectors.toList()); - retrieveQueryResult.setRetrieval(retrievals); - results.add(retrieveQueryResult); + return retrieveQuery.getQueryTextsList().stream() + .map(queryText -> retrieveSingleQuery(queryText, embeddingModel, embeddingStore, filterCondition, num)) + .collect(Collectors.toList()); + } + + private RetrieveQueryResult retrieveSingleQuery(String queryText, + EmbeddingModel embeddingModel, + EmbeddingStore embeddingStore, + Map filterCondition, + int num) { + Embedding embeddedText = embeddingModel.embed(queryText).content(); + Filter filter = createCombinedFilter(filterCondition); + EmbeddingSearchRequest request = EmbeddingSearchRequest.builder() + .queryEmbedding(embeddedText).filter(filter).maxResults(num).build(); + EmbeddingSearchResult result = embeddingStore.search(request); + + List retrievals = result.matches().stream() + .map(this::convertToRetrieval) + .sorted(Comparator.comparingDouble(Retrieval::getDistance).reversed()) + .limit(num) + .collect(Collectors.toList()); + + RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult(); + retrieveQueryResult.setQuery(queryText); + retrieveQueryResult.setRetrieval(retrievals); + return retrieveQueryResult; + } + + private Retrieval convertToRetrieval(EmbeddingMatch embeddingMatch) { + Retrieval retrieval = new Retrieval(); + TextSegment embedded = embeddingMatch.embedded(); + retrieval.setDistance(1 - embeddingMatch.score()); + retrieval.setId(TextSegmentConvert.getQueryId(embedded)); + retrieval.setQuery(embedded.text()); + + Map metadata = new HashMap<>(); + if (Objects.nonNull(embedded) && MapUtils.isNotEmpty(embedded.metadata().toMap())) { + metadata.putAll(embedded.metadata().toMap()); } - return results; + retrieval.setMetadata(metadata); + return retrieval; } private static Filter createCombinedFilter(Map map) { - Filter result = null; if (MapUtils.isEmpty(map)) { return null; } + Filter result = null; for (Map.Entry entry : map.entrySet()) { IsEqualTo isEqualTo = new IsEqualTo(entry.getKey(), entry.getValue()); result = (result == null) ? isEqualTo : Filter.and(result, isEqualTo); diff --git a/common/src/main/java/dev/langchain4j/provider/DashscopeModelFactory.java b/common/src/main/java/dev/langchain4j/provider/DashscopeModelFactory.java index 8460cd98a..d7741623d 100644 --- a/common/src/main/java/dev/langchain4j/provider/DashscopeModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/DashscopeModelFactory.java @@ -12,7 +12,8 @@ import org.springframework.stereotype.Service; @Service public class DashscopeModelFactory implements ModelFactory, InitializingBean { public static final String PROVIDER = "DASHSCOPE"; - public static final String DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"; + public static final String DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/api/v1"; + public static final String DEFAULT_COMPATIBLE_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"; @Override public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {