mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 03:58:14 +00:00
(improvement)(Chat) Fix start failed and integration tests failed (#1039)
Co-authored-by: jolunoluo
This commit is contained in:
@@ -13,6 +13,7 @@ import javax.sql.DataSource;
|
||||
public class DataBaseConfig {
|
||||
|
||||
@Bean("h2")
|
||||
@Primary
|
||||
@ConfigurationProperties("spring.datasource")
|
||||
public DataSource dataSource() {
|
||||
return new DruidDataSource();
|
||||
|
||||
@@ -3,8 +3,6 @@ package com.tencent.supersonic.headless.server.service.impl;
|
||||
import com.alibaba.fastjson2.JSONObject;
|
||||
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
||||
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
|
||||
import com.google.common.cache.Cache;
|
||||
import com.google.common.cache.CacheBuilder;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.AuthType;
|
||||
@@ -52,7 +50,6 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@@ -61,9 +58,6 @@ import java.util.stream.Collectors;
|
||||
public class DataSetServiceImpl
|
||||
extends ServiceImpl<DataSetDOMapper, DataSetDO> implements DataSetService {
|
||||
|
||||
protected final Cache<MetaFilter, List<DataSetResp>> dataSetSchemaCache =
|
||||
CacheBuilder.newBuilder().expireAfterWrite(30, TimeUnit.SECONDS).build();
|
||||
|
||||
@Autowired
|
||||
private DomainService domainService;
|
||||
|
||||
@@ -249,11 +243,7 @@ public class DataSetServiceImpl
|
||||
MetaFilter metaFilter = new MetaFilter();
|
||||
metaFilter.setStatus(StatusEnum.ONLINE.getCode());
|
||||
metaFilter.setIds(dataSetIds);
|
||||
List<DataSetResp> dataSetList = dataSetSchemaCache.getIfPresent(metaFilter);
|
||||
if (CollectionUtils.isEmpty(dataSetList)) {
|
||||
dataSetList = getDataSetList(metaFilter);
|
||||
dataSetSchemaCache.put(metaFilter, dataSetList);
|
||||
}
|
||||
List<DataSetResp> dataSetList = getDataSetList(metaFilter);
|
||||
return dataSetList.stream()
|
||||
.flatMap(
|
||||
dataSetResp -> dataSetResp.getAllModels().stream().map(modelId ->
|
||||
|
||||
@@ -53,6 +53,7 @@ import com.tencent.supersonic.headless.server.utils.StatUtils;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@@ -91,6 +92,9 @@ public class SchemaServiceImpl implements SchemaService {
|
||||
private final TagMetaService tagService;
|
||||
private final TermService termService;
|
||||
|
||||
@Value("${s2.schema.cache.enable:true}")
|
||||
private boolean schemaCacheEnable;
|
||||
|
||||
public SchemaServiceImpl(ModelService modelService,
|
||||
DimensionService dimensionService,
|
||||
MetricService metricService,
|
||||
@@ -112,7 +116,10 @@ public class SchemaServiceImpl implements SchemaService {
|
||||
@SneakyThrows
|
||||
@Override
|
||||
public List<DataSetSchemaResp> fetchDataSetSchema(DataSetFilterReq filter) {
|
||||
List<DataSetSchemaResp> dataSetList = dataSetSchemaCache.getIfPresent(filter);
|
||||
List<DataSetSchemaResp> dataSetList = Lists.newArrayList();
|
||||
if (schemaCacheEnable) {
|
||||
dataSetList = dataSetSchemaCache.getIfPresent(filter);
|
||||
}
|
||||
if (CollectionUtils.isEmpty(dataSetList)) {
|
||||
dataSetList = buildDataSetSchema(filter);
|
||||
dataSetSchemaCache.put(filter, dataSetList);
|
||||
@@ -376,7 +383,10 @@ public class SchemaServiceImpl implements SchemaService {
|
||||
|
||||
@Override
|
||||
public SemanticSchemaResp fetchSemanticSchema(SchemaFilterReq schemaFilterReq) {
|
||||
SemanticSchemaResp semanticSchemaResp = semanticSchemaCache.getIfPresent(schemaFilterReq);
|
||||
SemanticSchemaResp semanticSchemaResp = null;
|
||||
if (schemaCacheEnable) {
|
||||
semanticSchemaResp = semanticSchemaCache.getIfPresent(schemaFilterReq);
|
||||
}
|
||||
if (semanticSchemaResp == null) {
|
||||
semanticSchemaResp = buildSemanticSchema(schemaFilterReq);
|
||||
semanticSchemaCache.put(schemaFilterReq, semanticSchemaResp);
|
||||
|
||||
@@ -13,7 +13,6 @@ import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.test.mock.mockito.MockBean;
|
||||
|
||||
import java.time.LocalDate;
|
||||
import java.util.Objects;
|
||||
@@ -33,7 +32,7 @@ public class BaseTest extends BaseApplication {
|
||||
protected ChatService chatService;
|
||||
@Autowired
|
||||
protected ConfigService configService;
|
||||
@MockBean
|
||||
@Autowired
|
||||
protected AgentService agentService;
|
||||
|
||||
protected QueryResult submitMultiTurnChat(String queryText, Integer agentId, Integer chatId) throws Exception {
|
||||
|
||||
@@ -29,7 +29,6 @@ public class MetricTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void testMetricFilter() throws Exception {
|
||||
MockConfiguration.mockMetricAgent(agentService);
|
||||
QueryResult actualResult = submitNewChat("alice的访问次数", DataUtils.metricAgentId);
|
||||
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
@@ -55,7 +54,6 @@ public class MetricTest extends BaseTest {
|
||||
@Test
|
||||
public void testMetricFilterWithAgent() {
|
||||
//agent only support METRIC_ENTITY, METRIC_FILTER
|
||||
MockConfiguration.mockMetricAgent(agentService);
|
||||
ParseResp parseResp = submitParseWithAgent("alice的访问次数", DataUtils.getMetricAgent().getId());
|
||||
Assert.assertNotNull(parseResp.getSelectedParses());
|
||||
List<String> queryModes = parseResp.getSelectedParses().stream()
|
||||
@@ -65,7 +63,6 @@ public class MetricTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void testMetricDomain() throws Exception {
|
||||
MockConfiguration.mockMetricAgent(agentService);
|
||||
QueryResult actualResult = submitNewChat("超音数的访问次数", DataUtils.metricAgentId);
|
||||
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
@@ -87,7 +84,6 @@ public class MetricTest extends BaseTest {
|
||||
@Test
|
||||
public void testMetricModelWithAgent() {
|
||||
//agent only support METRIC_ENTITY, METRIC_FILTER
|
||||
MockConfiguration.mockMetricAgent(agentService);
|
||||
ParseResp parseResp = submitParseWithAgent("超音数的访问次数", DataUtils.getMetricAgent().getId());
|
||||
List<String> queryModes = parseResp.getSelectedParses().stream()
|
||||
.map(SemanticParseInfo::getQueryMode).collect(Collectors.toList());
|
||||
@@ -118,7 +114,6 @@ public class MetricTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void testMetricFilterCompare() throws Exception {
|
||||
MockConfiguration.mockMetricAgent(agentService);
|
||||
QueryResult actualResult = submitNewChat("对比alice和lucy的访问次数", DataUtils.metricAgentId);
|
||||
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
@@ -187,7 +182,6 @@ public class MetricTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void testMetricFilterTime() throws Exception {
|
||||
MockConfiguration.mockMetricAgent(agentService);
|
||||
DateFormat format = new SimpleDateFormat("yyyy-mm-dd");
|
||||
DateFormat textFormat = new SimpleDateFormat("yyyy年mm月dd日");
|
||||
String dateStr = textFormat.format(format.parse(startDay));
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
package com.tencent.supersonic.chat;
|
||||
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
@Configuration
|
||||
@Slf4j
|
||||
public class MockConfiguration {
|
||||
|
||||
public static void mockEmbeddingRecognize(PluginManager pluginManager, String text, String id) {
|
||||
RetrieveQueryResult embeddingResp = new RetrieveQueryResult();
|
||||
Retrieval embeddingRetrieval = new Retrieval();
|
||||
embeddingRetrieval.setId(id);
|
||||
embeddingRetrieval.setDistance(0.15);
|
||||
embeddingResp.setQuery(text);
|
||||
embeddingResp.setRetrieval(Lists.newArrayList(embeddingRetrieval));
|
||||
when(pluginManager.recognize(text)).thenReturn(embeddingResp);
|
||||
}
|
||||
|
||||
public static void mockEmbeddingUrl(EmbeddingConfig embeddingConfig) {
|
||||
when(embeddingConfig.getUrl()).thenReturn("test");
|
||||
}
|
||||
|
||||
public static void mockMetricAgent(AgentService agentService) {
|
||||
when(agentService.getAgent(1)).thenReturn(DataUtils.getMetricAgent());
|
||||
}
|
||||
|
||||
public static void mockTagAgent(AgentService agentService) {
|
||||
when(agentService.getAgent(2)).thenReturn(DataUtils.getTagAgent());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -17,7 +17,6 @@ public class MultiTurnsTest extends BaseTest {
|
||||
@Test
|
||||
@Order(1)
|
||||
public void queryTest_01() throws Exception {
|
||||
MockConfiguration.mockMetricAgent(agentService);
|
||||
QueryResult actualResult = submitMultiTurnChat("alice的访问次数",
|
||||
DataUtils.metricAgentId, DataUtils.MULTI_TURNS_CHAT_ID);
|
||||
|
||||
@@ -44,7 +43,6 @@ public class MultiTurnsTest extends BaseTest {
|
||||
@Test
|
||||
@Order(2)
|
||||
public void queryTest_02() throws Exception {
|
||||
MockConfiguration.mockMetricAgent(agentService);
|
||||
QueryResult actualResult = submitMultiTurnChat("停留时长呢", DataUtils.metricAgentId,
|
||||
DataUtils.MULTI_TURNS_CHAT_ID);
|
||||
|
||||
@@ -69,7 +67,6 @@ public class MultiTurnsTest extends BaseTest {
|
||||
@Test
|
||||
@Order(3)
|
||||
public void queryTest_03() throws Exception {
|
||||
MockConfiguration.mockMetricAgent(agentService);
|
||||
QueryResult actualResult = submitMultiTurnChat("lucy的如何", DataUtils.metricAgentId,
|
||||
DataUtils.MULTI_TURNS_CHAT_ID);
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ public class TagTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void queryTest_tag_list_filter() throws Exception {
|
||||
MockConfiguration.mockTagAgent(agentService);
|
||||
QueryResult actualResult = submitNewChat("爱情、流行类型的艺人", DataUtils.tagAgentId);
|
||||
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
|
||||
@@ -24,7 +24,7 @@ public class ModelSchemaTest extends BaseTest {
|
||||
fieldRemovedReq.setModelId(2L);
|
||||
fieldRemovedReq.setFields(Lists.newArrayList("pv"));
|
||||
UnAvailableItemResp unAvailableItemResp = modelService.getUnAvailableItem(fieldRemovedReq);
|
||||
List<Long> expectedUnAvailableMetricId = Lists.newArrayList(1L, 3L);
|
||||
List<Long> expectedUnAvailableMetricId = Lists.newArrayList(1L, 4L);
|
||||
List<Long> actualUnAvailableMetricId = unAvailableItemResp.getMetricResps()
|
||||
.stream().map(MetricResp::getId).sorted(Comparator.naturalOrder()).collect(Collectors.toList());
|
||||
Assertions.assertEquals(expectedUnAvailableMetricId, actualUnAvailableMetricId);
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
package com.tencent.supersonic.headless;
|
||||
|
||||
import static org.junit.Assert.assertThrows;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryMetricReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.server.service.MetricService;
|
||||
import java.util.Arrays;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
import static org.junit.Assert.assertThrows;
|
||||
|
||||
public class QueryByMetricTest extends BaseTest {
|
||||
|
||||
@Autowired
|
||||
@@ -58,7 +59,7 @@ public class QueryByMetricTest extends BaseTest {
|
||||
public void testWithMetricAndDimensionIds() throws Exception {
|
||||
QueryMetricReq queryMetricReq = new QueryMetricReq();
|
||||
queryMetricReq.setDomainId(1L);
|
||||
queryMetricReq.setMetricIds(Arrays.asList(1L, 4L));
|
||||
queryMetricReq.setMetricIds(Arrays.asList(1L, 3L));
|
||||
queryMetricReq.setDimensionIds(Arrays.asList(1L, 2L));
|
||||
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getFakeUser());
|
||||
Assert.assertNotNull(queryResp.getResultList());
|
||||
|
||||
@@ -66,6 +66,10 @@ s2:
|
||||
names: S2VisitsDemo,S2ArtistDemo
|
||||
enableLLM: true
|
||||
|
||||
schema:
|
||||
cache:
|
||||
enable: false
|
||||
|
||||
langchain4j:
|
||||
#1.chat-model
|
||||
chat-model:
|
||||
|
||||
Reference in New Issue
Block a user