(improvement)(chat) Fix the path issue in Dashscope for large models and optimize the EmbeddingService code. (#1566)

This commit is contained in:
lexluo09
2024-08-13 13:37:24 +08:00
committed by GitHub
parent 93fedd787f
commit 8d29e89317
2 changed files with 45 additions and 39 deletions

View File

@@ -26,7 +26,6 @@ import org.apache.commons.collections.MapUtils;
import org.apache.commons.collections4.CollectionUtils;
import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
@@ -114,51 +113,57 @@ public class EmbeddingServiceImpl implements EmbeddingService {
@Override
public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
List<RetrieveQueryResult> results = new ArrayList<>();
EmbeddingStoreFactory embeddingStoreFactory = EmbeddingStoreFactoryProvider.getFactory();
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
List<String> queryTextsList = retrieveQuery.getQueryTextsList();
Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
for (String queryText : queryTextsList) {
EmbeddingStore embeddingStore = EmbeddingStoreFactoryProvider.getFactory().create(collectionName);
EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel();
Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
return retrieveQuery.getQueryTextsList().stream()
.map(queryText -> retrieveSingleQuery(queryText, embeddingModel, embeddingStore, filterCondition, num))
.collect(Collectors.toList());
}
private RetrieveQueryResult retrieveSingleQuery(String queryText,
EmbeddingModel embeddingModel,
EmbeddingStore embeddingStore,
Map<String, String> filterCondition,
int num) {
Embedding embeddedText = embeddingModel.embed(queryText).content();
Filter filter = createCombinedFilter(filterCondition);
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
.queryEmbedding(embeddedText).filter(filter).maxResults(num).build();
EmbeddingSearchResult result = embeddingStore.search(request);
List<EmbeddingMatch<TextSegment>> relevant = result.matches();
EmbeddingSearchResult<TextSegment> result = embeddingStore.search(request);
List<Retrieval> retrievals = result.matches().stream()
.map(this::convertToRetrieval)
.sorted(Comparator.comparingDouble(Retrieval::getDistance).reversed())
.limit(num)
.collect(Collectors.toList());
RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult();
retrieveQueryResult.setQuery(queryText);
List<Retrieval> retrievals = new ArrayList<>();
for (EmbeddingMatch<TextSegment> embeddingMatch : relevant) {
retrieveQueryResult.setRetrieval(retrievals);
return retrieveQueryResult;
}
private Retrieval convertToRetrieval(EmbeddingMatch<TextSegment> embeddingMatch) {
Retrieval retrieval = new Retrieval();
TextSegment embedded = embeddingMatch.embedded();
retrieval.setDistance(1 - embeddingMatch.score());
retrieval.setId(TextSegmentConvert.getQueryId(embedded));
retrieval.setQuery(embedded.text());
Map<String, Object> metadata = new HashMap<>();
if (Objects.nonNull(embedded)
&& MapUtils.isNotEmpty(embedded.metadata().toMap())) {
if (Objects.nonNull(embedded) && MapUtils.isNotEmpty(embedded.metadata().toMap())) {
metadata.putAll(embedded.metadata().toMap());
}
retrieval.setMetadata(metadata);
retrievals.add(retrieval);
}
retrievals = retrievals.stream()
.sorted(Comparator.comparingDouble(Retrieval::getDistance).reversed())
.limit(num)
.collect(Collectors.toList());
retrieveQueryResult.setRetrieval(retrievals);
results.add(retrieveQueryResult);
}
return results;
return retrieval;
}
private static Filter createCombinedFilter(Map<String, String> map) {
Filter result = null;
if (MapUtils.isEmpty(map)) {
return null;
}
Filter result = null;
for (Map.Entry<String, String> entry : map.entrySet()) {
IsEqualTo isEqualTo = new IsEqualTo(entry.getKey(), entry.getValue());
result = (result == null) ? isEqualTo : Filter.and(result, isEqualTo);

View File

@@ -12,7 +12,8 @@ import org.springframework.stereotype.Service;
@Service
public class DashscopeModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "DASHSCOPE";
public static final String DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1";
public static final String DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/api/v1";
public static final String DEFAULT_COMPATIBLE_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1";
@Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {