mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 12:37:55 +00:00
(improvement)(chat) The embedding model will be uniformly adopted using the textSegment and will be compatible with the queryId parameter. (#1202)
This commit is contained in:
@@ -14,16 +14,17 @@ import com.tencent.supersonic.chat.server.plugin.event.PluginUpdateEvent;
|
|||||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||||
import com.tencent.supersonic.chat.server.service.PluginService;
|
import com.tencent.supersonic.chat.server.service.PluginService;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
|
||||||
import dev.langchain4j.store.embedding.EmbeddingQuery;
|
|
||||||
import com.tencent.supersonic.common.service.EmbeddingService;
|
import com.tencent.supersonic.common.service.EmbeddingService;
|
||||||
import dev.langchain4j.store.embedding.Retrieval;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import dev.langchain4j.store.embedding.RetrieveQuery;
|
|
||||||
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||||
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
|
import dev.langchain4j.store.embedding.Retrieval;
|
||||||
|
import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||||
|
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||||
|
import dev.langchain4j.store.embedding.TextSegmentConvert;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
@@ -116,16 +117,16 @@ public class PluginManager {
|
|||||||
}
|
}
|
||||||
String presetCollection = embeddingConfig.getPresetCollection();
|
String presetCollection = embeddingConfig.getPresetCollection();
|
||||||
|
|
||||||
List<EmbeddingQuery> queries = new ArrayList<>();
|
List<TextSegment> queries = new ArrayList<>();
|
||||||
for (String id : queryIds) {
|
for (String id : queryIds) {
|
||||||
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
|
TextSegment query = TextSegment.from("");
|
||||||
embeddingQuery.setQueryId(id);
|
TextSegmentConvert.addQueryId(query, id);
|
||||||
queries.add(embeddingQuery);
|
queries.add(query);
|
||||||
}
|
}
|
||||||
embeddingService.deleteQuery(presetCollection, queries);
|
embeddingService.deleteQuery(presetCollection, queries);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void requestEmbeddingPluginAdd(List<EmbeddingQuery> queries) {
|
public void requestEmbeddingPluginAdd(List<TextSegment> queries) {
|
||||||
if (CollectionUtils.isEmpty(queries)) {
|
if (CollectionUtils.isEmpty(queries)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -133,10 +134,6 @@ public class PluginManager {
|
|||||||
embeddingService.addQuery(presetCollection, queries);
|
embeddingService.addQuery(presetCollection, queries);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void requestEmbeddingPluginAddALL(List<Plugin> plugins) {
|
|
||||||
requestEmbeddingPluginAdd(convert(plugins));
|
|
||||||
}
|
|
||||||
|
|
||||||
public RetrieveQueryResult recognize(String embeddingText) {
|
public RetrieveQueryResult recognize(String embeddingText) {
|
||||||
|
|
||||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder()
|
RetrieveQuery retrieveQuery = RetrieveQuery.builder()
|
||||||
@@ -158,15 +155,14 @@ public class PluginManager {
|
|||||||
throw new RuntimeException("get embedding result failed");
|
throw new RuntimeException("get embedding result failed");
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<EmbeddingQuery> convert(List<Plugin> plugins) {
|
public List<TextSegment> convert(List<Plugin> plugins) {
|
||||||
List<EmbeddingQuery> queries = Lists.newArrayList();
|
List<TextSegment> queries = Lists.newArrayList();
|
||||||
for (Plugin plugin : plugins) {
|
for (Plugin plugin : plugins) {
|
||||||
List<String> exampleQuestions = plugin.getExampleQuestionList();
|
List<String> exampleQuestions = plugin.getExampleQuestionList();
|
||||||
int num = 0;
|
int num = 0;
|
||||||
for (String pattern : exampleQuestions) {
|
for (String pattern : exampleQuestions) {
|
||||||
EmbeddingQuery query = new EmbeddingQuery();
|
TextSegment query = TextSegment.from(pattern);
|
||||||
query.setQueryId(generateUniqueEmbeddingId(num, plugin.getId()));
|
TextSegmentConvert.addQueryId(query, generateUniqueEmbeddingId(num, plugin.getId()));
|
||||||
query.setQuery(pattern);
|
|
||||||
queries.add(query);
|
queries.add(query);
|
||||||
num++;
|
num++;
|
||||||
}
|
}
|
||||||
@@ -176,8 +172,8 @@ public class PluginManager {
|
|||||||
|
|
||||||
private Set<String> getEmbeddingId(List<Plugin> plugins) {
|
private Set<String> getEmbeddingId(List<Plugin> plugins) {
|
||||||
Set<String> embeddingIdSet = new HashSet<>();
|
Set<String> embeddingIdSet = new HashSet<>();
|
||||||
for (EmbeddingQuery query : convert(plugins)) {
|
for (TextSegment query : convert(plugins)) {
|
||||||
embeddingIdSet.add(query.getQueryId());
|
TextSegmentConvert.addQueryId(query, TextSegmentConvert.getQueryId(query));
|
||||||
}
|
}
|
||||||
return embeddingIdSet;
|
return embeddingIdSet;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,18 +5,12 @@ import com.tencent.supersonic.chat.api.pojo.request.SimilarQueryReq;
|
|||||||
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
|
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
import com.tencent.supersonic.common.service.EmbeddingService;
|
import com.tencent.supersonic.common.service.EmbeddingService;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingQuery;
|
import dev.langchain4j.data.document.Metadata;
|
||||||
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
import dev.langchain4j.store.embedding.Retrieval;
|
import dev.langchain4j.store.embedding.Retrieval;
|
||||||
import dev.langchain4j.store.embedding.RetrieveQuery;
|
import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||||
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||||
import java.net.URI;
|
import dev.langchain4j.store.embedding.TextSegmentConvert;
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.HashSet;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Optional;
|
|
||||||
import java.util.Set;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
@@ -32,6 +26,15 @@ import org.springframework.stereotype.Component;
|
|||||||
import org.springframework.web.client.RestTemplate;
|
import org.springframework.web.client.RestTemplate;
|
||||||
import org.springframework.web.util.UriComponentsBuilder;
|
import org.springframework.web.util.UriComponentsBuilder;
|
||||||
|
|
||||||
|
import java.net.URI;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Component
|
@Component
|
||||||
public class SimilarQueryManager {
|
public class SimilarQueryManager {
|
||||||
@@ -52,15 +55,13 @@ public class SimilarQueryManager {
|
|||||||
}
|
}
|
||||||
String queryText = similarQueryReq.getQueryText();
|
String queryText = similarQueryReq.getQueryText();
|
||||||
try {
|
try {
|
||||||
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
|
|
||||||
embeddingQuery.setQueryId(String.valueOf(similarQueryReq.getQueryId()));
|
|
||||||
embeddingQuery.setQuery(queryText);
|
|
||||||
|
|
||||||
Map<String, Object> metaData = new HashMap<>();
|
Map<String, Object> metaData = new HashMap<>();
|
||||||
metaData.put("agentId", similarQueryReq.getAgentId());
|
metaData.put("agentId", similarQueryReq.getAgentId());
|
||||||
embeddingQuery.setMetadata(metaData);
|
TextSegment textSegment = TextSegment.from(queryText, new Metadata(metaData));
|
||||||
|
TextSegmentConvert.addQueryId(textSegment, String.valueOf(similarQueryReq.getQueryId()));
|
||||||
|
|
||||||
String solvedQueryCollection = embeddingConfig.getSolvedQueryCollection();
|
String solvedQueryCollection = embeddingConfig.getSolvedQueryCollection();
|
||||||
embeddingService.addQuery(solvedQueryCollection, Lists.newArrayList(embeddingQuery));
|
embeddingService.addQuery(solvedQueryCollection, Lists.newArrayList(textSegment));
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.warn("save history question to embedding failed, queryText:{}", queryText, e);
|
log.warn("save history question to embedding failed, queryText:{}", queryText, e);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
package com.tencent.supersonic.common.service;
|
package com.tencent.supersonic.common.service;
|
||||||
|
|
||||||
import dev.langchain4j.store.embedding.EmbeddingQuery;
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
import dev.langchain4j.store.embedding.RetrieveQuery;
|
import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||||
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -13,9 +14,9 @@ public interface EmbeddingService {
|
|||||||
|
|
||||||
void addCollection(String collectionName);
|
void addCollection(String collectionName);
|
||||||
|
|
||||||
void addQuery(String collectionName, List<EmbeddingQuery> queries);
|
void addQuery(String collectionName, List<TextSegment> queries);
|
||||||
|
|
||||||
void deleteQuery(String collectionName, List<EmbeddingQuery> queries);
|
void deleteQuery(String collectionName, List<TextSegment> queries);
|
||||||
|
|
||||||
List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num);
|
List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num);
|
||||||
|
|
||||||
|
|||||||
@@ -2,9 +2,9 @@ package com.tencent.supersonic.common.service.impl;
|
|||||||
|
|
||||||
import com.tencent.supersonic.common.service.EmbeddingService;
|
import com.tencent.supersonic.common.service.EmbeddingService;
|
||||||
import dev.langchain4j.data.embedding.Embedding;
|
import dev.langchain4j.data.embedding.Embedding;
|
||||||
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingQuery;
|
|
||||||
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
|
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
|
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||||
@@ -12,8 +12,13 @@ import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
|
|||||||
import dev.langchain4j.store.embedding.Retrieval;
|
import dev.langchain4j.store.embedding.Retrieval;
|
||||||
import dev.langchain4j.store.embedding.RetrieveQuery;
|
import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||||
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||||
|
import dev.langchain4j.store.embedding.TextSegmentConvert;
|
||||||
import dev.langchain4j.store.embedding.filter.Filter;
|
import dev.langchain4j.store.embedding.filter.Filter;
|
||||||
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
|
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
|
||||||
|
import org.apache.commons.collections.MapUtils;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
@@ -21,9 +26,6 @@ import java.util.List;
|
|||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import org.apache.commons.collections.MapUtils;
|
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
|
||||||
import org.springframework.stereotype.Service;
|
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
public class EmbeddingServiceImpl implements EmbeddingService {
|
public class EmbeddingServiceImpl implements EmbeddingService {
|
||||||
@@ -39,17 +41,17 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void addQuery(String collectionName, List<EmbeddingQuery> queries) {
|
public void addQuery(String collectionName, List<TextSegment> queries) {
|
||||||
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
|
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
|
||||||
for (EmbeddingQuery query : queries) {
|
for (TextSegment query : queries) {
|
||||||
String question = query.getQuery();
|
String question = query.text();
|
||||||
Embedding embedding = embeddingModel.embed(question).content();
|
Embedding embedding = embeddingModel.embed(question).content();
|
||||||
embeddingStore.add(embedding, query);
|
embeddingStore.add(embedding, query);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void deleteQuery(String collectionName, List<EmbeddingQuery> queries) {
|
public void deleteQuery(String collectionName, List<TextSegment> queries) {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -66,21 +68,21 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
|||||||
.queryEmbedding(embeddedText).filter(filter).maxResults(num).build();
|
.queryEmbedding(embeddedText).filter(filter).maxResults(num).build();
|
||||||
|
|
||||||
EmbeddingSearchResult result = embeddingStore.search(request);
|
EmbeddingSearchResult result = embeddingStore.search(request);
|
||||||
List<EmbeddingMatch<EmbeddingQuery>> relevant = result.matches();
|
List<EmbeddingMatch<TextSegment>> relevant = result.matches();
|
||||||
|
|
||||||
RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult();
|
RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult();
|
||||||
retrieveQueryResult.setQuery(queryText);
|
retrieveQueryResult.setQuery(queryText);
|
||||||
List<Retrieval> retrievals = new ArrayList<>();
|
List<Retrieval> retrievals = new ArrayList<>();
|
||||||
for (EmbeddingMatch<EmbeddingQuery> embeddingMatch : relevant) {
|
for (EmbeddingMatch<TextSegment> embeddingMatch : relevant) {
|
||||||
Retrieval retrieval = new Retrieval();
|
Retrieval retrieval = new Retrieval();
|
||||||
EmbeddingQuery embedded = embeddingMatch.embedded();
|
TextSegment embedded = embeddingMatch.embedded();
|
||||||
retrieval.setDistance(1 - embeddingMatch.score());
|
retrieval.setDistance(1 - embeddingMatch.score());
|
||||||
retrieval.setId(embedded.getQueryId());
|
retrieval.setId(TextSegmentConvert.getQueryId(embedded));
|
||||||
retrieval.setQuery(embedded.getQuery());
|
retrieval.setQuery(embedded.text());
|
||||||
Map<String, Object> metadata = new HashMap<>();
|
Map<String, Object> metadata = new HashMap<>();
|
||||||
if (Objects.nonNull(embedded)
|
if (Objects.nonNull(embedded)
|
||||||
&& MapUtils.isNotEmpty(embedded.getMetadata())) {
|
&& MapUtils.isNotEmpty(embedded.metadata().toMap())) {
|
||||||
metadata.putAll(embedded.getMetadata());
|
metadata.putAll(embedded.metadata().toMap());
|
||||||
}
|
}
|
||||||
retrieval.setMetadata(metadata);
|
retrieval.setMetadata(metadata);
|
||||||
retrievals.add(retrieval);
|
retrievals.add(retrieval);
|
||||||
|
|||||||
@@ -1,21 +1,21 @@
|
|||||||
package dev.langchain4j.chroma.spring;
|
package dev.langchain4j.chroma.spring;
|
||||||
|
|
||||||
|
|
||||||
import static dev.langchain4j.chroma.spring.Properties.PREFIX;
|
|
||||||
|
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
|
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
|
||||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
||||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||||
import org.springframework.context.annotation.Bean;
|
import org.springframework.context.annotation.Bean;
|
||||||
import org.springframework.context.annotation.Configuration;
|
import org.springframework.context.annotation.Configuration;
|
||||||
|
|
||||||
|
import static dev.langchain4j.chroma.spring.Properties.PREFIX;
|
||||||
|
|
||||||
@Configuration
|
@Configuration
|
||||||
@EnableConfigurationProperties(Properties.class)
|
@EnableConfigurationProperties(Properties.class)
|
||||||
public class ChromaAutoConfig {
|
public class ChromaAutoConfig {
|
||||||
|
|
||||||
@Bean
|
@Bean
|
||||||
@ConditionalOnProperty(PREFIX + ".embedding-store.base-url")
|
@ConditionalOnProperty(PREFIX + ".embedding-store.base-url")
|
||||||
EmbeddingStoreFactory milvusChatModel(Properties properties) {
|
EmbeddingStoreFactory chromaChatModel(Properties properties) {
|
||||||
return new ChromaEmbeddingStoreFactory(properties);
|
return new ChromaEmbeddingStoreFactory(properties);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -17,7 +17,7 @@ public class ChromaEmbeddingStoreFactory implements EmbeddingStoreFactory {
|
|||||||
EmbeddingStoreProperties embeddingStore = properties.getEmbeddingStore();
|
EmbeddingStoreProperties embeddingStore = properties.getEmbeddingStore();
|
||||||
return ChromaEmbeddingStore.builder()
|
return ChromaEmbeddingStore.builder()
|
||||||
.baseUrl(embeddingStore.getBaseUrl())
|
.baseUrl(embeddingStore.getBaseUrl())
|
||||||
.collectionName(embeddingStore.getCollectionName())
|
.collectionName(collectionName)
|
||||||
.timeout(embeddingStore.getTimeout())
|
.timeout(embeddingStore.getTimeout())
|
||||||
.build();
|
.build();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
package dev.langchain4j.chroma.spring;
|
package dev.langchain4j.chroma.spring;
|
||||||
|
|
||||||
import java.time.Duration;
|
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
|
|
||||||
|
import java.time.Duration;
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
@Setter
|
@Setter
|
||||||
class EmbeddingStoreProperties {
|
class EmbeddingStoreProperties {
|
||||||
|
|
||||||
private String baseUrl;
|
private String baseUrl;
|
||||||
private String collectionName;
|
|
||||||
private Duration timeout;
|
private Duration timeout;
|
||||||
}
|
}
|
||||||
@@ -1,18 +1,19 @@
|
|||||||
package dev.langchain4j.inmemory.spring;
|
package dev.langchain4j.inmemory.spring;
|
||||||
|
|
||||||
import dev.langchain4j.store.embedding.EmbeddingQuery;
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
|
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
|
||||||
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
|
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory {
|
public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory {
|
||||||
|
|
||||||
private static Map<String, InMemoryEmbeddingStore<EmbeddingQuery>> collectionNameToStore =
|
private static Map<String, InMemoryEmbeddingStore<TextSegment>> collectionNameToStore =
|
||||||
new ConcurrentHashMap<>();
|
new ConcurrentHashMap<>();
|
||||||
private Properties properties;
|
private Properties properties;
|
||||||
|
|
||||||
@@ -23,14 +24,12 @@ public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public synchronized EmbeddingStore create(String collectionName) {
|
public synchronized EmbeddingStore create(String collectionName) {
|
||||||
InMemoryEmbeddingStore<EmbeddingQuery> embeddingStore = collectionNameToStore.get(collectionName);
|
InMemoryEmbeddingStore<TextSegment> embeddingStore = collectionNameToStore.get(collectionName);
|
||||||
if (Objects.nonNull(embeddingStore)) {
|
if (Objects.nonNull(embeddingStore)) {
|
||||||
return embeddingStore;
|
return embeddingStore;
|
||||||
}
|
}
|
||||||
if (Objects.isNull(embeddingStore)) {
|
embeddingStore = new InMemoryEmbeddingStore();
|
||||||
embeddingStore = new InMemoryEmbeddingStore();
|
collectionNameToStore.putIfAbsent(collectionName, embeddingStore);
|
||||||
collectionNameToStore.putIfAbsent(collectionName, embeddingStore);
|
|
||||||
}
|
|
||||||
return embeddingStore;
|
return embeddingStore;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,35 +0,0 @@
|
|||||||
package dev.langchain4j.store.embedding;
|
|
||||||
|
|
||||||
|
|
||||||
import com.alibaba.fastjson.JSONObject;
|
|
||||||
import com.tencent.supersonic.common.pojo.DataItem;
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
|
|
||||||
@Data
|
|
||||||
public class EmbeddingQuery {
|
|
||||||
|
|
||||||
private String queryId;
|
|
||||||
|
|
||||||
private String query;
|
|
||||||
|
|
||||||
private Map<String, Object> metadata;
|
|
||||||
|
|
||||||
private List<Double> queryEmbedding;
|
|
||||||
public static List<EmbeddingQuery> convertToEmbedding(List<DataItem> dataItems) {
|
|
||||||
return dataItems.stream().map(dataItem -> {
|
|
||||||
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
|
|
||||||
embeddingQuery.setQueryId(
|
|
||||||
dataItem.getId() + dataItem.getType().name().toLowerCase());
|
|
||||||
embeddingQuery.setQuery(dataItem.getName());
|
|
||||||
Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class);
|
|
||||||
embeddingQuery.setMetadata(meta);
|
|
||||||
embeddingQuery.setQueryEmbedding(null);
|
|
||||||
return embeddingQuery;
|
|
||||||
}).collect(Collectors.toList());
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,39 @@
|
|||||||
|
package dev.langchain4j.store.embedding;
|
||||||
|
|
||||||
|
|
||||||
|
import com.alibaba.fastjson.JSONObject;
|
||||||
|
import com.tencent.supersonic.common.pojo.DataItem;
|
||||||
|
import dev.langchain4j.data.document.Metadata;
|
||||||
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class TextSegmentConvert {
|
||||||
|
|
||||||
|
public static final String QUERY_ID = "queryId";
|
||||||
|
|
||||||
|
public static List<TextSegment> convertToEmbedding(List<DataItem> dataItems) {
|
||||||
|
return dataItems.stream().map(dataItem -> {
|
||||||
|
Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class);
|
||||||
|
TextSegment textSegment = TextSegment.from(dataItem.getName(), new Metadata(meta));
|
||||||
|
addQueryId(textSegment, dataItem.getId() + dataItem.getType().name().toLowerCase());
|
||||||
|
return textSegment;
|
||||||
|
}).collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void addQueryId(TextSegment textSegment, String queryId) {
|
||||||
|
textSegment.metadata().put(QUERY_ID, queryId);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static String getQueryId(TextSegment textSegment) {
|
||||||
|
if (Objects.isNull(textSegment) || Objects.isNull(textSegment.metadata())) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return textSegment.metadata().get(QUERY_ID);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -5,10 +5,18 @@ import com.fasterxml.jackson.core.type.TypeReference;
|
|||||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
import com.tencent.supersonic.common.service.EmbeddingService;
|
import com.tencent.supersonic.common.service.EmbeddingService;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingQuery;
|
import dev.langchain4j.data.document.Metadata;
|
||||||
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
import dev.langchain4j.store.embedding.Retrieval;
|
import dev.langchain4j.store.embedding.Retrieval;
|
||||||
import dev.langchain4j.store.embedding.RetrieveQuery;
|
import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||||
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||||
|
import dev.langchain4j.store.embedding.TextSegmentConvert;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.core.io.ClassPathResource;
|
||||||
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.io.InputStream;
|
import java.io.InputStream;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
@@ -17,11 +25,6 @@ import java.util.List;
|
|||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
|
||||||
import org.springframework.core.io.ClassPathResource;
|
|
||||||
import org.springframework.stereotype.Component;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Component
|
@Component
|
||||||
@@ -44,15 +47,13 @@ public class ExemplarManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public void addExemplars(List<Exemplar> exemplars, String collectionName) {
|
public void addExemplars(List<Exemplar> exemplars, String collectionName) {
|
||||||
List<EmbeddingQuery> queries = new ArrayList<>();
|
List<TextSegment> queries = new ArrayList<>();
|
||||||
for (int i = 0; i < exemplars.size(); i++) {
|
for (int i = 0; i < exemplars.size(); i++) {
|
||||||
Exemplar exemplar = exemplars.get(i);
|
Exemplar exemplar = exemplars.get(i);
|
||||||
String question = exemplar.getQuestion();
|
String question = exemplar.getQuestion();
|
||||||
Map<String, Object> metaDataMap = JsonUtil.toMap(JsonUtil.toString(exemplar), String.class, Object.class);
|
Map<String, Object> metaDataMap = JsonUtil.toMap(JsonUtil.toString(exemplar), String.class, Object.class);
|
||||||
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
|
TextSegment embeddingQuery = TextSegment.from(question, new Metadata(metaDataMap));
|
||||||
embeddingQuery.setQueryId(String.valueOf(i));
|
TextSegmentConvert.addQueryId(embeddingQuery, String.valueOf(i));
|
||||||
embeddingQuery.setQuery(question);
|
|
||||||
embeddingQuery.setMetadata(metaDataMap);
|
|
||||||
queries.add(embeddingQuery);
|
queries.add(embeddingQuery);
|
||||||
}
|
}
|
||||||
embeddingService.addQuery(collectionName, queries);
|
embeddingService.addQuery(collectionName, queries);
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ import com.tencent.supersonic.common.pojo.DataEvent;
|
|||||||
import com.tencent.supersonic.common.pojo.DataItem;
|
import com.tencent.supersonic.common.pojo.DataItem;
|
||||||
import com.tencent.supersonic.common.pojo.enums.EventType;
|
import com.tencent.supersonic.common.pojo.enums.EventType;
|
||||||
import com.tencent.supersonic.common.service.EmbeddingService;
|
import com.tencent.supersonic.common.service.EmbeddingService;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingQuery;
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
import java.util.List;
|
import dev.langchain4j.store.embedding.TextSegmentConvert;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.beans.factory.annotation.Value;
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
@@ -15,6 +15,8 @@ import org.springframework.scheduling.annotation.Async;
|
|||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
@Component
|
@Component
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class MetaEmbeddingListener implements ApplicationListener<DataEvent> {
|
public class MetaEmbeddingListener implements ApplicationListener<DataEvent> {
|
||||||
@@ -35,19 +37,19 @@ public class MetaEmbeddingListener implements ApplicationListener<DataEvent> {
|
|||||||
if (CollectionUtils.isEmpty(dataItems)) {
|
if (CollectionUtils.isEmpty(dataItems)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
List<EmbeddingQuery> embeddingQueries = EmbeddingQuery.convertToEmbedding(dataItems);
|
List<TextSegment> textSegments = TextSegmentConvert.convertToEmbedding(dataItems);
|
||||||
if (CollectionUtils.isEmpty(embeddingQueries)) {
|
if (CollectionUtils.isEmpty(textSegments)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
sleep();
|
sleep();
|
||||||
embeddingService.addCollection(embeddingConfig.getMetaCollectionName());
|
embeddingService.addCollection(embeddingConfig.getMetaCollectionName());
|
||||||
if (event.getEventType().equals(EventType.ADD)) {
|
if (event.getEventType().equals(EventType.ADD)) {
|
||||||
embeddingService.addQuery(embeddingConfig.getMetaCollectionName(), embeddingQueries);
|
embeddingService.addQuery(embeddingConfig.getMetaCollectionName(), textSegments);
|
||||||
} else if (event.getEventType().equals(EventType.DELETE)) {
|
} else if (event.getEventType().equals(EventType.DELETE)) {
|
||||||
embeddingService.deleteQuery(embeddingConfig.getMetaCollectionName(), embeddingQueries);
|
embeddingService.deleteQuery(embeddingConfig.getMetaCollectionName(), textSegments);
|
||||||
} else if (event.getEventType().equals(EventType.UPDATE)) {
|
} else if (event.getEventType().equals(EventType.UPDATE)) {
|
||||||
embeddingService.deleteQuery(embeddingConfig.getMetaCollectionName(), embeddingQueries);
|
embeddingService.deleteQuery(embeddingConfig.getMetaCollectionName(), textSegments);
|
||||||
embeddingService.addQuery(embeddingConfig.getMetaCollectionName(), embeddingQueries);
|
embeddingService.addQuery(embeddingConfig.getMetaCollectionName(), textSegments);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,14 +5,15 @@ import com.tencent.supersonic.common.pojo.DataItem;
|
|||||||
import com.tencent.supersonic.common.service.EmbeddingService;
|
import com.tencent.supersonic.common.service.EmbeddingService;
|
||||||
import com.tencent.supersonic.headless.server.service.DimensionService;
|
import com.tencent.supersonic.headless.server.service.DimensionService;
|
||||||
import com.tencent.supersonic.headless.server.service.MetricService;
|
import com.tencent.supersonic.headless.server.service.MetricService;
|
||||||
import dev.langchain4j.store.embedding.EmbeddingQuery;
|
import dev.langchain4j.store.embedding.TextSegmentConvert;
|
||||||
import java.util.List;
|
|
||||||
import javax.annotation.PreDestroy;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.scheduling.annotation.Scheduled;
|
import org.springframework.scheduling.annotation.Scheduled;
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
|
import javax.annotation.PreDestroy;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
@Component
|
@Component
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class EmbeddingTask {
|
public class EmbeddingTask {
|
||||||
@@ -55,11 +56,11 @@ public class EmbeddingTask {
|
|||||||
List<DataItem> metricDataItems = metricService.getDataEvent().getDataItems();
|
List<DataItem> metricDataItems = metricService.getDataEvent().getDataItems();
|
||||||
|
|
||||||
embeddingService.addQuery(embeddingConfig.getMetaCollectionName(),
|
embeddingService.addQuery(embeddingConfig.getMetaCollectionName(),
|
||||||
EmbeddingQuery.convertToEmbedding(metricDataItems));
|
TextSegmentConvert.convertToEmbedding(metricDataItems));
|
||||||
|
|
||||||
List<DataItem> dimensionDataItems = dimensionService.getDataEvent().getDataItems();
|
List<DataItem> dimensionDataItems = dimensionService.getDataEvent().getDataItems();
|
||||||
embeddingService.addQuery(embeddingConfig.getMetaCollectionName(),
|
embeddingService.addQuery(embeddingConfig.getMetaCollectionName(),
|
||||||
EmbeddingQuery.convertToEmbedding(dimensionDataItems));
|
TextSegmentConvert.convertToEmbedding(dimensionDataItems));
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("reload.meta.embedding error", e);
|
log.error("reload.meta.embedding error", e);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user