(improvement)(Chat) Fix start failed and integration tests failed (#1039)

Co-authored-by: jolunoluo
This commit is contained in:
LXW
2024-05-28 17:59:51 +08:00
committed by GitHub
parent c51c278f33
commit 5064708c56
11 changed files with 25 additions and 72 deletions

View File

@@ -13,6 +13,7 @@ import javax.sql.DataSource;
public class DataBaseConfig {
@Bean("h2")
@Primary
@ConfigurationProperties("spring.datasource")
public DataSource dataSource() {
return new DruidDataSource();

View File

@@ -3,8 +3,6 @@ package com.tencent.supersonic.headless.server.service.impl;
import com.alibaba.fastjson2.JSONObject;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.enums.AuthType;
@@ -52,7 +50,6 @@ import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collectors;
@@ -61,9 +58,6 @@ import java.util.stream.Collectors;
public class DataSetServiceImpl
extends ServiceImpl<DataSetDOMapper, DataSetDO> implements DataSetService {
protected final Cache<MetaFilter, List<DataSetResp>> dataSetSchemaCache =
CacheBuilder.newBuilder().expireAfterWrite(30, TimeUnit.SECONDS).build();
@Autowired
private DomainService domainService;
@@ -249,11 +243,7 @@ public class DataSetServiceImpl
MetaFilter metaFilter = new MetaFilter();
metaFilter.setStatus(StatusEnum.ONLINE.getCode());
metaFilter.setIds(dataSetIds);
List<DataSetResp> dataSetList = dataSetSchemaCache.getIfPresent(metaFilter);
if (CollectionUtils.isEmpty(dataSetList)) {
dataSetList = getDataSetList(metaFilter);
dataSetSchemaCache.put(metaFilter, dataSetList);
}
List<DataSetResp> dataSetList = getDataSetList(metaFilter);
return dataSetList.stream()
.flatMap(
dataSetResp -> dataSetResp.getAllModels().stream().map(modelId ->

View File

@@ -53,6 +53,7 @@ import com.tencent.supersonic.headless.server.utils.StatUtils;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
@@ -91,6 +92,9 @@ public class SchemaServiceImpl implements SchemaService {
private final TagMetaService tagService;
private final TermService termService;
@Value("${s2.schema.cache.enable:true}")
private boolean schemaCacheEnable;
public SchemaServiceImpl(ModelService modelService,
DimensionService dimensionService,
MetricService metricService,
@@ -112,7 +116,10 @@ public class SchemaServiceImpl implements SchemaService {
@SneakyThrows
@Override
public List<DataSetSchemaResp> fetchDataSetSchema(DataSetFilterReq filter) {
List<DataSetSchemaResp> dataSetList = dataSetSchemaCache.getIfPresent(filter);
List<DataSetSchemaResp> dataSetList = Lists.newArrayList();
if (schemaCacheEnable) {
dataSetList = dataSetSchemaCache.getIfPresent(filter);
}
if (CollectionUtils.isEmpty(dataSetList)) {
dataSetList = buildDataSetSchema(filter);
dataSetSchemaCache.put(filter, dataSetList);
@@ -376,7 +383,10 @@ public class SchemaServiceImpl implements SchemaService {
@Override
public SemanticSchemaResp fetchSemanticSchema(SchemaFilterReq schemaFilterReq) {
SemanticSchemaResp semanticSchemaResp = semanticSchemaCache.getIfPresent(schemaFilterReq);
SemanticSchemaResp semanticSchemaResp = null;
if (schemaCacheEnable) {
semanticSchemaResp = semanticSchemaCache.getIfPresent(schemaFilterReq);
}
if (semanticSchemaResp == null) {
semanticSchemaResp = buildSemanticSchema(schemaFilterReq);
semanticSchemaCache.put(schemaFilterReq, semanticSchemaResp);

View File

@@ -13,7 +13,6 @@ import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import com.tencent.supersonic.util.DataUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.mock.mockito.MockBean;
import java.time.LocalDate;
import java.util.Objects;
@@ -33,7 +32,7 @@ public class BaseTest extends BaseApplication {
protected ChatService chatService;
@Autowired
protected ConfigService configService;
@MockBean
@Autowired
protected AgentService agentService;
protected QueryResult submitMultiTurnChat(String queryText, Integer agentId, Integer chatId) throws Exception {

View File

@@ -29,7 +29,6 @@ public class MetricTest extends BaseTest {
@Test
public void testMetricFilter() throws Exception {
MockConfiguration.mockMetricAgent(agentService);
QueryResult actualResult = submitNewChat("alice的访问次数", DataUtils.metricAgentId);
QueryResult expectedResult = new QueryResult();
@@ -55,7 +54,6 @@ public class MetricTest extends BaseTest {
@Test
public void testMetricFilterWithAgent() {
//agent only support METRIC_ENTITY, METRIC_FILTER
MockConfiguration.mockMetricAgent(agentService);
ParseResp parseResp = submitParseWithAgent("alice的访问次数", DataUtils.getMetricAgent().getId());
Assert.assertNotNull(parseResp.getSelectedParses());
List<String> queryModes = parseResp.getSelectedParses().stream()
@@ -65,7 +63,6 @@ public class MetricTest extends BaseTest {
@Test
public void testMetricDomain() throws Exception {
MockConfiguration.mockMetricAgent(agentService);
QueryResult actualResult = submitNewChat("超音数的访问次数", DataUtils.metricAgentId);
QueryResult expectedResult = new QueryResult();
@@ -87,7 +84,6 @@ public class MetricTest extends BaseTest {
@Test
public void testMetricModelWithAgent() {
//agent only support METRIC_ENTITY, METRIC_FILTER
MockConfiguration.mockMetricAgent(agentService);
ParseResp parseResp = submitParseWithAgent("超音数的访问次数", DataUtils.getMetricAgent().getId());
List<String> queryModes = parseResp.getSelectedParses().stream()
.map(SemanticParseInfo::getQueryMode).collect(Collectors.toList());
@@ -118,7 +114,6 @@ public class MetricTest extends BaseTest {
@Test
public void testMetricFilterCompare() throws Exception {
MockConfiguration.mockMetricAgent(agentService);
QueryResult actualResult = submitNewChat("对比alice和lucy的访问次数", DataUtils.metricAgentId);
QueryResult expectedResult = new QueryResult();
@@ -187,7 +182,6 @@ public class MetricTest extends BaseTest {
@Test
public void testMetricFilterTime() throws Exception {
MockConfiguration.mockMetricAgent(agentService);
DateFormat format = new SimpleDateFormat("yyyy-mm-dd");
DateFormat textFormat = new SimpleDateFormat("yyyy年mm月dd日");
String dateStr = textFormat.format(format.parse(startDay));

View File

@@ -1,42 +0,0 @@
package com.tencent.supersonic.chat;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.server.plugin.PluginManager;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.util.DataUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Configuration;
import static org.mockito.Mockito.when;
@Configuration
@Slf4j
public class MockConfiguration {
public static void mockEmbeddingRecognize(PluginManager pluginManager, String text, String id) {
RetrieveQueryResult embeddingResp = new RetrieveQueryResult();
Retrieval embeddingRetrieval = new Retrieval();
embeddingRetrieval.setId(id);
embeddingRetrieval.setDistance(0.15);
embeddingResp.setQuery(text);
embeddingResp.setRetrieval(Lists.newArrayList(embeddingRetrieval));
when(pluginManager.recognize(text)).thenReturn(embeddingResp);
}
public static void mockEmbeddingUrl(EmbeddingConfig embeddingConfig) {
when(embeddingConfig.getUrl()).thenReturn("test");
}
public static void mockMetricAgent(AgentService agentService) {
when(agentService.getAgent(1)).thenReturn(DataUtils.getMetricAgent());
}
public static void mockTagAgent(AgentService agentService) {
when(agentService.getAgent(2)).thenReturn(DataUtils.getTagAgent());
}
}

View File

@@ -17,7 +17,6 @@ public class MultiTurnsTest extends BaseTest {
@Test
@Order(1)
public void queryTest_01() throws Exception {
MockConfiguration.mockMetricAgent(agentService);
QueryResult actualResult = submitMultiTurnChat("alice的访问次数",
DataUtils.metricAgentId, DataUtils.MULTI_TURNS_CHAT_ID);
@@ -44,7 +43,6 @@ public class MultiTurnsTest extends BaseTest {
@Test
@Order(2)
public void queryTest_02() throws Exception {
MockConfiguration.mockMetricAgent(agentService);
QueryResult actualResult = submitMultiTurnChat("停留时长呢", DataUtils.metricAgentId,
DataUtils.MULTI_TURNS_CHAT_ID);
@@ -69,7 +67,6 @@ public class MultiTurnsTest extends BaseTest {
@Test
@Order(3)
public void queryTest_03() throws Exception {
MockConfiguration.mockMetricAgent(agentService);
QueryResult actualResult = submitMultiTurnChat("lucy的如何", DataUtils.metricAgentId,
DataUtils.MULTI_TURNS_CHAT_ID);

View File

@@ -16,7 +16,6 @@ public class TagTest extends BaseTest {
@Test
public void queryTest_tag_list_filter() throws Exception {
MockConfiguration.mockTagAgent(agentService);
QueryResult actualResult = submitNewChat("爱情、流行类型的艺人", DataUtils.tagAgentId);
QueryResult expectedResult = new QueryResult();

View File

@@ -24,7 +24,7 @@ public class ModelSchemaTest extends BaseTest {
fieldRemovedReq.setModelId(2L);
fieldRemovedReq.setFields(Lists.newArrayList("pv"));
UnAvailableItemResp unAvailableItemResp = modelService.getUnAvailableItem(fieldRemovedReq);
List<Long> expectedUnAvailableMetricId = Lists.newArrayList(1L, 3L);
List<Long> expectedUnAvailableMetricId = Lists.newArrayList(1L, 4L);
List<Long> actualUnAvailableMetricId = unAvailableItemResp.getMetricResps()
.stream().map(MetricResp::getId).sorted(Comparator.naturalOrder()).collect(Collectors.toList());
Assertions.assertEquals(expectedUnAvailableMetricId, actualUnAvailableMetricId);

View File

@@ -1,17 +1,18 @@
package com.tencent.supersonic.headless;
import static org.junit.Assert.assertThrows;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.request.QueryMetricReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.server.service.MetricService;
import java.util.Arrays;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import java.util.Arrays;
import static org.junit.Assert.assertThrows;
public class QueryByMetricTest extends BaseTest {
@Autowired
@@ -58,7 +59,7 @@ public class QueryByMetricTest extends BaseTest {
public void testWithMetricAndDimensionIds() throws Exception {
QueryMetricReq queryMetricReq = new QueryMetricReq();
queryMetricReq.setDomainId(1L);
queryMetricReq.setMetricIds(Arrays.asList(1L, 4L));
queryMetricReq.setMetricIds(Arrays.asList(1L, 3L));
queryMetricReq.setDimensionIds(Arrays.asList(1L, 2L));
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getFakeUser());
Assert.assertNotNull(queryResp.getResultList());

View File

@@ -66,6 +66,10 @@ s2:
names: S2VisitsDemo,S2ArtistDemo
enableLLM: true
schema:
cache:
enable: false
langchain4j:
#1.chat-model
chat-model: