(improvement)(Chat) Fix start failed and integration tests failed (#1039)

Co-authored-by: jolunoluo
This commit is contained in:
LXW
2024-05-28 17:59:51 +08:00
committed by GitHub
parent c51c278f33
commit 5064708c56
11 changed files with 25 additions and 72 deletions

View File

@@ -13,7 +13,6 @@ import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import com.tencent.supersonic.util.DataUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.mock.mockito.MockBean;
import java.time.LocalDate;
import java.util.Objects;
@@ -33,7 +32,7 @@ public class BaseTest extends BaseApplication {
protected ChatService chatService;
@Autowired
protected ConfigService configService;
@MockBean
@Autowired
protected AgentService agentService;
protected QueryResult submitMultiTurnChat(String queryText, Integer agentId, Integer chatId) throws Exception {

View File

@@ -29,7 +29,6 @@ public class MetricTest extends BaseTest {
@Test
public void testMetricFilter() throws Exception {
MockConfiguration.mockMetricAgent(agentService);
QueryResult actualResult = submitNewChat("alice的访问次数", DataUtils.metricAgentId);
QueryResult expectedResult = new QueryResult();
@@ -55,7 +54,6 @@ public class MetricTest extends BaseTest {
@Test
public void testMetricFilterWithAgent() {
//agent only support METRIC_ENTITY, METRIC_FILTER
MockConfiguration.mockMetricAgent(agentService);
ParseResp parseResp = submitParseWithAgent("alice的访问次数", DataUtils.getMetricAgent().getId());
Assert.assertNotNull(parseResp.getSelectedParses());
List<String> queryModes = parseResp.getSelectedParses().stream()
@@ -65,7 +63,6 @@ public class MetricTest extends BaseTest {
@Test
public void testMetricDomain() throws Exception {
MockConfiguration.mockMetricAgent(agentService);
QueryResult actualResult = submitNewChat("超音数的访问次数", DataUtils.metricAgentId);
QueryResult expectedResult = new QueryResult();
@@ -87,7 +84,6 @@ public class MetricTest extends BaseTest {
@Test
public void testMetricModelWithAgent() {
//agent only support METRIC_ENTITY, METRIC_FILTER
MockConfiguration.mockMetricAgent(agentService);
ParseResp parseResp = submitParseWithAgent("超音数的访问次数", DataUtils.getMetricAgent().getId());
List<String> queryModes = parseResp.getSelectedParses().stream()
.map(SemanticParseInfo::getQueryMode).collect(Collectors.toList());
@@ -118,7 +114,6 @@ public class MetricTest extends BaseTest {
@Test
public void testMetricFilterCompare() throws Exception {
MockConfiguration.mockMetricAgent(agentService);
QueryResult actualResult = submitNewChat("对比alice和lucy的访问次数", DataUtils.metricAgentId);
QueryResult expectedResult = new QueryResult();
@@ -187,7 +182,6 @@ public class MetricTest extends BaseTest {
@Test
public void testMetricFilterTime() throws Exception {
MockConfiguration.mockMetricAgent(agentService);
DateFormat format = new SimpleDateFormat("yyyy-mm-dd");
DateFormat textFormat = new SimpleDateFormat("yyyy年mm月dd日");
String dateStr = textFormat.format(format.parse(startDay));

View File

@@ -1,42 +0,0 @@
package com.tencent.supersonic.chat;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.server.plugin.PluginManager;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.util.DataUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Configuration;
import static org.mockito.Mockito.when;
@Configuration
@Slf4j
public class MockConfiguration {
public static void mockEmbeddingRecognize(PluginManager pluginManager, String text, String id) {
RetrieveQueryResult embeddingResp = new RetrieveQueryResult();
Retrieval embeddingRetrieval = new Retrieval();
embeddingRetrieval.setId(id);
embeddingRetrieval.setDistance(0.15);
embeddingResp.setQuery(text);
embeddingResp.setRetrieval(Lists.newArrayList(embeddingRetrieval));
when(pluginManager.recognize(text)).thenReturn(embeddingResp);
}
public static void mockEmbeddingUrl(EmbeddingConfig embeddingConfig) {
when(embeddingConfig.getUrl()).thenReturn("test");
}
public static void mockMetricAgent(AgentService agentService) {
when(agentService.getAgent(1)).thenReturn(DataUtils.getMetricAgent());
}
public static void mockTagAgent(AgentService agentService) {
when(agentService.getAgent(2)).thenReturn(DataUtils.getTagAgent());
}
}

View File

@@ -17,7 +17,6 @@ public class MultiTurnsTest extends BaseTest {
@Test
@Order(1)
public void queryTest_01() throws Exception {
MockConfiguration.mockMetricAgent(agentService);
QueryResult actualResult = submitMultiTurnChat("alice的访问次数",
DataUtils.metricAgentId, DataUtils.MULTI_TURNS_CHAT_ID);
@@ -44,7 +43,6 @@ public class MultiTurnsTest extends BaseTest {
@Test
@Order(2)
public void queryTest_02() throws Exception {
MockConfiguration.mockMetricAgent(agentService);
QueryResult actualResult = submitMultiTurnChat("停留时长呢", DataUtils.metricAgentId,
DataUtils.MULTI_TURNS_CHAT_ID);
@@ -69,7 +67,6 @@ public class MultiTurnsTest extends BaseTest {
@Test
@Order(3)
public void queryTest_03() throws Exception {
MockConfiguration.mockMetricAgent(agentService);
QueryResult actualResult = submitMultiTurnChat("lucy的如何", DataUtils.metricAgentId,
DataUtils.MULTI_TURNS_CHAT_ID);

View File

@@ -16,7 +16,6 @@ public class TagTest extends BaseTest {
@Test
public void queryTest_tag_list_filter() throws Exception {
MockConfiguration.mockTagAgent(agentService);
QueryResult actualResult = submitNewChat("爱情、流行类型的艺人", DataUtils.tagAgentId);
QueryResult expectedResult = new QueryResult();

View File

@@ -24,7 +24,7 @@ public class ModelSchemaTest extends BaseTest {
fieldRemovedReq.setModelId(2L);
fieldRemovedReq.setFields(Lists.newArrayList("pv"));
UnAvailableItemResp unAvailableItemResp = modelService.getUnAvailableItem(fieldRemovedReq);
List<Long> expectedUnAvailableMetricId = Lists.newArrayList(1L, 3L);
List<Long> expectedUnAvailableMetricId = Lists.newArrayList(1L, 4L);
List<Long> actualUnAvailableMetricId = unAvailableItemResp.getMetricResps()
.stream().map(MetricResp::getId).sorted(Comparator.naturalOrder()).collect(Collectors.toList());
Assertions.assertEquals(expectedUnAvailableMetricId, actualUnAvailableMetricId);

View File

@@ -1,17 +1,18 @@
package com.tencent.supersonic.headless;
import static org.junit.Assert.assertThrows;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.request.QueryMetricReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.server.service.MetricService;
import java.util.Arrays;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import java.util.Arrays;
import static org.junit.Assert.assertThrows;
public class QueryByMetricTest extends BaseTest {
@Autowired
@@ -58,7 +59,7 @@ public class QueryByMetricTest extends BaseTest {
public void testWithMetricAndDimensionIds() throws Exception {
QueryMetricReq queryMetricReq = new QueryMetricReq();
queryMetricReq.setDomainId(1L);
queryMetricReq.setMetricIds(Arrays.asList(1L, 4L));
queryMetricReq.setMetricIds(Arrays.asList(1L, 3L));
queryMetricReq.setDimensionIds(Arrays.asList(1L, 2L));
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getFakeUser());
Assert.assertNotNull(queryResp.getResultList());

View File

@@ -66,6 +66,10 @@ s2:
names: S2VisitsDemo,S2ArtistDemo
enableLLM: true
schema:
cache:
enable: false
langchain4j:
#1.chat-model
chat-model: