[improvement](chat) Unified vector-related interfaces to go through EmbeddingUtils. (#476)

This commit is contained in:
lexluo09
2023-12-06 14:50:57 +08:00
committed by GitHub
parent 9aa5c93d9d
commit ed0f856438
9 changed files with 168 additions and 208 deletions

View File

@@ -2,8 +2,8 @@ package com.tencent.supersonic.chat.parser.plugin.embedding;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.parser.PythonLLMProxy;
import com.tencent.supersonic.chat.parser.LLMProxy;
import com.tencent.supersonic.chat.parser.PythonLLMProxy;
import com.tencent.supersonic.chat.parser.plugin.ParseMode;
import com.tencent.supersonic.chat.parser.plugin.PluginParser;
import com.tencent.supersonic.chat.plugin.Plugin;
@@ -12,6 +12,8 @@ import com.tencent.supersonic.chat.plugin.PluginRecallResult;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
@@ -40,13 +42,13 @@ public class EmbeddingRecallParser extends PluginParser {
@Override
public PluginRecallResult recallPlugin(QueryContext queryContext) {
String text = queryContext.getRequest().getQueryText();
List<RecallRetrieval> embeddingRetrievals = embeddingRecall(text);
List<Retrieval> embeddingRetrievals = embeddingRecall(text);
if (CollectionUtils.isEmpty(embeddingRetrievals)) {
return null;
}
List<Plugin> plugins = getPluginList(queryContext);
Map<Long, Plugin> pluginMap = plugins.stream().collect(Collectors.toMap(Plugin::getId, p -> p));
for (RecallRetrieval embeddingRetrieval : embeddingRetrievals) {
for (Retrieval embeddingRetrieval : embeddingRetrievals) {
Plugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId()));
if (plugin == null) {
continue;
@@ -59,7 +61,7 @@ public class EmbeddingRecallParser extends PluginParser {
continue;
}
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
double distance = Double.parseDouble(embeddingRetrieval.getDistance());
double distance = embeddingRetrieval.getDistance();
double score = queryContext.getRequest().getQueryText().length() * (1 - distance);
return PluginRecallResult.builder()
.plugin(plugin).modelIds(modelList).score(score).distance(distance).build();
@@ -68,14 +70,15 @@ public class EmbeddingRecallParser extends PluginParser {
return null;
}
public List<RecallRetrieval> embeddingRecall(String embeddingText) {
public List<Retrieval> embeddingRecall(String embeddingText) {
try {
PluginManager pluginManager = ContextUtils.getBean(PluginManager.class);
RecallRetrievalResp embeddingResp = pluginManager.recognize(embeddingText);
List<RecallRetrieval> embeddingRetrievals = embeddingResp.getRetrieval();
RetrieveQueryResult embeddingResp = pluginManager.recognize(embeddingText);
List<Retrieval> embeddingRetrievals = embeddingResp.getRetrieval();
if (!CollectionUtils.isEmpty(embeddingRetrievals)) {
embeddingRetrievals = embeddingRetrievals.stream().sorted(Comparator.comparingDouble(o ->
Math.abs(Double.parseDouble(o.getDistance())))).collect(Collectors.toList());
Math.abs(o.getDistance()))).collect(Collectors.toList());
embeddingResp.setRetrieval(embeddingRetrievals);
}
return embeddingRetrievals;

View File

@@ -3,17 +3,14 @@ package com.tencent.supersonic.chat.plugin;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.agent.AgentToolType;
import com.tencent.supersonic.chat.agent.PluginTool;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrievalResp;
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrieval;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.chat.plugin.event.PluginAddEvent;
import com.tencent.supersonic.chat.plugin.event.PluginDelEvent;
import com.tencent.supersonic.chat.plugin.event.PluginUpdateEvent;
@@ -21,31 +18,28 @@ import com.tencent.supersonic.chat.query.plugin.ParamOption;
import com.tencent.supersonic.chat.query.plugin.WebBase;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.service.PluginService;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.ContextUtils;
import java.net.URI;
import java.util.List;
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Set;
import java.util.Optional;
import java.util.Collections;
import java.util.HashSet;
import java.util.HashMap;
import java.util.Objects;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.util.Strings;
import org.springframework.context.event.EventListener;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.ResponseEntity;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.HttpEntity;
import org.springframework.stereotype.Component;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;
@Slf4j
@Component
@@ -55,9 +49,12 @@ public class PluginManager {
private RestTemplate restTemplate;
public PluginManager(EmbeddingConfig embeddingConfig, RestTemplate restTemplate) {
private EmbeddingUtils embeddingUtils;
public PluginManager(EmbeddingConfig embeddingConfig, RestTemplate restTemplate, EmbeddingUtils embeddingUtils) {
this.embeddingConfig = embeddingConfig;
this.restTemplate = restTemplate;
this.embeddingUtils = embeddingUtils;
}
public static List<Plugin> getPluginAgentCanSupport(Integer agentId) {
@@ -124,96 +121,77 @@ public class PluginManager {
}
}
public void requestEmbeddingPluginDelete(Set<String> ids) {
if (CollectionUtils.isEmpty(ids)) {
public void requestEmbeddingPluginDelete(Set<String> queryIds) {
if (CollectionUtils.isEmpty(queryIds)) {
return;
}
doRequest(embeddingConfig.getDeletePath(), JSONObject.toJSONString(ids));
String presetCollection = embeddingConfig.getPresetCollection();
List<EmbeddingQuery> queries = new ArrayList<>();
for (String id : queryIds) {
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
embeddingQuery.setQueryId(id);
queries.add(embeddingQuery);
}
embeddingUtils.deleteQuery(presetCollection, queries);
}
public void requestEmbeddingPluginAdd(List<Map<String, String>> maps) {
if (CollectionUtils.isEmpty(maps)) {
public void requestEmbeddingPluginAdd(List<EmbeddingQuery> queries) {
if (CollectionUtils.isEmpty(queries)) {
return;
}
doRequest(embeddingConfig.getAddPath(), JSONObject.toJSONString(maps));
}
public ResponseEntity<String> doRequest(String path, String jsonBody) {
if (Strings.isEmpty(embeddingConfig.getUrl())) {
return ResponseEntity.of(Optional.empty());
}
String url = embeddingConfig.getUrl() + path;
try {
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
headers.setLocation(URI.create(url));
URI requestUrl = UriComponentsBuilder
.fromHttpUrl(url).build().encode().toUri();
HttpEntity<String> entity = new HttpEntity<>(jsonBody, headers);
log.info("[embedding] equest body :{}, url:{}", jsonBody, url);
ResponseEntity<String> responseEntity = restTemplate.exchange(requestUrl,
HttpMethod.POST, entity, new ParameterizedTypeReference<String>() {});
log.info("[embedding] result body:{}", responseEntity);
return responseEntity;
} catch (Throwable e) {
log.warn("connect to embedding service failed, url:{}", url);
}
return ResponseEntity.of(Optional.empty());
String presetCollection = embeddingConfig.getPresetCollection();
embeddingUtils.addQuery(presetCollection, queries);
}
public void requestEmbeddingPluginAddALL(List<Plugin> plugins) {
requestEmbeddingPluginAdd(convert(plugins));
}
public RecallRetrievalResp recognize(String embeddingText) {
String url = embeddingConfig.getUrl() + embeddingConfig.getRecognizePath() + "?n_results="
+ embeddingConfig.getNResult();
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
headers.setLocation(URI.create(url));
URI requestUrl = UriComponentsBuilder
.fromHttpUrl(url).build().encode().toUri();
String jsonBody = JSONObject.toJSONString(Lists.newArrayList(embeddingText));
HttpEntity<String> entity = new HttpEntity<>(jsonBody, headers);
log.info("[embedding] request body:{}, url:{}", jsonBody, url);
ResponseEntity<List<RecallRetrievalResp>> embeddingResponseEntity =
restTemplate.exchange(requestUrl, HttpMethod.POST, entity,
new ParameterizedTypeReference<List<RecallRetrievalResp>>() {
});
log.info("[embedding] recognize result body:{}", embeddingResponseEntity);
List<RecallRetrievalResp> embeddingResps = embeddingResponseEntity.getBody();
if (CollectionUtils.isNotEmpty(embeddingResps)) {
for (RecallRetrievalResp embeddingResp : embeddingResps) {
List<RecallRetrieval> embeddingRetrievals = embeddingResp.getRetrieval();
for (RecallRetrieval embeddingRetrieval : embeddingRetrievals) {
public RetrieveQueryResult recognize(String embeddingText) {
EmbeddingUtils embeddingUtils = ContextUtils.getBean(EmbeddingUtils.class);
RetrieveQuery retrieveQuery = RetrieveQuery.builder()
.queryTextsList(Collections.singletonList(embeddingText))
.build();
List<RetrieveQueryResult> resultList = embeddingUtils.retrieveQuery(embeddingConfig.getPresetCollection(),
retrieveQuery, embeddingConfig.getNResult());
if (CollectionUtils.isNotEmpty(resultList)) {
for (RetrieveQueryResult embeddingResp : resultList) {
List<Retrieval> embeddingRetrievals = embeddingResp.getRetrieval();
for (Retrieval embeddingRetrieval : embeddingRetrievals) {
embeddingRetrieval.setId(getPluginIdFromEmbeddingId(embeddingRetrieval.getId()));
}
}
return embeddingResps.get(0);
return resultList.get(0);
}
throw new RuntimeException("get embedding result failed");
}
public List<Map<String, String>> convert(List<Plugin> plugins) {
List<Map<String, String>> maps = Lists.newArrayList();
public List<EmbeddingQuery> convert(List<Plugin> plugins) {
List<EmbeddingQuery> queries = Lists.newArrayList();
for (Plugin plugin : plugins) {
List<String> exampleQuestions = plugin.getExampleQuestionList();
int num = 0;
for (String pattern : exampleQuestions) {
Map<String, String> map = new HashMap<>();
map.put("preset_query_id", generateUniqueEmbeddingId(num, plugin.getId()));
map.put("preset_query", pattern);
maps.add(map);
EmbeddingQuery query = new EmbeddingQuery();
query.setQueryId(generateUniqueEmbeddingId(num, plugin.getId()));
query.setQuery(pattern);
queries.add(query);
num++;
}
}
return maps;
return queries;
}
private Set<String> getEmbeddingId(List<Plugin> plugins) {
Set<String> embeddingIdSet = new HashSet<>();
for (Map<String, String> map : convert(plugins)) {
embeddingIdSet.add(map.get("preset_query_id"));
for (EmbeddingQuery query : convert(plugins)) {
embeddingIdSet.add(query.getQueryId());
}
return embeddingIdSet;
}

View File

@@ -10,16 +10,17 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
import com.tencent.supersonic.chat.config.OptimizationConfig;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
import com.tencent.supersonic.chat.utils.ComponentFactory;
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.Aggregator;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.pojo.QueryType;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
import java.util.HashMap;
@@ -30,6 +31,7 @@ import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.commons.lang3.StringUtils;
import org.springframework.http.HttpMethod;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
@@ -149,12 +151,18 @@ public class MetricAnalyzeQuery extends LLMSemanticQuery {
}
public String fetchInterpret(String queryText, String dataText) {
PluginManager pluginManager = ContextUtils.getBean(PluginManager.class);
LLMAnswerReq lLmAnswerReq = new LLMAnswerReq();
lLmAnswerReq.setQueryText(queryText);
lLmAnswerReq.setPluginOutput(dataText);
ResponseEntity<String> responseEntity = pluginManager.doRequest("answer_with_plugin_call",
JSONObject.toJSONString(lLmAnswerReq));
EmbeddingUtils embeddingUtils = ContextUtils.getBean(EmbeddingUtils.class);
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
String metricAnalyzeQueryCollection = embeddingConfig.getMetricAnalyzeQueryCollection();
String url = String.format("%s/retrieve_query?collection_name=%s", embeddingConfig.getUrl(),
metricAnalyzeQueryCollection);
ResponseEntity<String> responseEntity = embeddingUtils.doRequest(url, JSONObject.toJSONString(lLmAnswerReq),
HttpMethod.POST);
LLMAnswerResp lLmAnswerResp = JSONObject.parseObject(responseEntity.getBody(), LLMAnswerResp.class);
if (lLmAnswerResp != null) {
return lLmAnswerResp.getAssistantMessage();

View File

@@ -199,7 +199,7 @@ public class QueryServiceImpl implements QueryService {
queryReq.getUser().getName(), queryReq.getChatId().longValue());
queryResult.setChatContext(parseInfo);
// update chat context after a successful semantic query
if (queryReq.isSaveAnswer() && QueryState.SUCCESS.equals(queryResult.getQueryState())) {
if (QueryState.SUCCESS.equals(queryResult.getQueryState())) {
chatCtx.setParseInfo(parseInfo);
chatService.updateContext(chatCtx);
saveSolvedQuery(queryReq, parseInfo, chatQueryDO, queryResult);

View File

@@ -1,13 +1,21 @@
package com.tencent.supersonic.chat.utils;
import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson.serializer.SerializerFeature;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.request.SolvedQueryReq;
import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrievalResp;
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrieval;
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import java.net.URI;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
@@ -21,13 +29,6 @@ import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Component;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;
import java.net.URI;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
@Slf4j
@Component
@@ -35,8 +36,11 @@ public class SolvedQueryManager {
private EmbeddingConfig embeddingConfig;
public SolvedQueryManager(EmbeddingConfig embeddingConfig) {
private EmbeddingUtils embeddingUtils;
public SolvedQueryManager(EmbeddingConfig embeddingConfig, EmbeddingUtils embeddingUtils) {
this.embeddingConfig = embeddingConfig;
this.embeddingUtils = embeddingUtils;
}
public void saveSolvedQuery(SolvedQueryReq solvedQueryReq) {
@@ -46,15 +50,16 @@ public class SolvedQueryManager {
String queryText = solvedQueryReq.getQueryText();
try {
String uniqueId = generateUniqueId(solvedQueryReq.getQueryId(), solvedQueryReq.getParseId());
Map<String, Object> requestMap = new HashMap<>();
requestMap.put("query", queryText);
requestMap.put("query_id", uniqueId);
Map<String, Object> metaData = new HashMap<>();
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
embeddingQuery.setQueryId(uniqueId);
embeddingQuery.setQuery(queryText);
Map<String, String> metaData = new HashMap<>();
metaData.put("modelId", String.valueOf(solvedQueryReq.getModelId()));
metaData.put("agentId", String.valueOf(solvedQueryReq.getAgentId()));
requestMap.put("metadata", metaData);
doRequest(embeddingConfig.getSolvedQueryAddPath(),
JSONObject.toJSONString(Lists.newArrayList(requestMap)));
embeddingQuery.setMetadata(metaData);
String solvedQueryCollection = embeddingConfig.getSolvedQueryCollection();
embeddingUtils.addQuery(solvedQueryCollection, Lists.newArrayList(embeddingQuery));
} catch (Exception e) {
log.warn("save history question to embedding failed, queryText:{}", queryText, e);
}
@@ -66,49 +71,41 @@ public class SolvedQueryManager {
}
List<SolvedQueryRecallResp> solvedQueryRecallResps = Lists.newArrayList();
try {
String url = embeddingConfig.getUrl() + embeddingConfig.getSolvedQueryRecallPath() + "?n_results="
+ embeddingConfig.getSolvedQueryResultNum();
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
headers.setLocation(URI.create(url));
URI requestUrl = UriComponentsBuilder
.fromHttpUrl(url).build().encode().toUri();
Map<String, Object> map = new HashMap<>();
map.put("queryTextsList", Lists.newArrayList(queryText));
Map<String, Object> filterCondition = new HashMap<>();
String solvedQueryCollection = embeddingConfig.getSolvedQueryCollection();
int solvedQueryResultNum = embeddingConfig.getSolvedQueryResultNum();
Map<String, String> filterCondition = new HashMap<>();
filterCondition.put("agentId", String.valueOf(agentId));
map.put("filterCondition", filterCondition);
String jsonBody = JSONObject.toJSONString(map, SerializerFeature.WriteMapNullValue);
HttpEntity<String> entity = new HttpEntity<>(jsonBody, headers);
log.info("[embedding] request body:{}, url:{}", jsonBody, url);
RestTemplate restTemplate = new RestTemplate();
ResponseEntity<List<RecallRetrievalResp>> embeddingResponseEntity =
restTemplate.exchange(requestUrl, HttpMethod.POST, entity,
new ParameterizedTypeReference<List<RecallRetrievalResp>>() {
});
log.info("[embedding] recognize result body:{}", embeddingResponseEntity);
List<RecallRetrievalResp> embeddingResps = embeddingResponseEntity.getBody();
RetrieveQuery retrieveQuery = RetrieveQuery.builder()
.queryTextsList(Lists.newArrayList(queryText))
.filterCondition(filterCondition)
.build();
List<RetrieveQueryResult> resultList = embeddingUtils.retrieveQuery(solvedQueryCollection, retrieveQuery,
solvedQueryResultNum);
log.info("[embedding] recognize result body:{}", resultList);
Set<String> querySet = new HashSet<>();
if (CollectionUtils.isNotEmpty(embeddingResps)) {
for (RecallRetrievalResp embeddingResp : embeddingResps) {
List<RecallRetrieval> embeddingRetrievals = embeddingResp.getRetrieval();
for (RecallRetrieval embeddingRetrieval : embeddingRetrievals) {
if (queryText.equalsIgnoreCase(embeddingRetrieval.getQuery())) {
if (CollectionUtils.isNotEmpty(resultList)) {
for (RetrieveQueryResult retrieveQueryResult : resultList) {
List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
for (Retrieval retrieval : retrievals) {
if (queryText.equalsIgnoreCase(retrieval.getQuery())) {
continue;
}
if (querySet.contains(embeddingRetrieval.getQuery())) {
if (querySet.contains(retrieval.getQuery())) {
continue;
}
String id = embeddingRetrieval.getId();
String id = retrieval.getId();
SolvedQueryRecallResp solvedQueryRecallResp = SolvedQueryRecallResp.builder()
.queryText(embeddingRetrieval.getQuery())
.queryText(retrieval.getQuery())
.queryId(getQueryId(id)).parseId(getParseId(id))
.build();
solvedQueryRecallResps.add(solvedQueryRecallResp);
querySet.add(embeddingRetrieval.getQuery());
querySet.add(retrieval.getQuery());
}
}
}
} catch (Exception e) {
log.warn("recall similar solved query failed, queryText:{}", queryText);
}
@@ -146,7 +143,8 @@ public class SolvedQueryManager {
log.info("[embedding] request body :{}, url:{}", jsonBody, url);
RestTemplate restTemplate = new RestTemplate();
ResponseEntity<String> responseEntity = restTemplate.exchange(requestUrl,
HttpMethod.POST, entity, new ParameterizedTypeReference<String>() {});
HttpMethod.POST, entity, new ParameterizedTypeReference<String>() {
});
log.info("[embedding] result body:{}", responseEntity);
return responseEntity;
} catch (Exception e) {

View File

@@ -14,22 +14,21 @@ public class EmbeddingConfig {
@Value("${embedding.recognize.path:/preset_query_retrival}")
private String recognizePath;
@Value("${embedding.delete.path:/preset_delete_by_ids}")
private String deletePath;
@Value("${embedding.add.path:/preset_query_add}")
private String addPath;
@Value("${embedding.preset.collection:preset_query_collection}")
private String presetCollection;
@Value("${embedding.nResult:1}")
private String nResult;
private int nResult;
@Value("${embedding.solvedQuery.recall.path:/solved_query_retrival}")
private String solvedQueryRecallPath;
@Value("${embedding.solvedQuery.add.path:/solved_query_add}")
private String solvedQueryAddPath;
@Value("${embedding.solved.query.collection:solved_query_collection}")
private String solvedQueryCollection;
@Value("${embedding.solved.query.nResult:5}")
private String solvedQueryResultNum;
private int solvedQueryResultNum;
@Value("${embedding.metric.analyzeQuery.collection:solved_query_collection}")
private String metricAnalyzeQueryCollection;
}

View File

@@ -81,7 +81,7 @@ public class EmbeddingUtils {
return embeddingCollections.stream().map(EmbeddingCollection::getName).collect(Collectors.toList());
}
private ResponseEntity doRequest(String url, String jsonBody, HttpMethod httpMethod) {
public ResponseEntity doRequest(String url, String jsonBody, HttpMethod httpMethod) {
try {
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
@@ -94,11 +94,12 @@ public class EmbeddingUtils {
entity = new HttpEntity<>(jsonBody, headers);
}
ResponseEntity<String> responseEntity = restTemplate.exchange(requestUrl,
httpMethod, entity, new ParameterizedTypeReference<String>() {});
httpMethod, entity, new ParameterizedTypeReference<String>() {
});
log.info("[embedding] url :{} result body:{}", url, responseEntity);
return responseEntity;
} catch (Throwable e) {
log.warn("connect to embedding service failed, url:{}", url);
log.warn("doRequest service failed, url:" + url, e);
}
return ResponseEntity.of(Optional.empty());
}

View File

@@ -2,21 +2,12 @@ package com.tencent.supersonic.integration;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.StandaloneLauncher;
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.query.llm.analytics.LLMAnswerResp;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.service.QueryService;
import com.tencent.supersonic.util.DataUtils;
import org.junit.Assert;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.http.ResponseEntity;
@@ -31,37 +22,21 @@ public class MetricInterpretTest {
@MockBean
private AgentService agentService;
@MockBean
private PluginManager pluginManager;
@MockBean
private EmbeddingConfig embeddingConfig;
@Autowired
@Qualifier("chatQueryService")
private QueryService queryService;
@MockBean
private EmbeddingUtils embeddingUtils;
@Test
public void testMetricInterpret() throws Exception {
MockConfiguration.mockAgent(agentService);
MockConfiguration.mockEmbeddingUrl(embeddingConfig);
LLMAnswerResp lLmAnswerResp = new LLMAnswerResp();
lLmAnswerResp.setAssistantMessage("alice最近在超音数的访问情况有增多");
MockConfiguration.mockPluginManagerDoRequest(pluginManager, "answer_with_plugin_call",
ResponseEntity.ok(JSONObject.toJSONString(lLmAnswerResp)));
QueryReq queryReq = DataUtils.getQueryReqWithAgent(1000, "能不能帮我解读分析下最近alice在超音数的访问情况",
DataUtils.getAgent().getId());
ParseResp parseResp = queryService.performParsing(queryReq);
ExecuteQueryReq executeReq = ExecuteQueryReq.builder().user(queryReq.getUser())
.chatId(parseResp.getChatId())
.queryId(parseResp.getQueryId())
.queryText(parseResp.getQueryText())
.parseInfo(parseResp.getCandidateParses().get(0))
.parseId(parseResp.getCandidateParses().get(0).getId())
.build();
QueryResult queryResult = queryService.performExecution(executeReq);
Assert.assertEquals(queryResult.getQueryResults().get(0).get("answer"), lLmAnswerResp.getAssistantMessage());
MockConfiguration.embeddingUtils(embeddingUtils, ResponseEntity.ok(JSONObject.toJSONString(lLmAnswerResp)));
}
}

View File

@@ -1,30 +1,30 @@
package com.tencent.supersonic.integration;
import static org.mockito.ArgumentMatchers.anyObject;
import static org.mockito.Mockito.when;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrievalResp;
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrieval;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.util.DataUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.ResponseEntity;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.notNull;
import static org.mockito.Mockito.when;
@Configuration
@Slf4j
public class MockConfiguration {
public static void mockEmbeddingRecognize(PluginManager pluginManager, String text, String id) {
RecallRetrievalResp embeddingResp = new RecallRetrievalResp();
RecallRetrieval embeddingRetrieval = new RecallRetrieval();
RetrieveQueryResult embeddingResp = new RetrieveQueryResult();
Retrieval embeddingRetrieval = new Retrieval();
embeddingRetrieval.setId(id);
embeddingRetrieval.setPresetId(id);
embeddingRetrieval.setDistance("0.15");
embeddingRetrieval.setDistance(0.15);
embeddingResp.setQuery(text);
embeddingResp.setRetrieval(Lists.newArrayList(embeddingRetrieval));
when(pluginManager.recognize(text)).thenReturn(embeddingResp);
@@ -34,13 +34,11 @@ public class MockConfiguration {
when(embeddingConfig.getUrl()).thenReturn("test");
}
public static void mockPluginManagerDoRequest(PluginManager pluginManager, String path,
ResponseEntity<String> responseEntity) {
when(pluginManager.doRequest(eq(path), notNull(String.class))).thenReturn(responseEntity);
}
public static void mockAgent(AgentService agentService) {
when(agentService.getAgent(1)).thenReturn(DataUtils.getAgent());
}
public static void embeddingUtils(EmbeddingUtils embeddingUtils, ResponseEntity<String> responseEntity) {
when(embeddingUtils.doRequest(anyObject(), anyObject(), anyObject())).thenReturn(responseEntity);
}
}