mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-20 06:34:55 +00:00
(improvement)(Chat) Fix start failed and integration tests failed (#1039)
Co-authored-by: jolunoluo
This commit is contained in:
@@ -13,7 +13,6 @@ import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.test.mock.mockito.MockBean;
|
||||
|
||||
import java.time.LocalDate;
|
||||
import java.util.Objects;
|
||||
@@ -33,7 +32,7 @@ public class BaseTest extends BaseApplication {
|
||||
protected ChatService chatService;
|
||||
@Autowired
|
||||
protected ConfigService configService;
|
||||
@MockBean
|
||||
@Autowired
|
||||
protected AgentService agentService;
|
||||
|
||||
protected QueryResult submitMultiTurnChat(String queryText, Integer agentId, Integer chatId) throws Exception {
|
||||
|
||||
@@ -29,7 +29,6 @@ public class MetricTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void testMetricFilter() throws Exception {
|
||||
MockConfiguration.mockMetricAgent(agentService);
|
||||
QueryResult actualResult = submitNewChat("alice的访问次数", DataUtils.metricAgentId);
|
||||
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
@@ -55,7 +54,6 @@ public class MetricTest extends BaseTest {
|
||||
@Test
|
||||
public void testMetricFilterWithAgent() {
|
||||
//agent only support METRIC_ENTITY, METRIC_FILTER
|
||||
MockConfiguration.mockMetricAgent(agentService);
|
||||
ParseResp parseResp = submitParseWithAgent("alice的访问次数", DataUtils.getMetricAgent().getId());
|
||||
Assert.assertNotNull(parseResp.getSelectedParses());
|
||||
List<String> queryModes = parseResp.getSelectedParses().stream()
|
||||
@@ -65,7 +63,6 @@ public class MetricTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void testMetricDomain() throws Exception {
|
||||
MockConfiguration.mockMetricAgent(agentService);
|
||||
QueryResult actualResult = submitNewChat("超音数的访问次数", DataUtils.metricAgentId);
|
||||
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
@@ -87,7 +84,6 @@ public class MetricTest extends BaseTest {
|
||||
@Test
|
||||
public void testMetricModelWithAgent() {
|
||||
//agent only support METRIC_ENTITY, METRIC_FILTER
|
||||
MockConfiguration.mockMetricAgent(agentService);
|
||||
ParseResp parseResp = submitParseWithAgent("超音数的访问次数", DataUtils.getMetricAgent().getId());
|
||||
List<String> queryModes = parseResp.getSelectedParses().stream()
|
||||
.map(SemanticParseInfo::getQueryMode).collect(Collectors.toList());
|
||||
@@ -118,7 +114,6 @@ public class MetricTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void testMetricFilterCompare() throws Exception {
|
||||
MockConfiguration.mockMetricAgent(agentService);
|
||||
QueryResult actualResult = submitNewChat("对比alice和lucy的访问次数", DataUtils.metricAgentId);
|
||||
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
@@ -187,7 +182,6 @@ public class MetricTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void testMetricFilterTime() throws Exception {
|
||||
MockConfiguration.mockMetricAgent(agentService);
|
||||
DateFormat format = new SimpleDateFormat("yyyy-mm-dd");
|
||||
DateFormat textFormat = new SimpleDateFormat("yyyy年mm月dd日");
|
||||
String dateStr = textFormat.format(format.parse(startDay));
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
package com.tencent.supersonic.chat;
|
||||
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
@Configuration
|
||||
@Slf4j
|
||||
public class MockConfiguration {
|
||||
|
||||
public static void mockEmbeddingRecognize(PluginManager pluginManager, String text, String id) {
|
||||
RetrieveQueryResult embeddingResp = new RetrieveQueryResult();
|
||||
Retrieval embeddingRetrieval = new Retrieval();
|
||||
embeddingRetrieval.setId(id);
|
||||
embeddingRetrieval.setDistance(0.15);
|
||||
embeddingResp.setQuery(text);
|
||||
embeddingResp.setRetrieval(Lists.newArrayList(embeddingRetrieval));
|
||||
when(pluginManager.recognize(text)).thenReturn(embeddingResp);
|
||||
}
|
||||
|
||||
public static void mockEmbeddingUrl(EmbeddingConfig embeddingConfig) {
|
||||
when(embeddingConfig.getUrl()).thenReturn("test");
|
||||
}
|
||||
|
||||
public static void mockMetricAgent(AgentService agentService) {
|
||||
when(agentService.getAgent(1)).thenReturn(DataUtils.getMetricAgent());
|
||||
}
|
||||
|
||||
public static void mockTagAgent(AgentService agentService) {
|
||||
when(agentService.getAgent(2)).thenReturn(DataUtils.getTagAgent());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -17,7 +17,6 @@ public class MultiTurnsTest extends BaseTest {
|
||||
@Test
|
||||
@Order(1)
|
||||
public void queryTest_01() throws Exception {
|
||||
MockConfiguration.mockMetricAgent(agentService);
|
||||
QueryResult actualResult = submitMultiTurnChat("alice的访问次数",
|
||||
DataUtils.metricAgentId, DataUtils.MULTI_TURNS_CHAT_ID);
|
||||
|
||||
@@ -44,7 +43,6 @@ public class MultiTurnsTest extends BaseTest {
|
||||
@Test
|
||||
@Order(2)
|
||||
public void queryTest_02() throws Exception {
|
||||
MockConfiguration.mockMetricAgent(agentService);
|
||||
QueryResult actualResult = submitMultiTurnChat("停留时长呢", DataUtils.metricAgentId,
|
||||
DataUtils.MULTI_TURNS_CHAT_ID);
|
||||
|
||||
@@ -69,7 +67,6 @@ public class MultiTurnsTest extends BaseTest {
|
||||
@Test
|
||||
@Order(3)
|
||||
public void queryTest_03() throws Exception {
|
||||
MockConfiguration.mockMetricAgent(agentService);
|
||||
QueryResult actualResult = submitMultiTurnChat("lucy的如何", DataUtils.metricAgentId,
|
||||
DataUtils.MULTI_TURNS_CHAT_ID);
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ public class TagTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void queryTest_tag_list_filter() throws Exception {
|
||||
MockConfiguration.mockTagAgent(agentService);
|
||||
QueryResult actualResult = submitNewChat("爱情、流行类型的艺人", DataUtils.tagAgentId);
|
||||
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
|
||||
@@ -24,7 +24,7 @@ public class ModelSchemaTest extends BaseTest {
|
||||
fieldRemovedReq.setModelId(2L);
|
||||
fieldRemovedReq.setFields(Lists.newArrayList("pv"));
|
||||
UnAvailableItemResp unAvailableItemResp = modelService.getUnAvailableItem(fieldRemovedReq);
|
||||
List<Long> expectedUnAvailableMetricId = Lists.newArrayList(1L, 3L);
|
||||
List<Long> expectedUnAvailableMetricId = Lists.newArrayList(1L, 4L);
|
||||
List<Long> actualUnAvailableMetricId = unAvailableItemResp.getMetricResps()
|
||||
.stream().map(MetricResp::getId).sorted(Comparator.naturalOrder()).collect(Collectors.toList());
|
||||
Assertions.assertEquals(expectedUnAvailableMetricId, actualUnAvailableMetricId);
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
package com.tencent.supersonic.headless;
|
||||
|
||||
import static org.junit.Assert.assertThrows;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryMetricReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.server.service.MetricService;
|
||||
import java.util.Arrays;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
import static org.junit.Assert.assertThrows;
|
||||
|
||||
public class QueryByMetricTest extends BaseTest {
|
||||
|
||||
@Autowired
|
||||
@@ -58,7 +59,7 @@ public class QueryByMetricTest extends BaseTest {
|
||||
public void testWithMetricAndDimensionIds() throws Exception {
|
||||
QueryMetricReq queryMetricReq = new QueryMetricReq();
|
||||
queryMetricReq.setDomainId(1L);
|
||||
queryMetricReq.setMetricIds(Arrays.asList(1L, 4L));
|
||||
queryMetricReq.setMetricIds(Arrays.asList(1L, 3L));
|
||||
queryMetricReq.setDimensionIds(Arrays.asList(1L, 2L));
|
||||
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getFakeUser());
|
||||
Assert.assertNotNull(queryResp.getResultList());
|
||||
|
||||
@@ -66,6 +66,10 @@ s2:
|
||||
names: S2VisitsDemo,S2ArtistDemo
|
||||
enableLLM: true
|
||||
|
||||
schema:
|
||||
cache:
|
||||
enable: false
|
||||
|
||||
langchain4j:
|
||||
#1.chat-model
|
||||
chat-model:
|
||||
|
||||
Reference in New Issue
Block a user