From ed0f856438acbbc0def096a2bdf63de93c6ffbce Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Wed, 6 Dec 2023 14:50:57 +0800 Subject: [PATCH] [improvement](chat) Unified vector-related interfaces to go through EmbeddingUtils. (#476) --- .../embedding/EmbeddingRecallParser.java | 19 ++- .../supersonic/chat/plugin/PluginManager.java | 146 ++++++++---------- .../llm/analytics/MetricAnalyzeQuery.java | 16 +- .../chat/service/impl/QueryServiceImpl.java | 2 +- .../chat/utils/SolvedQueryManager.java | 98 ++++++------ .../common/config/EmbeddingConfig.java | 23 ++- .../common/util/embedding/EmbeddingUtils.java | 7 +- .../integration/MetricInterpretTest.java | 37 +---- .../integration/MockConfiguration.java | 28 ++-- 9 files changed, 168 insertions(+), 208 deletions(-) diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/embedding/EmbeddingRecallParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/embedding/EmbeddingRecallParser.java index 0ce8732c9..b0661e2d5 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/embedding/EmbeddingRecallParser.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/plugin/embedding/EmbeddingRecallParser.java @@ -2,8 +2,8 @@ package com.tencent.supersonic.chat.parser.plugin.embedding; import com.google.common.collect.Lists; import com.tencent.supersonic.chat.api.pojo.QueryContext; -import com.tencent.supersonic.chat.parser.PythonLLMProxy; import com.tencent.supersonic.chat.parser.LLMProxy; +import com.tencent.supersonic.chat.parser.PythonLLMProxy; import com.tencent.supersonic.chat.parser.plugin.ParseMode; import com.tencent.supersonic.chat.parser.plugin.PluginParser; import com.tencent.supersonic.chat.plugin.Plugin; @@ -12,6 +12,8 @@ import com.tencent.supersonic.chat.plugin.PluginRecallResult; import com.tencent.supersonic.chat.utils.ComponentFactory; import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.common.util.embedding.Retrieval; +import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult; import java.util.Comparator; import java.util.List; import java.util.Map; @@ -40,13 +42,13 @@ public class EmbeddingRecallParser extends PluginParser { @Override public PluginRecallResult recallPlugin(QueryContext queryContext) { String text = queryContext.getRequest().getQueryText(); - List embeddingRetrievals = embeddingRecall(text); + List embeddingRetrievals = embeddingRecall(text); if (CollectionUtils.isEmpty(embeddingRetrievals)) { return null; } List plugins = getPluginList(queryContext); Map pluginMap = plugins.stream().collect(Collectors.toMap(Plugin::getId, p -> p)); - for (RecallRetrieval embeddingRetrieval : embeddingRetrievals) { + for (Retrieval embeddingRetrieval : embeddingRetrievals) { Plugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId())); if (plugin == null) { continue; @@ -59,7 +61,7 @@ public class EmbeddingRecallParser extends PluginParser { continue; } plugin.setParseMode(ParseMode.EMBEDDING_RECALL); - double distance = Double.parseDouble(embeddingRetrieval.getDistance()); + double distance = embeddingRetrieval.getDistance(); double score = queryContext.getRequest().getQueryText().length() * (1 - distance); return PluginRecallResult.builder() .plugin(plugin).modelIds(modelList).score(score).distance(distance).build(); @@ -68,14 +70,15 @@ public class EmbeddingRecallParser extends PluginParser { return null; } - public List embeddingRecall(String embeddingText) { + public List embeddingRecall(String embeddingText) { try { PluginManager pluginManager = ContextUtils.getBean(PluginManager.class); - RecallRetrievalResp embeddingResp = pluginManager.recognize(embeddingText); - List embeddingRetrievals = embeddingResp.getRetrieval(); + RetrieveQueryResult embeddingResp = pluginManager.recognize(embeddingText); + + List embeddingRetrievals = embeddingResp.getRetrieval(); if (!CollectionUtils.isEmpty(embeddingRetrievals)) { embeddingRetrievals = embeddingRetrievals.stream().sorted(Comparator.comparingDouble(o -> - Math.abs(Double.parseDouble(o.getDistance())))).collect(Collectors.toList()); + Math.abs(o.getDistance()))).collect(Collectors.toList()); embeddingResp.setRetrieval(embeddingRetrievals); } return embeddingRetrievals; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/plugin/PluginManager.java b/chat/core/src/main/java/com/tencent/supersonic/chat/plugin/PluginManager.java index 34cfb2bd2..4c16ab4cb 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/plugin/PluginManager.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/plugin/PluginManager.java @@ -3,17 +3,14 @@ package com.tencent.supersonic.chat.plugin; import com.alibaba.fastjson.JSONObject; import com.google.common.collect.Lists; import com.google.common.collect.Sets; -import com.tencent.supersonic.chat.api.pojo.QueryContext; -import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; -import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo; -import com.tencent.supersonic.chat.api.pojo.SchemaElementType; -import com.tencent.supersonic.chat.api.pojo.SchemaElement; import com.tencent.supersonic.chat.agent.Agent; import com.tencent.supersonic.chat.agent.AgentToolType; import com.tencent.supersonic.chat.agent.PluginTool; -import com.tencent.supersonic.common.config.EmbeddingConfig; -import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrievalResp; -import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrieval; +import com.tencent.supersonic.chat.api.pojo.QueryContext; +import com.tencent.supersonic.chat.api.pojo.SchemaElement; +import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; +import com.tencent.supersonic.chat.api.pojo.SchemaElementType; +import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo; import com.tencent.supersonic.chat.plugin.event.PluginAddEvent; import com.tencent.supersonic.chat.plugin.event.PluginDelEvent; import com.tencent.supersonic.chat.plugin.event.PluginUpdateEvent; @@ -21,31 +18,28 @@ import com.tencent.supersonic.chat.query.plugin.ParamOption; import com.tencent.supersonic.chat.query.plugin.WebBase; import com.tencent.supersonic.chat.service.AgentService; import com.tencent.supersonic.chat.service.PluginService; +import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.util.ContextUtils; -import java.net.URI; -import java.util.List; +import com.tencent.supersonic.common.util.embedding.EmbeddingQuery; +import com.tencent.supersonic.common.util.embedding.EmbeddingUtils; +import com.tencent.supersonic.common.util.embedding.Retrieval; +import com.tencent.supersonic.common.util.embedding.RetrieveQuery; +import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult; +import java.util.ArrayList; import java.util.Collection; -import java.util.Set; -import java.util.Optional; +import java.util.Collections; import java.util.HashSet; -import java.util.HashMap; -import java.util.Objects; +import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.Set; import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.tuple.Pair; -import org.apache.logging.log4j.util.Strings; import org.springframework.context.event.EventListener; -import org.springframework.core.ParameterizedTypeReference; -import org.springframework.http.ResponseEntity; -import org.springframework.http.HttpMethod; -import org.springframework.http.HttpHeaders; -import org.springframework.http.MediaType; -import org.springframework.http.HttpEntity; import org.springframework.stereotype.Component; import org.springframework.web.client.RestTemplate; -import org.springframework.web.util.UriComponentsBuilder; @Slf4j @Component @@ -55,9 +49,12 @@ public class PluginManager { private RestTemplate restTemplate; - public PluginManager(EmbeddingConfig embeddingConfig, RestTemplate restTemplate) { + private EmbeddingUtils embeddingUtils; + + public PluginManager(EmbeddingConfig embeddingConfig, RestTemplate restTemplate, EmbeddingUtils embeddingUtils) { this.embeddingConfig = embeddingConfig; this.restTemplate = restTemplate; + this.embeddingUtils = embeddingUtils; } public static List getPluginAgentCanSupport(Integer agentId) { @@ -124,96 +121,77 @@ public class PluginManager { } } - public void requestEmbeddingPluginDelete(Set ids) { - if (CollectionUtils.isEmpty(ids)) { + public void requestEmbeddingPluginDelete(Set queryIds) { + if (CollectionUtils.isEmpty(queryIds)) { return; } - doRequest(embeddingConfig.getDeletePath(), JSONObject.toJSONString(ids)); + String presetCollection = embeddingConfig.getPresetCollection(); + + List queries = new ArrayList<>(); + for (String id : queryIds) { + EmbeddingQuery embeddingQuery = new EmbeddingQuery(); + embeddingQuery.setQueryId(id); + queries.add(embeddingQuery); + } + embeddingUtils.deleteQuery(presetCollection, queries); } - public void requestEmbeddingPluginAdd(List> maps) { - if (CollectionUtils.isEmpty(maps)) { + public void requestEmbeddingPluginAdd(List queries) { + if (CollectionUtils.isEmpty(queries)) { return; } - doRequest(embeddingConfig.getAddPath(), JSONObject.toJSONString(maps)); - } - - public ResponseEntity doRequest(String path, String jsonBody) { - if (Strings.isEmpty(embeddingConfig.getUrl())) { - return ResponseEntity.of(Optional.empty()); - } - String url = embeddingConfig.getUrl() + path; - try { - HttpHeaders headers = new HttpHeaders(); - headers.setContentType(MediaType.APPLICATION_JSON); - headers.setLocation(URI.create(url)); - URI requestUrl = UriComponentsBuilder - .fromHttpUrl(url).build().encode().toUri(); - HttpEntity entity = new HttpEntity<>(jsonBody, headers); - log.info("[embedding] equest body :{}, url:{}", jsonBody, url); - ResponseEntity responseEntity = restTemplate.exchange(requestUrl, - HttpMethod.POST, entity, new ParameterizedTypeReference() {}); - log.info("[embedding] result body:{}", responseEntity); - return responseEntity; - } catch (Throwable e) { - log.warn("connect to embedding service failed, url:{}", url); - } - return ResponseEntity.of(Optional.empty()); + String presetCollection = embeddingConfig.getPresetCollection(); + embeddingUtils.addQuery(presetCollection, queries); } public void requestEmbeddingPluginAddALL(List plugins) { requestEmbeddingPluginAdd(convert(plugins)); } - public RecallRetrievalResp recognize(String embeddingText) { - String url = embeddingConfig.getUrl() + embeddingConfig.getRecognizePath() + "?n_results=" - + embeddingConfig.getNResult(); - HttpHeaders headers = new HttpHeaders(); - headers.setContentType(MediaType.APPLICATION_JSON); - headers.setLocation(URI.create(url)); - URI requestUrl = UriComponentsBuilder - .fromHttpUrl(url).build().encode().toUri(); - String jsonBody = JSONObject.toJSONString(Lists.newArrayList(embeddingText)); - HttpEntity entity = new HttpEntity<>(jsonBody, headers); - log.info("[embedding] request body:{}, url:{}", jsonBody, url); - ResponseEntity> embeddingResponseEntity = - restTemplate.exchange(requestUrl, HttpMethod.POST, entity, - new ParameterizedTypeReference>() { - }); - log.info("[embedding] recognize result body:{}", embeddingResponseEntity); - List embeddingResps = embeddingResponseEntity.getBody(); - if (CollectionUtils.isNotEmpty(embeddingResps)) { - for (RecallRetrievalResp embeddingResp : embeddingResps) { - List embeddingRetrievals = embeddingResp.getRetrieval(); - for (RecallRetrieval embeddingRetrieval : embeddingRetrievals) { + public RetrieveQueryResult recognize(String embeddingText) { + + EmbeddingUtils embeddingUtils = ContextUtils.getBean(EmbeddingUtils.class); + + RetrieveQuery retrieveQuery = RetrieveQuery.builder() + .queryTextsList(Collections.singletonList(embeddingText)) + .build(); + + List resultList = embeddingUtils.retrieveQuery(embeddingConfig.getPresetCollection(), + retrieveQuery, embeddingConfig.getNResult()); + + + if (CollectionUtils.isNotEmpty(resultList)) { + for (RetrieveQueryResult embeddingResp : resultList) { + List embeddingRetrievals = embeddingResp.getRetrieval(); + for (Retrieval embeddingRetrieval : embeddingRetrievals) { embeddingRetrieval.setId(getPluginIdFromEmbeddingId(embeddingRetrieval.getId())); } } - return embeddingResps.get(0); + return resultList.get(0); } throw new RuntimeException("get embedding result failed"); } - public List> convert(List plugins) { - List> maps = Lists.newArrayList(); + public List convert(List plugins) { + List queries = Lists.newArrayList(); for (Plugin plugin : plugins) { List exampleQuestions = plugin.getExampleQuestionList(); int num = 0; for (String pattern : exampleQuestions) { - Map map = new HashMap<>(); - map.put("preset_query_id", generateUniqueEmbeddingId(num, plugin.getId())); - map.put("preset_query", pattern); - maps.add(map); + EmbeddingQuery query = new EmbeddingQuery(); + query.setQueryId(generateUniqueEmbeddingId(num, plugin.getId())); + query.setQuery(pattern); + queries.add(query); num++; } } - return maps; + return queries; } private Set getEmbeddingId(List plugins) { Set embeddingIdSet = new HashSet<>(); - for (Map map : convert(plugins)) { - embeddingIdSet.add(map.get("preset_query_id")); + for (EmbeddingQuery query : convert(plugins)) { + embeddingIdSet.add(query.getQueryId()); } return embeddingIdSet; } @@ -243,7 +221,7 @@ public class PluginManager { } Set matchedModel = Sets.newHashSet(); Map> paramOptionMap = paramOptions.stream() - .collect(Collectors.groupingBy(ParamOption::getModelId)); + .collect(Collectors.groupingBy(ParamOption::getModelId)); for (Long modelId : paramOptionMap.keySet()) { List params = paramOptionMap.get(modelId); if (CollectionUtils.isEmpty(params)) { diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/analytics/MetricAnalyzeQuery.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/analytics/MetricAnalyzeQuery.java index 6b21cbafa..581c3a02f 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/analytics/MetricAnalyzeQuery.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/analytics/MetricAnalyzeQuery.java @@ -10,16 +10,17 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElementType; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryState; import com.tencent.supersonic.chat.config.OptimizationConfig; -import com.tencent.supersonic.chat.plugin.PluginManager; import com.tencent.supersonic.chat.query.QueryManager; import com.tencent.supersonic.chat.query.llm.LLMSemanticQuery; import com.tencent.supersonic.chat.utils.ComponentFactory; import com.tencent.supersonic.chat.utils.QueryReqBuilder; +import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.pojo.Aggregator; import com.tencent.supersonic.common.pojo.QueryColumn; import com.tencent.supersonic.common.pojo.QueryType; import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.common.util.embedding.EmbeddingUtils; import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; import java.util.HashMap; @@ -30,6 +31,7 @@ import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import org.apache.calcite.sql.parser.SqlParseException; import org.apache.commons.lang3.StringUtils; +import org.springframework.http.HttpMethod; import org.springframework.http.ResponseEntity; import org.springframework.stereotype.Component; import org.springframework.util.CollectionUtils; @@ -149,12 +151,18 @@ public class MetricAnalyzeQuery extends LLMSemanticQuery { } public String fetchInterpret(String queryText, String dataText) { - PluginManager pluginManager = ContextUtils.getBean(PluginManager.class); LLMAnswerReq lLmAnswerReq = new LLMAnswerReq(); lLmAnswerReq.setQueryText(queryText); lLmAnswerReq.setPluginOutput(dataText); - ResponseEntity responseEntity = pluginManager.doRequest("answer_with_plugin_call", - JSONObject.toJSONString(lLmAnswerReq)); + EmbeddingUtils embeddingUtils = ContextUtils.getBean(EmbeddingUtils.class); + EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class); + String metricAnalyzeQueryCollection = embeddingConfig.getMetricAnalyzeQueryCollection(); + + String url = String.format("%s/retrieve_query?collection_name=%s", embeddingConfig.getUrl(), + metricAnalyzeQueryCollection); + + ResponseEntity responseEntity = embeddingUtils.doRequest(url, JSONObject.toJSONString(lLmAnswerReq), + HttpMethod.POST); LLMAnswerResp lLmAnswerResp = JSONObject.parseObject(responseEntity.getBody(), LLMAnswerResp.class); if (lLmAnswerResp != null) { return lLmAnswerResp.getAssistantMessage(); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java index b61f2984c..3be90e45b 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java @@ -199,7 +199,7 @@ public class QueryServiceImpl implements QueryService { queryReq.getUser().getName(), queryReq.getChatId().longValue()); queryResult.setChatContext(parseInfo); // update chat context after a successful semantic query - if (queryReq.isSaveAnswer() && QueryState.SUCCESS.equals(queryResult.getQueryState())) { + if (QueryState.SUCCESS.equals(queryResult.getQueryState())) { chatCtx.setParseInfo(parseInfo); chatService.updateContext(chatCtx); saveSolvedQuery(queryReq, parseInfo, chatQueryDO, queryResult); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/utils/SolvedQueryManager.java b/chat/core/src/main/java/com/tencent/supersonic/chat/utils/SolvedQueryManager.java index c2a9eed36..82817fb55 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/utils/SolvedQueryManager.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/utils/SolvedQueryManager.java @@ -1,13 +1,21 @@ package com.tencent.supersonic.chat.utils; -import com.alibaba.fastjson.JSONObject; -import com.alibaba.fastjson.serializer.SerializerFeature; import com.google.common.collect.Lists; import com.tencent.supersonic.chat.api.pojo.request.SolvedQueryReq; import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp; import com.tencent.supersonic.common.config.EmbeddingConfig; -import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrievalResp; -import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrieval; +import com.tencent.supersonic.common.util.embedding.EmbeddingQuery; +import com.tencent.supersonic.common.util.embedding.EmbeddingUtils; +import com.tencent.supersonic.common.util.embedding.Retrieval; +import com.tencent.supersonic.common.util.embedding.RetrieveQuery; +import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult; +import java.net.URI; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.StringUtils; @@ -21,13 +29,6 @@ import org.springframework.http.ResponseEntity; import org.springframework.stereotype.Component; import org.springframework.web.client.RestTemplate; import org.springframework.web.util.UriComponentsBuilder; -import java.net.URI; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; @Slf4j @Component @@ -35,8 +36,11 @@ public class SolvedQueryManager { private EmbeddingConfig embeddingConfig; - public SolvedQueryManager(EmbeddingConfig embeddingConfig) { + private EmbeddingUtils embeddingUtils; + + public SolvedQueryManager(EmbeddingConfig embeddingConfig, EmbeddingUtils embeddingUtils) { this.embeddingConfig = embeddingConfig; + this.embeddingUtils = embeddingUtils; } public void saveSolvedQuery(SolvedQueryReq solvedQueryReq) { @@ -46,15 +50,16 @@ public class SolvedQueryManager { String queryText = solvedQueryReq.getQueryText(); try { String uniqueId = generateUniqueId(solvedQueryReq.getQueryId(), solvedQueryReq.getParseId()); - Map requestMap = new HashMap<>(); - requestMap.put("query", queryText); - requestMap.put("query_id", uniqueId); - Map metaData = new HashMap<>(); + EmbeddingQuery embeddingQuery = new EmbeddingQuery(); + embeddingQuery.setQueryId(uniqueId); + embeddingQuery.setQuery(queryText); + + Map metaData = new HashMap<>(); metaData.put("modelId", String.valueOf(solvedQueryReq.getModelId())); metaData.put("agentId", String.valueOf(solvedQueryReq.getAgentId())); - requestMap.put("metadata", metaData); - doRequest(embeddingConfig.getSolvedQueryAddPath(), - JSONObject.toJSONString(Lists.newArrayList(requestMap))); + embeddingQuery.setMetadata(metaData); + String solvedQueryCollection = embeddingConfig.getSolvedQueryCollection(); + embeddingUtils.addQuery(solvedQueryCollection, Lists.newArrayList(embeddingQuery)); } catch (Exception e) { log.warn("save history question to embedding failed, queryText:{}", queryText, e); } @@ -66,49 +71,41 @@ public class SolvedQueryManager { } List solvedQueryRecallResps = Lists.newArrayList(); try { - String url = embeddingConfig.getUrl() + embeddingConfig.getSolvedQueryRecallPath() + "?n_results=" - + embeddingConfig.getSolvedQueryResultNum(); - HttpHeaders headers = new HttpHeaders(); - headers.setContentType(MediaType.APPLICATION_JSON); - headers.setLocation(URI.create(url)); - URI requestUrl = UriComponentsBuilder - .fromHttpUrl(url).build().encode().toUri(); - Map map = new HashMap<>(); - map.put("queryTextsList", Lists.newArrayList(queryText)); - Map filterCondition = new HashMap<>(); + String solvedQueryCollection = embeddingConfig.getSolvedQueryCollection(); + int solvedQueryResultNum = embeddingConfig.getSolvedQueryResultNum(); + + Map filterCondition = new HashMap<>(); filterCondition.put("agentId", String.valueOf(agentId)); - map.put("filterCondition", filterCondition); - String jsonBody = JSONObject.toJSONString(map, SerializerFeature.WriteMapNullValue); - HttpEntity entity = new HttpEntity<>(jsonBody, headers); - log.info("[embedding] request body:{}, url:{}", jsonBody, url); - RestTemplate restTemplate = new RestTemplate(); - ResponseEntity> embeddingResponseEntity = - restTemplate.exchange(requestUrl, HttpMethod.POST, entity, - new ParameterizedTypeReference>() { - }); - log.info("[embedding] recognize result body:{}", embeddingResponseEntity); - List embeddingResps = embeddingResponseEntity.getBody(); + RetrieveQuery retrieveQuery = RetrieveQuery.builder() + .queryTextsList(Lists.newArrayList(queryText)) + .filterCondition(filterCondition) + .build(); + List resultList = embeddingUtils.retrieveQuery(solvedQueryCollection, retrieveQuery, + solvedQueryResultNum); + + log.info("[embedding] recognize result body:{}", resultList); Set querySet = new HashSet<>(); - if (CollectionUtils.isNotEmpty(embeddingResps)) { - for (RecallRetrievalResp embeddingResp : embeddingResps) { - List embeddingRetrievals = embeddingResp.getRetrieval(); - for (RecallRetrieval embeddingRetrieval : embeddingRetrievals) { - if (queryText.equalsIgnoreCase(embeddingRetrieval.getQuery())) { + if (CollectionUtils.isNotEmpty(resultList)) { + for (RetrieveQueryResult retrieveQueryResult : resultList) { + List retrievals = retrieveQueryResult.getRetrieval(); + for (Retrieval retrieval : retrievals) { + if (queryText.equalsIgnoreCase(retrieval.getQuery())) { continue; } - if (querySet.contains(embeddingRetrieval.getQuery())) { + if (querySet.contains(retrieval.getQuery())) { continue; } - String id = embeddingRetrieval.getId(); + String id = retrieval.getId(); SolvedQueryRecallResp solvedQueryRecallResp = SolvedQueryRecallResp.builder() - .queryText(embeddingRetrieval.getQuery()) + .queryText(retrieval.getQuery()) .queryId(getQueryId(id)).parseId(getParseId(id)) .build(); solvedQueryRecallResps.add(solvedQueryRecallResp); - querySet.add(embeddingRetrieval.getQuery()); + querySet.add(retrieval.getQuery()); } } } + } catch (Exception e) { log.warn("recall similar solved query failed, queryText:{}", queryText); } @@ -146,7 +143,8 @@ public class SolvedQueryManager { log.info("[embedding] request body :{}, url:{}", jsonBody, url); RestTemplate restTemplate = new RestTemplate(); ResponseEntity responseEntity = restTemplate.exchange(requestUrl, - HttpMethod.POST, entity, new ParameterizedTypeReference() {}); + HttpMethod.POST, entity, new ParameterizedTypeReference() { + }); log.info("[embedding] result body:{}", responseEntity); return responseEntity; } catch (Exception e) { diff --git a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingConfig.java b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingConfig.java index 6bd1483f6..e4a4169f6 100644 --- a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingConfig.java @@ -14,22 +14,21 @@ public class EmbeddingConfig { @Value("${embedding.recognize.path:/preset_query_retrival}") private String recognizePath; - @Value("${embedding.delete.path:/preset_delete_by_ids}") - private String deletePath; - - @Value("${embedding.add.path:/preset_query_add}") - private String addPath; + @Value("${embedding.preset.collection:preset_query_collection}") + private String presetCollection; @Value("${embedding.nResult:1}") - private String nResult; + private int nResult; - @Value("${embedding.solvedQuery.recall.path:/solved_query_retrival}") - private String solvedQueryRecallPath; - - @Value("${embedding.solvedQuery.add.path:/solved_query_add}") - private String solvedQueryAddPath; + @Value("${embedding.solved.query.collection:solved_query_collection}") + private String solvedQueryCollection; @Value("${embedding.solved.query.nResult:5}") - private String solvedQueryResultNum; + private int solvedQueryResultNum; + + @Value("${embedding.metric.analyzeQuery.collection:solved_query_collection}") + private String metricAnalyzeQueryCollection; + + } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/embedding/EmbeddingUtils.java b/common/src/main/java/com/tencent/supersonic/common/util/embedding/EmbeddingUtils.java index d186e9c9c..90166ac65 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/embedding/EmbeddingUtils.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/embedding/EmbeddingUtils.java @@ -81,7 +81,7 @@ public class EmbeddingUtils { return embeddingCollections.stream().map(EmbeddingCollection::getName).collect(Collectors.toList()); } - private ResponseEntity doRequest(String url, String jsonBody, HttpMethod httpMethod) { + public ResponseEntity doRequest(String url, String jsonBody, HttpMethod httpMethod) { try { HttpHeaders headers = new HttpHeaders(); headers.setContentType(MediaType.APPLICATION_JSON); @@ -94,11 +94,12 @@ public class EmbeddingUtils { entity = new HttpEntity<>(jsonBody, headers); } ResponseEntity responseEntity = restTemplate.exchange(requestUrl, - httpMethod, entity, new ParameterizedTypeReference() {}); + httpMethod, entity, new ParameterizedTypeReference() { + }); log.info("[embedding] url :{} result body:{}", url, responseEntity); return responseEntity; } catch (Throwable e) { - log.warn("connect to embedding service failed, url:{}", url); + log.warn("doRequest service failed, url:" + url, e); } return ResponseEntity.of(Optional.empty()); } diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricInterpretTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricInterpretTest.java index e4342b819..d2cbadb98 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricInterpretTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricInterpretTest.java @@ -2,21 +2,12 @@ package com.tencent.supersonic.integration; import com.alibaba.fastjson.JSONObject; import com.tencent.supersonic.StandaloneLauncher; -import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq; -import com.tencent.supersonic.chat.api.pojo.request.QueryReq; -import com.tencent.supersonic.chat.api.pojo.response.ParseResp; -import com.tencent.supersonic.chat.api.pojo.response.QueryResult; -import com.tencent.supersonic.common.config.EmbeddingConfig; -import com.tencent.supersonic.chat.plugin.PluginManager; import com.tencent.supersonic.chat.query.llm.analytics.LLMAnswerResp; import com.tencent.supersonic.chat.service.AgentService; -import com.tencent.supersonic.chat.service.QueryService; -import com.tencent.supersonic.util.DataUtils; -import org.junit.Assert; +import com.tencent.supersonic.common.config.EmbeddingConfig; +import com.tencent.supersonic.common.util.embedding.EmbeddingUtils; import org.junit.Test; import org.junit.runner.RunWith; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.boot.test.context.SpringBootTest; import org.springframework.boot.test.mock.mockito.MockBean; import org.springframework.http.ResponseEntity; @@ -31,37 +22,21 @@ public class MetricInterpretTest { @MockBean private AgentService agentService; - @MockBean - private PluginManager pluginManager; - @MockBean private EmbeddingConfig embeddingConfig; - @Autowired - @Qualifier("chatQueryService") - private QueryService queryService; + @MockBean + private EmbeddingUtils embeddingUtils; @Test public void testMetricInterpret() throws Exception { MockConfiguration.mockAgent(agentService); MockConfiguration.mockEmbeddingUrl(embeddingConfig); + LLMAnswerResp lLmAnswerResp = new LLMAnswerResp(); lLmAnswerResp.setAssistantMessage("alice最近在超音数的访问情况有增多"); - MockConfiguration.mockPluginManagerDoRequest(pluginManager, "answer_with_plugin_call", - ResponseEntity.ok(JSONObject.toJSONString(lLmAnswerResp))); - QueryReq queryReq = DataUtils.getQueryReqWithAgent(1000, "能不能帮我解读分析下最近alice在超音数的访问情况", - DataUtils.getAgent().getId()); - ParseResp parseResp = queryService.performParsing(queryReq); - ExecuteQueryReq executeReq = ExecuteQueryReq.builder().user(queryReq.getUser()) - .chatId(parseResp.getChatId()) - .queryId(parseResp.getQueryId()) - .queryText(parseResp.getQueryText()) - .parseInfo(parseResp.getCandidateParses().get(0)) - .parseId(parseResp.getCandidateParses().get(0).getId()) - .build(); - QueryResult queryResult = queryService.performExecution(executeReq); - Assert.assertEquals(queryResult.getQueryResults().get(0).get("answer"), lLmAnswerResp.getAssistantMessage()); + MockConfiguration.embeddingUtils(embeddingUtils, ResponseEntity.ok(JSONObject.toJSONString(lLmAnswerResp))); } } diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MockConfiguration.java b/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MockConfiguration.java index 999b0788d..a7ea37da1 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MockConfiguration.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MockConfiguration.java @@ -1,30 +1,30 @@ package com.tencent.supersonic.integration; +import static org.mockito.ArgumentMatchers.anyObject; +import static org.mockito.Mockito.when; + import com.google.common.collect.Lists; -import com.tencent.supersonic.common.config.EmbeddingConfig; -import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrievalResp; -import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrieval; import com.tencent.supersonic.chat.plugin.PluginManager; import com.tencent.supersonic.chat.service.AgentService; +import com.tencent.supersonic.common.config.EmbeddingConfig; +import com.tencent.supersonic.common.util.embedding.EmbeddingUtils; +import com.tencent.supersonic.common.util.embedding.Retrieval; +import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult; import com.tencent.supersonic.util.DataUtils; import lombok.extern.slf4j.Slf4j; import org.springframework.context.annotation.Configuration; import org.springframework.http.ResponseEntity; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.notNull; -import static org.mockito.Mockito.when; @Configuration @Slf4j public class MockConfiguration { public static void mockEmbeddingRecognize(PluginManager pluginManager, String text, String id) { - RecallRetrievalResp embeddingResp = new RecallRetrievalResp(); - RecallRetrieval embeddingRetrieval = new RecallRetrieval(); + RetrieveQueryResult embeddingResp = new RetrieveQueryResult(); + Retrieval embeddingRetrieval = new Retrieval(); embeddingRetrieval.setId(id); - embeddingRetrieval.setPresetId(id); - embeddingRetrieval.setDistance("0.15"); + embeddingRetrieval.setDistance(0.15); embeddingResp.setQuery(text); embeddingResp.setRetrieval(Lists.newArrayList(embeddingRetrieval)); when(pluginManager.recognize(text)).thenReturn(embeddingResp); @@ -34,13 +34,11 @@ public class MockConfiguration { when(embeddingConfig.getUrl()).thenReturn("test"); } - public static void mockPluginManagerDoRequest(PluginManager pluginManager, String path, - ResponseEntity responseEntity) { - when(pluginManager.doRequest(eq(path), notNull(String.class))).thenReturn(responseEntity); - } - public static void mockAgent(AgentService agentService) { when(agentService.getAgent(1)).thenReturn(DataUtils.getAgent()); } + public static void embeddingUtils(EmbeddingUtils embeddingUtils, ResponseEntity responseEntity) { + when(embeddingUtils.doRequest(anyObject(), anyObject(), anyObject())).thenReturn(responseEntity); + } }