(improvement)(chat) Fix the path issue in Dashscope for large models and optimize the EmbeddingService code. (#1566)

This commit is contained in:
lexluo09
2024-08-13 13:37:24 +08:00
committed by GitHub
parent 93fedd787f
commit 8d29e89317
2 changed files with 45 additions and 39 deletions

View File

@@ -26,7 +26,6 @@ import org.apache.commons.collections.MapUtils;
import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.collections4.CollectionUtils;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.ArrayList;
import java.util.Comparator; import java.util.Comparator;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
@@ -114,51 +113,57 @@ public class EmbeddingServiceImpl implements EmbeddingService {
@Override @Override
public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) { public List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num) {
List<RetrieveQueryResult> results = new ArrayList<>(); EmbeddingStore embeddingStore = EmbeddingStoreFactoryProvider.getFactory().create(collectionName);
EmbeddingStoreFactory embeddingStoreFactory = EmbeddingStoreFactoryProvider.getFactory(); EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel();
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
List<String> queryTextsList = retrieveQuery.getQueryTextsList();
Map<String, String> filterCondition = retrieveQuery.getFilterCondition(); Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
for (String queryText : queryTextsList) { return retrieveQuery.getQueryTextsList().stream()
EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(); .map(queryText -> retrieveSingleQuery(queryText, embeddingModel, embeddingStore, filterCondition, num))
Embedding embeddedText = embeddingModel.embed(queryText).content(); .collect(Collectors.toList());
Filter filter = createCombinedFilter(filterCondition); }
EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
.queryEmbedding(embeddedText).filter(filter).maxResults(num).build(); private RetrieveQueryResult retrieveSingleQuery(String queryText,
EmbeddingSearchResult result = embeddingStore.search(request); EmbeddingModel embeddingModel,
List<EmbeddingMatch<TextSegment>> relevant = result.matches(); EmbeddingStore embeddingStore,
RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult(); Map<String, String> filterCondition,
retrieveQueryResult.setQuery(queryText); int num) {
List<Retrieval> retrievals = new ArrayList<>(); Embedding embeddedText = embeddingModel.embed(queryText).content();
for (EmbeddingMatch<TextSegment> embeddingMatch : relevant) { Filter filter = createCombinedFilter(filterCondition);
Retrieval retrieval = new Retrieval(); EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
TextSegment embedded = embeddingMatch.embedded(); .queryEmbedding(embeddedText).filter(filter).maxResults(num).build();
retrieval.setDistance(1 - embeddingMatch.score()); EmbeddingSearchResult<TextSegment> result = embeddingStore.search(request);
retrieval.setId(TextSegmentConvert.getQueryId(embedded));
retrieval.setQuery(embedded.text()); List<Retrieval> retrievals = result.matches().stream()
Map<String, Object> metadata = new HashMap<>(); .map(this::convertToRetrieval)
if (Objects.nonNull(embedded) .sorted(Comparator.comparingDouble(Retrieval::getDistance).reversed())
&& MapUtils.isNotEmpty(embedded.metadata().toMap())) { .limit(num)
metadata.putAll(embedded.metadata().toMap()); .collect(Collectors.toList());
}
retrieval.setMetadata(metadata); RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult();
retrievals.add(retrieval); retrieveQueryResult.setQuery(queryText);
} retrieveQueryResult.setRetrieval(retrievals);
retrievals = retrievals.stream() return retrieveQueryResult;
.sorted(Comparator.comparingDouble(Retrieval::getDistance).reversed()) }
.limit(num)
.collect(Collectors.toList()); private Retrieval convertToRetrieval(EmbeddingMatch<TextSegment> embeddingMatch) {
retrieveQueryResult.setRetrieval(retrievals); Retrieval retrieval = new Retrieval();
results.add(retrieveQueryResult); TextSegment embedded = embeddingMatch.embedded();
retrieval.setDistance(1 - embeddingMatch.score());
retrieval.setId(TextSegmentConvert.getQueryId(embedded));
retrieval.setQuery(embedded.text());
Map<String, Object> metadata = new HashMap<>();
if (Objects.nonNull(embedded) && MapUtils.isNotEmpty(embedded.metadata().toMap())) {
metadata.putAll(embedded.metadata().toMap());
} }
return results; retrieval.setMetadata(metadata);
return retrieval;
} }
private static Filter createCombinedFilter(Map<String, String> map) { private static Filter createCombinedFilter(Map<String, String> map) {
Filter result = null;
if (MapUtils.isEmpty(map)) { if (MapUtils.isEmpty(map)) {
return null; return null;
} }
Filter result = null;
for (Map.Entry<String, String> entry : map.entrySet()) { for (Map.Entry<String, String> entry : map.entrySet()) {
IsEqualTo isEqualTo = new IsEqualTo(entry.getKey(), entry.getValue()); IsEqualTo isEqualTo = new IsEqualTo(entry.getKey(), entry.getValue());
result = (result == null) ? isEqualTo : Filter.and(result, isEqualTo); result = (result == null) ? isEqualTo : Filter.and(result, isEqualTo);

View File

@@ -12,7 +12,8 @@ import org.springframework.stereotype.Service;
@Service @Service
public class DashscopeModelFactory implements ModelFactory, InitializingBean { public class DashscopeModelFactory implements ModelFactory, InitializingBean {
public static final String PROVIDER = "DASHSCOPE"; public static final String PROVIDER = "DASHSCOPE";
public static final String DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"; public static final String DEFAULT_BASE_URL = "https://dashscope.aliyuncs.com/api/v1";
public static final String DEFAULT_COMPATIBLE_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1";
@Override @Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {