mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-20 06:34:55 +00:00
[improvement](chat) Unified vector-related interfaces to go through EmbeddingUtils. (#476)
This commit is contained in:
@@ -2,21 +2,12 @@ package com.tencent.supersonic.integration;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.tencent.supersonic.StandaloneLauncher;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.query.llm.analytics.LLMAnswerResp;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import com.tencent.supersonic.chat.service.QueryService;
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
import org.junit.Assert;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
|
||||
import org.junit.Test;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Qualifier;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
import org.springframework.boot.test.mock.mockito.MockBean;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
@@ -31,37 +22,21 @@ public class MetricInterpretTest {
|
||||
@MockBean
|
||||
private AgentService agentService;
|
||||
|
||||
@MockBean
|
||||
private PluginManager pluginManager;
|
||||
|
||||
@MockBean
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
|
||||
@Autowired
|
||||
@Qualifier("chatQueryService")
|
||||
private QueryService queryService;
|
||||
@MockBean
|
||||
private EmbeddingUtils embeddingUtils;
|
||||
|
||||
@Test
|
||||
public void testMetricInterpret() throws Exception {
|
||||
MockConfiguration.mockAgent(agentService);
|
||||
MockConfiguration.mockEmbeddingUrl(embeddingConfig);
|
||||
|
||||
LLMAnswerResp lLmAnswerResp = new LLMAnswerResp();
|
||||
lLmAnswerResp.setAssistantMessage("alice最近在超音数的访问情况有增多");
|
||||
MockConfiguration.mockPluginManagerDoRequest(pluginManager, "answer_with_plugin_call",
|
||||
ResponseEntity.ok(JSONObject.toJSONString(lLmAnswerResp)));
|
||||
QueryReq queryReq = DataUtils.getQueryReqWithAgent(1000, "能不能帮我解读分析下最近alice在超音数的访问情况",
|
||||
DataUtils.getAgent().getId());
|
||||
|
||||
ParseResp parseResp = queryService.performParsing(queryReq);
|
||||
ExecuteQueryReq executeReq = ExecuteQueryReq.builder().user(queryReq.getUser())
|
||||
.chatId(parseResp.getChatId())
|
||||
.queryId(parseResp.getQueryId())
|
||||
.queryText(parseResp.getQueryText())
|
||||
.parseInfo(parseResp.getCandidateParses().get(0))
|
||||
.parseId(parseResp.getCandidateParses().get(0).getId())
|
||||
.build();
|
||||
QueryResult queryResult = queryService.performExecution(executeReq);
|
||||
Assert.assertEquals(queryResult.getQueryResults().get(0).get("answer"), lLmAnswerResp.getAssistantMessage());
|
||||
MockConfiguration.embeddingUtils(embeddingUtils, ResponseEntity.ok(JSONObject.toJSONString(lLmAnswerResp)));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,30 +1,30 @@
|
||||
package com.tencent.supersonic.integration;
|
||||
|
||||
|
||||
import static org.mockito.ArgumentMatchers.anyObject;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrievalResp;
|
||||
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrieval;
|
||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.service.AgentService;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import static org.mockito.ArgumentMatchers.eq;
|
||||
import static org.mockito.ArgumentMatchers.notNull;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
@Configuration
|
||||
@Slf4j
|
||||
public class MockConfiguration {
|
||||
|
||||
public static void mockEmbeddingRecognize(PluginManager pluginManager, String text, String id) {
|
||||
RecallRetrievalResp embeddingResp = new RecallRetrievalResp();
|
||||
RecallRetrieval embeddingRetrieval = new RecallRetrieval();
|
||||
RetrieveQueryResult embeddingResp = new RetrieveQueryResult();
|
||||
Retrieval embeddingRetrieval = new Retrieval();
|
||||
embeddingRetrieval.setId(id);
|
||||
embeddingRetrieval.setPresetId(id);
|
||||
embeddingRetrieval.setDistance("0.15");
|
||||
embeddingRetrieval.setDistance(0.15);
|
||||
embeddingResp.setQuery(text);
|
||||
embeddingResp.setRetrieval(Lists.newArrayList(embeddingRetrieval));
|
||||
when(pluginManager.recognize(text)).thenReturn(embeddingResp);
|
||||
@@ -34,13 +34,11 @@ public class MockConfiguration {
|
||||
when(embeddingConfig.getUrl()).thenReturn("test");
|
||||
}
|
||||
|
||||
public static void mockPluginManagerDoRequest(PluginManager pluginManager, String path,
|
||||
ResponseEntity<String> responseEntity) {
|
||||
when(pluginManager.doRequest(eq(path), notNull(String.class))).thenReturn(responseEntity);
|
||||
}
|
||||
|
||||
public static void mockAgent(AgentService agentService) {
|
||||
when(agentService.getAgent(1)).thenReturn(DataUtils.getAgent());
|
||||
}
|
||||
|
||||
public static void embeddingUtils(EmbeddingUtils embeddingUtils, ResponseEntity<String> responseEntity) {
|
||||
when(embeddingUtils.doRequest(anyObject(), anyObject(), anyObject())).thenReturn(responseEntity);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user