[improvement](chat) support filterCondition in InMemoryS2EmbeddingStore (#523)

This commit is contained in:
lexluo09
2023-12-16 20:04:01 +08:00
committed by GitHub
parent 95334441b1
commit 59c21ea19a
2 changed files with 42 additions and 6 deletions

View File

@@ -56,8 +56,8 @@ public class SqlPromptGenerator {
public String generateSqlPrompt(LLMReq llmReq, String schemaLinkStr, List<Map<String, String>> fewshotExampleList) { public String generateSqlPrompt(LLMReq llmReq, String schemaLinkStr, List<Map<String, String>> fewshotExampleList) {
String instruction = "# Use the the schema links to generate the SQL queries for each of the questions."; String instruction = "# Use the the schema links to generate the SQL queries for each of the questions.";
List<String> exampleKeys = Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkings", "sql"); List<String> exampleKeys = Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkings", "sql");
String exampleTemplate = "dbSchema\nQ: questionAugmented\n" String exampleTemplate = "dbSchema\nQ: questionAugmented\n" + "Schema_links: generatedSchemaLinkings\n"
+ "Schema_links: generatedSchemaLinkings\nSQL: {sql}"; + "SQL: {sql}";
String schemaLinkingPrompt = InputFormat.format(exampleTemplate, exampleKeys, fewshotExampleList); String schemaLinkingPrompt = InputFormat.format(exampleTemplate, exampleKeys, fewshotExampleList);
Pair<String, String> questionPrompt = transformQuestionPrompt(llmReq); Pair<String, String> questionPrompt = transformQuestionPrompt(llmReq);

View File

@@ -20,11 +20,15 @@ import java.util.Collections;
import java.util.Comparator; import java.util.Comparator;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects; import java.util.Objects;
import java.util.PriorityQueue; import java.util.PriorityQueue;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CopyOnWriteArrayList;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.MapUtils;
import org.apache.commons.lang3.StringUtils;
/*** /***
* Implementation of S2EmbeddingStore within the Java process's in-memory. * Implementation of S2EmbeddingStore within the Java process's in-memory.
@@ -75,9 +79,11 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
List<RetrieveQueryResult> results = new ArrayList<>(); List<RetrieveQueryResult> results = new ArrayList<>();
List<String> queryTextsList = retrieveQuery.getQueryTextsList(); List<String> queryTextsList = retrieveQuery.getQueryTextsList();
Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
for (String queryText : queryTextsList) { for (String queryText : queryTextsList) {
Embedding embeddedText = embeddingModel.embed(queryText).content(); Embedding embeddedText = embeddingModel.embed(queryText).content();
List<EmbeddingMatch<EmbeddingQuery>> relevant = embeddingStore.findRelevant(embeddedText, num); int maxResults = getMaxResults(num, filterCondition);
List<EmbeddingMatch<EmbeddingQuery>> relevant = embeddingStore.findRelevant(embeddedText, maxResults);
RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult(); RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult();
retrieveQueryResult.setQuery(queryText); retrieveQueryResult.setQuery(queryText);
@@ -87,9 +93,17 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
retrieval.setDistance(embeddingMatch.score()); retrieval.setDistance(embeddingMatch.score());
retrieval.setId(embeddingMatch.embeddingId()); retrieval.setId(embeddingMatch.embeddingId());
retrieval.setQuery(embeddingMatch.embedded().getQuery()); retrieval.setQuery(embeddingMatch.embedded().getQuery());
retrieval.setMetadata(embeddingMatch.embedded().getMetadata()); Map<String, Object> metadata = embeddingMatch.embedded().getMetadata();
if (filterRetrieval(filterCondition, metadata)) {
continue;
}
retrieval.setMetadata(metadata);
retrievals.add(retrieval); retrievals.add(retrieval);
} }
retrievals = retrievals.stream()
.sorted(Comparator.comparingDouble(Retrieval::getDistance).reversed())
.limit(num)
.collect(Collectors.toList());
retrieveQueryResult.setRetrieval(retrievals); retrieveQueryResult.setRetrieval(retrievals);
results.add(retrieveQueryResult); results.add(retrieveQueryResult);
} }
@@ -97,10 +111,32 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
return results; return results;
} }
private int getMaxResults(int num, Map<String, String> filterCondition) {
int maxResults = num;
if (MapUtils.isNotEmpty(filterCondition)) {
maxResults = num * 5;
}
return maxResults;
}
private boolean filterRetrieval(Map<String, String> filterCondition, Map<String, Object> metadata) {
if (MapUtils.isNotEmpty(metadata) && MapUtils.isNotEmpty(filterCondition)) {
for (Entry<String, Object> entry : metadata.entrySet()) {
String filterValue = filterCondition.get(entry.getKey());
if (StringUtils.isNotBlank(filterValue) && !filterValue.equalsIgnoreCase(
entry.getValue().toString())) {
return true;
}
}
}
return false;
}
/** /**
* An {@link EmbeddingStore} that stores embeddings in memory. * An {@link EmbeddingStore} that stores embeddings in memory.
* <p> * <p>
* Uses a brute force approach by iterating over all embeddings to find the best matches. * Uses a brute force approach by iterating over all embeddings to find the best matches.
*
* @param <Embedded> The class of the object that has been embedded. * @param <Embedded> The class of the object that has been embedded.
* Typically, it is {@link dev.langchain4j.data.segment.TextSegment}. * Typically, it is {@link dev.langchain4j.data.segment.TextSegment}.
* copy from dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore * copy from dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore