[improvement](chat) Unified vector-related interfaces to go through EmbeddingUtils. (#476)

This commit is contained in:
lexluo09
2023-12-06 14:50:57 +08:00
committed by GitHub
parent 9aa5c93d9d
commit ed0f856438
9 changed files with 168 additions and 208 deletions

View File

@@ -2,8 +2,8 @@ package com.tencent.supersonic.chat.parser.plugin.embedding;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.parser.PythonLLMProxy;
import com.tencent.supersonic.chat.parser.LLMProxy; import com.tencent.supersonic.chat.parser.LLMProxy;
import com.tencent.supersonic.chat.parser.PythonLLMProxy;
import com.tencent.supersonic.chat.parser.plugin.ParseMode; import com.tencent.supersonic.chat.parser.plugin.ParseMode;
import com.tencent.supersonic.chat.parser.plugin.PluginParser; import com.tencent.supersonic.chat.parser.plugin.PluginParser;
import com.tencent.supersonic.chat.plugin.Plugin; import com.tencent.supersonic.chat.plugin.Plugin;
@@ -12,6 +12,8 @@ import com.tencent.supersonic.chat.plugin.PluginRecallResult;
import com.tencent.supersonic.chat.utils.ComponentFactory; import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import java.util.Comparator; import java.util.Comparator;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
@@ -40,13 +42,13 @@ public class EmbeddingRecallParser extends PluginParser {
@Override @Override
public PluginRecallResult recallPlugin(QueryContext queryContext) { public PluginRecallResult recallPlugin(QueryContext queryContext) {
String text = queryContext.getRequest().getQueryText(); String text = queryContext.getRequest().getQueryText();
List<RecallRetrieval> embeddingRetrievals = embeddingRecall(text); List<Retrieval> embeddingRetrievals = embeddingRecall(text);
if (CollectionUtils.isEmpty(embeddingRetrievals)) { if (CollectionUtils.isEmpty(embeddingRetrievals)) {
return null; return null;
} }
List<Plugin> plugins = getPluginList(queryContext); List<Plugin> plugins = getPluginList(queryContext);
Map<Long, Plugin> pluginMap = plugins.stream().collect(Collectors.toMap(Plugin::getId, p -> p)); Map<Long, Plugin> pluginMap = plugins.stream().collect(Collectors.toMap(Plugin::getId, p -> p));
for (RecallRetrieval embeddingRetrieval : embeddingRetrievals) { for (Retrieval embeddingRetrieval : embeddingRetrievals) {
Plugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId())); Plugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId()));
if (plugin == null) { if (plugin == null) {
continue; continue;
@@ -59,7 +61,7 @@ public class EmbeddingRecallParser extends PluginParser {
continue; continue;
} }
plugin.setParseMode(ParseMode.EMBEDDING_RECALL); plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
double distance = Double.parseDouble(embeddingRetrieval.getDistance()); double distance = embeddingRetrieval.getDistance();
double score = queryContext.getRequest().getQueryText().length() * (1 - distance); double score = queryContext.getRequest().getQueryText().length() * (1 - distance);
return PluginRecallResult.builder() return PluginRecallResult.builder()
.plugin(plugin).modelIds(modelList).score(score).distance(distance).build(); .plugin(plugin).modelIds(modelList).score(score).distance(distance).build();
@@ -68,14 +70,15 @@ public class EmbeddingRecallParser extends PluginParser {
return null; return null;
} }
public List<RecallRetrieval> embeddingRecall(String embeddingText) { public List<Retrieval> embeddingRecall(String embeddingText) {
try { try {
PluginManager pluginManager = ContextUtils.getBean(PluginManager.class); PluginManager pluginManager = ContextUtils.getBean(PluginManager.class);
RecallRetrievalResp embeddingResp = pluginManager.recognize(embeddingText); RetrieveQueryResult embeddingResp = pluginManager.recognize(embeddingText);
List<RecallRetrieval> embeddingRetrievals = embeddingResp.getRetrieval();
List<Retrieval> embeddingRetrievals = embeddingResp.getRetrieval();
if (!CollectionUtils.isEmpty(embeddingRetrievals)) { if (!CollectionUtils.isEmpty(embeddingRetrievals)) {
embeddingRetrievals = embeddingRetrievals.stream().sorted(Comparator.comparingDouble(o -> embeddingRetrievals = embeddingRetrievals.stream().sorted(Comparator.comparingDouble(o ->
Math.abs(Double.parseDouble(o.getDistance())))).collect(Collectors.toList()); Math.abs(o.getDistance()))).collect(Collectors.toList());
embeddingResp.setRetrieval(embeddingRetrievals); embeddingResp.setRetrieval(embeddingRetrievals);
} }
return embeddingRetrievals; return embeddingRetrievals;

View File

@@ -3,17 +3,14 @@ package com.tencent.supersonic.chat.plugin;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.google.common.collect.Sets; import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.agent.Agent; import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.agent.AgentToolType; import com.tencent.supersonic.chat.agent.AgentToolType;
import com.tencent.supersonic.chat.agent.PluginTool; import com.tencent.supersonic.chat.agent.PluginTool;
import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrievalResp; import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrieval; import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.plugin.event.PluginAddEvent; import com.tencent.supersonic.chat.plugin.event.PluginAddEvent;
import com.tencent.supersonic.chat.plugin.event.PluginDelEvent; import com.tencent.supersonic.chat.plugin.event.PluginDelEvent;
import com.tencent.supersonic.chat.plugin.event.PluginUpdateEvent; import com.tencent.supersonic.chat.plugin.event.PluginUpdateEvent;
@@ -21,31 +18,28 @@ import com.tencent.supersonic.chat.query.plugin.ParamOption;
import com.tencent.supersonic.chat.query.plugin.WebBase; import com.tencent.supersonic.chat.query.plugin.WebBase;
import com.tencent.supersonic.chat.service.AgentService; import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.service.PluginService; import com.tencent.supersonic.chat.service.PluginService;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import java.net.URI; import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
import java.util.List; import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Set; import java.util.Collections;
import java.util.Optional;
import java.util.HashSet; import java.util.HashSet;
import java.util.HashMap; import java.util.List;
import java.util.Objects;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.util.Strings;
import org.springframework.context.event.EventListener; import org.springframework.context.event.EventListener;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.ResponseEntity;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.HttpEntity;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.web.client.RestTemplate; import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;
@Slf4j @Slf4j
@Component @Component
@@ -55,9 +49,12 @@ public class PluginManager {
private RestTemplate restTemplate; private RestTemplate restTemplate;
public PluginManager(EmbeddingConfig embeddingConfig, RestTemplate restTemplate) { private EmbeddingUtils embeddingUtils;
public PluginManager(EmbeddingConfig embeddingConfig, RestTemplate restTemplate, EmbeddingUtils embeddingUtils) {
this.embeddingConfig = embeddingConfig; this.embeddingConfig = embeddingConfig;
this.restTemplate = restTemplate; this.restTemplate = restTemplate;
this.embeddingUtils = embeddingUtils;
} }
public static List<Plugin> getPluginAgentCanSupport(Integer agentId) { public static List<Plugin> getPluginAgentCanSupport(Integer agentId) {
@@ -124,96 +121,77 @@ public class PluginManager {
} }
} }
public void requestEmbeddingPluginDelete(Set<String> ids) { public void requestEmbeddingPluginDelete(Set<String> queryIds) {
if (CollectionUtils.isEmpty(ids)) { if (CollectionUtils.isEmpty(queryIds)) {
return; return;
} }
doRequest(embeddingConfig.getDeletePath(), JSONObject.toJSONString(ids)); String presetCollection = embeddingConfig.getPresetCollection();
List<EmbeddingQuery> queries = new ArrayList<>();
for (String id : queryIds) {
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
embeddingQuery.setQueryId(id);
queries.add(embeddingQuery);
}
embeddingUtils.deleteQuery(presetCollection, queries);
} }
public void requestEmbeddingPluginAdd(List<Map<String, String>> maps) { public void requestEmbeddingPluginAdd(List<EmbeddingQuery> queries) {
if (CollectionUtils.isEmpty(maps)) { if (CollectionUtils.isEmpty(queries)) {
return; return;
} }
doRequest(embeddingConfig.getAddPath(), JSONObject.toJSONString(maps)); String presetCollection = embeddingConfig.getPresetCollection();
} embeddingUtils.addQuery(presetCollection, queries);
public ResponseEntity<String> doRequest(String path, String jsonBody) {
if (Strings.isEmpty(embeddingConfig.getUrl())) {
return ResponseEntity.of(Optional.empty());
}
String url = embeddingConfig.getUrl() + path;
try {
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
headers.setLocation(URI.create(url));
URI requestUrl = UriComponentsBuilder
.fromHttpUrl(url).build().encode().toUri();
HttpEntity<String> entity = new HttpEntity<>(jsonBody, headers);
log.info("[embedding] equest body :{}, url:{}", jsonBody, url);
ResponseEntity<String> responseEntity = restTemplate.exchange(requestUrl,
HttpMethod.POST, entity, new ParameterizedTypeReference<String>() {});
log.info("[embedding] result body:{}", responseEntity);
return responseEntity;
} catch (Throwable e) {
log.warn("connect to embedding service failed, url:{}", url);
}
return ResponseEntity.of(Optional.empty());
} }
public void requestEmbeddingPluginAddALL(List<Plugin> plugins) { public void requestEmbeddingPluginAddALL(List<Plugin> plugins) {
requestEmbeddingPluginAdd(convert(plugins)); requestEmbeddingPluginAdd(convert(plugins));
} }
public RecallRetrievalResp recognize(String embeddingText) { public RetrieveQueryResult recognize(String embeddingText) {
String url = embeddingConfig.getUrl() + embeddingConfig.getRecognizePath() + "?n_results="
+ embeddingConfig.getNResult(); EmbeddingUtils embeddingUtils = ContextUtils.getBean(EmbeddingUtils.class);
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON); RetrieveQuery retrieveQuery = RetrieveQuery.builder()
headers.setLocation(URI.create(url)); .queryTextsList(Collections.singletonList(embeddingText))
URI requestUrl = UriComponentsBuilder .build();
.fromHttpUrl(url).build().encode().toUri();
String jsonBody = JSONObject.toJSONString(Lists.newArrayList(embeddingText)); List<RetrieveQueryResult> resultList = embeddingUtils.retrieveQuery(embeddingConfig.getPresetCollection(),
HttpEntity<String> entity = new HttpEntity<>(jsonBody, headers); retrieveQuery, embeddingConfig.getNResult());
log.info("[embedding] request body:{}, url:{}", jsonBody, url);
ResponseEntity<List<RecallRetrievalResp>> embeddingResponseEntity =
restTemplate.exchange(requestUrl, HttpMethod.POST, entity, if (CollectionUtils.isNotEmpty(resultList)) {
new ParameterizedTypeReference<List<RecallRetrievalResp>>() { for (RetrieveQueryResult embeddingResp : resultList) {
}); List<Retrieval> embeddingRetrievals = embeddingResp.getRetrieval();
log.info("[embedding] recognize result body:{}", embeddingResponseEntity); for (Retrieval embeddingRetrieval : embeddingRetrievals) {
List<RecallRetrievalResp> embeddingResps = embeddingResponseEntity.getBody();
if (CollectionUtils.isNotEmpty(embeddingResps)) {
for (RecallRetrievalResp embeddingResp : embeddingResps) {
List<RecallRetrieval> embeddingRetrievals = embeddingResp.getRetrieval();
for (RecallRetrieval embeddingRetrieval : embeddingRetrievals) {
embeddingRetrieval.setId(getPluginIdFromEmbeddingId(embeddingRetrieval.getId())); embeddingRetrieval.setId(getPluginIdFromEmbeddingId(embeddingRetrieval.getId()));
} }
} }
return embeddingResps.get(0); return resultList.get(0);
} }
throw new RuntimeException("get embedding result failed"); throw new RuntimeException("get embedding result failed");
} }
public List<Map<String, String>> convert(List<Plugin> plugins) { public List<EmbeddingQuery> convert(List<Plugin> plugins) {
List<Map<String, String>> maps = Lists.newArrayList(); List<EmbeddingQuery> queries = Lists.newArrayList();
for (Plugin plugin : plugins) { for (Plugin plugin : plugins) {
List<String> exampleQuestions = plugin.getExampleQuestionList(); List<String> exampleQuestions = plugin.getExampleQuestionList();
int num = 0; int num = 0;
for (String pattern : exampleQuestions) { for (String pattern : exampleQuestions) {
Map<String, String> map = new HashMap<>(); EmbeddingQuery query = new EmbeddingQuery();
map.put("preset_query_id", generateUniqueEmbeddingId(num, plugin.getId())); query.setQueryId(generateUniqueEmbeddingId(num, plugin.getId()));
map.put("preset_query", pattern); query.setQuery(pattern);
maps.add(map); queries.add(query);
num++; num++;
} }
} }
return maps; return queries;
} }
private Set<String> getEmbeddingId(List<Plugin> plugins) { private Set<String> getEmbeddingId(List<Plugin> plugins) {
Set<String> embeddingIdSet = new HashSet<>(); Set<String> embeddingIdSet = new HashSet<>();
for (Map<String, String> map : convert(plugins)) { for (EmbeddingQuery query : convert(plugins)) {
embeddingIdSet.add(map.get("preset_query_id")); embeddingIdSet.add(query.getQueryId());
} }
return embeddingIdSet; return embeddingIdSet;
} }
@@ -243,7 +221,7 @@ public class PluginManager {
} }
Set<Long> matchedModel = Sets.newHashSet(); Set<Long> matchedModel = Sets.newHashSet();
Map<Long, List<ParamOption>> paramOptionMap = paramOptions.stream() Map<Long, List<ParamOption>> paramOptionMap = paramOptions.stream()
.collect(Collectors.groupingBy(ParamOption::getModelId)); .collect(Collectors.groupingBy(ParamOption::getModelId));
for (Long modelId : paramOptionMap.keySet()) { for (Long modelId : paramOptionMap.keySet()) {
List<ParamOption> params = paramOptionMap.get(modelId); List<ParamOption> params = paramOptionMap.get(modelId);
if (CollectionUtils.isEmpty(params)) { if (CollectionUtils.isEmpty(params)) {

View File

@@ -10,16 +10,17 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState; import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.config.OptimizationConfig; import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.query.QueryManager; import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery; import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
import com.tencent.supersonic.chat.utils.ComponentFactory; import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.utils.QueryReqBuilder; import com.tencent.supersonic.chat.utils.QueryReqBuilder;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.Aggregator; import com.tencent.supersonic.common.pojo.Aggregator;
import com.tencent.supersonic.common.pojo.QueryColumn; import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.pojo.QueryType; import com.tencent.supersonic.common.pojo.QueryType;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import java.util.HashMap; import java.util.HashMap;
@@ -30,6 +31,7 @@ import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.calcite.sql.parser.SqlParseException; import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.http.HttpMethod;
import org.springframework.http.ResponseEntity; import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
@@ -149,12 +151,18 @@ public class MetricAnalyzeQuery extends LLMSemanticQuery {
} }
public String fetchInterpret(String queryText, String dataText) { public String fetchInterpret(String queryText, String dataText) {
PluginManager pluginManager = ContextUtils.getBean(PluginManager.class);
LLMAnswerReq lLmAnswerReq = new LLMAnswerReq(); LLMAnswerReq lLmAnswerReq = new LLMAnswerReq();
lLmAnswerReq.setQueryText(queryText); lLmAnswerReq.setQueryText(queryText);
lLmAnswerReq.setPluginOutput(dataText); lLmAnswerReq.setPluginOutput(dataText);
ResponseEntity<String> responseEntity = pluginManager.doRequest("answer_with_plugin_call", EmbeddingUtils embeddingUtils = ContextUtils.getBean(EmbeddingUtils.class);
JSONObject.toJSONString(lLmAnswerReq)); EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
String metricAnalyzeQueryCollection = embeddingConfig.getMetricAnalyzeQueryCollection();
String url = String.format("%s/retrieve_query?collection_name=%s", embeddingConfig.getUrl(),
metricAnalyzeQueryCollection);
ResponseEntity<String> responseEntity = embeddingUtils.doRequest(url, JSONObject.toJSONString(lLmAnswerReq),
HttpMethod.POST);
LLMAnswerResp lLmAnswerResp = JSONObject.parseObject(responseEntity.getBody(), LLMAnswerResp.class); LLMAnswerResp lLmAnswerResp = JSONObject.parseObject(responseEntity.getBody(), LLMAnswerResp.class);
if (lLmAnswerResp != null) { if (lLmAnswerResp != null) {
return lLmAnswerResp.getAssistantMessage(); return lLmAnswerResp.getAssistantMessage();

View File

@@ -199,7 +199,7 @@ public class QueryServiceImpl implements QueryService {
queryReq.getUser().getName(), queryReq.getChatId().longValue()); queryReq.getUser().getName(), queryReq.getChatId().longValue());
queryResult.setChatContext(parseInfo); queryResult.setChatContext(parseInfo);
// update chat context after a successful semantic query // update chat context after a successful semantic query
if (queryReq.isSaveAnswer() && QueryState.SUCCESS.equals(queryResult.getQueryState())) { if (QueryState.SUCCESS.equals(queryResult.getQueryState())) {
chatCtx.setParseInfo(parseInfo); chatCtx.setParseInfo(parseInfo);
chatService.updateContext(chatCtx); chatService.updateContext(chatCtx);
saveSolvedQuery(queryReq, parseInfo, chatQueryDO, queryResult); saveSolvedQuery(queryReq, parseInfo, chatQueryDO, queryResult);

View File

@@ -1,13 +1,21 @@
package com.tencent.supersonic.chat.utils; package com.tencent.supersonic.chat.utils;
import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson.serializer.SerializerFeature;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.request.SolvedQueryReq; import com.tencent.supersonic.chat.api.pojo.request.SolvedQueryReq;
import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp; import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp;
import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrievalResp; import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrieval; import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import java.net.URI;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
@@ -21,13 +29,6 @@ import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.web.client.RestTemplate; import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder; import org.springframework.web.util.UriComponentsBuilder;
import java.net.URI;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
@Slf4j @Slf4j
@Component @Component
@@ -35,8 +36,11 @@ public class SolvedQueryManager {
private EmbeddingConfig embeddingConfig; private EmbeddingConfig embeddingConfig;
public SolvedQueryManager(EmbeddingConfig embeddingConfig) { private EmbeddingUtils embeddingUtils;
public SolvedQueryManager(EmbeddingConfig embeddingConfig, EmbeddingUtils embeddingUtils) {
this.embeddingConfig = embeddingConfig; this.embeddingConfig = embeddingConfig;
this.embeddingUtils = embeddingUtils;
} }
public void saveSolvedQuery(SolvedQueryReq solvedQueryReq) { public void saveSolvedQuery(SolvedQueryReq solvedQueryReq) {
@@ -46,15 +50,16 @@ public class SolvedQueryManager {
String queryText = solvedQueryReq.getQueryText(); String queryText = solvedQueryReq.getQueryText();
try { try {
String uniqueId = generateUniqueId(solvedQueryReq.getQueryId(), solvedQueryReq.getParseId()); String uniqueId = generateUniqueId(solvedQueryReq.getQueryId(), solvedQueryReq.getParseId());
Map<String, Object> requestMap = new HashMap<>(); EmbeddingQuery embeddingQuery = new EmbeddingQuery();
requestMap.put("query", queryText); embeddingQuery.setQueryId(uniqueId);
requestMap.put("query_id", uniqueId); embeddingQuery.setQuery(queryText);
Map<String, Object> metaData = new HashMap<>();
Map<String, String> metaData = new HashMap<>();
metaData.put("modelId", String.valueOf(solvedQueryReq.getModelId())); metaData.put("modelId", String.valueOf(solvedQueryReq.getModelId()));
metaData.put("agentId", String.valueOf(solvedQueryReq.getAgentId())); metaData.put("agentId", String.valueOf(solvedQueryReq.getAgentId()));
requestMap.put("metadata", metaData); embeddingQuery.setMetadata(metaData);
doRequest(embeddingConfig.getSolvedQueryAddPath(), String solvedQueryCollection = embeddingConfig.getSolvedQueryCollection();
JSONObject.toJSONString(Lists.newArrayList(requestMap))); embeddingUtils.addQuery(solvedQueryCollection, Lists.newArrayList(embeddingQuery));
} catch (Exception e) { } catch (Exception e) {
log.warn("save history question to embedding failed, queryText:{}", queryText, e); log.warn("save history question to embedding failed, queryText:{}", queryText, e);
} }
@@ -66,49 +71,41 @@ public class SolvedQueryManager {
} }
List<SolvedQueryRecallResp> solvedQueryRecallResps = Lists.newArrayList(); List<SolvedQueryRecallResp> solvedQueryRecallResps = Lists.newArrayList();
try { try {
String url = embeddingConfig.getUrl() + embeddingConfig.getSolvedQueryRecallPath() + "?n_results=" String solvedQueryCollection = embeddingConfig.getSolvedQueryCollection();
+ embeddingConfig.getSolvedQueryResultNum(); int solvedQueryResultNum = embeddingConfig.getSolvedQueryResultNum();
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON); Map<String, String> filterCondition = new HashMap<>();
headers.setLocation(URI.create(url));
URI requestUrl = UriComponentsBuilder
.fromHttpUrl(url).build().encode().toUri();
Map<String, Object> map = new HashMap<>();
map.put("queryTextsList", Lists.newArrayList(queryText));
Map<String, Object> filterCondition = new HashMap<>();
filterCondition.put("agentId", String.valueOf(agentId)); filterCondition.put("agentId", String.valueOf(agentId));
map.put("filterCondition", filterCondition); RetrieveQuery retrieveQuery = RetrieveQuery.builder()
String jsonBody = JSONObject.toJSONString(map, SerializerFeature.WriteMapNullValue); .queryTextsList(Lists.newArrayList(queryText))
HttpEntity<String> entity = new HttpEntity<>(jsonBody, headers); .filterCondition(filterCondition)
log.info("[embedding] request body:{}, url:{}", jsonBody, url); .build();
RestTemplate restTemplate = new RestTemplate(); List<RetrieveQueryResult> resultList = embeddingUtils.retrieveQuery(solvedQueryCollection, retrieveQuery,
ResponseEntity<List<RecallRetrievalResp>> embeddingResponseEntity = solvedQueryResultNum);
restTemplate.exchange(requestUrl, HttpMethod.POST, entity,
new ParameterizedTypeReference<List<RecallRetrievalResp>>() { log.info("[embedding] recognize result body:{}", resultList);
});
log.info("[embedding] recognize result body:{}", embeddingResponseEntity);
List<RecallRetrievalResp> embeddingResps = embeddingResponseEntity.getBody();
Set<String> querySet = new HashSet<>(); Set<String> querySet = new HashSet<>();
if (CollectionUtils.isNotEmpty(embeddingResps)) { if (CollectionUtils.isNotEmpty(resultList)) {
for (RecallRetrievalResp embeddingResp : embeddingResps) { for (RetrieveQueryResult retrieveQueryResult : resultList) {
List<RecallRetrieval> embeddingRetrievals = embeddingResp.getRetrieval(); List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
for (RecallRetrieval embeddingRetrieval : embeddingRetrievals) { for (Retrieval retrieval : retrievals) {
if (queryText.equalsIgnoreCase(embeddingRetrieval.getQuery())) { if (queryText.equalsIgnoreCase(retrieval.getQuery())) {
continue; continue;
} }
if (querySet.contains(embeddingRetrieval.getQuery())) { if (querySet.contains(retrieval.getQuery())) {
continue; continue;
} }
String id = embeddingRetrieval.getId(); String id = retrieval.getId();
SolvedQueryRecallResp solvedQueryRecallResp = SolvedQueryRecallResp.builder() SolvedQueryRecallResp solvedQueryRecallResp = SolvedQueryRecallResp.builder()
.queryText(embeddingRetrieval.getQuery()) .queryText(retrieval.getQuery())
.queryId(getQueryId(id)).parseId(getParseId(id)) .queryId(getQueryId(id)).parseId(getParseId(id))
.build(); .build();
solvedQueryRecallResps.add(solvedQueryRecallResp); solvedQueryRecallResps.add(solvedQueryRecallResp);
querySet.add(embeddingRetrieval.getQuery()); querySet.add(retrieval.getQuery());
} }
} }
} }
} catch (Exception e) { } catch (Exception e) {
log.warn("recall similar solved query failed, queryText:{}", queryText); log.warn("recall similar solved query failed, queryText:{}", queryText);
} }
@@ -146,7 +143,8 @@ public class SolvedQueryManager {
log.info("[embedding] request body :{}, url:{}", jsonBody, url); log.info("[embedding] request body :{}, url:{}", jsonBody, url);
RestTemplate restTemplate = new RestTemplate(); RestTemplate restTemplate = new RestTemplate();
ResponseEntity<String> responseEntity = restTemplate.exchange(requestUrl, ResponseEntity<String> responseEntity = restTemplate.exchange(requestUrl,
HttpMethod.POST, entity, new ParameterizedTypeReference<String>() {}); HttpMethod.POST, entity, new ParameterizedTypeReference<String>() {
});
log.info("[embedding] result body:{}", responseEntity); log.info("[embedding] result body:{}", responseEntity);
return responseEntity; return responseEntity;
} catch (Exception e) { } catch (Exception e) {

View File

@@ -14,22 +14,21 @@ public class EmbeddingConfig {
@Value("${embedding.recognize.path:/preset_query_retrival}") @Value("${embedding.recognize.path:/preset_query_retrival}")
private String recognizePath; private String recognizePath;
@Value("${embedding.delete.path:/preset_delete_by_ids}") @Value("${embedding.preset.collection:preset_query_collection}")
private String deletePath; private String presetCollection;
@Value("${embedding.add.path:/preset_query_add}")
private String addPath;
@Value("${embedding.nResult:1}") @Value("${embedding.nResult:1}")
private String nResult; private int nResult;
@Value("${embedding.solvedQuery.recall.path:/solved_query_retrival}") @Value("${embedding.solved.query.collection:solved_query_collection}")
private String solvedQueryRecallPath; private String solvedQueryCollection;
@Value("${embedding.solvedQuery.add.path:/solved_query_add}")
private String solvedQueryAddPath;
@Value("${embedding.solved.query.nResult:5}") @Value("${embedding.solved.query.nResult:5}")
private String solvedQueryResultNum; private int solvedQueryResultNum;
@Value("${embedding.metric.analyzeQuery.collection:solved_query_collection}")
private String metricAnalyzeQueryCollection;
} }

View File

@@ -81,7 +81,7 @@ public class EmbeddingUtils {
return embeddingCollections.stream().map(EmbeddingCollection::getName).collect(Collectors.toList()); return embeddingCollections.stream().map(EmbeddingCollection::getName).collect(Collectors.toList());
} }
private ResponseEntity doRequest(String url, String jsonBody, HttpMethod httpMethod) { public ResponseEntity doRequest(String url, String jsonBody, HttpMethod httpMethod) {
try { try {
HttpHeaders headers = new HttpHeaders(); HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON); headers.setContentType(MediaType.APPLICATION_JSON);
@@ -94,11 +94,12 @@ public class EmbeddingUtils {
entity = new HttpEntity<>(jsonBody, headers); entity = new HttpEntity<>(jsonBody, headers);
} }
ResponseEntity<String> responseEntity = restTemplate.exchange(requestUrl, ResponseEntity<String> responseEntity = restTemplate.exchange(requestUrl,
httpMethod, entity, new ParameterizedTypeReference<String>() {}); httpMethod, entity, new ParameterizedTypeReference<String>() {
});
log.info("[embedding] url :{} result body:{}", url, responseEntity); log.info("[embedding] url :{} result body:{}", url, responseEntity);
return responseEntity; return responseEntity;
} catch (Throwable e) { } catch (Throwable e) {
log.warn("connect to embedding service failed, url:{}", url); log.warn("doRequest service failed, url:" + url, e);
} }
return ResponseEntity.of(Optional.empty()); return ResponseEntity.of(Optional.empty());
} }

View File

@@ -2,21 +2,12 @@ package com.tencent.supersonic.integration;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.StandaloneLauncher; import com.tencent.supersonic.StandaloneLauncher;
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.query.llm.analytics.LLMAnswerResp; import com.tencent.supersonic.chat.query.llm.analytics.LLMAnswerResp;
import com.tencent.supersonic.chat.service.AgentService; import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.service.QueryService; import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.util.DataUtils; import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.test.context.SpringBootTest; import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.mock.mockito.MockBean; import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.http.ResponseEntity; import org.springframework.http.ResponseEntity;
@@ -31,37 +22,21 @@ public class MetricInterpretTest {
@MockBean @MockBean
private AgentService agentService; private AgentService agentService;
@MockBean
private PluginManager pluginManager;
@MockBean @MockBean
private EmbeddingConfig embeddingConfig; private EmbeddingConfig embeddingConfig;
@Autowired @MockBean
@Qualifier("chatQueryService") private EmbeddingUtils embeddingUtils;
private QueryService queryService;
@Test @Test
public void testMetricInterpret() throws Exception { public void testMetricInterpret() throws Exception {
MockConfiguration.mockAgent(agentService); MockConfiguration.mockAgent(agentService);
MockConfiguration.mockEmbeddingUrl(embeddingConfig); MockConfiguration.mockEmbeddingUrl(embeddingConfig);
LLMAnswerResp lLmAnswerResp = new LLMAnswerResp(); LLMAnswerResp lLmAnswerResp = new LLMAnswerResp();
lLmAnswerResp.setAssistantMessage("alice最近在超音数的访问情况有增多"); lLmAnswerResp.setAssistantMessage("alice最近在超音数的访问情况有增多");
MockConfiguration.mockPluginManagerDoRequest(pluginManager, "answer_with_plugin_call",
ResponseEntity.ok(JSONObject.toJSONString(lLmAnswerResp)));
QueryReq queryReq = DataUtils.getQueryReqWithAgent(1000, "能不能帮我解读分析下最近alice在超音数的访问情况",
DataUtils.getAgent().getId());
ParseResp parseResp = queryService.performParsing(queryReq); MockConfiguration.embeddingUtils(embeddingUtils, ResponseEntity.ok(JSONObject.toJSONString(lLmAnswerResp)));
ExecuteQueryReq executeReq = ExecuteQueryReq.builder().user(queryReq.getUser())
.chatId(parseResp.getChatId())
.queryId(parseResp.getQueryId())
.queryText(parseResp.getQueryText())
.parseInfo(parseResp.getCandidateParses().get(0))
.parseId(parseResp.getCandidateParses().get(0).getId())
.build();
QueryResult queryResult = queryService.performExecution(executeReq);
Assert.assertEquals(queryResult.getQueryResults().get(0).get("answer"), lLmAnswerResp.getAssistantMessage());
} }
} }

View File

@@ -1,30 +1,30 @@
package com.tencent.supersonic.integration; package com.tencent.supersonic.integration;
import static org.mockito.ArgumentMatchers.anyObject;
import static org.mockito.Mockito.when;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrievalResp;
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrieval;
import com.tencent.supersonic.chat.plugin.PluginManager; import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.service.AgentService; import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.util.DataUtils; import com.tencent.supersonic.util.DataUtils;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.http.ResponseEntity; import org.springframework.http.ResponseEntity;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.notNull;
import static org.mockito.Mockito.when;
@Configuration @Configuration
@Slf4j @Slf4j
public class MockConfiguration { public class MockConfiguration {
public static void mockEmbeddingRecognize(PluginManager pluginManager, String text, String id) { public static void mockEmbeddingRecognize(PluginManager pluginManager, String text, String id) {
RecallRetrievalResp embeddingResp = new RecallRetrievalResp(); RetrieveQueryResult embeddingResp = new RetrieveQueryResult();
RecallRetrieval embeddingRetrieval = new RecallRetrieval(); Retrieval embeddingRetrieval = new Retrieval();
embeddingRetrieval.setId(id); embeddingRetrieval.setId(id);
embeddingRetrieval.setPresetId(id); embeddingRetrieval.setDistance(0.15);
embeddingRetrieval.setDistance("0.15");
embeddingResp.setQuery(text); embeddingResp.setQuery(text);
embeddingResp.setRetrieval(Lists.newArrayList(embeddingRetrieval)); embeddingResp.setRetrieval(Lists.newArrayList(embeddingRetrieval));
when(pluginManager.recognize(text)).thenReturn(embeddingResp); when(pluginManager.recognize(text)).thenReturn(embeddingResp);
@@ -34,13 +34,11 @@ public class MockConfiguration {
when(embeddingConfig.getUrl()).thenReturn("test"); when(embeddingConfig.getUrl()).thenReturn("test");
} }
public static void mockPluginManagerDoRequest(PluginManager pluginManager, String path,
ResponseEntity<String> responseEntity) {
when(pluginManager.doRequest(eq(path), notNull(String.class))).thenReturn(responseEntity);
}
public static void mockAgent(AgentService agentService) { public static void mockAgent(AgentService agentService) {
when(agentService.getAgent(1)).thenReturn(DataUtils.getAgent()); when(agentService.getAgent(1)).thenReturn(DataUtils.getAgent());
} }
public static void embeddingUtils(EmbeddingUtils embeddingUtils, ResponseEntity<String> responseEntity) {
when(embeddingUtils.doRequest(anyObject(), anyObject(), anyObject())).thenReturn(responseEntity);
}
} }