(improvement)(chat) Embedding stores now uniformly use TextSegment as the stored type, while remaining compatible with the queryId parameter. (#1202)

This commit is contained in:
lexluo09
2024-06-24 13:27:03 +08:00
committed by GitHub
parent a7d367baa3
commit 4b288d9815
13 changed files with 134 additions and 127 deletions

View File

@@ -1,21 +1,21 @@
package dev.langchain4j.chroma.spring;
import static dev.langchain4j.chroma.spring.Properties.PREFIX;
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import static dev.langchain4j.chroma.spring.Properties.PREFIX;
/**
 * Spring configuration that enables the {@link Properties} configuration-properties class
 * and exposes an {@link EmbeddingStoreFactory} bean implemented by
 * {@link ChromaEmbeddingStoreFactory}.
 */
@Configuration
@EnableConfigurationProperties(Properties.class)
public class ChromaAutoConfig {
// The factory bean is conditional on the PREFIX + ".embedding-store.base-url" property.
@Bean
@ConditionalOnProperty(PREFIX + ".embedding-store.base-url")
// NOTE(review): two method headers follow but only one body; this looks like the
// before/after lines of a rename (milvusChatModel -> chromaChatModel) — confirm which is live.
EmbeddingStoreFactory milvusChatModel(Properties properties) {
EmbeddingStoreFactory chromaChatModel(Properties properties) {
// The factory receives the bound properties and reads its settings from them.
return new ChromaEmbeddingStoreFactory(properties);
}
}

View File

@@ -17,7 +17,7 @@ public class ChromaEmbeddingStoreFactory implements EmbeddingStoreFactory {
EmbeddingStoreProperties embeddingStore = properties.getEmbeddingStore();
return ChromaEmbeddingStore.builder()
.baseUrl(embeddingStore.getBaseUrl())
.collectionName(embeddingStore.getCollectionName())
.collectionName(collectionName)
.timeout(embeddingStore.getTimeout())
.build();
}

View File

@@ -1,14 +1,14 @@
package dev.langchain4j.chroma.spring;
import java.time.Duration;
import lombok.Getter;
import lombok.Setter;
import java.time.Duration;
/**
 * Settings holder for the Chroma embedding store. Each field gets a Lombok-generated
 * getter and setter.
 */
class EmbeddingStoreProperties {

    /** Base URL handed to the Chroma embedding store builder. */
    @Getter
    @Setter
    private String baseUrl;

    /** Name of the Chroma collection. */
    @Getter
    @Setter
    private String collectionName;

    /** Timeout handed to the Chroma embedding store builder. */
    @Getter
    @Setter
    private Duration timeout;
}

View File

@@ -1,18 +1,19 @@
package dev.langchain4j.inmemory.spring;
import dev.langchain4j.store.embedding.EmbeddingQuery;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import lombok.extern.slf4j.Slf4j;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory {
private static Map<String, InMemoryEmbeddingStore<EmbeddingQuery>> collectionNameToStore =
private static Map<String, InMemoryEmbeddingStore<TextSegment>> collectionNameToStore =
new ConcurrentHashMap<>();
private Properties properties;
@@ -23,14 +24,12 @@ public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory {
/**
 * Returns the in-memory embedding store registered under {@code collectionName},
 * creating a new {@link InMemoryEmbeddingStore} and registering it in
 * {@code collectionNameToStore} when none is present.
 *
 * NOTE(review): this span appears to interleave pre- and post-change lines of a diff
 * (two declarations of {@code embeddingStore}, two creation paths) — confirm which
 * lines are live before relying on it.
 *
 * @param collectionName key used to look up or register the store
 * @return the existing or newly created store for that name
 */
@Override
public synchronized EmbeddingStore create(String collectionName) {
InMemoryEmbeddingStore<EmbeddingQuery> embeddingStore = collectionNameToStore.get(collectionName);
InMemoryEmbeddingStore<TextSegment> embeddingStore = collectionNameToStore.get(collectionName);
// Variant A: return early when a store is already registered.
if (Objects.nonNull(embeddingStore)) {
return embeddingStore;
}
// Variant B: create and register only when the lookup found nothing.
if (Objects.isNull(embeddingStore)) {
embeddingStore = new InMemoryEmbeddingStore();
collectionNameToStore.putIfAbsent(collectionName, embeddingStore);
}
// Unconditional creation path that pairs with the early return of variant A.
embeddingStore = new InMemoryEmbeddingStore();
collectionNameToStore.putIfAbsent(collectionName, embeddingStore);
return embeddingStore;
}

View File

@@ -1,35 +0,0 @@
package dev.langchain4j.store.embedding;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.common.pojo.DataItem;
import lombok.Data;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
 * Carrier for one embedding query: an id, the query text, free-form metadata and an
 * optional embedding vector. Accessors, equals/hashCode and toString come from Lombok.
 */
@Data
public class EmbeddingQuery {

    private String queryId;

    private String query;

    private Map<String, Object> metadata;

    private List<Double> queryEmbedding;

    /**
     * Converts every {@link DataItem} into an {@link EmbeddingQuery}.
     *
     * @param dataItems items to convert; must not be null
     * @return one query per item, in encounter order
     */
    public static List<EmbeddingQuery> convertToEmbedding(List<DataItem> dataItems) {
        return dataItems.stream()
                .map(EmbeddingQuery::fromDataItem)
                .collect(Collectors.toList());
    }

    /** Builds a single query from one item; the embedding is left unset. */
    private static EmbeddingQuery fromDataItem(DataItem item) {
        EmbeddingQuery result = new EmbeddingQuery();
        // The id is the item id followed by the lower-cased type name.
        result.setQueryId(item.getId() + item.getType().name().toLowerCase());
        result.setQuery(item.getName());
        // Round-trip through JSON to obtain the item's properties as a map.
        String itemJson = JSONObject.toJSONString(item);
        Map itemProperties = JSONObject.parseObject(itemJson, Map.class);
        result.setMetadata(itemProperties);
        result.setQueryEmbedding(null);
        return result;
    }
}

View File

@@ -0,0 +1,39 @@
package dev.langchain4j.store.embedding;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.common.pojo.DataItem;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.segment.TextSegment;
import lombok.Data;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
/**
 * Static helpers that turn {@link DataItem}s into {@link TextSegment}s and that store or
 * read the query id kept in a segment's metadata under {@link #QUERY_ID}.
 */
@Data
public class TextSegmentConvert {

    /** Metadata key under which the query id is stored on a segment. */
    public static final String QUERY_ID = "queryId";

    /**
     * Converts every {@link DataItem} into a {@link TextSegment} whose text is the item
     * name and whose metadata is the item's JSON representation plus the query id.
     *
     * @param dataItems items to convert; must not be null
     * @return one segment per item, in encounter order
     */
    public static List<TextSegment> convertToEmbedding(List<DataItem> dataItems) {
        return dataItems.stream().map(dataItem -> {
            // Round-trip through JSON to obtain the item's properties as an untyped map.
            Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class);
            TextSegment textSegment = TextSegment.from(dataItem.getName(), new Metadata(meta));
            // Locale.ROOT keeps the id stable regardless of the JVM default locale
            // (a Turkish default locale would otherwise lower-case 'I' to a dotless i).
            String typeSuffix = dataItem.getType().name().toLowerCase(Locale.ROOT);
            addQueryId(textSegment, dataItem.getId() + typeSuffix);
            return textSegment;
        }).collect(Collectors.toList());
    }

    /**
     * Stores {@code queryId} in the segment's metadata under {@link #QUERY_ID}.
     *
     * @param textSegment segment whose metadata is updated
     * @param queryId id to store
     */
    public static void addQueryId(TextSegment textSegment, String queryId) {
        textSegment.metadata().put(QUERY_ID, queryId);
    }

    /**
     * Reads the query id from the segment's metadata.
     *
     * @param textSegment segment to read from; may be null
     * @return the value stored under {@link #QUERY_ID}, or {@code null} when the segment
     *         or its metadata is null
     */
    public static String getQueryId(TextSegment textSegment) {
        if (Objects.isNull(textSegment) || Objects.isNull(textSegment.metadata())) {
            return null;
        }
        return textSegment.metadata().get(QUERY_ID);
    }
}