mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 12:07:42 +00:00
(improvement)(chat) The embedding model will be uniformly adopted using the textSegment and will be compatible with the queryId parameter. (#1202)
This commit is contained in:
@@ -14,16 +14,17 @@ import com.tencent.supersonic.chat.server.plugin.event.PluginUpdateEvent;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.service.PluginService;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.store.embedding.EmbeddingQuery;
|
||||
import com.tencent.supersonic.common.service.EmbeddingService;
|
||||
import dev.langchain4j.store.embedding.Retrieval;
|
||||
import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.store.embedding.Retrieval;
|
||||
import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||
import dev.langchain4j.store.embedding.TextSegmentConvert;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
@@ -116,16 +117,16 @@ public class PluginManager {
|
||||
}
|
||||
String presetCollection = embeddingConfig.getPresetCollection();
|
||||
|
||||
List<EmbeddingQuery> queries = new ArrayList<>();
|
||||
List<TextSegment> queries = new ArrayList<>();
|
||||
for (String id : queryIds) {
|
||||
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
|
||||
embeddingQuery.setQueryId(id);
|
||||
queries.add(embeddingQuery);
|
||||
TextSegment query = TextSegment.from("");
|
||||
TextSegmentConvert.addQueryId(query, id);
|
||||
queries.add(query);
|
||||
}
|
||||
embeddingService.deleteQuery(presetCollection, queries);
|
||||
}
|
||||
|
||||
public void requestEmbeddingPluginAdd(List<EmbeddingQuery> queries) {
|
||||
public void requestEmbeddingPluginAdd(List<TextSegment> queries) {
|
||||
if (CollectionUtils.isEmpty(queries)) {
|
||||
return;
|
||||
}
|
||||
@@ -133,10 +134,6 @@ public class PluginManager {
|
||||
embeddingService.addQuery(presetCollection, queries);
|
||||
}
|
||||
|
||||
public void requestEmbeddingPluginAddALL(List<Plugin> plugins) {
|
||||
requestEmbeddingPluginAdd(convert(plugins));
|
||||
}
|
||||
|
||||
public RetrieveQueryResult recognize(String embeddingText) {
|
||||
|
||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder()
|
||||
@@ -158,15 +155,14 @@ public class PluginManager {
|
||||
throw new RuntimeException("get embedding result failed");
|
||||
}
|
||||
|
||||
public List<EmbeddingQuery> convert(List<Plugin> plugins) {
|
||||
List<EmbeddingQuery> queries = Lists.newArrayList();
|
||||
public List<TextSegment> convert(List<Plugin> plugins) {
|
||||
List<TextSegment> queries = Lists.newArrayList();
|
||||
for (Plugin plugin : plugins) {
|
||||
List<String> exampleQuestions = plugin.getExampleQuestionList();
|
||||
int num = 0;
|
||||
for (String pattern : exampleQuestions) {
|
||||
EmbeddingQuery query = new EmbeddingQuery();
|
||||
query.setQueryId(generateUniqueEmbeddingId(num, plugin.getId()));
|
||||
query.setQuery(pattern);
|
||||
TextSegment query = TextSegment.from(pattern);
|
||||
TextSegmentConvert.addQueryId(query, generateUniqueEmbeddingId(num, plugin.getId()));
|
||||
queries.add(query);
|
||||
num++;
|
||||
}
|
||||
@@ -176,8 +172,8 @@ public class PluginManager {
|
||||
|
||||
private Set<String> getEmbeddingId(List<Plugin> plugins) {
|
||||
Set<String> embeddingIdSet = new HashSet<>();
|
||||
for (EmbeddingQuery query : convert(plugins)) {
|
||||
embeddingIdSet.add(query.getQueryId());
|
||||
for (TextSegment query : convert(plugins)) {
|
||||
TextSegmentConvert.addQueryId(query, TextSegmentConvert.getQueryId(query));
|
||||
}
|
||||
return embeddingIdSet;
|
||||
}
|
||||
|
||||
@@ -5,18 +5,12 @@ import com.tencent.supersonic.chat.api.pojo.request.SimilarQueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.service.EmbeddingService;
|
||||
import dev.langchain4j.store.embedding.EmbeddingQuery;
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.store.embedding.Retrieval;
|
||||
import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||
import java.net.URI;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import dev.langchain4j.store.embedding.TextSegmentConvert;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
@@ -32,6 +26,15 @@ import org.springframework.stereotype.Component;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
import org.springframework.web.util.UriComponentsBuilder;
|
||||
|
||||
import java.net.URI;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class SimilarQueryManager {
|
||||
@@ -52,15 +55,13 @@ public class SimilarQueryManager {
|
||||
}
|
||||
String queryText = similarQueryReq.getQueryText();
|
||||
try {
|
||||
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
|
||||
embeddingQuery.setQueryId(String.valueOf(similarQueryReq.getQueryId()));
|
||||
embeddingQuery.setQuery(queryText);
|
||||
|
||||
Map<String, Object> metaData = new HashMap<>();
|
||||
metaData.put("agentId", similarQueryReq.getAgentId());
|
||||
embeddingQuery.setMetadata(metaData);
|
||||
TextSegment textSegment = TextSegment.from(queryText, new Metadata(metaData));
|
||||
TextSegmentConvert.addQueryId(textSegment, String.valueOf(similarQueryReq.getQueryId()));
|
||||
|
||||
String solvedQueryCollection = embeddingConfig.getSolvedQueryCollection();
|
||||
embeddingService.addQuery(solvedQueryCollection, Lists.newArrayList(embeddingQuery));
|
||||
embeddingService.addQuery(solvedQueryCollection, Lists.newArrayList(textSegment));
|
||||
} catch (Exception e) {
|
||||
log.warn("save history question to embedding failed, queryText:{}", queryText, e);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user