mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 03:58:14 +00:00
(improvement)(chat) Fix the path issue in Dashscope for large models and optimize the EmbeddingService code. (#1566)
This commit is contained in:
@@ -26,7 +26,6 @@ import org.apache.commons.collections.MapUtils;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
@@ -114,51 +113,57 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
|
||||
@Override
|
||||
public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
|
||||
List<RetrieveQueryResult> results = new ArrayList<>();
|
||||
EmbeddingStoreFactory embeddingStoreFactory = EmbeddingStoreFactoryProvider.getFactory();
|
||||
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
|
||||
List<String> queryTextsList = retrieveQuery.getQueryTextsList();
|
||||
Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
|
||||
for (String queryText : queryTextsList) {
|
||||
EmbeddingStore embeddingStore = EmbeddingStoreFactoryProvider.getFactory().create(collectionName);
|
||||
EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel();
|
||||
Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
|
||||
return retrieveQuery.getQueryTextsList().stream()
|
||||
.map(queryText -> retrieveSingleQuery(queryText, embeddingModel, embeddingStore, filterCondition, num))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private RetrieveQueryResult retrieveSingleQuery(String queryText,
|
||||
EmbeddingModel embeddingModel,
|
||||
EmbeddingStore embeddingStore,
|
||||
Map<String, String> filterCondition,
|
||||
int num) {
|
||||
Embedding embeddedText = embeddingModel.embed(queryText).content();
|
||||
Filter filter = createCombinedFilter(filterCondition);
|
||||
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
|
||||
.queryEmbedding(embeddedText).filter(filter).maxResults(num).build();
|
||||
EmbeddingSearchResult result = embeddingStore.search(request);
|
||||
List<EmbeddingMatch<TextSegment>> relevant = result.matches();
|
||||
EmbeddingSearchResult<TextSegment> result = embeddingStore.search(request);
|
||||
|
||||
List<Retrieval> retrievals = result.matches().stream()
|
||||
.map(this::convertToRetrieval)
|
||||
.sorted(Comparator.comparingDouble(Retrieval::getDistance).reversed())
|
||||
.limit(num)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult();
|
||||
retrieveQueryResult.setQuery(queryText);
|
||||
List<Retrieval> retrievals = new ArrayList<>();
|
||||
for (EmbeddingMatch<TextSegment> embeddingMatch : relevant) {
|
||||
retrieveQueryResult.setRetrieval(retrievals);
|
||||
return retrieveQueryResult;
|
||||
}
|
||||
|
||||
private Retrieval convertToRetrieval(EmbeddingMatch<TextSegment> embeddingMatch) {
|
||||
Retrieval retrieval = new Retrieval();
|
||||
TextSegment embedded = embeddingMatch.embedded();
|
||||
retrieval.setDistance(1 - embeddingMatch.score());
|
||||
retrieval.setId(TextSegmentConvert.getQueryId(embedded));
|
||||
retrieval.setQuery(embedded.text());
|
||||
|
||||
Map<String, Object> metadata = new HashMap<>();
|
||||
if (Objects.nonNull(embedded)
|
||||
&& MapUtils.isNotEmpty(embedded.metadata().toMap())) {
|
||||
if (Objects.nonNull(embedded) && MapUtils.isNotEmpty(embedded.metadata().toMap())) {
|
||||
metadata.putAll(embedded.metadata().toMap());
|
||||
}
|
||||
retrieval.setMetadata(metadata);
|
||||
retrievals.add(retrieval);
|
||||
}
|
||||
retrievals = retrievals.stream()
|
||||
.sorted(Comparator.comparingDouble(Retrieval::getDistance).reversed())
|
||||
.limit(num)
|
||||
.collect(Collectors.toList());
|
||||
retrieveQueryResult.setRetrieval(retrievals);
|
||||
results.add(retrieveQueryResult);
|
||||
}
|
||||
return results;
|
||||
return retrieval;
|
||||
}
|
||||
|
||||
private static Filter createCombinedFilter(Map<String, String> map) {
|
||||
Filter result = null;
|
||||
if (MapUtils.isEmpty(map)) {
|
||||
return null;
|
||||
}
|
||||
Filter result = null;
|
||||
for (Map.Entry<String, String> entry : map.entrySet()) {
|
||||
IsEqualTo isEqualTo = new IsEqualTo(entry.getKey(), entry.getValue());
|
||||
result = (result == null) ? isEqualTo : Filter.and(result, isEqualTo);
|
||||
|
||||
@@ -12,7 +12,8 @@ import org.springframework.stereotype.Service;
|
||||
@Service
|
||||
public class DashscopeModelFactory implements ModelFactory, InitializingBean {
|
||||
public static final String PROVIDER = "DASHSCOPE";
|
||||
public static final String DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1";
|
||||
public static final String DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/api/v1";
|
||||
public static final String DEFAULT_COMPATIBLE_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1";
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
|
||||
Reference in New Issue
Block a user