From 5064708c56177bbf59379fb61257e06507b45e2f Mon Sep 17 00:00:00 2001 From: LXW <1264174498@qq.com> Date: Tue, 28 May 2024 17:59:51 +0800 Subject: [PATCH] (improvement)(Chat) Fix start failed and integration tests failed (#1039) Co-authored-by: jolunoluo --- .../common/config/DataBaseConfig.java | 1 + .../service/impl/DataSetServiceImpl.java | 12 +----- .../service/impl/SchemaServiceImpl.java | 14 ++++++- .../com/tencent/supersonic/chat/BaseTest.java | 3 +- .../tencent/supersonic/chat/MetricTest.java | 6 --- .../supersonic/chat/MockConfiguration.java | 42 ------------------- .../supersonic/chat/MultiTurnsTest.java | 3 -- .../com/tencent/supersonic/chat/TagTest.java | 1 - .../supersonic/headless/ModelSchemaTest.java | 2 +- .../headless/QueryByMetricTest.java | 9 ++-- .../src/test/resources/application-local.yaml | 4 ++ 11 files changed, 25 insertions(+), 72 deletions(-) delete mode 100644 launchers/standalone/src/test/java/com/tencent/supersonic/chat/MockConfiguration.java diff --git a/common/src/main/java/com/tencent/supersonic/common/config/DataBaseConfig.java b/common/src/main/java/com/tencent/supersonic/common/config/DataBaseConfig.java index 8fb2293b9..1ff339c37 100644 --- a/common/src/main/java/com/tencent/supersonic/common/config/DataBaseConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/config/DataBaseConfig.java @@ -13,6 +13,7 @@ import javax.sql.DataSource; public class DataBaseConfig { @Bean("h2") + @Primary @ConfigurationProperties("spring.datasource") public DataSource dataSource() { return new DruidDataSource(); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DataSetServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DataSetServiceImpl.java index ebbb39022..a2467eedf 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DataSetServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DataSetServiceImpl.java @@ -3,8 +3,6 @@ package com.tencent.supersonic.headless.server.service.impl; import com.alibaba.fastjson2.JSONObject; import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; -import com.google.common.cache.Cache; -import com.google.common.cache.CacheBuilder; import com.google.common.collect.Lists; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.common.pojo.enums.AuthType; @@ -52,7 +50,6 @@ import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; -import java.util.concurrent.TimeUnit; import java.util.function.Function; import java.util.stream.Collectors; @@ -61,9 +58,6 @@ import java.util.stream.Collectors; public class DataSetServiceImpl extends ServiceImpl implements DataSetService { - protected final Cache> dataSetSchemaCache = - CacheBuilder.newBuilder().expireAfterWrite(30, TimeUnit.SECONDS).build(); - @Autowired private DomainService domainService; @@ -249,11 +243,7 @@ public class DataSetServiceImpl MetaFilter metaFilter = new MetaFilter(); metaFilter.setStatus(StatusEnum.ONLINE.getCode()); metaFilter.setIds(dataSetIds); - List dataSetList = dataSetSchemaCache.getIfPresent(metaFilter); - if (CollectionUtils.isEmpty(dataSetList)) { - dataSetList = getDataSetList(metaFilter); - dataSetSchemaCache.put(metaFilter, dataSetList); - } + List dataSetList = getDataSetList(metaFilter); return dataSetList.stream() .flatMap( dataSetResp -> dataSetResp.getAllModels().stream().map(modelId -> diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/SchemaServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/SchemaServiceImpl.java index 937bace3d..696a5c401 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/SchemaServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/SchemaServiceImpl.java @@ -53,6 +53,7 @@ import com.tencent.supersonic.headless.server.utils.StatUtils; import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.BeanUtils; +import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; import org.springframework.util.CollectionUtils; @@ -91,6 +92,9 @@ public class SchemaServiceImpl implements SchemaService { private final TagMetaService tagService; private final TermService termService; + @Value("${s2.schema.cache.enable:true}") + private boolean schemaCacheEnable; + public SchemaServiceImpl(ModelService modelService, DimensionService dimensionService, MetricService metricService, @@ -112,7 +116,10 @@ public class SchemaServiceImpl implements SchemaService { @SneakyThrows @Override public List fetchDataSetSchema(DataSetFilterReq filter) { - List dataSetList = dataSetSchemaCache.getIfPresent(filter); + List dataSetList = Lists.newArrayList(); + if (schemaCacheEnable) { + dataSetList = dataSetSchemaCache.getIfPresent(filter); + } if (CollectionUtils.isEmpty(dataSetList)) { dataSetList = buildDataSetSchema(filter); dataSetSchemaCache.put(filter, dataSetList); @@ -376,7 +383,10 @@ public class SchemaServiceImpl implements SchemaService { @Override public SemanticSchemaResp fetchSemanticSchema(SchemaFilterReq schemaFilterReq) { - SemanticSchemaResp semanticSchemaResp = semanticSchemaCache.getIfPresent(schemaFilterReq); + SemanticSchemaResp semanticSchemaResp = null; + if (schemaCacheEnable) { + semanticSchemaResp = semanticSchemaCache.getIfPresent(schemaFilterReq); + } if (semanticSchemaResp == null) { semanticSchemaResp = buildSemanticSchema(schemaFilterReq); semanticSchemaCache.put(schemaFilterReq, semanticSchemaResp); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/BaseTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/BaseTest.java index e9209e987..12c3517a7 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/BaseTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/BaseTest.java @@ -13,7 +13,6 @@ import com.tencent.supersonic.headless.api.pojo.response.QueryResult; import com.tencent.supersonic.headless.api.pojo.response.QueryState; import com.tencent.supersonic.util.DataUtils; import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.boot.test.mock.mockito.MockBean; import java.time.LocalDate; import java.util.Objects; @@ -33,7 +32,7 @@ public class BaseTest extends BaseApplication { protected ChatService chatService; @Autowired protected ConfigService configService; - @MockBean + @Autowired protected AgentService agentService; protected QueryResult submitMultiTurnChat(String queryText, Integer agentId, Integer chatId) throws Exception { diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MetricTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MetricTest.java index 93d7b1923..1e3b5cc59 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MetricTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MetricTest.java @@ -29,7 +29,6 @@ public class MetricTest extends BaseTest { @Test public void testMetricFilter() throws Exception { - MockConfiguration.mockMetricAgent(agentService); QueryResult actualResult = submitNewChat("alice的访问次数", DataUtils.metricAgentId); QueryResult expectedResult = new QueryResult(); @@ -55,7 +54,6 @@ public class MetricTest extends BaseTest { @Test public void testMetricFilterWithAgent() { //agent only support METRIC_ENTITY, METRIC_FILTER - MockConfiguration.mockMetricAgent(agentService); ParseResp parseResp = submitParseWithAgent("alice的访问次数", DataUtils.getMetricAgent().getId()); Assert.assertNotNull(parseResp.getSelectedParses()); List queryModes = parseResp.getSelectedParses().stream() @@ -65,7 +63,6 @@ public class MetricTest extends BaseTest { @Test public void testMetricDomain() throws Exception { - MockConfiguration.mockMetricAgent(agentService); QueryResult actualResult = submitNewChat("超音数的访问次数", DataUtils.metricAgentId); QueryResult expectedResult = new QueryResult(); @@ -87,7 +84,6 @@ public class MetricTest extends BaseTest { @Test public void testMetricModelWithAgent() { //agent only support METRIC_ENTITY, METRIC_FILTER - MockConfiguration.mockMetricAgent(agentService); ParseResp parseResp = submitParseWithAgent("超音数的访问次数", DataUtils.getMetricAgent().getId()); List queryModes = parseResp.getSelectedParses().stream() .map(SemanticParseInfo::getQueryMode).collect(Collectors.toList()); @@ -118,7 +114,6 @@ public class MetricTest extends BaseTest { @Test public void testMetricFilterCompare() throws Exception { - MockConfiguration.mockMetricAgent(agentService); QueryResult actualResult = submitNewChat("对比alice和lucy的访问次数", DataUtils.metricAgentId); QueryResult expectedResult = new QueryResult(); @@ -187,7 +182,6 @@ public class MetricTest extends BaseTest { @Test public void testMetricFilterTime() throws Exception { - MockConfiguration.mockMetricAgent(agentService); DateFormat format = new SimpleDateFormat("yyyy-mm-dd"); DateFormat textFormat = new SimpleDateFormat("yyyy年mm月dd日"); String dateStr = textFormat.format(format.parse(startDay)); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MockConfiguration.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MockConfiguration.java deleted file mode 100644 index bbf2d45ce..000000000 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MockConfiguration.java +++ /dev/null @@ -1,42 +0,0 @@ -package com.tencent.supersonic.chat; - - -import com.google.common.collect.Lists; -import com.tencent.supersonic.chat.server.plugin.PluginManager; -import com.tencent.supersonic.chat.server.service.AgentService; -import com.tencent.supersonic.common.config.EmbeddingConfig; -import com.tencent.supersonic.common.util.embedding.Retrieval; -import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult; -import com.tencent.supersonic.util.DataUtils; -import lombok.extern.slf4j.Slf4j; -import org.springframework.context.annotation.Configuration; - -import static org.mockito.Mockito.when; - -@Configuration -@Slf4j -public class MockConfiguration { - - public static void mockEmbeddingRecognize(PluginManager pluginManager, String text, String id) { - RetrieveQueryResult embeddingResp = new RetrieveQueryResult(); - Retrieval embeddingRetrieval = new Retrieval(); - embeddingRetrieval.setId(id); - embeddingRetrieval.setDistance(0.15); - embeddingResp.setQuery(text); - embeddingResp.setRetrieval(Lists.newArrayList(embeddingRetrieval)); - when(pluginManager.recognize(text)).thenReturn(embeddingResp); - } - - public static void mockEmbeddingUrl(EmbeddingConfig embeddingConfig) { - when(embeddingConfig.getUrl()).thenReturn("test"); - } - - public static void mockMetricAgent(AgentService agentService) { - when(agentService.getAgent(1)).thenReturn(DataUtils.getMetricAgent()); - } - - public static void mockTagAgent(AgentService agentService) { - when(agentService.getAgent(2)).thenReturn(DataUtils.getTagAgent()); - } - -} diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MultiTurnsTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MultiTurnsTest.java index 648cfba75..444f10cf4 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MultiTurnsTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MultiTurnsTest.java @@ -17,7 +17,6 @@ public class MultiTurnsTest extends BaseTest { @Test @Order(1) public void queryTest_01() throws Exception { - MockConfiguration.mockMetricAgent(agentService); QueryResult actualResult = submitMultiTurnChat("alice的访问次数", DataUtils.metricAgentId, DataUtils.MULTI_TURNS_CHAT_ID); @@ -44,7 +43,6 @@ public class MultiTurnsTest extends BaseTest { @Test @Order(2) public void queryTest_02() throws Exception { - MockConfiguration.mockMetricAgent(agentService); QueryResult actualResult = submitMultiTurnChat("停留时长呢", DataUtils.metricAgentId, DataUtils.MULTI_TURNS_CHAT_ID); @@ -69,7 +67,6 @@ public class MultiTurnsTest extends BaseTest { @Test @Order(3) public void queryTest_03() throws Exception { - MockConfiguration.mockMetricAgent(agentService); QueryResult actualResult = submitMultiTurnChat("lucy的如何", DataUtils.metricAgentId, DataUtils.MULTI_TURNS_CHAT_ID); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/TagTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/TagTest.java index 6ee8e495e..3f41da267 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/TagTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/TagTest.java @@ -16,7 +16,6 @@ public class TagTest extends BaseTest { @Test public void queryTest_tag_list_filter() throws Exception { - MockConfiguration.mockTagAgent(agentService); QueryResult actualResult = submitNewChat("爱情、流行类型的艺人", DataUtils.tagAgentId); QueryResult expectedResult = new QueryResult(); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/ModelSchemaTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/ModelSchemaTest.java index 02e6d90a8..f688c7909 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/ModelSchemaTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/ModelSchemaTest.java @@ -24,7 +24,7 @@ public class ModelSchemaTest extends BaseTest { fieldRemovedReq.setModelId(2L); fieldRemovedReq.setFields(Lists.newArrayList("pv")); UnAvailableItemResp unAvailableItemResp = modelService.getUnAvailableItem(fieldRemovedReq); - List expectedUnAvailableMetricId = Lists.newArrayList(1L, 3L); + List expectedUnAvailableMetricId = Lists.newArrayList(1L, 4L); List actualUnAvailableMetricId = unAvailableItemResp.getMetricResps() .stream().map(MetricResp::getId).sorted(Comparator.naturalOrder()).collect(Collectors.toList()); Assertions.assertEquals(expectedUnAvailableMetricId, actualUnAvailableMetricId); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/QueryByMetricTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/QueryByMetricTest.java index 8278d2bba..ff92d1623 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/QueryByMetricTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/QueryByMetricTest.java @@ -1,17 +1,18 @@ package com.tencent.supersonic.headless; -import static org.junit.Assert.assertThrows; - import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.headless.api.pojo.request.QueryMetricReq; import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq; import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; import com.tencent.supersonic.headless.server.service.MetricService; -import java.util.Arrays; import org.junit.Assert; import org.junit.jupiter.api.Test; import org.springframework.beans.factory.annotation.Autowired; +import java.util.Arrays; + +import static org.junit.Assert.assertThrows; + public class QueryByMetricTest extends BaseTest { @Autowired @@ -58,7 +59,7 @@ public class QueryByMetricTest extends BaseTest { public void testWithMetricAndDimensionIds() throws Exception { QueryMetricReq queryMetricReq = new QueryMetricReq(); queryMetricReq.setDomainId(1L); - queryMetricReq.setMetricIds(Arrays.asList(1L, 4L)); + queryMetricReq.setMetricIds(Arrays.asList(1L, 3L)); queryMetricReq.setDimensionIds(Arrays.asList(1L, 2L)); SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getFakeUser()); Assert.assertNotNull(queryResp.getResultList()); diff --git a/launchers/standalone/src/test/resources/application-local.yaml b/launchers/standalone/src/test/resources/application-local.yaml index 5fbd67c00..510c4dc69 100644 --- a/launchers/standalone/src/test/resources/application-local.yaml +++ b/launchers/standalone/src/test/resources/application-local.yaml @@ -66,6 +66,10 @@ s2: names: S2VisitsDemo,S2ArtistDemo enableLLM: true + schema: + cache: + enable: false + langchain4j: #1.chat-model chat-model: