[improvement](chat) Unified vector-related interfaces to go through EmbeddingUtils. (#476)

This commit is contained in:
lexluo09
2023-12-06 14:50:57 +08:00
committed by GitHub
parent 9aa5c93d9d
commit ed0f856438
9 changed files with 168 additions and 208 deletions

View File

@@ -2,21 +2,12 @@ package com.tencent.supersonic.integration;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.StandaloneLauncher;
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.query.llm.analytics.LLMAnswerResp;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.chat.service.QueryService;
import com.tencent.supersonic.util.DataUtils;
import org.junit.Assert;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.http.ResponseEntity;
@@ -31,37 +22,21 @@ public class MetricInterpretTest {
@MockBean
private AgentService agentService;
@MockBean
private PluginManager pluginManager;
@MockBean
private EmbeddingConfig embeddingConfig;
@Autowired
@Qualifier("chatQueryService")
private QueryService queryService;
@MockBean
private EmbeddingUtils embeddingUtils;
@Test
public void testMetricInterpret() throws Exception {
MockConfiguration.mockAgent(agentService);
MockConfiguration.mockEmbeddingUrl(embeddingConfig);
LLMAnswerResp lLmAnswerResp = new LLMAnswerResp();
lLmAnswerResp.setAssistantMessage("alice最近在超音数的访问情况有增多");
MockConfiguration.mockPluginManagerDoRequest(pluginManager, "answer_with_plugin_call",
ResponseEntity.ok(JSONObject.toJSONString(lLmAnswerResp)));
QueryReq queryReq = DataUtils.getQueryReqWithAgent(1000, "能不能帮我解读分析下最近alice在超音数的访问情况",
DataUtils.getAgent().getId());
ParseResp parseResp = queryService.performParsing(queryReq);
ExecuteQueryReq executeReq = ExecuteQueryReq.builder().user(queryReq.getUser())
.chatId(parseResp.getChatId())
.queryId(parseResp.getQueryId())
.queryText(parseResp.getQueryText())
.parseInfo(parseResp.getCandidateParses().get(0))
.parseId(parseResp.getCandidateParses().get(0).getId())
.build();
QueryResult queryResult = queryService.performExecution(executeReq);
Assert.assertEquals(queryResult.getQueryResults().get(0).get("answer"), lLmAnswerResp.getAssistantMessage());
MockConfiguration.embeddingUtils(embeddingUtils, ResponseEntity.ok(JSONObject.toJSONString(lLmAnswerResp)));
}
}

View File

@@ -1,30 +1,30 @@
package com.tencent.supersonic.integration;
import static org.mockito.ArgumentMatchers.anyObject;
import static org.mockito.Mockito.when;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrievalResp;
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrieval;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.service.AgentService;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.util.DataUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.ResponseEntity;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.notNull;
import static org.mockito.Mockito.when;
@Configuration
@Slf4j
public class MockConfiguration {
public static void mockEmbeddingRecognize(PluginManager pluginManager, String text, String id) {
RecallRetrievalResp embeddingResp = new RecallRetrievalResp();
RecallRetrieval embeddingRetrieval = new RecallRetrieval();
RetrieveQueryResult embeddingResp = new RetrieveQueryResult();
Retrieval embeddingRetrieval = new Retrieval();
embeddingRetrieval.setId(id);
embeddingRetrieval.setPresetId(id);
embeddingRetrieval.setDistance("0.15");
embeddingRetrieval.setDistance(0.15);
embeddingResp.setQuery(text);
embeddingResp.setRetrieval(Lists.newArrayList(embeddingRetrieval));
when(pluginManager.recognize(text)).thenReturn(embeddingResp);
@@ -34,13 +34,11 @@ public class MockConfiguration {
when(embeddingConfig.getUrl()).thenReturn("test");
}
public static void mockPluginManagerDoRequest(PluginManager pluginManager, String path,
ResponseEntity<String> responseEntity) {
when(pluginManager.doRequest(eq(path), notNull(String.class))).thenReturn(responseEntity);
}
public static void mockAgent(AgentService agentService) {
when(agentService.getAgent(1)).thenReturn(DataUtils.getAgent());
}
public static void embeddingUtils(EmbeddingUtils embeddingUtils, ResponseEntity<String> responseEntity) {
when(embeddingUtils.doRequest(anyObject(), anyObject(), anyObject())).thenReturn(responseEntity);
}
}