[improvement](chat) support filterCondition in InMemoryS2EmbeddingStore (#523)

This commit is contained in:
lexluo09
2023-12-16 20:04:01 +08:00
committed by GitHub
parent 95334441b1
commit 59c21ea19a
2 changed files with 42 additions and 6 deletions

View File

@@ -56,8 +56,8 @@ public class SqlPromptGenerator {
public String generateSqlPrompt(LLMReq llmReq, String schemaLinkStr, List<Map<String, String>> fewshotExampleList) {
String instruction = "# Use the the schema links to generate the SQL queries for each of the questions.";
List<String> exampleKeys = Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkings", "sql");
String exampleTemplate = "dbSchema\nQ: questionAugmented\n"
+ "Schema_links: generatedSchemaLinkings\nSQL: {sql}";
String exampleTemplate = "dbSchema\nQ: questionAugmented\n" + "Schema_links: generatedSchemaLinkings\n"
+ "SQL: {sql}";
String schemaLinkingPrompt = InputFormat.format(exampleTemplate, exampleKeys, fewshotExampleList);
Pair<String, String> questionPrompt = transformQuestionPrompt(llmReq);

View File

@@ -20,11 +20,15 @@ import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.PriorityQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.StringUtils;
/***
* Implementation of S2EmbeddingStore within the Java process's in-memory.
@@ -75,9 +79,11 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
List<RetrieveQueryResult> results = new ArrayList<>();
List<String> queryTextsList = retrieveQuery.getQueryTextsList();
Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
for (String queryText : queryTextsList) {
Embedding embeddedText = embeddingModel.embed(queryText).content();
List<EmbeddingMatch<EmbeddingQuery>> relevant = embeddingStore.findRelevant(embeddedText, num);
int maxResults = getMaxResults(num, filterCondition);
List<EmbeddingMatch<EmbeddingQuery>> relevant = embeddingStore.findRelevant(embeddedText, maxResults);
RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult();
retrieveQueryResult.setQuery(queryText);
@@ -87,9 +93,17 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
retrieval.setDistance(embeddingMatch.score());
retrieval.setId(embeddingMatch.embeddingId());
retrieval.setQuery(embeddingMatch.embedded().getQuery());
retrieval.setMetadata(embeddingMatch.embedded().getMetadata());
Map<String, Object> metadata = embeddingMatch.embedded().getMetadata();
if (filterRetrieval(filterCondition, metadata)) {
continue;
}
retrieval.setMetadata(metadata);
retrievals.add(retrieval);
}
retrievals = retrievals.stream()
.sorted(Comparator.comparingDouble(Retrieval::getDistance).reversed())
.limit(num)
.collect(Collectors.toList());
retrieveQueryResult.setRetrieval(retrievals);
results.add(retrieveQueryResult);
}
@@ -97,14 +111,36 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
return results;
}
/**
 * Computes how many candidates to request from the embedding store.
 * <p>
 * Without a filter condition the requested count is used as-is. With a filter condition
 * the count is multiplied by a fixed over-fetch factor; the caller applies the filter to
 * the returned matches and trims the survivors back to {@code num}.
 *
 * @param num             number of results the caller ultimately wants
 * @param filterCondition metadata filter supplied with the query; may be {@code null} or empty
 * @return {@code num} when there is no filter, otherwise {@code num * 5} saturated to the
 *         {@code int} range
 */
private int getMaxResults(int num, Map<String, String> filterCondition) {
    if (filterCondition == null || filterCondition.isEmpty()) {
        return num;
    }
    final int overFetchFactor = 5;
    // Multiply in long and clamp: plain int arithmetic would silently wrap to a
    // negative limit for large values of num.
    long widened = (long) num * overFetchFactor;
    return (int) Math.max(Integer.MIN_VALUE, Math.min(Integer.MAX_VALUE, widened));
}
/**
 * Decides whether a retrieved item must be dropped because its metadata contradicts the
 * filter condition.
 * <p>
 * Each metadata entry is looked up by key in {@code filterCondition}. An entry rejects the
 * item when a non-blank filter value exists for its key and the metadata value's string
 * form differs from it (ignoring case). Metadata keys without a filter value, and filter
 * keys that do not occur in the metadata, never cause rejection.
 *
 * @param filterCondition expected metadata values keyed by metadata name; may be
 *                        {@code null} or empty, in which case nothing is filtered
 * @param metadata        metadata of the retrieved item; may be {@code null} or empty,
 *                        in which case nothing is filtered
 * @return {@code true} if the item should be excluded from the results, {@code false} otherwise
 */
private boolean filterRetrieval(Map<String, String> filterCondition, Map<String, Object> metadata) {
    if (metadata == null || metadata.isEmpty() || filterCondition == null || filterCondition.isEmpty()) {
        return false;
    }
    for (Entry<String, Object> entry : metadata.entrySet()) {
        String filterValue = filterCondition.get(entry.getKey());
        // A missing, empty or whitespace-only filter value places no constraint on this key.
        if (filterValue == null || filterValue.chars().allMatch(Character::isWhitespace)) {
            continue;
        }
        Object actual = entry.getValue();
        // A null metadata value can never equal a non-blank filter value; treat it as a
        // mismatch instead of dereferencing it.
        if (actual == null || !filterValue.equalsIgnoreCase(actual.toString())) {
            return true;
        }
    }
    return false;
}
/**
* An {@link EmbeddingStore} that stores embeddings in memory.
* <p>
* Uses a brute force approach by iterating over all embeddings to find the best matches.
*
* @param <Embedded> The class of the object that has been embedded.
* Typically, it is {@link dev.langchain4j.data.segment.TextSegment}.
* Copied from dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore,
* with a fix for ConcurrentModificationException in a multi-threaded environment.
*/
public static class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded> {