mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-13 13:07:32 +00:00
[improvement](chat) support filterCondition in InMemoryS2EmbeddingStore (#523)
This commit is contained in:
@@ -56,8 +56,8 @@ public class SqlPromptGenerator {
|
|||||||
public String generateSqlPrompt(LLMReq llmReq, String schemaLinkStr, List<Map<String, String>> fewshotExampleList) {
|
public String generateSqlPrompt(LLMReq llmReq, String schemaLinkStr, List<Map<String, String>> fewshotExampleList) {
|
||||||
String instruction = "# Use the the schema links to generate the SQL queries for each of the questions.";
|
String instruction = "# Use the the schema links to generate the SQL queries for each of the questions.";
|
||||||
List<String> exampleKeys = Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkings", "sql");
|
List<String> exampleKeys = Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkings", "sql");
|
||||||
String exampleTemplate = "dbSchema\nQ: questionAugmented\n"
|
String exampleTemplate = "dbSchema\nQ: questionAugmented\n" + "Schema_links: generatedSchemaLinkings\n"
|
||||||
+ "Schema_links: generatedSchemaLinkings\nSQL: {sql}";
|
+ "SQL: {sql}";
|
||||||
|
|
||||||
String schemaLinkingPrompt = InputFormat.format(exampleTemplate, exampleKeys, fewshotExampleList);
|
String schemaLinkingPrompt = InputFormat.format(exampleTemplate, exampleKeys, fewshotExampleList);
|
||||||
Pair<String, String> questionPrompt = transformQuestionPrompt(llmReq);
|
Pair<String, String> questionPrompt = transformQuestionPrompt(llmReq);
|
||||||
|
|||||||
@@ -20,11 +20,15 @@ import java.util.Collections;
|
|||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Map.Entry;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.PriorityQueue;
|
import java.util.PriorityQueue;
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
import java.util.concurrent.CopyOnWriteArrayList;
|
import java.util.concurrent.CopyOnWriteArrayList;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections4.MapUtils;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
|
||||||
/***
|
/***
|
||||||
* In-memory implementation of S2EmbeddingStore that runs inside the Java process.
|
* In-memory implementation of S2EmbeddingStore that runs inside the Java process.
|
||||||
@@ -75,9 +79,11 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
|
|||||||
List<RetrieveQueryResult> results = new ArrayList<>();
|
List<RetrieveQueryResult> results = new ArrayList<>();
|
||||||
|
|
||||||
List<String> queryTextsList = retrieveQuery.getQueryTextsList();
|
List<String> queryTextsList = retrieveQuery.getQueryTextsList();
|
||||||
|
Map<String, String> filterCondition = retrieveQuery.getFilterCondition();
|
||||||
for (String queryText : queryTextsList) {
|
for (String queryText : queryTextsList) {
|
||||||
Embedding embeddedText = embeddingModel.embed(queryText).content();
|
Embedding embeddedText = embeddingModel.embed(queryText).content();
|
||||||
List<EmbeddingMatch<EmbeddingQuery>> relevant = embeddingStore.findRelevant(embeddedText, num);
|
int maxResults = getMaxResults(num, filterCondition);
|
||||||
|
List<EmbeddingMatch<EmbeddingQuery>> relevant = embeddingStore.findRelevant(embeddedText, maxResults);
|
||||||
|
|
||||||
RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult();
|
RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult();
|
||||||
retrieveQueryResult.setQuery(queryText);
|
retrieveQueryResult.setQuery(queryText);
|
||||||
@@ -87,9 +93,17 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
|
|||||||
retrieval.setDistance(embeddingMatch.score());
|
retrieval.setDistance(embeddingMatch.score());
|
||||||
retrieval.setId(embeddingMatch.embeddingId());
|
retrieval.setId(embeddingMatch.embeddingId());
|
||||||
retrieval.setQuery(embeddingMatch.embedded().getQuery());
|
retrieval.setQuery(embeddingMatch.embedded().getQuery());
|
||||||
retrieval.setMetadata(embeddingMatch.embedded().getMetadata());
|
Map<String, Object> metadata = embeddingMatch.embedded().getMetadata();
|
||||||
|
if (filterRetrieval(filterCondition, metadata)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
retrieval.setMetadata(metadata);
|
||||||
retrievals.add(retrieval);
|
retrievals.add(retrieval);
|
||||||
}
|
}
|
||||||
|
retrievals = retrievals.stream()
|
||||||
|
.sorted(Comparator.comparingDouble(Retrieval::getDistance).reversed())
|
||||||
|
.limit(num)
|
||||||
|
.collect(Collectors.toList());
|
||||||
retrieveQueryResult.setRetrieval(retrievals);
|
retrieveQueryResult.setRetrieval(retrievals);
|
||||||
results.add(retrieveQueryResult);
|
results.add(retrieveQueryResult);
|
||||||
}
|
}
|
||||||
@@ -97,14 +111,36 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
|
|||||||
return results;
|
return results;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private int getMaxResults(int num, Map<String, String> filterCondition) {
|
||||||
|
int maxResults = num;
|
||||||
|
if (MapUtils.isNotEmpty(filterCondition)) {
|
||||||
|
maxResults = num * 5;
|
||||||
|
}
|
||||||
|
return maxResults;
|
||||||
|
}
|
||||||
|
|
||||||
|
private boolean filterRetrieval(Map<String, String> filterCondition, Map<String, Object> metadata) {
|
||||||
|
if (MapUtils.isNotEmpty(metadata) && MapUtils.isNotEmpty(filterCondition)) {
|
||||||
|
for (Entry<String, Object> entry : metadata.entrySet()) {
|
||||||
|
String filterValue = filterCondition.get(entry.getKey());
|
||||||
|
if (StringUtils.isNotBlank(filterValue) && !filterValue.equalsIgnoreCase(
|
||||||
|
entry.getValue().toString())) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* An {@link EmbeddingStore} that stores embeddings in memory.
|
* An {@link EmbeddingStore} that stores embeddings in memory.
|
||||||
* <p>
|
* <p>
|
||||||
* Uses a brute force approach by iterating over all embeddings to find the best matches.
|
* Uses a brute force approach by iterating over all embeddings to find the best matches.
|
||||||
|
*
|
||||||
* @param <Embedded> The class of the object that has been embedded.
|
* @param <Embedded> The class of the object that has been embedded.
|
||||||
* Typically, it is {@link dev.langchain4j.data.segment.TextSegment}.
|
* Typically, it is {@link dev.langchain4j.data.segment.TextSegment}.
|
||||||
* copy from dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore
|
* copy from dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore
|
||||||
* and fix concurrentModificationException in a multi-threaded environment
|
* and fix concurrentModificationException in a multi-threaded environment
|
||||||
*/
|
*/
|
||||||
public static class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded> {
|
public static class InMemoryEmbeddingStore<Embedded> implements EmbeddingStore<Embedded> {
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user