mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-14 22:25:19 +00:00
[improvement](chat) Unified vector-related interfaces to go through EmbeddingUtils. (#476)
This commit is contained in:
@@ -2,8 +2,8 @@ package com.tencent.supersonic.chat.parser.plugin.embedding;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.parser.PythonLLMProxy;
|
||||
import com.tencent.supersonic.chat.parser.LLMProxy;
|
||||
import com.tencent.supersonic.chat.parser.PythonLLMProxy;
|
||||
import com.tencent.supersonic.chat.parser.plugin.ParseMode;
|
||||
import com.tencent.supersonic.chat.parser.plugin.PluginParser;
|
||||
import com.tencent.supersonic.chat.plugin.Plugin;
|
||||
@@ -12,6 +12,8 @@ import com.tencent.supersonic.chat.plugin.PluginRecallResult;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -40,13 +42,13 @@ public class EmbeddingRecallParser extends PluginParser {
|
||||
@Override
|
||||
public PluginRecallResult recallPlugin(QueryContext queryContext) {
|
||||
String text = queryContext.getRequest().getQueryText();
|
||||
List<RecallRetrieval> embeddingRetrievals = embeddingRecall(text);
|
||||
List<Retrieval> embeddingRetrievals = embeddingRecall(text);
|
||||
if (CollectionUtils.isEmpty(embeddingRetrievals)) {
|
||||
return null;
|
||||
}
|
||||
List<Plugin> plugins = getPluginList(queryContext);
|
||||
Map<Long, Plugin> pluginMap = plugins.stream().collect(Collectors.toMap(Plugin::getId, p -> p));
|
||||
for (RecallRetrieval embeddingRetrieval : embeddingRetrievals) {
|
||||
for (Retrieval embeddingRetrieval : embeddingRetrievals) {
|
||||
Plugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId()));
|
||||
if (plugin == null) {
|
||||
continue;
|
||||
@@ -59,7 +61,7 @@ public class EmbeddingRecallParser extends PluginParser {
|
||||
continue;
|
||||
}
|
||||
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
|
||||
double distance = Double.parseDouble(embeddingRetrieval.getDistance());
|
||||
double distance = embeddingRetrieval.getDistance();
|
||||
double score = queryContext.getRequest().getQueryText().length() * (1 - distance);
|
||||
return PluginRecallResult.builder()
|
||||
.plugin(plugin).modelIds(modelList).score(score).distance(distance).build();
|
||||
@@ -68,14 +70,15 @@ public class EmbeddingRecallParser extends PluginParser {
|
||||
return null;
|
||||
}
|
||||
|
||||
public List<RecallRetrieval> embeddingRecall(String embeddingText) {
|
||||
public List<Retrieval> embeddingRecall(String embeddingText) {
|
||||
try {
|
||||
PluginManager pluginManager = ContextUtils.getBean(PluginManager.class);
|
||||
RecallRetrievalResp embeddingResp = pluginManager.recognize(embeddingText);
|
||||
List<RecallRetrieval> embeddingRetrievals = embeddingResp.getRetrieval();
|
||||
RetrieveQueryResult embeddingResp = pluginManager.recognize(embeddingText);
|
||||
|
||||
List<Retrieval> embeddingRetrievals = embeddingResp.getRetrieval();
|
||||
if (!CollectionUtils.isEmpty(embeddingRetrievals)) {
|
||||
embeddingRetrievals = embeddingRetrievals.stream().sorted(Comparator.comparingDouble(o ->
|
||||
Math.abs(Double.parseDouble(o.getDistance())))).collect(Collectors.toList());
|
||||
Math.abs(o.getDistance()))).collect(Collectors.toList());
|
||||
embeddingResp.setRetrieval(embeddingRetrievals);
|
||||
}
|
||||
return embeddingRetrievals;
|
||||
|
||||
@@ -3,17 +3,14 @@ package com.tencent.supersonic.chat.plugin;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.agent.Agent;
|
||||
import com.tencent.supersonic.chat.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.agent.PluginTool;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrievalResp;
|
||||
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrieval;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.plugin.event.PluginAddEvent;
|
||||
import com.tencent.supersonic.chat.plugin.event.PluginDelEvent;
|
||||
import com.tencent.supersonic.chat.plugin.event.PluginUpdateEvent;
|
||||
@@ -21,31 +18,28 @@ import com.tencent.supersonic.chat.query.plugin.ParamOption;
|
||||
import com.tencent.supersonic.chat.query.plugin.WebBase;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import com.tencent.supersonic.chat.service.PluginService;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import java.net.URI;
|
||||
import java.util.List;
|
||||
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
|
||||
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Set;
|
||||
import java.util.Optional;
|
||||
import java.util.Collections;
|
||||
import java.util.HashSet;
|
||||
import java.util.HashMap;
|
||||
import java.util.Objects;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.apache.logging.log4j.util.Strings;
|
||||
import org.springframework.context.event.EventListener;
|
||||
import org.springframework.core.ParameterizedTypeReference;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.HttpEntity;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
import org.springframework.web.util.UriComponentsBuilder;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
@@ -55,9 +49,12 @@ public class PluginManager {
|
||||
|
||||
private RestTemplate restTemplate;
|
||||
|
||||
public PluginManager(EmbeddingConfig embeddingConfig, RestTemplate restTemplate) {
|
||||
private EmbeddingUtils embeddingUtils;
|
||||
|
||||
public PluginManager(EmbeddingConfig embeddingConfig, RestTemplate restTemplate, EmbeddingUtils embeddingUtils) {
|
||||
this.embeddingConfig = embeddingConfig;
|
||||
this.restTemplate = restTemplate;
|
||||
this.embeddingUtils = embeddingUtils;
|
||||
}
|
||||
|
||||
public static List<Plugin> getPluginAgentCanSupport(Integer agentId) {
|
||||
@@ -124,96 +121,77 @@ public class PluginManager {
|
||||
}
|
||||
}
|
||||
|
||||
public void requestEmbeddingPluginDelete(Set<String> ids) {
|
||||
if (CollectionUtils.isEmpty(ids)) {
|
||||
public void requestEmbeddingPluginDelete(Set<String> queryIds) {
|
||||
if (CollectionUtils.isEmpty(queryIds)) {
|
||||
return;
|
||||
}
|
||||
doRequest(embeddingConfig.getDeletePath(), JSONObject.toJSONString(ids));
|
||||
String presetCollection = embeddingConfig.getPresetCollection();
|
||||
|
||||
List<EmbeddingQuery> queries = new ArrayList<>();
|
||||
for (String id : queryIds) {
|
||||
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
|
||||
embeddingQuery.setQueryId(id);
|
||||
queries.add(embeddingQuery);
|
||||
}
|
||||
embeddingUtils.deleteQuery(presetCollection, queries);
|
||||
}
|
||||
|
||||
public void requestEmbeddingPluginAdd(List<Map<String, String>> maps) {
|
||||
if (CollectionUtils.isEmpty(maps)) {
|
||||
public void requestEmbeddingPluginAdd(List<EmbeddingQuery> queries) {
|
||||
if (CollectionUtils.isEmpty(queries)) {
|
||||
return;
|
||||
}
|
||||
doRequest(embeddingConfig.getAddPath(), JSONObject.toJSONString(maps));
|
||||
}
|
||||
|
||||
public ResponseEntity<String> doRequest(String path, String jsonBody) {
|
||||
if (Strings.isEmpty(embeddingConfig.getUrl())) {
|
||||
return ResponseEntity.of(Optional.empty());
|
||||
}
|
||||
String url = embeddingConfig.getUrl() + path;
|
||||
try {
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||
headers.setLocation(URI.create(url));
|
||||
URI requestUrl = UriComponentsBuilder
|
||||
.fromHttpUrl(url).build().encode().toUri();
|
||||
HttpEntity<String> entity = new HttpEntity<>(jsonBody, headers);
|
||||
log.info("[embedding] equest body :{}, url:{}", jsonBody, url);
|
||||
ResponseEntity<String> responseEntity = restTemplate.exchange(requestUrl,
|
||||
HttpMethod.POST, entity, new ParameterizedTypeReference<String>() {});
|
||||
log.info("[embedding] result body:{}", responseEntity);
|
||||
return responseEntity;
|
||||
} catch (Throwable e) {
|
||||
log.warn("connect to embedding service failed, url:{}", url);
|
||||
}
|
||||
return ResponseEntity.of(Optional.empty());
|
||||
String presetCollection = embeddingConfig.getPresetCollection();
|
||||
embeddingUtils.addQuery(presetCollection, queries);
|
||||
}
|
||||
|
||||
public void requestEmbeddingPluginAddALL(List<Plugin> plugins) {
|
||||
requestEmbeddingPluginAdd(convert(plugins));
|
||||
}
|
||||
|
||||
public RecallRetrievalResp recognize(String embeddingText) {
|
||||
String url = embeddingConfig.getUrl() + embeddingConfig.getRecognizePath() + "?n_results="
|
||||
+ embeddingConfig.getNResult();
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||
headers.setLocation(URI.create(url));
|
||||
URI requestUrl = UriComponentsBuilder
|
||||
.fromHttpUrl(url).build().encode().toUri();
|
||||
String jsonBody = JSONObject.toJSONString(Lists.newArrayList(embeddingText));
|
||||
HttpEntity<String> entity = new HttpEntity<>(jsonBody, headers);
|
||||
log.info("[embedding] request body:{}, url:{}", jsonBody, url);
|
||||
ResponseEntity<List<RecallRetrievalResp>> embeddingResponseEntity =
|
||||
restTemplate.exchange(requestUrl, HttpMethod.POST, entity,
|
||||
new ParameterizedTypeReference<List<RecallRetrievalResp>>() {
|
||||
});
|
||||
log.info("[embedding] recognize result body:{}", embeddingResponseEntity);
|
||||
List<RecallRetrievalResp> embeddingResps = embeddingResponseEntity.getBody();
|
||||
if (CollectionUtils.isNotEmpty(embeddingResps)) {
|
||||
for (RecallRetrievalResp embeddingResp : embeddingResps) {
|
||||
List<RecallRetrieval> embeddingRetrievals = embeddingResp.getRetrieval();
|
||||
for (RecallRetrieval embeddingRetrieval : embeddingRetrievals) {
|
||||
public RetrieveQueryResult recognize(String embeddingText) {
|
||||
|
||||
EmbeddingUtils embeddingUtils = ContextUtils.getBean(EmbeddingUtils.class);
|
||||
|
||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder()
|
||||
.queryTextsList(Collections.singletonList(embeddingText))
|
||||
.build();
|
||||
|
||||
List<RetrieveQueryResult> resultList = embeddingUtils.retrieveQuery(embeddingConfig.getPresetCollection(),
|
||||
retrieveQuery, embeddingConfig.getNResult());
|
||||
|
||||
|
||||
if (CollectionUtils.isNotEmpty(resultList)) {
|
||||
for (RetrieveQueryResult embeddingResp : resultList) {
|
||||
List<Retrieval> embeddingRetrievals = embeddingResp.getRetrieval();
|
||||
for (Retrieval embeddingRetrieval : embeddingRetrievals) {
|
||||
embeddingRetrieval.setId(getPluginIdFromEmbeddingId(embeddingRetrieval.getId()));
|
||||
}
|
||||
}
|
||||
return embeddingResps.get(0);
|
||||
return resultList.get(0);
|
||||
}
|
||||
throw new RuntimeException("get embedding result failed");
|
||||
}
|
||||
|
||||
public List<Map<String, String>> convert(List<Plugin> plugins) {
|
||||
List<Map<String, String>> maps = Lists.newArrayList();
|
||||
public List<EmbeddingQuery> convert(List<Plugin> plugins) {
|
||||
List<EmbeddingQuery> queries = Lists.newArrayList();
|
||||
for (Plugin plugin : plugins) {
|
||||
List<String> exampleQuestions = plugin.getExampleQuestionList();
|
||||
int num = 0;
|
||||
for (String pattern : exampleQuestions) {
|
||||
Map<String, String> map = new HashMap<>();
|
||||
map.put("preset_query_id", generateUniqueEmbeddingId(num, plugin.getId()));
|
||||
map.put("preset_query", pattern);
|
||||
maps.add(map);
|
||||
EmbeddingQuery query = new EmbeddingQuery();
|
||||
query.setQueryId(generateUniqueEmbeddingId(num, plugin.getId()));
|
||||
query.setQuery(pattern);
|
||||
queries.add(query);
|
||||
num++;
|
||||
}
|
||||
}
|
||||
return maps;
|
||||
return queries;
|
||||
}
|
||||
|
||||
private Set<String> getEmbeddingId(List<Plugin> plugins) {
|
||||
Set<String> embeddingIdSet = new HashSet<>();
|
||||
for (Map<String, String> map : convert(plugins)) {
|
||||
embeddingIdSet.add(map.get("preset_query_id"));
|
||||
for (EmbeddingQuery query : convert(plugins)) {
|
||||
embeddingIdSet.add(query.getQueryId());
|
||||
}
|
||||
return embeddingIdSet;
|
||||
}
|
||||
@@ -243,7 +221,7 @@ public class PluginManager {
|
||||
}
|
||||
Set<Long> matchedModel = Sets.newHashSet();
|
||||
Map<Long, List<ParamOption>> paramOptionMap = paramOptions.stream()
|
||||
.collect(Collectors.groupingBy(ParamOption::getModelId));
|
||||
.collect(Collectors.groupingBy(ParamOption::getModelId));
|
||||
for (Long modelId : paramOptionMap.keySet()) {
|
||||
List<ParamOption> params = paramOptionMap.get(modelId);
|
||||
if (CollectionUtils.isEmpty(params)) {
|
||||
|
||||
@@ -10,16 +10,17 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
|
||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.pojo.Aggregator;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
import com.tencent.supersonic.common.pojo.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
|
||||
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
|
||||
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
|
||||
import java.util.HashMap;
|
||||
@@ -30,6 +31,7 @@ import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.calcite.sql.parser.SqlParseException;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
@@ -149,12 +151,18 @@ public class MetricAnalyzeQuery extends LLMSemanticQuery {
|
||||
}
|
||||
|
||||
public String fetchInterpret(String queryText, String dataText) {
|
||||
PluginManager pluginManager = ContextUtils.getBean(PluginManager.class);
|
||||
LLMAnswerReq lLmAnswerReq = new LLMAnswerReq();
|
||||
lLmAnswerReq.setQueryText(queryText);
|
||||
lLmAnswerReq.setPluginOutput(dataText);
|
||||
ResponseEntity<String> responseEntity = pluginManager.doRequest("answer_with_plugin_call",
|
||||
JSONObject.toJSONString(lLmAnswerReq));
|
||||
EmbeddingUtils embeddingUtils = ContextUtils.getBean(EmbeddingUtils.class);
|
||||
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||
String metricAnalyzeQueryCollection = embeddingConfig.getMetricAnalyzeQueryCollection();
|
||||
|
||||
String url = String.format("%s/retrieve_query?collection_name=%s", embeddingConfig.getUrl(),
|
||||
metricAnalyzeQueryCollection);
|
||||
|
||||
ResponseEntity<String> responseEntity = embeddingUtils.doRequest(url, JSONObject.toJSONString(lLmAnswerReq),
|
||||
HttpMethod.POST);
|
||||
LLMAnswerResp lLmAnswerResp = JSONObject.parseObject(responseEntity.getBody(), LLMAnswerResp.class);
|
||||
if (lLmAnswerResp != null) {
|
||||
return lLmAnswerResp.getAssistantMessage();
|
||||
|
||||
@@ -199,7 +199,7 @@ public class QueryServiceImpl implements QueryService {
|
||||
queryReq.getUser().getName(), queryReq.getChatId().longValue());
|
||||
queryResult.setChatContext(parseInfo);
|
||||
// update chat context after a successful semantic query
|
||||
if (queryReq.isSaveAnswer() && QueryState.SUCCESS.equals(queryResult.getQueryState())) {
|
||||
if (QueryState.SUCCESS.equals(queryResult.getQueryState())) {
|
||||
chatCtx.setParseInfo(parseInfo);
|
||||
chatService.updateContext(chatCtx);
|
||||
saveSolvedQuery(queryReq, parseInfo, chatQueryDO, queryResult);
|
||||
|
||||
@@ -1,13 +1,21 @@
|
||||
package com.tencent.supersonic.chat.utils;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.alibaba.fastjson.serializer.SerializerFeature;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.SolvedQueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrievalResp;
|
||||
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrieval;
|
||||
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
|
||||
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||
import java.net.URI;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
@@ -21,13 +29,6 @@ import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
import org.springframework.web.util.UriComponentsBuilder;
|
||||
import java.net.URI;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
@@ -35,8 +36,11 @@ public class SolvedQueryManager {
|
||||
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
|
||||
public SolvedQueryManager(EmbeddingConfig embeddingConfig) {
|
||||
private EmbeddingUtils embeddingUtils;
|
||||
|
||||
public SolvedQueryManager(EmbeddingConfig embeddingConfig, EmbeddingUtils embeddingUtils) {
|
||||
this.embeddingConfig = embeddingConfig;
|
||||
this.embeddingUtils = embeddingUtils;
|
||||
}
|
||||
|
||||
public void saveSolvedQuery(SolvedQueryReq solvedQueryReq) {
|
||||
@@ -46,15 +50,16 @@ public class SolvedQueryManager {
|
||||
String queryText = solvedQueryReq.getQueryText();
|
||||
try {
|
||||
String uniqueId = generateUniqueId(solvedQueryReq.getQueryId(), solvedQueryReq.getParseId());
|
||||
Map<String, Object> requestMap = new HashMap<>();
|
||||
requestMap.put("query", queryText);
|
||||
requestMap.put("query_id", uniqueId);
|
||||
Map<String, Object> metaData = new HashMap<>();
|
||||
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
|
||||
embeddingQuery.setQueryId(uniqueId);
|
||||
embeddingQuery.setQuery(queryText);
|
||||
|
||||
Map<String, String> metaData = new HashMap<>();
|
||||
metaData.put("modelId", String.valueOf(solvedQueryReq.getModelId()));
|
||||
metaData.put("agentId", String.valueOf(solvedQueryReq.getAgentId()));
|
||||
requestMap.put("metadata", metaData);
|
||||
doRequest(embeddingConfig.getSolvedQueryAddPath(),
|
||||
JSONObject.toJSONString(Lists.newArrayList(requestMap)));
|
||||
embeddingQuery.setMetadata(metaData);
|
||||
String solvedQueryCollection = embeddingConfig.getSolvedQueryCollection();
|
||||
embeddingUtils.addQuery(solvedQueryCollection, Lists.newArrayList(embeddingQuery));
|
||||
} catch (Exception e) {
|
||||
log.warn("save history question to embedding failed, queryText:{}", queryText, e);
|
||||
}
|
||||
@@ -66,49 +71,41 @@ public class SolvedQueryManager {
|
||||
}
|
||||
List<SolvedQueryRecallResp> solvedQueryRecallResps = Lists.newArrayList();
|
||||
try {
|
||||
String url = embeddingConfig.getUrl() + embeddingConfig.getSolvedQueryRecallPath() + "?n_results="
|
||||
+ embeddingConfig.getSolvedQueryResultNum();
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||
headers.setLocation(URI.create(url));
|
||||
URI requestUrl = UriComponentsBuilder
|
||||
.fromHttpUrl(url).build().encode().toUri();
|
||||
Map<String, Object> map = new HashMap<>();
|
||||
map.put("queryTextsList", Lists.newArrayList(queryText));
|
||||
Map<String, Object> filterCondition = new HashMap<>();
|
||||
String solvedQueryCollection = embeddingConfig.getSolvedQueryCollection();
|
||||
int solvedQueryResultNum = embeddingConfig.getSolvedQueryResultNum();
|
||||
|
||||
Map<String, String> filterCondition = new HashMap<>();
|
||||
filterCondition.put("agentId", String.valueOf(agentId));
|
||||
map.put("filterCondition", filterCondition);
|
||||
String jsonBody = JSONObject.toJSONString(map, SerializerFeature.WriteMapNullValue);
|
||||
HttpEntity<String> entity = new HttpEntity<>(jsonBody, headers);
|
||||
log.info("[embedding] request body:{}, url:{}", jsonBody, url);
|
||||
RestTemplate restTemplate = new RestTemplate();
|
||||
ResponseEntity<List<RecallRetrievalResp>> embeddingResponseEntity =
|
||||
restTemplate.exchange(requestUrl, HttpMethod.POST, entity,
|
||||
new ParameterizedTypeReference<List<RecallRetrievalResp>>() {
|
||||
});
|
||||
log.info("[embedding] recognize result body:{}", embeddingResponseEntity);
|
||||
List<RecallRetrievalResp> embeddingResps = embeddingResponseEntity.getBody();
|
||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder()
|
||||
.queryTextsList(Lists.newArrayList(queryText))
|
||||
.filterCondition(filterCondition)
|
||||
.build();
|
||||
List<RetrieveQueryResult> resultList = embeddingUtils.retrieveQuery(solvedQueryCollection, retrieveQuery,
|
||||
solvedQueryResultNum);
|
||||
|
||||
log.info("[embedding] recognize result body:{}", resultList);
|
||||
Set<String> querySet = new HashSet<>();
|
||||
if (CollectionUtils.isNotEmpty(embeddingResps)) {
|
||||
for (RecallRetrievalResp embeddingResp : embeddingResps) {
|
||||
List<RecallRetrieval> embeddingRetrievals = embeddingResp.getRetrieval();
|
||||
for (RecallRetrieval embeddingRetrieval : embeddingRetrievals) {
|
||||
if (queryText.equalsIgnoreCase(embeddingRetrieval.getQuery())) {
|
||||
if (CollectionUtils.isNotEmpty(resultList)) {
|
||||
for (RetrieveQueryResult retrieveQueryResult : resultList) {
|
||||
List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
|
||||
for (Retrieval retrieval : retrievals) {
|
||||
if (queryText.equalsIgnoreCase(retrieval.getQuery())) {
|
||||
continue;
|
||||
}
|
||||
if (querySet.contains(embeddingRetrieval.getQuery())) {
|
||||
if (querySet.contains(retrieval.getQuery())) {
|
||||
continue;
|
||||
}
|
||||
String id = embeddingRetrieval.getId();
|
||||
String id = retrieval.getId();
|
||||
SolvedQueryRecallResp solvedQueryRecallResp = SolvedQueryRecallResp.builder()
|
||||
.queryText(embeddingRetrieval.getQuery())
|
||||
.queryText(retrieval.getQuery())
|
||||
.queryId(getQueryId(id)).parseId(getParseId(id))
|
||||
.build();
|
||||
solvedQueryRecallResps.add(solvedQueryRecallResp);
|
||||
querySet.add(embeddingRetrieval.getQuery());
|
||||
querySet.add(retrieval.getQuery());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
log.warn("recall similar solved query failed, queryText:{}", queryText);
|
||||
}
|
||||
@@ -146,7 +143,8 @@ public class SolvedQueryManager {
|
||||
log.info("[embedding] request body :{}, url:{}", jsonBody, url);
|
||||
RestTemplate restTemplate = new RestTemplate();
|
||||
ResponseEntity<String> responseEntity = restTemplate.exchange(requestUrl,
|
||||
HttpMethod.POST, entity, new ParameterizedTypeReference<String>() {});
|
||||
HttpMethod.POST, entity, new ParameterizedTypeReference<String>() {
|
||||
});
|
||||
log.info("[embedding] result body:{}", responseEntity);
|
||||
return responseEntity;
|
||||
} catch (Exception e) {
|
||||
|
||||
Reference in New Issue
Block a user