(improvement)(chat) Use TextSegment uniformly as the embedding data model, carrying the queryId in its metadata for compatibility. (#1202)

This commit is contained in:
lexluo09
2024-06-24 13:27:03 +08:00
committed by GitHub
parent a7d367baa3
commit 4b288d9815
13 changed files with 134 additions and 127 deletions

View File

@@ -14,16 +14,17 @@ import com.tencent.supersonic.chat.server.plugin.event.PluginUpdateEvent;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext; import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.service.PluginService; import com.tencent.supersonic.chat.server.service.PluginService;
import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.store.embedding.EmbeddingQuery;
import com.tencent.supersonic.common.service.EmbeddingService; import com.tencent.supersonic.common.service.EmbeddingService;
import dev.langchain4j.store.embedding.Retrieval; import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult;
import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.Retrieval;
import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult;
import dev.langchain4j.store.embedding.TextSegmentConvert;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.Pair;
@@ -116,16 +117,16 @@ public class PluginManager {
} }
String presetCollection = embeddingConfig.getPresetCollection(); String presetCollection = embeddingConfig.getPresetCollection();
List<EmbeddingQuery> queries = new ArrayList<>(); List<TextSegment> queries = new ArrayList<>();
for (String id : queryIds) { for (String id : queryIds) {
EmbeddingQuery embeddingQuery = new EmbeddingQuery(); TextSegment query = TextSegment.from("");
embeddingQuery.setQueryId(id); TextSegmentConvert.addQueryId(query, id);
queries.add(embeddingQuery); queries.add(query);
} }
embeddingService.deleteQuery(presetCollection, queries); embeddingService.deleteQuery(presetCollection, queries);
} }
public void requestEmbeddingPluginAdd(List<EmbeddingQuery> queries) { public void requestEmbeddingPluginAdd(List<TextSegment> queries) {
if (CollectionUtils.isEmpty(queries)) { if (CollectionUtils.isEmpty(queries)) {
return; return;
} }
@@ -133,10 +134,6 @@ public class PluginManager {
embeddingService.addQuery(presetCollection, queries); embeddingService.addQuery(presetCollection, queries);
} }
public void requestEmbeddingPluginAddALL(List<Plugin> plugins) {
requestEmbeddingPluginAdd(convert(plugins));
}
public RetrieveQueryResult recognize(String embeddingText) { public RetrieveQueryResult recognize(String embeddingText) {
RetrieveQuery retrieveQuery = RetrieveQuery.builder() RetrieveQuery retrieveQuery = RetrieveQuery.builder()
@@ -158,15 +155,14 @@ public class PluginManager {
throw new RuntimeException("get embedding result failed"); throw new RuntimeException("get embedding result failed");
} }
public List<EmbeddingQuery> convert(List<Plugin> plugins) { public List<TextSegment> convert(List<Plugin> plugins) {
List<EmbeddingQuery> queries = Lists.newArrayList(); List<TextSegment> queries = Lists.newArrayList();
for (Plugin plugin : plugins) { for (Plugin plugin : plugins) {
List<String> exampleQuestions = plugin.getExampleQuestionList(); List<String> exampleQuestions = plugin.getExampleQuestionList();
int num = 0; int num = 0;
for (String pattern : exampleQuestions) { for (String pattern : exampleQuestions) {
EmbeddingQuery query = new EmbeddingQuery(); TextSegment query = TextSegment.from(pattern);
query.setQueryId(generateUniqueEmbeddingId(num, plugin.getId())); TextSegmentConvert.addQueryId(query, generateUniqueEmbeddingId(num, plugin.getId()));
query.setQuery(pattern);
queries.add(query); queries.add(query);
num++; num++;
} }
@@ -176,8 +172,8 @@ public class PluginManager {
private Set<String> getEmbeddingId(List<Plugin> plugins) { private Set<String> getEmbeddingId(List<Plugin> plugins) {
Set<String> embeddingIdSet = new HashSet<>(); Set<String> embeddingIdSet = new HashSet<>();
for (EmbeddingQuery query : convert(plugins)) { for (TextSegment query : convert(plugins)) {
embeddingIdSet.add(query.getQueryId()); TextSegmentConvert.addQueryId(query, TextSegmentConvert.getQueryId(query));
} }
return embeddingIdSet; return embeddingIdSet;
} }

View File

@@ -5,18 +5,12 @@ import com.tencent.supersonic.chat.api.pojo.request.SimilarQueryReq;
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp; import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.service.EmbeddingService; import com.tencent.supersonic.common.service.EmbeddingService;
import dev.langchain4j.store.embedding.EmbeddingQuery; import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.Retrieval; import dev.langchain4j.store.embedding.Retrieval;
import dev.langchain4j.store.embedding.RetrieveQuery; import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult; import dev.langchain4j.store.embedding.RetrieveQueryResult;
import java.net.URI; import dev.langchain4j.store.embedding.TextSegmentConvert;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
@@ -32,6 +26,15 @@ import org.springframework.stereotype.Component;
import org.springframework.web.client.RestTemplate; import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder; import org.springframework.web.util.UriComponentsBuilder;
import java.net.URI;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
@Slf4j @Slf4j
@Component @Component
public class SimilarQueryManager { public class SimilarQueryManager {
@@ -52,15 +55,13 @@ public class SimilarQueryManager {
} }
String queryText = similarQueryReq.getQueryText(); String queryText = similarQueryReq.getQueryText();
try { try {
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
embeddingQuery.setQueryId(String.valueOf(similarQueryReq.getQueryId()));
embeddingQuery.setQuery(queryText);
Map<String, Object> metaData = new HashMap<>(); Map<String, Object> metaData = new HashMap<>();
metaData.put("agentId", similarQueryReq.getAgentId()); metaData.put("agentId", similarQueryReq.getAgentId());
embeddingQuery.setMetadata(metaData); TextSegment textSegment = TextSegment.from(queryText, new Metadata(metaData));
TextSegmentConvert.addQueryId(textSegment, String.valueOf(similarQueryReq.getQueryId()));
String solvedQueryCollection = embeddingConfig.getSolvedQueryCollection(); String solvedQueryCollection = embeddingConfig.getSolvedQueryCollection();
embeddingService.addQuery(solvedQueryCollection, Lists.newArrayList(embeddingQuery)); embeddingService.addQuery(solvedQueryCollection, Lists.newArrayList(textSegment));
} catch (Exception e) { } catch (Exception e) {
log.warn("save history question to embedding failed, queryText:{}", queryText, e); log.warn("save history question to embedding failed, queryText:{}", queryText, e);
} }

View File

@@ -1,8 +1,9 @@
package com.tencent.supersonic.common.service; package com.tencent.supersonic.common.service;
import dev.langchain4j.store.embedding.EmbeddingQuery; import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.RetrieveQuery; import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult; import dev.langchain4j.store.embedding.RetrieveQueryResult;
import java.util.List; import java.util.List;
/** /**
@@ -13,9 +14,9 @@ public interface EmbeddingService {
void addCollection(String collectionName); void addCollection(String collectionName);
void addQuery(String collectionName, List<EmbeddingQuery> queries); void addQuery(String collectionName, List<TextSegment> queries);
void deleteQuery(String collectionName, List<EmbeddingQuery> queries); void deleteQuery(String collectionName, List<TextSegment> queries);
List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num); List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num);

View File

@@ -2,9 +2,9 @@ package com.tencent.supersonic.common.service.impl;
import com.tencent.supersonic.common.service.EmbeddingService; import com.tencent.supersonic.common.service.EmbeddingService;
import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingQuery;
import dev.langchain4j.store.embedding.EmbeddingSearchRequest; import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
import dev.langchain4j.store.embedding.EmbeddingSearchResult; import dev.langchain4j.store.embedding.EmbeddingSearchResult;
import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStore;
@@ -12,8 +12,13 @@ import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
import dev.langchain4j.store.embedding.Retrieval; import dev.langchain4j.store.embedding.Retrieval;
import dev.langchain4j.store.embedding.RetrieveQuery; import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult; import dev.langchain4j.store.embedding.RetrieveQueryResult;
import dev.langchain4j.store.embedding.TextSegmentConvert;
import dev.langchain4j.store.embedding.filter.Filter; import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo; import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
import org.apache.commons.collections.MapUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Comparator; import java.util.Comparator;
import java.util.HashMap; import java.util.HashMap;
@@ -21,9 +26,6 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.apache.commons.collections.MapUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
@Service @Service
public class EmbeddingServiceImpl implements EmbeddingService { public class EmbeddingServiceImpl implements EmbeddingService {
@@ -39,17 +41,17 @@ public class EmbeddingServiceImpl implements EmbeddingService {
} }
@Override @Override
public void addQuery(String collectionName, List<EmbeddingQuery> queries) { public void addQuery(String collectionName, List<TextSegment> queries) {
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName); EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
for (EmbeddingQuery query : queries) { for (TextSegment query : queries) {
String question = query.getQuery(); String question = query.text();
Embedding embedding = embeddingModel.embed(question).content(); Embedding embedding = embeddingModel.embed(question).content();
embeddingStore.add(embedding, query); embeddingStore.add(embedding, query);
} }
} }
@Override @Override
public void deleteQuery(String collectionName, List<EmbeddingQuery> queries) { public void deleteQuery(String collectionName, List<TextSegment> queries) {
} }
@Override @Override
@@ -66,21 +68,21 @@ public class EmbeddingServiceImpl implements EmbeddingService {
.queryEmbedding(embeddedText).filter(filter).maxResults(num).build(); .queryEmbedding(embeddedText).filter(filter).maxResults(num).build();
EmbeddingSearchResult result = embeddingStore.search(request); EmbeddingSearchResult result = embeddingStore.search(request);
List<EmbeddingMatch<EmbeddingQuery>> relevant = result.matches(); List<EmbeddingMatch<TextSegment>> relevant = result.matches();
RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult(); RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult();
retrieveQueryResult.setQuery(queryText); retrieveQueryResult.setQuery(queryText);
List<Retrieval> retrievals = new ArrayList<>(); List<Retrieval> retrievals = new ArrayList<>();
for (EmbeddingMatch<EmbeddingQuery> embeddingMatch : relevant) { for (EmbeddingMatch<TextSegment> embeddingMatch : relevant) {
Retrieval retrieval = new Retrieval(); Retrieval retrieval = new Retrieval();
EmbeddingQuery embedded = embeddingMatch.embedded(); TextSegment embedded = embeddingMatch.embedded();
retrieval.setDistance(1 - embeddingMatch.score()); retrieval.setDistance(1 - embeddingMatch.score());
retrieval.setId(embedded.getQueryId()); retrieval.setId(TextSegmentConvert.getQueryId(embedded));
retrieval.setQuery(embedded.getQuery()); retrieval.setQuery(embedded.text());
Map<String, Object> metadata = new HashMap<>(); Map<String, Object> metadata = new HashMap<>();
if (Objects.nonNull(embedded) if (Objects.nonNull(embedded)
&& MapUtils.isNotEmpty(embedded.getMetadata())) { && MapUtils.isNotEmpty(embedded.metadata().toMap())) {
metadata.putAll(embedded.getMetadata()); metadata.putAll(embedded.metadata().toMap());
} }
retrieval.setMetadata(metadata); retrieval.setMetadata(metadata);
retrievals.add(retrieval); retrievals.add(retrieval);

View File

@@ -1,21 +1,21 @@
package dev.langchain4j.chroma.spring; package dev.langchain4j.chroma.spring;
import static dev.langchain4j.chroma.spring.Properties.PREFIX;
import dev.langchain4j.store.embedding.EmbeddingStoreFactory; import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import static dev.langchain4j.chroma.spring.Properties.PREFIX;
@Configuration @Configuration
@EnableConfigurationProperties(Properties.class) @EnableConfigurationProperties(Properties.class)
public class ChromaAutoConfig { public class ChromaAutoConfig {
@Bean @Bean
@ConditionalOnProperty(PREFIX + ".embedding-store.base-url") @ConditionalOnProperty(PREFIX + ".embedding-store.base-url")
EmbeddingStoreFactory milvusChatModel(Properties properties) { EmbeddingStoreFactory chromaChatModel(Properties properties) {
return new ChromaEmbeddingStoreFactory(properties); return new ChromaEmbeddingStoreFactory(properties);
} }
} }

View File

@@ -17,7 +17,7 @@ public class ChromaEmbeddingStoreFactory implements EmbeddingStoreFactory {
EmbeddingStoreProperties embeddingStore = properties.getEmbeddingStore(); EmbeddingStoreProperties embeddingStore = properties.getEmbeddingStore();
return ChromaEmbeddingStore.builder() return ChromaEmbeddingStore.builder()
.baseUrl(embeddingStore.getBaseUrl()) .baseUrl(embeddingStore.getBaseUrl())
.collectionName(embeddingStore.getCollectionName()) .collectionName(collectionName)
.timeout(embeddingStore.getTimeout()) .timeout(embeddingStore.getTimeout())
.build(); .build();
} }

View File

@@ -1,14 +1,14 @@
package dev.langchain4j.chroma.spring; package dev.langchain4j.chroma.spring;
import java.time.Duration;
import lombok.Getter; import lombok.Getter;
import lombok.Setter; import lombok.Setter;
import java.time.Duration;
@Getter @Getter
@Setter @Setter
class EmbeddingStoreProperties { class EmbeddingStoreProperties {
private String baseUrl; private String baseUrl;
private String collectionName;
private Duration timeout; private Duration timeout;
} }

View File

@@ -1,18 +1,19 @@
package dev.langchain4j.inmemory.spring; package dev.langchain4j.inmemory.spring;
import dev.langchain4j.store.embedding.EmbeddingQuery; import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreFactory; import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore; import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import lombok.extern.slf4j.Slf4j;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import lombok.extern.slf4j.Slf4j;
@Slf4j @Slf4j
public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory { public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory {
private static Map<String, InMemoryEmbeddingStore<EmbeddingQuery>> collectionNameToStore = private static Map<String, InMemoryEmbeddingStore<TextSegment>> collectionNameToStore =
new ConcurrentHashMap<>(); new ConcurrentHashMap<>();
private Properties properties; private Properties properties;
@@ -23,14 +24,12 @@ public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory {
@Override @Override
public synchronized EmbeddingStore create(String collectionName) { public synchronized EmbeddingStore create(String collectionName) {
InMemoryEmbeddingStore<EmbeddingQuery> embeddingStore = collectionNameToStore.get(collectionName); InMemoryEmbeddingStore<TextSegment> embeddingStore = collectionNameToStore.get(collectionName);
if (Objects.nonNull(embeddingStore)) { if (Objects.nonNull(embeddingStore)) {
return embeddingStore; return embeddingStore;
} }
if (Objects.isNull(embeddingStore)) { embeddingStore = new InMemoryEmbeddingStore();
embeddingStore = new InMemoryEmbeddingStore(); collectionNameToStore.putIfAbsent(collectionName, embeddingStore);
collectionNameToStore.putIfAbsent(collectionName, embeddingStore);
}
return embeddingStore; return embeddingStore;
} }

View File

@@ -1,35 +0,0 @@
package dev.langchain4j.store.embedding;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.common.pojo.DataItem;
import lombok.Data;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@Data
public class EmbeddingQuery {
private String queryId;
private String query;
private Map<String, Object> metadata;
private List<Double> queryEmbedding;
public static List<EmbeddingQuery> convertToEmbedding(List<DataItem> dataItems) {
return dataItems.stream().map(dataItem -> {
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
embeddingQuery.setQueryId(
dataItem.getId() + dataItem.getType().name().toLowerCase());
embeddingQuery.setQuery(dataItem.getName());
Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class);
embeddingQuery.setMetadata(meta);
embeddingQuery.setQueryEmbedding(null);
return embeddingQuery;
}).collect(Collectors.toList());
}
}

View File

@@ -0,0 +1,39 @@
package dev.langchain4j.store.embedding;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.common.pojo.DataItem;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.segment.TextSegment;
import lombok.Data;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
/**
 * Static helpers that turn project data into langchain4j {@link TextSegment}s and that store
 * and read a segment's query id in its metadata under the {@link #QUERY_ID} key.
 */
// NOTE(review): every member here is static, so @Data looks unnecessary; confirm before removing.
@Data
public class TextSegmentConvert {

    /** Metadata key under which a segment's query id is stored. */
    public static final String QUERY_ID = "queryId";

    /**
     * Converts data items into text segments, one segment per item and in the same order.
     *
     * @param dataItems items to convert
     * @return the converted segments
     */
    public static List<TextSegment> convertToEmbedding(List<DataItem> dataItems) {
        return dataItems.stream()
                .map(TextSegmentConvert::toTextSegment)
                .collect(Collectors.toList());
    }

    /**
     * Stores {@code queryId} in the segment's metadata under {@link #QUERY_ID}.
     *
     * @param textSegment segment to tag; must not be null
     * @param queryId     id to store
     */
    public static void addQueryId(TextSegment textSegment, String queryId) {
        textSegment.metadata().put(QUERY_ID, queryId);
    }

    /**
     * Reads the query id stored in the segment's metadata.
     *
     * @param textSegment segment to inspect; may be null
     * @return the value stored under {@link #QUERY_ID}, or null when the segment or its
     *         metadata is null
     */
    public static String getQueryId(TextSegment textSegment) {
        if (Objects.isNull(textSegment) || Objects.isNull(textSegment.metadata())) {
            return null;
        }
        return textSegment.metadata().get(QUERY_ID);
    }

    /**
     * Builds the segment for one data item: the item name is the text, the item's JSON form
     * parsed back into a map is the metadata, and the query id is the item id followed by the
     * lower-cased type name.
     */
    private static TextSegment toTextSegment(DataItem dataItem) {
        // JSON object keys are always strings, so treating the parsed map as
        // Map<String, Object> is safe.
        @SuppressWarnings("unchecked")
        Map<String, Object> meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class);
        TextSegment textSegment = TextSegment.from(dataItem.getName(), new Metadata(meta));
        // Locale.ROOT keeps the id identical on every JVM; the default-locale toLowerCase()
        // yields different characters under some locales (e.g. Turkish dotless i).
        String typeSuffix = dataItem.getType().name().toLowerCase(java.util.Locale.ROOT);
        addQueryId(textSegment, dataItem.getId() + typeSuffix);
        return textSegment;
    }
}

View File

@@ -5,10 +5,18 @@ import com.fasterxml.jackson.core.type.TypeReference;
import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.service.EmbeddingService; import com.tencent.supersonic.common.service.EmbeddingService;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
import dev.langchain4j.store.embedding.EmbeddingQuery; import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.Retrieval; import dev.langchain4j.store.embedding.Retrieval;
import dev.langchain4j.store.embedding.RetrieveQuery; import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult; import dev.langchain4j.store.embedding.RetrieveQueryResult;
import dev.langchain4j.store.embedding.TextSegmentConvert;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.io.ClassPathResource;
import org.springframework.stereotype.Component;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.util.ArrayList; import java.util.ArrayList;
@@ -17,11 +25,6 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.io.ClassPathResource;
import org.springframework.stereotype.Component;
@Slf4j @Slf4j
@Component @Component
@@ -44,15 +47,13 @@ public class ExemplarManager {
} }
public void addExemplars(List<Exemplar> exemplars, String collectionName) { public void addExemplars(List<Exemplar> exemplars, String collectionName) {
List<EmbeddingQuery> queries = new ArrayList<>(); List<TextSegment> queries = new ArrayList<>();
for (int i = 0; i < exemplars.size(); i++) { for (int i = 0; i < exemplars.size(); i++) {
Exemplar exemplar = exemplars.get(i); Exemplar exemplar = exemplars.get(i);
String question = exemplar.getQuestion(); String question = exemplar.getQuestion();
Map<String, Object> metaDataMap = JsonUtil.toMap(JsonUtil.toString(exemplar), String.class, Object.class); Map<String, Object> metaDataMap = JsonUtil.toMap(JsonUtil.toString(exemplar), String.class, Object.class);
EmbeddingQuery embeddingQuery = new EmbeddingQuery(); TextSegment embeddingQuery = TextSegment.from(question, new Metadata(metaDataMap));
embeddingQuery.setQueryId(String.valueOf(i)); TextSegmentConvert.addQueryId(embeddingQuery, String.valueOf(i));
embeddingQuery.setQuery(question);
embeddingQuery.setMetadata(metaDataMap);
queries.add(embeddingQuery); queries.add(embeddingQuery);
} }
embeddingService.addQuery(collectionName, queries); embeddingService.addQuery(collectionName, queries);

View File

@@ -5,8 +5,8 @@ import com.tencent.supersonic.common.pojo.DataEvent;
import com.tencent.supersonic.common.pojo.DataItem; import com.tencent.supersonic.common.pojo.DataItem;
import com.tencent.supersonic.common.pojo.enums.EventType; import com.tencent.supersonic.common.pojo.enums.EventType;
import com.tencent.supersonic.common.service.EmbeddingService; import com.tencent.supersonic.common.service.EmbeddingService;
import dev.langchain4j.store.embedding.EmbeddingQuery; import dev.langchain4j.data.segment.TextSegment;
import java.util.List; import dev.langchain4j.store.embedding.TextSegmentConvert;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
@@ -15,6 +15,8 @@ import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.List;
@Component @Component
@Slf4j @Slf4j
public class MetaEmbeddingListener implements ApplicationListener<DataEvent> { public class MetaEmbeddingListener implements ApplicationListener<DataEvent> {
@@ -35,19 +37,19 @@ public class MetaEmbeddingListener implements ApplicationListener<DataEvent> {
if (CollectionUtils.isEmpty(dataItems)) { if (CollectionUtils.isEmpty(dataItems)) {
return; return;
} }
List<EmbeddingQuery> embeddingQueries = EmbeddingQuery.convertToEmbedding(dataItems); List<TextSegment> textSegments = TextSegmentConvert.convertToEmbedding(dataItems);
if (CollectionUtils.isEmpty(embeddingQueries)) { if (CollectionUtils.isEmpty(textSegments)) {
return; return;
} }
sleep(); sleep();
embeddingService.addCollection(embeddingConfig.getMetaCollectionName()); embeddingService.addCollection(embeddingConfig.getMetaCollectionName());
if (event.getEventType().equals(EventType.ADD)) { if (event.getEventType().equals(EventType.ADD)) {
embeddingService.addQuery(embeddingConfig.getMetaCollectionName(), embeddingQueries); embeddingService.addQuery(embeddingConfig.getMetaCollectionName(), textSegments);
} else if (event.getEventType().equals(EventType.DELETE)) { } else if (event.getEventType().equals(EventType.DELETE)) {
embeddingService.deleteQuery(embeddingConfig.getMetaCollectionName(), embeddingQueries); embeddingService.deleteQuery(embeddingConfig.getMetaCollectionName(), textSegments);
} else if (event.getEventType().equals(EventType.UPDATE)) { } else if (event.getEventType().equals(EventType.UPDATE)) {
embeddingService.deleteQuery(embeddingConfig.getMetaCollectionName(), embeddingQueries); embeddingService.deleteQuery(embeddingConfig.getMetaCollectionName(), textSegments);
embeddingService.addQuery(embeddingConfig.getMetaCollectionName(), embeddingQueries); embeddingService.addQuery(embeddingConfig.getMetaCollectionName(), textSegments);
} }
} }

View File

@@ -5,14 +5,15 @@ import com.tencent.supersonic.common.pojo.DataItem;
import com.tencent.supersonic.common.service.EmbeddingService; import com.tencent.supersonic.common.service.EmbeddingService;
import com.tencent.supersonic.headless.server.service.DimensionService; import com.tencent.supersonic.headless.server.service.DimensionService;
import com.tencent.supersonic.headless.server.service.MetricService; import com.tencent.supersonic.headless.server.service.MetricService;
import dev.langchain4j.store.embedding.EmbeddingQuery; import dev.langchain4j.store.embedding.TextSegmentConvert;
import java.util.List;
import javax.annotation.PreDestroy;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.annotation.Scheduled; import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import javax.annotation.PreDestroy;
import java.util.List;
@Component @Component
@Slf4j @Slf4j
public class EmbeddingTask { public class EmbeddingTask {
@@ -55,11 +56,11 @@ public class EmbeddingTask {
List<DataItem> metricDataItems = metricService.getDataEvent().getDataItems(); List<DataItem> metricDataItems = metricService.getDataEvent().getDataItems();
embeddingService.addQuery(embeddingConfig.getMetaCollectionName(), embeddingService.addQuery(embeddingConfig.getMetaCollectionName(),
EmbeddingQuery.convertToEmbedding(metricDataItems)); TextSegmentConvert.convertToEmbedding(metricDataItems));
List<DataItem> dimensionDataItems = dimensionService.getDataEvent().getDataItems(); List<DataItem> dimensionDataItems = dimensionService.getDataEvent().getDataItems();
embeddingService.addQuery(embeddingConfig.getMetaCollectionName(), embeddingService.addQuery(embeddingConfig.getMetaCollectionName(),
EmbeddingQuery.convertToEmbedding(dimensionDataItems)); TextSegmentConvert.convertToEmbedding(dimensionDataItems));
} catch (Exception e) { } catch (Exception e) {
log.error("reload.meta.embedding error", e); log.error("reload.meta.embedding error", e);
} }