From 4b288d981508f143e2c4589c04015778cf14b6c9 Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Mon, 24 Jun 2024 13:27:03 +0800 Subject: [PATCH] (improvement)(chat) The embedding model will be uniformly adopted using the textSegment and will be compatible with the queryId parameter. (#1202) --- .../chat/server/plugin/PluginManager.java | 38 ++++++++---------- .../chat/server/util/SimilarQueryManager.java | 31 ++++++++------- .../common/service/EmbeddingService.java | 7 ++-- .../service/impl/EmbeddingServiceImpl.java | 32 ++++++++------- .../chroma/spring/ChromaAutoConfig.java | 6 +-- .../spring/ChromaEmbeddingStoreFactory.java | 2 +- .../spring/EmbeddingStoreProperties.java | 4 +- .../spring/InMemoryEmbeddingStoreFactory.java | 15 ++++--- .../store/embedding/EmbeddingQuery.java | 35 ----------------- .../store/embedding/TextSegmentConvert.java | 39 +++++++++++++++++++ .../chat/parser/llm/ExemplarManager.java | 23 +++++------ .../listener/MetaEmbeddingListener.java | 18 +++++---- .../server/schedule/EmbeddingTask.java | 11 +++--- 13 files changed, 134 insertions(+), 127 deletions(-) delete mode 100644 common/src/main/java/dev/langchain4j/store/embedding/EmbeddingQuery.java create mode 100644 common/src/main/java/dev/langchain4j/store/embedding/TextSegmentConvert.java diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginManager.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginManager.java index 245268b66..a8c8b9734 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginManager.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginManager.java @@ -14,16 +14,17 @@ import com.tencent.supersonic.chat.server.plugin.event.PluginUpdateEvent; import com.tencent.supersonic.chat.server.pojo.ChatParseContext; import com.tencent.supersonic.chat.server.service.PluginService; import com.tencent.supersonic.common.config.EmbeddingConfig; -import com.tencent.supersonic.common.util.ContextUtils; -import dev.langchain4j.store.embedding.EmbeddingQuery; import com.tencent.supersonic.common.service.EmbeddingService; -import dev.langchain4j.store.embedding.Retrieval; -import dev.langchain4j.store.embedding.RetrieveQuery; -import dev.langchain4j.store.embedding.RetrieveQueryResult; +import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.store.embedding.Retrieval; +import dev.langchain4j.store.embedding.RetrieveQuery; +import dev.langchain4j.store.embedding.RetrieveQueryResult; +import dev.langchain4j.store.embedding.TextSegmentConvert; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.tuple.Pair; @@ -116,16 +117,16 @@ public class PluginManager { } String presetCollection = embeddingConfig.getPresetCollection(); - List queries = new ArrayList<>(); + List queries = new ArrayList<>(); for (String id : queryIds) { - EmbeddingQuery embeddingQuery = new EmbeddingQuery(); - embeddingQuery.setQueryId(id); - queries.add(embeddingQuery); + TextSegment query = TextSegment.from(""); + TextSegmentConvert.addQueryId(query, id); + queries.add(query); } embeddingService.deleteQuery(presetCollection, queries); } - public void requestEmbeddingPluginAdd(List queries) { + public void requestEmbeddingPluginAdd(List queries) { if (CollectionUtils.isEmpty(queries)) { return; } @@ -133,10 +134,6 @@ public class PluginManager { embeddingService.addQuery(presetCollection, queries); } - public void requestEmbeddingPluginAddALL(List plugins) { - requestEmbeddingPluginAdd(convert(plugins)); - } - public RetrieveQueryResult recognize(String embeddingText) { RetrieveQuery retrieveQuery = RetrieveQuery.builder() @@ -158,15 +155,14 @@ public class PluginManager { throw new RuntimeException("get embedding result failed"); } - public List convert(List plugins) { - List queries = Lists.newArrayList(); + public List convert(List plugins) { + List queries = Lists.newArrayList(); for (Plugin plugin : plugins) { List exampleQuestions = plugin.getExampleQuestionList(); int num = 0; for (String pattern : exampleQuestions) { - EmbeddingQuery query = new EmbeddingQuery(); - query.setQueryId(generateUniqueEmbeddingId(num, plugin.getId())); - query.setQuery(pattern); + TextSegment query = TextSegment.from(pattern); + TextSegmentConvert.addQueryId(query, generateUniqueEmbeddingId(num, plugin.getId())); queries.add(query); num++; } @@ -176,8 +172,8 @@ public class PluginManager { private Set getEmbeddingId(List plugins) { Set embeddingIdSet = new HashSet<>(); - for (EmbeddingQuery query : convert(plugins)) { - embeddingIdSet.add(query.getQueryId()); + for (TextSegment query : convert(plugins)) { + TextSegmentConvert.addQueryId(query, TextSegmentConvert.getQueryId(query)); } return embeddingIdSet; } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/SimilarQueryManager.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/SimilarQueryManager.java index 1cee59f78..6fb2f41f4 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/SimilarQueryManager.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/SimilarQueryManager.java @@ -5,18 +5,12 @@ import com.tencent.supersonic.chat.api.pojo.request.SimilarQueryReq; import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp; import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.service.EmbeddingService; -import dev.langchain4j.store.embedding.EmbeddingQuery; +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.store.embedding.Retrieval; import dev.langchain4j.store.embedding.RetrieveQuery; import dev.langchain4j.store.embedding.RetrieveQueryResult; -import java.net.URI; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.stream.Collectors; +import dev.langchain4j.store.embedding.TextSegmentConvert; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.StringUtils; @@ -32,6 +26,15 @@ import org.springframework.stereotype.Component; import org.springframework.web.client.RestTemplate; import org.springframework.web.util.UriComponentsBuilder; +import java.net.URI; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + @Slf4j @Component public class SimilarQueryManager { @@ -52,15 +55,13 @@ public class SimilarQueryManager { } String queryText = similarQueryReq.getQueryText(); try { - EmbeddingQuery embeddingQuery = new EmbeddingQuery(); - embeddingQuery.setQueryId(String.valueOf(similarQueryReq.getQueryId())); - embeddingQuery.setQuery(queryText); - Map metaData = new HashMap<>(); metaData.put("agentId", similarQueryReq.getAgentId()); - embeddingQuery.setMetadata(metaData); + TextSegment textSegment = TextSegment.from(queryText, new Metadata(metaData)); + TextSegmentConvert.addQueryId(textSegment, String.valueOf(similarQueryReq.getQueryId())); + String solvedQueryCollection = embeddingConfig.getSolvedQueryCollection(); - embeddingService.addQuery(solvedQueryCollection, Lists.newArrayList(embeddingQuery)); + embeddingService.addQuery(solvedQueryCollection, Lists.newArrayList(textSegment)); } catch (Exception e) { log.warn("save history question to embedding failed, queryText:{}", queryText, e); } diff --git a/common/src/main/java/com/tencent/supersonic/common/service/EmbeddingService.java b/common/src/main/java/com/tencent/supersonic/common/service/EmbeddingService.java index 2b8085aa6..5e34bea30 100644 --- a/common/src/main/java/com/tencent/supersonic/common/service/EmbeddingService.java +++ b/common/src/main/java/com/tencent/supersonic/common/service/EmbeddingService.java @@ -1,8 +1,9 @@ package com.tencent.supersonic.common.service; -import dev.langchain4j.store.embedding.EmbeddingQuery; +import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.store.embedding.RetrieveQuery; import dev.langchain4j.store.embedding.RetrieveQueryResult; + import java.util.List; /** @@ -13,9 +14,9 @@ public interface EmbeddingService { void addCollection(String collectionName); - void addQuery(String collectionName, List queries); + void addQuery(String collectionName, List queries); - void deleteQuery(String collectionName, List queries); + void deleteQuery(String collectionName, List queries); List retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num); diff --git a/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java b/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java index 47292b406..6b6c1e47b 100644 --- a/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java +++ b/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java @@ -2,9 +2,9 @@ package com.tencent.supersonic.common.service.impl; import com.tencent.supersonic.common.service.EmbeddingService; import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.store.embedding.EmbeddingMatch; -import dev.langchain4j.store.embedding.EmbeddingQuery; import dev.langchain4j.store.embedding.EmbeddingSearchRequest; import dev.langchain4j.store.embedding.EmbeddingSearchResult; import dev.langchain4j.store.embedding.EmbeddingStore; @@ -12,8 +12,13 @@ import dev.langchain4j.store.embedding.EmbeddingStoreFactory; import dev.langchain4j.store.embedding.Retrieval; import dev.langchain4j.store.embedding.RetrieveQuery; import dev.langchain4j.store.embedding.RetrieveQueryResult; +import dev.langchain4j.store.embedding.TextSegmentConvert; import dev.langchain4j.store.embedding.filter.Filter; import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo; +import org.apache.commons.collections.MapUtils; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; + import java.util.ArrayList; import java.util.Comparator; import java.util.HashMap; @@ -21,9 +26,6 @@ import java.util.List; import java.util.Map; import java.util.Objects; import java.util.stream.Collectors; -import org.apache.commons.collections.MapUtils; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.stereotype.Service; @Service public class EmbeddingServiceImpl implements EmbeddingService { @@ -39,17 +41,17 @@ public class EmbeddingServiceImpl implements EmbeddingService { } @Override - public void addQuery(String collectionName, List queries) { + public void addQuery(String collectionName, List queries) { EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName); - for (EmbeddingQuery query : queries) { - String question = query.getQuery(); + for (TextSegment query : queries) { + String question = query.text(); Embedding embedding = embeddingModel.embed(question).content(); embeddingStore.add(embedding, query); } } @Override - public void deleteQuery(String collectionName, List queries) { + public void deleteQuery(String collectionName, List queries) { } @Override @@ -66,21 +68,21 @@ public class EmbeddingServiceImpl implements EmbeddingService { .queryEmbedding(embeddedText).filter(filter).maxResults(num).build(); EmbeddingSearchResult result = embeddingStore.search(request); - List> relevant = result.matches(); + List> relevant = result.matches(); RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult(); retrieveQueryResult.setQuery(queryText); List retrievals = new ArrayList<>(); - for (EmbeddingMatch embeddingMatch : relevant) { + for (EmbeddingMatch embeddingMatch : relevant) { Retrieval retrieval = new Retrieval(); - EmbeddingQuery embedded = embeddingMatch.embedded(); + TextSegment embedded = embeddingMatch.embedded(); retrieval.setDistance(1 - embeddingMatch.score()); - retrieval.setId(embedded.getQueryId()); - retrieval.setQuery(embedded.getQuery()); + retrieval.setId(TextSegmentConvert.getQueryId(embedded)); + retrieval.setQuery(embedded.text()); Map metadata = new HashMap<>(); if (Objects.nonNull(embedded) - && MapUtils.isNotEmpty(embedded.getMetadata())) { - metadata.putAll(embedded.getMetadata()); + && MapUtils.isNotEmpty(embedded.metadata().toMap())) { + metadata.putAll(embedded.metadata().toMap()); } retrieval.setMetadata(metadata); retrievals.add(retrieval); diff --git a/common/src/main/java/dev/langchain4j/chroma/spring/ChromaAutoConfig.java b/common/src/main/java/dev/langchain4j/chroma/spring/ChromaAutoConfig.java index 86145bdf8..276966088 100644 --- a/common/src/main/java/dev/langchain4j/chroma/spring/ChromaAutoConfig.java +++ b/common/src/main/java/dev/langchain4j/chroma/spring/ChromaAutoConfig.java @@ -1,21 +1,21 @@ package dev.langchain4j.chroma.spring; -import static dev.langchain4j.chroma.spring.Properties.PREFIX; - import dev.langchain4j.store.embedding.EmbeddingStoreFactory; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import static dev.langchain4j.chroma.spring.Properties.PREFIX; + @Configuration @EnableConfigurationProperties(Properties.class) public class ChromaAutoConfig { @Bean @ConditionalOnProperty(PREFIX + ".embedding-store.base-url") - EmbeddingStoreFactory milvusChatModel(Properties properties) { + EmbeddingStoreFactory chromaChatModel(Properties properties) { return new ChromaEmbeddingStoreFactory(properties); } } \ No newline at end of file diff --git a/common/src/main/java/dev/langchain4j/chroma/spring/ChromaEmbeddingStoreFactory.java b/common/src/main/java/dev/langchain4j/chroma/spring/ChromaEmbeddingStoreFactory.java index d4bf44d0c..5b30510f1 100644 --- a/common/src/main/java/dev/langchain4j/chroma/spring/ChromaEmbeddingStoreFactory.java +++ b/common/src/main/java/dev/langchain4j/chroma/spring/ChromaEmbeddingStoreFactory.java @@ -17,7 +17,7 @@ public class ChromaEmbeddingStoreFactory implements EmbeddingStoreFactory { EmbeddingStoreProperties embeddingStore = properties.getEmbeddingStore(); return ChromaEmbeddingStore.builder() .baseUrl(embeddingStore.getBaseUrl()) - .collectionName(embeddingStore.getCollectionName()) + .collectionName(collectionName) .timeout(embeddingStore.getTimeout()) .build(); } diff --git a/common/src/main/java/dev/langchain4j/chroma/spring/EmbeddingStoreProperties.java b/common/src/main/java/dev/langchain4j/chroma/spring/EmbeddingStoreProperties.java index d9fb880ba..0603dcdaa 100644 --- a/common/src/main/java/dev/langchain4j/chroma/spring/EmbeddingStoreProperties.java +++ b/common/src/main/java/dev/langchain4j/chroma/spring/EmbeddingStoreProperties.java @@ -1,14 +1,14 @@ package dev.langchain4j.chroma.spring; -import java.time.Duration; import lombok.Getter; import lombok.Setter; +import java.time.Duration; + @Getter @Setter class EmbeddingStoreProperties { private String baseUrl; - private String collectionName; private Duration timeout; } \ No newline at end of file diff --git a/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryEmbeddingStoreFactory.java b/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryEmbeddingStoreFactory.java index ff3ec4ee9..56edefc6e 100644 --- a/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryEmbeddingStoreFactory.java +++ b/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryEmbeddingStoreFactory.java @@ -1,18 +1,19 @@ package dev.langchain4j.inmemory.spring; -import dev.langchain4j.store.embedding.EmbeddingQuery; +import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStoreFactory; import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore; +import lombok.extern.slf4j.Slf4j; + import java.util.Map; import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; -import lombok.extern.slf4j.Slf4j; @Slf4j public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory { - private static Map> collectionNameToStore = + private static Map> collectionNameToStore = new ConcurrentHashMap<>(); private Properties properties; @@ -23,14 +24,12 @@ public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory { @Override public synchronized EmbeddingStore create(String collectionName) { - InMemoryEmbeddingStore embeddingStore = collectionNameToStore.get(collectionName); + InMemoryEmbeddingStore embeddingStore = collectionNameToStore.get(collectionName); if (Objects.nonNull(embeddingStore)) { return embeddingStore; } - if (Objects.isNull(embeddingStore)) { - embeddingStore = new InMemoryEmbeddingStore(); - collectionNameToStore.putIfAbsent(collectionName, embeddingStore); - } + embeddingStore = new InMemoryEmbeddingStore(); + collectionNameToStore.putIfAbsent(collectionName, embeddingStore); return embeddingStore; } diff --git a/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingQuery.java b/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingQuery.java deleted file mode 100644 index 2da328dda..000000000 --- a/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingQuery.java +++ /dev/null @@ -1,35 +0,0 @@ -package dev.langchain4j.store.embedding; - - -import com.alibaba.fastjson.JSONObject; -import com.tencent.supersonic.common.pojo.DataItem; -import lombok.Data; - -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - -@Data -public class EmbeddingQuery { - - private String queryId; - - private String query; - - private Map metadata; - - private List queryEmbedding; - public static List convertToEmbedding(List dataItems) { - return dataItems.stream().map(dataItem -> { - EmbeddingQuery embeddingQuery = new EmbeddingQuery(); - embeddingQuery.setQueryId( - dataItem.getId() + dataItem.getType().name().toLowerCase()); - embeddingQuery.setQuery(dataItem.getName()); - Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class); - embeddingQuery.setMetadata(meta); - embeddingQuery.setQueryEmbedding(null); - return embeddingQuery; - }).collect(Collectors.toList()); - } - -} diff --git a/common/src/main/java/dev/langchain4j/store/embedding/TextSegmentConvert.java b/common/src/main/java/dev/langchain4j/store/embedding/TextSegmentConvert.java new file mode 100644 index 000000000..2deb3de40 --- /dev/null +++ b/common/src/main/java/dev/langchain4j/store/embedding/TextSegmentConvert.java @@ -0,0 +1,39 @@ +package dev.langchain4j.store.embedding; + + +import com.alibaba.fastjson.JSONObject; +import com.tencent.supersonic.common.pojo.DataItem; +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.segment.TextSegment; +import lombok.Data; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + +@Data +public class TextSegmentConvert { + + public static final String QUERY_ID = "queryId"; + + public static List convertToEmbedding(List dataItems) { + return dataItems.stream().map(dataItem -> { + Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class); + TextSegment textSegment = TextSegment.from(dataItem.getName(), new Metadata(meta)); + addQueryId(textSegment, dataItem.getId() + dataItem.getType().name().toLowerCase()); + return textSegment; + }).collect(Collectors.toList()); + } + + public static void addQueryId(TextSegment textSegment, String queryId) { + textSegment.metadata().put(QUERY_ID, queryId); + } + + public static String getQueryId(TextSegment textSegment) { + if (Objects.isNull(textSegment) || Objects.isNull(textSegment.metadata())) { + return null; + } + return textSegment.metadata().get(QUERY_ID); + } +} diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/ExemplarManager.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/ExemplarManager.java index 62db8feef..ea54fe6bd 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/ExemplarManager.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/ExemplarManager.java @@ -5,10 +5,18 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.service.EmbeddingService; import com.tencent.supersonic.common.util.JsonUtil; -import dev.langchain4j.store.embedding.EmbeddingQuery; +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.store.embedding.Retrieval; import dev.langchain4j.store.embedding.RetrieveQuery; import dev.langchain4j.store.embedding.RetrieveQueryResult; +import dev.langchain4j.store.embedding.TextSegmentConvert; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.collections.CollectionUtils; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.core.io.ClassPathResource; +import org.springframework.stereotype.Component; + import java.io.IOException; import java.io.InputStream; import java.util.ArrayList; @@ -17,11 +25,6 @@ import java.util.List; import java.util.Map; import java.util.Objects; import java.util.stream.Collectors; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.collections.CollectionUtils; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.core.io.ClassPathResource; -import org.springframework.stereotype.Component; @Slf4j @Component @@ -44,15 +47,13 @@ public class ExemplarManager { } public void addExemplars(List exemplars, String collectionName) { - List queries = new ArrayList<>(); + List queries = new ArrayList<>(); for (int i = 0; i < exemplars.size(); i++) { Exemplar exemplar = exemplars.get(i); String question = exemplar.getQuestion(); Map metaDataMap = JsonUtil.toMap(JsonUtil.toString(exemplar), String.class, Object.class); - EmbeddingQuery embeddingQuery = new EmbeddingQuery(); - embeddingQuery.setQueryId(String.valueOf(i)); - embeddingQuery.setQuery(question); - embeddingQuery.setMetadata(metaDataMap); + TextSegment embeddingQuery = TextSegment.from(question, new Metadata(metaDataMap)); + TextSegmentConvert.addQueryId(embeddingQuery, String.valueOf(i)); queries.add(embeddingQuery); } embeddingService.addQuery(collectionName, queries); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/MetaEmbeddingListener.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/MetaEmbeddingListener.java index be5111532..04b8ad20b 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/MetaEmbeddingListener.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/MetaEmbeddingListener.java @@ -5,8 +5,8 @@ import com.tencent.supersonic.common.pojo.DataEvent; import com.tencent.supersonic.common.pojo.DataItem; import com.tencent.supersonic.common.pojo.enums.EventType; import com.tencent.supersonic.common.service.EmbeddingService; -import dev.langchain4j.store.embedding.EmbeddingQuery; -import java.util.List; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.store.embedding.TextSegmentConvert; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; @@ -15,6 +15,8 @@ import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Component; import org.springframework.util.CollectionUtils; +import java.util.List; + @Component @Slf4j public class MetaEmbeddingListener implements ApplicationListener { @@ -35,19 +37,19 @@ public class MetaEmbeddingListener implements ApplicationListener { if (CollectionUtils.isEmpty(dataItems)) { return; } - List embeddingQueries = EmbeddingQuery.convertToEmbedding(dataItems); - if (CollectionUtils.isEmpty(embeddingQueries)) { + List textSegments = TextSegmentConvert.convertToEmbedding(dataItems); + if (CollectionUtils.isEmpty(textSegments)) { return; } sleep(); embeddingService.addCollection(embeddingConfig.getMetaCollectionName()); if (event.getEventType().equals(EventType.ADD)) { - embeddingService.addQuery(embeddingConfig.getMetaCollectionName(), embeddingQueries); + embeddingService.addQuery(embeddingConfig.getMetaCollectionName(), textSegments); } else if (event.getEventType().equals(EventType.DELETE)) { - embeddingService.deleteQuery(embeddingConfig.getMetaCollectionName(), embeddingQueries); + embeddingService.deleteQuery(embeddingConfig.getMetaCollectionName(), textSegments); } else if (event.getEventType().equals(EventType.UPDATE)) { - embeddingService.deleteQuery(embeddingConfig.getMetaCollectionName(), embeddingQueries); - embeddingService.addQuery(embeddingConfig.getMetaCollectionName(), embeddingQueries); + embeddingService.deleteQuery(embeddingConfig.getMetaCollectionName(), textSegments); + embeddingService.addQuery(embeddingConfig.getMetaCollectionName(), textSegments); } } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/schedule/EmbeddingTask.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/schedule/EmbeddingTask.java index 52246fb72..07c688297 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/schedule/EmbeddingTask.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/schedule/EmbeddingTask.java @@ -5,14 +5,15 @@ import com.tencent.supersonic.common.pojo.DataItem; import com.tencent.supersonic.common.service.EmbeddingService; import com.tencent.supersonic.headless.server.service.DimensionService; import com.tencent.supersonic.headless.server.service.MetricService; -import dev.langchain4j.store.embedding.EmbeddingQuery; -import java.util.List; -import javax.annotation.PreDestroy; +import dev.langchain4j.store.embedding.TextSegmentConvert; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.scheduling.annotation.Scheduled; import org.springframework.stereotype.Component; +import javax.annotation.PreDestroy; +import java.util.List; + @Component @Slf4j public class EmbeddingTask { @@ -55,11 +56,11 @@ public class EmbeddingTask { List metricDataItems = metricService.getDataEvent().getDataItems(); embeddingService.addQuery(embeddingConfig.getMetaCollectionName(), - EmbeddingQuery.convertToEmbedding(metricDataItems)); + TextSegmentConvert.convertToEmbedding(metricDataItems)); List dimensionDataItems = dimensionService.getDataEvent().getDataItems(); embeddingService.addQuery(embeddingConfig.getMetaCollectionName(), - EmbeddingQuery.convertToEmbedding(dimensionDataItems)); + TextSegmentConvert.convertToEmbedding(dimensionDataItems)); } catch (Exception e) { log.error("reload.meta.embedding error", e); }