From 59c21ea19a0b6fbc56a64f2f11d424c3dba996fb Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Sat, 16 Dec 2023 20:04:01 +0800 Subject: [PATCH] [improvement](chat) support filterCondition in InMemoryS2EmbeddingStore (#523) --- .../parser/sql/llm/SqlPromptGenerator.java | 4 +- .../embedding/InMemoryS2EmbeddingStore.java | 44 +++++++++++++++++-- 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/SqlPromptGenerator.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/SqlPromptGenerator.java index 71c9aa98d..568aa3a5e 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/SqlPromptGenerator.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/SqlPromptGenerator.java @@ -56,8 +56,8 @@ public class SqlPromptGenerator { public String generateSqlPrompt(LLMReq llmReq, String schemaLinkStr, List> fewshotExampleList) { String instruction = "# Use the the schema links to generate the SQL queries for each of the questions."; List exampleKeys = Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkings", "sql"); - String exampleTemplate = "dbSchema\nQ: questionAugmented\n" - + "Schema_links: generatedSchemaLinkings\nSQL: {sql}"; + String exampleTemplate = "dbSchema\nQ: questionAugmented\n" + "Schema_links: generatedSchemaLinkings\n" + + "SQL: {sql}"; String schemaLinkingPrompt = InputFormat.format(exampleTemplate, exampleKeys, fewshotExampleList); Pair questionPrompt = transformQuestionPrompt(llmReq); diff --git a/common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryS2EmbeddingStore.java b/common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryS2EmbeddingStore.java index 93f709182..c1b56a32c 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryS2EmbeddingStore.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/embedding/InMemoryS2EmbeddingStore.java @@ -20,11 +20,15 @@ import java.util.Collections; import java.util.Comparator; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.Objects; import java.util.PriorityQueue; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CopyOnWriteArrayList; +import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; +import org.apache.commons.collections4.MapUtils; +import org.apache.commons.lang3.StringUtils; /*** * Implementation of S2EmbeddingStore within the Java process's in-memory. @@ -75,9 +79,11 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore { List results = new ArrayList<>(); List queryTextsList = retrieveQuery.getQueryTextsList(); + Map filterCondition = retrieveQuery.getFilterCondition(); for (String queryText : queryTextsList) { Embedding embeddedText = embeddingModel.embed(queryText).content(); - List> relevant = embeddingStore.findRelevant(embeddedText, num); + int maxResults = getMaxResults(num, filterCondition); + List> relevant = embeddingStore.findRelevant(embeddedText, maxResults); RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult(); retrieveQueryResult.setQuery(queryText); @@ -87,9 +93,17 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore { retrieval.setDistance(embeddingMatch.score()); retrieval.setId(embeddingMatch.embeddingId()); retrieval.setQuery(embeddingMatch.embedded().getQuery()); - retrieval.setMetadata(embeddingMatch.embedded().getMetadata()); + Map metadata = embeddingMatch.embedded().getMetadata(); + if (filterRetrieval(filterCondition, metadata)) { + continue; + } + retrieval.setMetadata(metadata); retrievals.add(retrieval); } + retrievals = retrievals.stream() + .sorted(Comparator.comparingDouble(Retrieval::getDistance).reversed()) + .limit(num) + .collect(Collectors.toList()); retrieveQueryResult.setRetrieval(retrievals); results.add(retrieveQueryResult); } @@ -97,14 +111,36 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore { return results; } + private int getMaxResults(int num, Map filterCondition) { + int maxResults = num; + if (MapUtils.isNotEmpty(filterCondition)) { + maxResults = num * 5; + } + return maxResults; + } + + private boolean filterRetrieval(Map filterCondition, Map metadata) { + if (MapUtils.isNotEmpty(metadata) && MapUtils.isNotEmpty(filterCondition)) { + for (Entry entry : metadata.entrySet()) { + String filterValue = filterCondition.get(entry.getKey()); + if (StringUtils.isNotBlank(filterValue) && !filterValue.equalsIgnoreCase( + entry.getValue().toString())) { + return true; + } + } + } + return false; + } + /** * An {@link EmbeddingStore} that stores embeddings in memory. *

* Uses a brute force approach by iterating over all embeddings to find the best matches. + * * @param The class of the object that has been embedded. * Typically, it is {@link dev.langchain4j.data.segment.TextSegment}. - * copy from dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore - * and fix concurrentModificationException in a multi-threaded environment + * copy from dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore + * and fix concurrentModificationException in a multi-threaded environment */ public static class InMemoryEmbeddingStore implements EmbeddingStore {