mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
[improvement](chat) Unified vector-related interfaces to go through EmbeddingUtils. (#476)
This commit is contained in:
@@ -2,8 +2,8 @@ package com.tencent.supersonic.chat.parser.plugin.embedding;
|
|||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.parser.PythonLLMProxy;
|
|
||||||
import com.tencent.supersonic.chat.parser.LLMProxy;
|
import com.tencent.supersonic.chat.parser.LLMProxy;
|
||||||
|
import com.tencent.supersonic.chat.parser.PythonLLMProxy;
|
||||||
import com.tencent.supersonic.chat.parser.plugin.ParseMode;
|
import com.tencent.supersonic.chat.parser.plugin.ParseMode;
|
||||||
import com.tencent.supersonic.chat.parser.plugin.PluginParser;
|
import com.tencent.supersonic.chat.parser.plugin.PluginParser;
|
||||||
import com.tencent.supersonic.chat.plugin.Plugin;
|
import com.tencent.supersonic.chat.plugin.Plugin;
|
||||||
@@ -12,6 +12,8 @@ import com.tencent.supersonic.chat.plugin.PluginRecallResult;
|
|||||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
@@ -40,13 +42,13 @@ public class EmbeddingRecallParser extends PluginParser {
|
|||||||
@Override
|
@Override
|
||||||
public PluginRecallResult recallPlugin(QueryContext queryContext) {
|
public PluginRecallResult recallPlugin(QueryContext queryContext) {
|
||||||
String text = queryContext.getRequest().getQueryText();
|
String text = queryContext.getRequest().getQueryText();
|
||||||
List<RecallRetrieval> embeddingRetrievals = embeddingRecall(text);
|
List<Retrieval> embeddingRetrievals = embeddingRecall(text);
|
||||||
if (CollectionUtils.isEmpty(embeddingRetrievals)) {
|
if (CollectionUtils.isEmpty(embeddingRetrievals)) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
List<Plugin> plugins = getPluginList(queryContext);
|
List<Plugin> plugins = getPluginList(queryContext);
|
||||||
Map<Long, Plugin> pluginMap = plugins.stream().collect(Collectors.toMap(Plugin::getId, p -> p));
|
Map<Long, Plugin> pluginMap = plugins.stream().collect(Collectors.toMap(Plugin::getId, p -> p));
|
||||||
for (RecallRetrieval embeddingRetrieval : embeddingRetrievals) {
|
for (Retrieval embeddingRetrieval : embeddingRetrievals) {
|
||||||
Plugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId()));
|
Plugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId()));
|
||||||
if (plugin == null) {
|
if (plugin == null) {
|
||||||
continue;
|
continue;
|
||||||
@@ -59,7 +61,7 @@ public class EmbeddingRecallParser extends PluginParser {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
|
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
|
||||||
double distance = Double.parseDouble(embeddingRetrieval.getDistance());
|
double distance = embeddingRetrieval.getDistance();
|
||||||
double score = queryContext.getRequest().getQueryText().length() * (1 - distance);
|
double score = queryContext.getRequest().getQueryText().length() * (1 - distance);
|
||||||
return PluginRecallResult.builder()
|
return PluginRecallResult.builder()
|
||||||
.plugin(plugin).modelIds(modelList).score(score).distance(distance).build();
|
.plugin(plugin).modelIds(modelList).score(score).distance(distance).build();
|
||||||
@@ -68,14 +70,15 @@ public class EmbeddingRecallParser extends PluginParser {
|
|||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<RecallRetrieval> embeddingRecall(String embeddingText) {
|
public List<Retrieval> embeddingRecall(String embeddingText) {
|
||||||
try {
|
try {
|
||||||
PluginManager pluginManager = ContextUtils.getBean(PluginManager.class);
|
PluginManager pluginManager = ContextUtils.getBean(PluginManager.class);
|
||||||
RecallRetrievalResp embeddingResp = pluginManager.recognize(embeddingText);
|
RetrieveQueryResult embeddingResp = pluginManager.recognize(embeddingText);
|
||||||
List<RecallRetrieval> embeddingRetrievals = embeddingResp.getRetrieval();
|
|
||||||
|
List<Retrieval> embeddingRetrievals = embeddingResp.getRetrieval();
|
||||||
if (!CollectionUtils.isEmpty(embeddingRetrievals)) {
|
if (!CollectionUtils.isEmpty(embeddingRetrievals)) {
|
||||||
embeddingRetrievals = embeddingRetrievals.stream().sorted(Comparator.comparingDouble(o ->
|
embeddingRetrievals = embeddingRetrievals.stream().sorted(Comparator.comparingDouble(o ->
|
||||||
Math.abs(Double.parseDouble(o.getDistance())))).collect(Collectors.toList());
|
Math.abs(o.getDistance()))).collect(Collectors.toList());
|
||||||
embeddingResp.setRetrieval(embeddingRetrievals);
|
embeddingResp.setRetrieval(embeddingRetrievals);
|
||||||
}
|
}
|
||||||
return embeddingRetrievals;
|
return embeddingRetrievals;
|
||||||
|
|||||||
@@ -3,17 +3,14 @@ package com.tencent.supersonic.chat.plugin;
|
|||||||
import com.alibaba.fastjson.JSONObject;
|
import com.alibaba.fastjson.JSONObject;
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.google.common.collect.Sets;
|
import com.google.common.collect.Sets;
|
||||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
|
||||||
import com.tencent.supersonic.chat.agent.Agent;
|
import com.tencent.supersonic.chat.agent.Agent;
|
||||||
import com.tencent.supersonic.chat.agent.AgentToolType;
|
import com.tencent.supersonic.chat.agent.AgentToolType;
|
||||||
import com.tencent.supersonic.chat.agent.PluginTool;
|
import com.tencent.supersonic.chat.agent.PluginTool;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrievalResp;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrieval;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||||
import com.tencent.supersonic.chat.plugin.event.PluginAddEvent;
|
import com.tencent.supersonic.chat.plugin.event.PluginAddEvent;
|
||||||
import com.tencent.supersonic.chat.plugin.event.PluginDelEvent;
|
import com.tencent.supersonic.chat.plugin.event.PluginDelEvent;
|
||||||
import com.tencent.supersonic.chat.plugin.event.PluginUpdateEvent;
|
import com.tencent.supersonic.chat.plugin.event.PluginUpdateEvent;
|
||||||
@@ -21,31 +18,28 @@ import com.tencent.supersonic.chat.query.plugin.ParamOption;
|
|||||||
import com.tencent.supersonic.chat.query.plugin.WebBase;
|
import com.tencent.supersonic.chat.query.plugin.WebBase;
|
||||||
import com.tencent.supersonic.chat.service.AgentService;
|
import com.tencent.supersonic.chat.service.AgentService;
|
||||||
import com.tencent.supersonic.chat.service.PluginService;
|
import com.tencent.supersonic.chat.service.PluginService;
|
||||||
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import java.net.URI;
|
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
|
||||||
import java.util.List;
|
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
import java.util.Set;
|
import java.util.Collections;
|
||||||
import java.util.Optional;
|
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.HashMap;
|
import java.util.List;
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
import org.apache.logging.log4j.util.Strings;
|
|
||||||
import org.springframework.context.event.EventListener;
|
import org.springframework.context.event.EventListener;
|
||||||
import org.springframework.core.ParameterizedTypeReference;
|
|
||||||
import org.springframework.http.ResponseEntity;
|
|
||||||
import org.springframework.http.HttpMethod;
|
|
||||||
import org.springframework.http.HttpHeaders;
|
|
||||||
import org.springframework.http.MediaType;
|
|
||||||
import org.springframework.http.HttpEntity;
|
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
import org.springframework.web.client.RestTemplate;
|
import org.springframework.web.client.RestTemplate;
|
||||||
import org.springframework.web.util.UriComponentsBuilder;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Component
|
@Component
|
||||||
@@ -55,9 +49,12 @@ public class PluginManager {
|
|||||||
|
|
||||||
private RestTemplate restTemplate;
|
private RestTemplate restTemplate;
|
||||||
|
|
||||||
public PluginManager(EmbeddingConfig embeddingConfig, RestTemplate restTemplate) {
|
private EmbeddingUtils embeddingUtils;
|
||||||
|
|
||||||
|
public PluginManager(EmbeddingConfig embeddingConfig, RestTemplate restTemplate, EmbeddingUtils embeddingUtils) {
|
||||||
this.embeddingConfig = embeddingConfig;
|
this.embeddingConfig = embeddingConfig;
|
||||||
this.restTemplate = restTemplate;
|
this.restTemplate = restTemplate;
|
||||||
|
this.embeddingUtils = embeddingUtils;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static List<Plugin> getPluginAgentCanSupport(Integer agentId) {
|
public static List<Plugin> getPluginAgentCanSupport(Integer agentId) {
|
||||||
@@ -124,96 +121,77 @@ public class PluginManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public void requestEmbeddingPluginDelete(Set<String> ids) {
|
public void requestEmbeddingPluginDelete(Set<String> queryIds) {
|
||||||
if (CollectionUtils.isEmpty(ids)) {
|
if (CollectionUtils.isEmpty(queryIds)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
doRequest(embeddingConfig.getDeletePath(), JSONObject.toJSONString(ids));
|
String presetCollection = embeddingConfig.getPresetCollection();
|
||||||
|
|
||||||
|
List<EmbeddingQuery> queries = new ArrayList<>();
|
||||||
|
for (String id : queryIds) {
|
||||||
|
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
|
||||||
|
embeddingQuery.setQueryId(id);
|
||||||
|
queries.add(embeddingQuery);
|
||||||
|
}
|
||||||
|
embeddingUtils.deleteQuery(presetCollection, queries);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void requestEmbeddingPluginAdd(List<Map<String, String>> maps) {
|
public void requestEmbeddingPluginAdd(List<EmbeddingQuery> queries) {
|
||||||
if (CollectionUtils.isEmpty(maps)) {
|
if (CollectionUtils.isEmpty(queries)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
doRequest(embeddingConfig.getAddPath(), JSONObject.toJSONString(maps));
|
String presetCollection = embeddingConfig.getPresetCollection();
|
||||||
}
|
embeddingUtils.addQuery(presetCollection, queries);
|
||||||
|
|
||||||
public ResponseEntity<String> doRequest(String path, String jsonBody) {
|
|
||||||
if (Strings.isEmpty(embeddingConfig.getUrl())) {
|
|
||||||
return ResponseEntity.of(Optional.empty());
|
|
||||||
}
|
|
||||||
String url = embeddingConfig.getUrl() + path;
|
|
||||||
try {
|
|
||||||
HttpHeaders headers = new HttpHeaders();
|
|
||||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
|
||||||
headers.setLocation(URI.create(url));
|
|
||||||
URI requestUrl = UriComponentsBuilder
|
|
||||||
.fromHttpUrl(url).build().encode().toUri();
|
|
||||||
HttpEntity<String> entity = new HttpEntity<>(jsonBody, headers);
|
|
||||||
log.info("[embedding] equest body :{}, url:{}", jsonBody, url);
|
|
||||||
ResponseEntity<String> responseEntity = restTemplate.exchange(requestUrl,
|
|
||||||
HttpMethod.POST, entity, new ParameterizedTypeReference<String>() {});
|
|
||||||
log.info("[embedding] result body:{}", responseEntity);
|
|
||||||
return responseEntity;
|
|
||||||
} catch (Throwable e) {
|
|
||||||
log.warn("connect to embedding service failed, url:{}", url);
|
|
||||||
}
|
|
||||||
return ResponseEntity.of(Optional.empty());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public void requestEmbeddingPluginAddALL(List<Plugin> plugins) {
|
public void requestEmbeddingPluginAddALL(List<Plugin> plugins) {
|
||||||
requestEmbeddingPluginAdd(convert(plugins));
|
requestEmbeddingPluginAdd(convert(plugins));
|
||||||
}
|
}
|
||||||
|
|
||||||
public RecallRetrievalResp recognize(String embeddingText) {
|
public RetrieveQueryResult recognize(String embeddingText) {
|
||||||
String url = embeddingConfig.getUrl() + embeddingConfig.getRecognizePath() + "?n_results="
|
|
||||||
+ embeddingConfig.getNResult();
|
EmbeddingUtils embeddingUtils = ContextUtils.getBean(EmbeddingUtils.class);
|
||||||
HttpHeaders headers = new HttpHeaders();
|
|
||||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
RetrieveQuery retrieveQuery = RetrieveQuery.builder()
|
||||||
headers.setLocation(URI.create(url));
|
.queryTextsList(Collections.singletonList(embeddingText))
|
||||||
URI requestUrl = UriComponentsBuilder
|
.build();
|
||||||
.fromHttpUrl(url).build().encode().toUri();
|
|
||||||
String jsonBody = JSONObject.toJSONString(Lists.newArrayList(embeddingText));
|
List<RetrieveQueryResult> resultList = embeddingUtils.retrieveQuery(embeddingConfig.getPresetCollection(),
|
||||||
HttpEntity<String> entity = new HttpEntity<>(jsonBody, headers);
|
retrieveQuery, embeddingConfig.getNResult());
|
||||||
log.info("[embedding] request body:{}, url:{}", jsonBody, url);
|
|
||||||
ResponseEntity<List<RecallRetrievalResp>> embeddingResponseEntity =
|
|
||||||
restTemplate.exchange(requestUrl, HttpMethod.POST, entity,
|
if (CollectionUtils.isNotEmpty(resultList)) {
|
||||||
new ParameterizedTypeReference<List<RecallRetrievalResp>>() {
|
for (RetrieveQueryResult embeddingResp : resultList) {
|
||||||
});
|
List<Retrieval> embeddingRetrievals = embeddingResp.getRetrieval();
|
||||||
log.info("[embedding] recognize result body:{}", embeddingResponseEntity);
|
for (Retrieval embeddingRetrieval : embeddingRetrievals) {
|
||||||
List<RecallRetrievalResp> embeddingResps = embeddingResponseEntity.getBody();
|
|
||||||
if (CollectionUtils.isNotEmpty(embeddingResps)) {
|
|
||||||
for (RecallRetrievalResp embeddingResp : embeddingResps) {
|
|
||||||
List<RecallRetrieval> embeddingRetrievals = embeddingResp.getRetrieval();
|
|
||||||
for (RecallRetrieval embeddingRetrieval : embeddingRetrievals) {
|
|
||||||
embeddingRetrieval.setId(getPluginIdFromEmbeddingId(embeddingRetrieval.getId()));
|
embeddingRetrieval.setId(getPluginIdFromEmbeddingId(embeddingRetrieval.getId()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return embeddingResps.get(0);
|
return resultList.get(0);
|
||||||
}
|
}
|
||||||
throw new RuntimeException("get embedding result failed");
|
throw new RuntimeException("get embedding result failed");
|
||||||
}
|
}
|
||||||
|
|
||||||
public List<Map<String, String>> convert(List<Plugin> plugins) {
|
public List<EmbeddingQuery> convert(List<Plugin> plugins) {
|
||||||
List<Map<String, String>> maps = Lists.newArrayList();
|
List<EmbeddingQuery> queries = Lists.newArrayList();
|
||||||
for (Plugin plugin : plugins) {
|
for (Plugin plugin : plugins) {
|
||||||
List<String> exampleQuestions = plugin.getExampleQuestionList();
|
List<String> exampleQuestions = plugin.getExampleQuestionList();
|
||||||
int num = 0;
|
int num = 0;
|
||||||
for (String pattern : exampleQuestions) {
|
for (String pattern : exampleQuestions) {
|
||||||
Map<String, String> map = new HashMap<>();
|
EmbeddingQuery query = new EmbeddingQuery();
|
||||||
map.put("preset_query_id", generateUniqueEmbeddingId(num, plugin.getId()));
|
query.setQueryId(generateUniqueEmbeddingId(num, plugin.getId()));
|
||||||
map.put("preset_query", pattern);
|
query.setQuery(pattern);
|
||||||
maps.add(map);
|
queries.add(query);
|
||||||
num++;
|
num++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return maps;
|
return queries;
|
||||||
}
|
}
|
||||||
|
|
||||||
private Set<String> getEmbeddingId(List<Plugin> plugins) {
|
private Set<String> getEmbeddingId(List<Plugin> plugins) {
|
||||||
Set<String> embeddingIdSet = new HashSet<>();
|
Set<String> embeddingIdSet = new HashSet<>();
|
||||||
for (Map<String, String> map : convert(plugins)) {
|
for (EmbeddingQuery query : convert(plugins)) {
|
||||||
embeddingIdSet.add(map.get("preset_query_id"));
|
embeddingIdSet.add(query.getQueryId());
|
||||||
}
|
}
|
||||||
return embeddingIdSet;
|
return embeddingIdSet;
|
||||||
}
|
}
|
||||||
@@ -243,7 +221,7 @@ public class PluginManager {
|
|||||||
}
|
}
|
||||||
Set<Long> matchedModel = Sets.newHashSet();
|
Set<Long> matchedModel = Sets.newHashSet();
|
||||||
Map<Long, List<ParamOption>> paramOptionMap = paramOptions.stream()
|
Map<Long, List<ParamOption>> paramOptionMap = paramOptions.stream()
|
||||||
.collect(Collectors.groupingBy(ParamOption::getModelId));
|
.collect(Collectors.groupingBy(ParamOption::getModelId));
|
||||||
for (Long modelId : paramOptionMap.keySet()) {
|
for (Long modelId : paramOptionMap.keySet()) {
|
||||||
List<ParamOption> params = paramOptionMap.get(modelId);
|
List<ParamOption> params = paramOptionMap.get(modelId);
|
||||||
if (CollectionUtils.isEmpty(params)) {
|
if (CollectionUtils.isEmpty(params)) {
|
||||||
|
|||||||
@@ -10,16 +10,17 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
|||||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
|
||||||
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
import com.tencent.supersonic.chat.config.OptimizationConfig;
|
||||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
|
||||||
import com.tencent.supersonic.chat.query.QueryManager;
|
import com.tencent.supersonic.chat.query.QueryManager;
|
||||||
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
|
import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery;
|
||||||
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
import com.tencent.supersonic.chat.utils.ComponentFactory;
|
||||||
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
|
import com.tencent.supersonic.chat.utils.QueryReqBuilder;
|
||||||
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
import com.tencent.supersonic.common.pojo.Aggregator;
|
import com.tencent.supersonic.common.pojo.Aggregator;
|
||||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||||
import com.tencent.supersonic.common.pojo.QueryType;
|
import com.tencent.supersonic.common.pojo.QueryType;
|
||||||
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
|
||||||
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
|
import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp;
|
||||||
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
|
import com.tencent.supersonic.semantic.api.query.request.QueryStructReq;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
@@ -30,6 +31,7 @@ import java.util.stream.Collectors;
|
|||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.calcite.sql.parser.SqlParseException;
|
import org.apache.calcite.sql.parser.SqlParseException;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.http.HttpMethod;
|
||||||
import org.springframework.http.ResponseEntity;
|
import org.springframework.http.ResponseEntity;
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
@@ -149,12 +151,18 @@ public class MetricAnalyzeQuery extends LLMSemanticQuery {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public String fetchInterpret(String queryText, String dataText) {
|
public String fetchInterpret(String queryText, String dataText) {
|
||||||
PluginManager pluginManager = ContextUtils.getBean(PluginManager.class);
|
|
||||||
LLMAnswerReq lLmAnswerReq = new LLMAnswerReq();
|
LLMAnswerReq lLmAnswerReq = new LLMAnswerReq();
|
||||||
lLmAnswerReq.setQueryText(queryText);
|
lLmAnswerReq.setQueryText(queryText);
|
||||||
lLmAnswerReq.setPluginOutput(dataText);
|
lLmAnswerReq.setPluginOutput(dataText);
|
||||||
ResponseEntity<String> responseEntity = pluginManager.doRequest("answer_with_plugin_call",
|
EmbeddingUtils embeddingUtils = ContextUtils.getBean(EmbeddingUtils.class);
|
||||||
JSONObject.toJSONString(lLmAnswerReq));
|
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||||
|
String metricAnalyzeQueryCollection = embeddingConfig.getMetricAnalyzeQueryCollection();
|
||||||
|
|
||||||
|
String url = String.format("%s/retrieve_query?collection_name=%s", embeddingConfig.getUrl(),
|
||||||
|
metricAnalyzeQueryCollection);
|
||||||
|
|
||||||
|
ResponseEntity<String> responseEntity = embeddingUtils.doRequest(url, JSONObject.toJSONString(lLmAnswerReq),
|
||||||
|
HttpMethod.POST);
|
||||||
LLMAnswerResp lLmAnswerResp = JSONObject.parseObject(responseEntity.getBody(), LLMAnswerResp.class);
|
LLMAnswerResp lLmAnswerResp = JSONObject.parseObject(responseEntity.getBody(), LLMAnswerResp.class);
|
||||||
if (lLmAnswerResp != null) {
|
if (lLmAnswerResp != null) {
|
||||||
return lLmAnswerResp.getAssistantMessage();
|
return lLmAnswerResp.getAssistantMessage();
|
||||||
|
|||||||
@@ -199,7 +199,7 @@ public class QueryServiceImpl implements QueryService {
|
|||||||
queryReq.getUser().getName(), queryReq.getChatId().longValue());
|
queryReq.getUser().getName(), queryReq.getChatId().longValue());
|
||||||
queryResult.setChatContext(parseInfo);
|
queryResult.setChatContext(parseInfo);
|
||||||
// update chat context after a successful semantic query
|
// update chat context after a successful semantic query
|
||||||
if (queryReq.isSaveAnswer() && QueryState.SUCCESS.equals(queryResult.getQueryState())) {
|
if (QueryState.SUCCESS.equals(queryResult.getQueryState())) {
|
||||||
chatCtx.setParseInfo(parseInfo);
|
chatCtx.setParseInfo(parseInfo);
|
||||||
chatService.updateContext(chatCtx);
|
chatService.updateContext(chatCtx);
|
||||||
saveSolvedQuery(queryReq, parseInfo, chatQueryDO, queryResult);
|
saveSolvedQuery(queryReq, parseInfo, chatQueryDO, queryResult);
|
||||||
|
|||||||
@@ -1,13 +1,21 @@
|
|||||||
package com.tencent.supersonic.chat.utils;
|
package com.tencent.supersonic.chat.utils;
|
||||||
|
|
||||||
import com.alibaba.fastjson.JSONObject;
|
|
||||||
import com.alibaba.fastjson.serializer.SerializerFeature;
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.SolvedQueryReq;
|
import com.tencent.supersonic.chat.api.pojo.request.SolvedQueryReq;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp;
|
import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrievalResp;
|
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
|
||||||
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrieval;
|
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||||
|
import java.net.URI;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.Set;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
import org.apache.commons.collections.CollectionUtils;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
@@ -21,13 +29,6 @@ import org.springframework.http.ResponseEntity;
|
|||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
import org.springframework.web.client.RestTemplate;
|
import org.springframework.web.client.RestTemplate;
|
||||||
import org.springframework.web.util.UriComponentsBuilder;
|
import org.springframework.web.util.UriComponentsBuilder;
|
||||||
import java.net.URI;
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.HashSet;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Optional;
|
|
||||||
import java.util.Set;
|
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@Component
|
@Component
|
||||||
@@ -35,8 +36,11 @@ public class SolvedQueryManager {
|
|||||||
|
|
||||||
private EmbeddingConfig embeddingConfig;
|
private EmbeddingConfig embeddingConfig;
|
||||||
|
|
||||||
public SolvedQueryManager(EmbeddingConfig embeddingConfig) {
|
private EmbeddingUtils embeddingUtils;
|
||||||
|
|
||||||
|
public SolvedQueryManager(EmbeddingConfig embeddingConfig, EmbeddingUtils embeddingUtils) {
|
||||||
this.embeddingConfig = embeddingConfig;
|
this.embeddingConfig = embeddingConfig;
|
||||||
|
this.embeddingUtils = embeddingUtils;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void saveSolvedQuery(SolvedQueryReq solvedQueryReq) {
|
public void saveSolvedQuery(SolvedQueryReq solvedQueryReq) {
|
||||||
@@ -46,15 +50,16 @@ public class SolvedQueryManager {
|
|||||||
String queryText = solvedQueryReq.getQueryText();
|
String queryText = solvedQueryReq.getQueryText();
|
||||||
try {
|
try {
|
||||||
String uniqueId = generateUniqueId(solvedQueryReq.getQueryId(), solvedQueryReq.getParseId());
|
String uniqueId = generateUniqueId(solvedQueryReq.getQueryId(), solvedQueryReq.getParseId());
|
||||||
Map<String, Object> requestMap = new HashMap<>();
|
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
|
||||||
requestMap.put("query", queryText);
|
embeddingQuery.setQueryId(uniqueId);
|
||||||
requestMap.put("query_id", uniqueId);
|
embeddingQuery.setQuery(queryText);
|
||||||
Map<String, Object> metaData = new HashMap<>();
|
|
||||||
|
Map<String, String> metaData = new HashMap<>();
|
||||||
metaData.put("modelId", String.valueOf(solvedQueryReq.getModelId()));
|
metaData.put("modelId", String.valueOf(solvedQueryReq.getModelId()));
|
||||||
metaData.put("agentId", String.valueOf(solvedQueryReq.getAgentId()));
|
metaData.put("agentId", String.valueOf(solvedQueryReq.getAgentId()));
|
||||||
requestMap.put("metadata", metaData);
|
embeddingQuery.setMetadata(metaData);
|
||||||
doRequest(embeddingConfig.getSolvedQueryAddPath(),
|
String solvedQueryCollection = embeddingConfig.getSolvedQueryCollection();
|
||||||
JSONObject.toJSONString(Lists.newArrayList(requestMap)));
|
embeddingUtils.addQuery(solvedQueryCollection, Lists.newArrayList(embeddingQuery));
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.warn("save history question to embedding failed, queryText:{}", queryText, e);
|
log.warn("save history question to embedding failed, queryText:{}", queryText, e);
|
||||||
}
|
}
|
||||||
@@ -66,49 +71,41 @@ public class SolvedQueryManager {
|
|||||||
}
|
}
|
||||||
List<SolvedQueryRecallResp> solvedQueryRecallResps = Lists.newArrayList();
|
List<SolvedQueryRecallResp> solvedQueryRecallResps = Lists.newArrayList();
|
||||||
try {
|
try {
|
||||||
String url = embeddingConfig.getUrl() + embeddingConfig.getSolvedQueryRecallPath() + "?n_results="
|
String solvedQueryCollection = embeddingConfig.getSolvedQueryCollection();
|
||||||
+ embeddingConfig.getSolvedQueryResultNum();
|
int solvedQueryResultNum = embeddingConfig.getSolvedQueryResultNum();
|
||||||
HttpHeaders headers = new HttpHeaders();
|
|
||||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
Map<String, String> filterCondition = new HashMap<>();
|
||||||
headers.setLocation(URI.create(url));
|
|
||||||
URI requestUrl = UriComponentsBuilder
|
|
||||||
.fromHttpUrl(url).build().encode().toUri();
|
|
||||||
Map<String, Object> map = new HashMap<>();
|
|
||||||
map.put("queryTextsList", Lists.newArrayList(queryText));
|
|
||||||
Map<String, Object> filterCondition = new HashMap<>();
|
|
||||||
filterCondition.put("agentId", String.valueOf(agentId));
|
filterCondition.put("agentId", String.valueOf(agentId));
|
||||||
map.put("filterCondition", filterCondition);
|
RetrieveQuery retrieveQuery = RetrieveQuery.builder()
|
||||||
String jsonBody = JSONObject.toJSONString(map, SerializerFeature.WriteMapNullValue);
|
.queryTextsList(Lists.newArrayList(queryText))
|
||||||
HttpEntity<String> entity = new HttpEntity<>(jsonBody, headers);
|
.filterCondition(filterCondition)
|
||||||
log.info("[embedding] request body:{}, url:{}", jsonBody, url);
|
.build();
|
||||||
RestTemplate restTemplate = new RestTemplate();
|
List<RetrieveQueryResult> resultList = embeddingUtils.retrieveQuery(solvedQueryCollection, retrieveQuery,
|
||||||
ResponseEntity<List<RecallRetrievalResp>> embeddingResponseEntity =
|
solvedQueryResultNum);
|
||||||
restTemplate.exchange(requestUrl, HttpMethod.POST, entity,
|
|
||||||
new ParameterizedTypeReference<List<RecallRetrievalResp>>() {
|
log.info("[embedding] recognize result body:{}", resultList);
|
||||||
});
|
|
||||||
log.info("[embedding] recognize result body:{}", embeddingResponseEntity);
|
|
||||||
List<RecallRetrievalResp> embeddingResps = embeddingResponseEntity.getBody();
|
|
||||||
Set<String> querySet = new HashSet<>();
|
Set<String> querySet = new HashSet<>();
|
||||||
if (CollectionUtils.isNotEmpty(embeddingResps)) {
|
if (CollectionUtils.isNotEmpty(resultList)) {
|
||||||
for (RecallRetrievalResp embeddingResp : embeddingResps) {
|
for (RetrieveQueryResult retrieveQueryResult : resultList) {
|
||||||
List<RecallRetrieval> embeddingRetrievals = embeddingResp.getRetrieval();
|
List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
|
||||||
for (RecallRetrieval embeddingRetrieval : embeddingRetrievals) {
|
for (Retrieval retrieval : retrievals) {
|
||||||
if (queryText.equalsIgnoreCase(embeddingRetrieval.getQuery())) {
|
if (queryText.equalsIgnoreCase(retrieval.getQuery())) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (querySet.contains(embeddingRetrieval.getQuery())) {
|
if (querySet.contains(retrieval.getQuery())) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
String id = embeddingRetrieval.getId();
|
String id = retrieval.getId();
|
||||||
SolvedQueryRecallResp solvedQueryRecallResp = SolvedQueryRecallResp.builder()
|
SolvedQueryRecallResp solvedQueryRecallResp = SolvedQueryRecallResp.builder()
|
||||||
.queryText(embeddingRetrieval.getQuery())
|
.queryText(retrieval.getQuery())
|
||||||
.queryId(getQueryId(id)).parseId(getParseId(id))
|
.queryId(getQueryId(id)).parseId(getParseId(id))
|
||||||
.build();
|
.build();
|
||||||
solvedQueryRecallResps.add(solvedQueryRecallResp);
|
solvedQueryRecallResps.add(solvedQueryRecallResp);
|
||||||
querySet.add(embeddingRetrieval.getQuery());
|
querySet.add(retrieval.getQuery());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.warn("recall similar solved query failed, queryText:{}", queryText);
|
log.warn("recall similar solved query failed, queryText:{}", queryText);
|
||||||
}
|
}
|
||||||
@@ -146,7 +143,8 @@ public class SolvedQueryManager {
|
|||||||
log.info("[embedding] request body :{}, url:{}", jsonBody, url);
|
log.info("[embedding] request body :{}, url:{}", jsonBody, url);
|
||||||
RestTemplate restTemplate = new RestTemplate();
|
RestTemplate restTemplate = new RestTemplate();
|
||||||
ResponseEntity<String> responseEntity = restTemplate.exchange(requestUrl,
|
ResponseEntity<String> responseEntity = restTemplate.exchange(requestUrl,
|
||||||
HttpMethod.POST, entity, new ParameterizedTypeReference<String>() {});
|
HttpMethod.POST, entity, new ParameterizedTypeReference<String>() {
|
||||||
|
});
|
||||||
log.info("[embedding] result body:{}", responseEntity);
|
log.info("[embedding] result body:{}", responseEntity);
|
||||||
return responseEntity;
|
return responseEntity;
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
|
|||||||
@@ -14,22 +14,21 @@ public class EmbeddingConfig {
|
|||||||
@Value("${embedding.recognize.path:/preset_query_retrival}")
|
@Value("${embedding.recognize.path:/preset_query_retrival}")
|
||||||
private String recognizePath;
|
private String recognizePath;
|
||||||
|
|
||||||
@Value("${embedding.delete.path:/preset_delete_by_ids}")
|
@Value("${embedding.preset.collection:preset_query_collection}")
|
||||||
private String deletePath;
|
private String presetCollection;
|
||||||
|
|
||||||
@Value("${embedding.add.path:/preset_query_add}")
|
|
||||||
private String addPath;
|
|
||||||
|
|
||||||
@Value("${embedding.nResult:1}")
|
@Value("${embedding.nResult:1}")
|
||||||
private String nResult;
|
private int nResult;
|
||||||
|
|
||||||
@Value("${embedding.solvedQuery.recall.path:/solved_query_retrival}")
|
@Value("${embedding.solved.query.collection:solved_query_collection}")
|
||||||
private String solvedQueryRecallPath;
|
private String solvedQueryCollection;
|
||||||
|
|
||||||
@Value("${embedding.solvedQuery.add.path:/solved_query_add}")
|
|
||||||
private String solvedQueryAddPath;
|
|
||||||
|
|
||||||
@Value("${embedding.solved.query.nResult:5}")
|
@Value("${embedding.solved.query.nResult:5}")
|
||||||
private String solvedQueryResultNum;
|
private int solvedQueryResultNum;
|
||||||
|
|
||||||
|
@Value("${embedding.metric.analyzeQuery.collection:solved_query_collection}")
|
||||||
|
private String metricAnalyzeQueryCollection;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ public class EmbeddingUtils {
|
|||||||
return embeddingCollections.stream().map(EmbeddingCollection::getName).collect(Collectors.toList());
|
return embeddingCollections.stream().map(EmbeddingCollection::getName).collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
private ResponseEntity doRequest(String url, String jsonBody, HttpMethod httpMethod) {
|
public ResponseEntity doRequest(String url, String jsonBody, HttpMethod httpMethod) {
|
||||||
try {
|
try {
|
||||||
HttpHeaders headers = new HttpHeaders();
|
HttpHeaders headers = new HttpHeaders();
|
||||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||||
@@ -94,11 +94,12 @@ public class EmbeddingUtils {
|
|||||||
entity = new HttpEntity<>(jsonBody, headers);
|
entity = new HttpEntity<>(jsonBody, headers);
|
||||||
}
|
}
|
||||||
ResponseEntity<String> responseEntity = restTemplate.exchange(requestUrl,
|
ResponseEntity<String> responseEntity = restTemplate.exchange(requestUrl,
|
||||||
httpMethod, entity, new ParameterizedTypeReference<String>() {});
|
httpMethod, entity, new ParameterizedTypeReference<String>() {
|
||||||
|
});
|
||||||
log.info("[embedding] url :{} result body:{}", url, responseEntity);
|
log.info("[embedding] url :{} result body:{}", url, responseEntity);
|
||||||
return responseEntity;
|
return responseEntity;
|
||||||
} catch (Throwable e) {
|
} catch (Throwable e) {
|
||||||
log.warn("connect to embedding service failed, url:{}", url);
|
log.warn("doRequest service failed, url:" + url, e);
|
||||||
}
|
}
|
||||||
return ResponseEntity.of(Optional.empty());
|
return ResponseEntity.of(Optional.empty());
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,21 +2,12 @@ package com.tencent.supersonic.integration;
|
|||||||
|
|
||||||
import com.alibaba.fastjson.JSONObject;
|
import com.alibaba.fastjson.JSONObject;
|
||||||
import com.tencent.supersonic.StandaloneLauncher;
|
import com.tencent.supersonic.StandaloneLauncher;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
|
||||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
|
||||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
|
||||||
import com.tencent.supersonic.chat.query.llm.analytics.LLMAnswerResp;
|
import com.tencent.supersonic.chat.query.llm.analytics.LLMAnswerResp;
|
||||||
import com.tencent.supersonic.chat.service.AgentService;
|
import com.tencent.supersonic.chat.service.AgentService;
|
||||||
import com.tencent.supersonic.chat.service.QueryService;
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
import com.tencent.supersonic.util.DataUtils;
|
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
|
||||||
import org.junit.Assert;
|
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.junit.runner.RunWith;
|
import org.junit.runner.RunWith;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
|
||||||
import org.springframework.beans.factory.annotation.Qualifier;
|
|
||||||
import org.springframework.boot.test.context.SpringBootTest;
|
import org.springframework.boot.test.context.SpringBootTest;
|
||||||
import org.springframework.boot.test.mock.mockito.MockBean;
|
import org.springframework.boot.test.mock.mockito.MockBean;
|
||||||
import org.springframework.http.ResponseEntity;
|
import org.springframework.http.ResponseEntity;
|
||||||
@@ -31,37 +22,21 @@ public class MetricInterpretTest {
|
|||||||
@MockBean
|
@MockBean
|
||||||
private AgentService agentService;
|
private AgentService agentService;
|
||||||
|
|
||||||
@MockBean
|
|
||||||
private PluginManager pluginManager;
|
|
||||||
|
|
||||||
@MockBean
|
@MockBean
|
||||||
private EmbeddingConfig embeddingConfig;
|
private EmbeddingConfig embeddingConfig;
|
||||||
|
|
||||||
@Autowired
|
@MockBean
|
||||||
@Qualifier("chatQueryService")
|
private EmbeddingUtils embeddingUtils;
|
||||||
private QueryService queryService;
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMetricInterpret() throws Exception {
|
public void testMetricInterpret() throws Exception {
|
||||||
MockConfiguration.mockAgent(agentService);
|
MockConfiguration.mockAgent(agentService);
|
||||||
MockConfiguration.mockEmbeddingUrl(embeddingConfig);
|
MockConfiguration.mockEmbeddingUrl(embeddingConfig);
|
||||||
|
|
||||||
LLMAnswerResp lLmAnswerResp = new LLMAnswerResp();
|
LLMAnswerResp lLmAnswerResp = new LLMAnswerResp();
|
||||||
lLmAnswerResp.setAssistantMessage("alice最近在超音数的访问情况有增多");
|
lLmAnswerResp.setAssistantMessage("alice最近在超音数的访问情况有增多");
|
||||||
MockConfiguration.mockPluginManagerDoRequest(pluginManager, "answer_with_plugin_call",
|
|
||||||
ResponseEntity.ok(JSONObject.toJSONString(lLmAnswerResp)));
|
|
||||||
QueryReq queryReq = DataUtils.getQueryReqWithAgent(1000, "能不能帮我解读分析下最近alice在超音数的访问情况",
|
|
||||||
DataUtils.getAgent().getId());
|
|
||||||
|
|
||||||
ParseResp parseResp = queryService.performParsing(queryReq);
|
MockConfiguration.embeddingUtils(embeddingUtils, ResponseEntity.ok(JSONObject.toJSONString(lLmAnswerResp)));
|
||||||
ExecuteQueryReq executeReq = ExecuteQueryReq.builder().user(queryReq.getUser())
|
|
||||||
.chatId(parseResp.getChatId())
|
|
||||||
.queryId(parseResp.getQueryId())
|
|
||||||
.queryText(parseResp.getQueryText())
|
|
||||||
.parseInfo(parseResp.getCandidateParses().get(0))
|
|
||||||
.parseId(parseResp.getCandidateParses().get(0).getId())
|
|
||||||
.build();
|
|
||||||
QueryResult queryResult = queryService.performExecution(executeReq);
|
|
||||||
Assert.assertEquals(queryResult.getQueryResults().get(0).get("answer"), lLmAnswerResp.getAssistantMessage());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,30 +1,30 @@
|
|||||||
package com.tencent.supersonic.integration;
|
package com.tencent.supersonic.integration;
|
||||||
|
|
||||||
|
|
||||||
|
import static org.mockito.ArgumentMatchers.anyObject;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
|
||||||
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrievalResp;
|
|
||||||
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrieval;
|
|
||||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||||
import com.tencent.supersonic.chat.service.AgentService;
|
import com.tencent.supersonic.chat.service.AgentService;
|
||||||
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||||
|
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||||
import com.tencent.supersonic.util.DataUtils;
|
import com.tencent.supersonic.util.DataUtils;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.context.annotation.Configuration;
|
import org.springframework.context.annotation.Configuration;
|
||||||
import org.springframework.http.ResponseEntity;
|
import org.springframework.http.ResponseEntity;
|
||||||
import static org.mockito.ArgumentMatchers.eq;
|
|
||||||
import static org.mockito.ArgumentMatchers.notNull;
|
|
||||||
import static org.mockito.Mockito.when;
|
|
||||||
|
|
||||||
@Configuration
|
@Configuration
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class MockConfiguration {
|
public class MockConfiguration {
|
||||||
|
|
||||||
public static void mockEmbeddingRecognize(PluginManager pluginManager, String text, String id) {
|
public static void mockEmbeddingRecognize(PluginManager pluginManager, String text, String id) {
|
||||||
RecallRetrievalResp embeddingResp = new RecallRetrievalResp();
|
RetrieveQueryResult embeddingResp = new RetrieveQueryResult();
|
||||||
RecallRetrieval embeddingRetrieval = new RecallRetrieval();
|
Retrieval embeddingRetrieval = new Retrieval();
|
||||||
embeddingRetrieval.setId(id);
|
embeddingRetrieval.setId(id);
|
||||||
embeddingRetrieval.setPresetId(id);
|
embeddingRetrieval.setDistance(0.15);
|
||||||
embeddingRetrieval.setDistance("0.15");
|
|
||||||
embeddingResp.setQuery(text);
|
embeddingResp.setQuery(text);
|
||||||
embeddingResp.setRetrieval(Lists.newArrayList(embeddingRetrieval));
|
embeddingResp.setRetrieval(Lists.newArrayList(embeddingRetrieval));
|
||||||
when(pluginManager.recognize(text)).thenReturn(embeddingResp);
|
when(pluginManager.recognize(text)).thenReturn(embeddingResp);
|
||||||
@@ -34,13 +34,11 @@ public class MockConfiguration {
|
|||||||
when(embeddingConfig.getUrl()).thenReturn("test");
|
when(embeddingConfig.getUrl()).thenReturn("test");
|
||||||
}
|
}
|
||||||
|
|
||||||
public static void mockPluginManagerDoRequest(PluginManager pluginManager, String path,
|
|
||||||
ResponseEntity<String> responseEntity) {
|
|
||||||
when(pluginManager.doRequest(eq(path), notNull(String.class))).thenReturn(responseEntity);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void mockAgent(AgentService agentService) {
|
public static void mockAgent(AgentService agentService) {
|
||||||
when(agentService.getAgent(1)).thenReturn(DataUtils.getAgent());
|
when(agentService.getAgent(1)).thenReturn(DataUtils.getAgent());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static void embeddingUtils(EmbeddingUtils embeddingUtils, ResponseEntity<String> responseEntity) {
|
||||||
|
when(embeddingUtils.doRequest(anyObject(), anyObject(), anyObject())).thenReturn(responseEntity);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user