mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-20 06:34:55 +00:00
[improvement][Headless] Simplify the QueryService interface, optimize Query permissions, and add integration testing. (#687)
This commit is contained in:
@@ -1,10 +0,0 @@
|
||||
package com.tencent.supersonic.benchmark;
|
||||
|
||||
import org.junit.Test;
|
||||
|
||||
public class CSpider {
|
||||
@Test
|
||||
public void case1(){
|
||||
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,8 @@
|
||||
package com.tencent.supersonic.integration;
|
||||
package com.tencent.supersonic.chat.integration;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
import com.tencent.supersonic.chat.integration.util.DataUtils;
|
||||
import com.tencent.supersonic.StandaloneLauncher;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
@@ -15,7 +16,6 @@ import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatService;
|
||||
import com.tencent.supersonic.chat.server.service.ConfigService;
|
||||
import com.tencent.supersonic.chat.server.service.QueryService;
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
import java.time.LocalDate;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
@@ -30,7 +30,7 @@ import org.springframework.test.context.junit4.SpringRunner;
|
||||
@RunWith(SpringRunner.class)
|
||||
@SpringBootTest(classes = StandaloneLauncher.class)
|
||||
@ActiveProfiles("local")
|
||||
public class BaseQueryTest {
|
||||
public class BaseTest {
|
||||
|
||||
protected final int unit = 7;
|
||||
protected final String startDay = LocalDate.now().plusDays(-unit).toString();
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.integration;
|
||||
package com.tencent.supersonic.chat.integration;
|
||||
|
||||
import com.tencent.supersonic.StandaloneLauncher;
|
||||
import com.tencent.supersonic.chat.core.query.llm.analytics.LLMAnswerResp;
|
||||
@@ -1,4 +1,7 @@
|
||||
package com.tencent.supersonic.integration;
|
||||
package com.tencent.supersonic.chat.integration;
|
||||
|
||||
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;
|
||||
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
@@ -8,27 +11,23 @@ import com.tencent.supersonic.chat.core.query.rule.metric.MetricFilterQuery;
|
||||
import com.tencent.supersonic.chat.core.query.rule.metric.MetricGroupByQuery;
|
||||
import com.tencent.supersonic.chat.core.query.rule.metric.MetricModelQuery;
|
||||
import com.tencent.supersonic.chat.core.query.rule.metric.MetricTopNQuery;
|
||||
import com.tencent.supersonic.chat.integration.util.DataUtils;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.text.DateFormat;
|
||||
import java.text.SimpleDateFormat;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;
|
||||
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
|
||||
|
||||
public class MetricQueryTest extends BaseQueryTest {
|
||||
public class MetricTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void queryTest_metric_filter() throws Exception {
|
||||
public void testMetricFilter() throws Exception {
|
||||
MockConfiguration.mockMetricAgent(agentService);
|
||||
QueryResult actualResult = submitNewChat("alice的访问次数", DataUtils.metricAgentId);
|
||||
|
||||
@@ -52,7 +51,7 @@ public class MetricQueryTest extends BaseQueryTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void queryTest_metric_filter_with_agent() {
|
||||
public void testMetricFilterWithAgent() {
|
||||
//agent only support METRIC_ENTITY, METRIC_FILTER
|
||||
MockConfiguration.mockMetricAgent(agentService);
|
||||
ParseResp parseResp = submitParseWithAgent("alice的访问次数", DataUtils.getMetricAgent().getId());
|
||||
@@ -63,7 +62,7 @@ public class MetricQueryTest extends BaseQueryTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void queryTest_metric_domain() throws Exception {
|
||||
public void testMetricDomain() throws Exception {
|
||||
MockConfiguration.mockMetricAgent(agentService);
|
||||
QueryResult actualResult = submitNewChat("超音数的访问次数", DataUtils.metricAgentId);
|
||||
|
||||
@@ -83,7 +82,7 @@ public class MetricQueryTest extends BaseQueryTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void queryTest_metric_model_with_agent() {
|
||||
public void testMetricModelWithAgent() {
|
||||
//agent only support METRIC_ENTITY, METRIC_FILTER
|
||||
MockConfiguration.mockMetricAgent(agentService);
|
||||
ParseResp parseResp = submitParseWithAgent("超音数的访问次数", DataUtils.getMetricAgent().getId());
|
||||
@@ -93,7 +92,7 @@ public class MetricQueryTest extends BaseQueryTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void queryTest_metric_groupby() throws Exception {
|
||||
public void testMetricGroupBy() throws Exception {
|
||||
QueryResult actualResult = submitNewChat("超音数各部门的访问次数", DataUtils.metricAgentId);
|
||||
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
@@ -114,7 +113,7 @@ public class MetricQueryTest extends BaseQueryTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void queryTest_metric_filter_compare() throws Exception {
|
||||
public void testMetricFilterCompare() throws Exception {
|
||||
MockConfiguration.mockMetricAgent(agentService);
|
||||
QueryResult actualResult = submitNewChat("对比alice和lucy的访问次数", DataUtils.metricAgentId);
|
||||
|
||||
@@ -139,7 +138,7 @@ public class MetricQueryTest extends BaseQueryTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void queryTest_metric_topn() throws Exception {
|
||||
public void testMetricTopN() throws Exception {
|
||||
QueryResult actualResult = submitNewChat("近3天访问次数最多的用户", DataUtils.metricAgentId);
|
||||
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
@@ -161,7 +160,7 @@ public class MetricQueryTest extends BaseQueryTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void queryTest_metric_groupby_sum() throws Exception {
|
||||
public void testMetricGroupBySum() throws Exception {
|
||||
QueryResult actualResult = submitNewChat("超音数各部门的访问次数总和", DataUtils.metricAgentId);
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
|
||||
@@ -181,7 +180,7 @@ public class MetricQueryTest extends BaseQueryTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void queryTest_metric_filter_time() throws Exception {
|
||||
public void testMetricFilterTime() throws Exception {
|
||||
MockConfiguration.mockMetricAgent(agentService);
|
||||
DateFormat format = new SimpleDateFormat("yyyy-mm-dd");
|
||||
DateFormat textFormat = new SimpleDateFormat("yyyy年mm月dd日");
|
||||
@@ -1,15 +1,15 @@
|
||||
package com.tencent.supersonic.integration;
|
||||
package com.tencent.supersonic.chat.integration;
|
||||
|
||||
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.integration.util.DataUtils;
|
||||
import com.tencent.supersonic.chat.core.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
package com.tencent.supersonic.integration;
|
||||
package com.tencent.supersonic.chat.integration;
|
||||
|
||||
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;
|
||||
|
||||
import com.tencent.supersonic.chat.integration.util.DataUtils;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.core.query.rule.metric.MetricFilterQuery;
|
||||
@@ -9,13 +10,12 @@ import com.tencent.supersonic.chat.core.query.rule.metric.MetricGroupByQuery;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
import java.text.DateFormat;
|
||||
import java.text.SimpleDateFormat;
|
||||
import org.junit.Test;
|
||||
import org.junit.jupiter.api.Order;
|
||||
|
||||
public class MultiTurnsTest extends BaseQueryTest {
|
||||
public class MultiTurnsTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
@Order(1)
|
||||
@@ -1,5 +1,8 @@
|
||||
package com.tencent.supersonic.integration;
|
||||
package com.tencent.supersonic.chat.integration;
|
||||
|
||||
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;
|
||||
|
||||
import com.tencent.supersonic.chat.integration.util.DataUtils;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
@@ -10,15 +13,11 @@ import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.junit.Test;
|
||||
|
||||
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;
|
||||
|
||||
public class TagQueryTest extends BaseQueryTest {
|
||||
public class TagTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void queryTest_metric_tag_query() throws Exception {
|
||||
@@ -1,5 +1,9 @@
|
||||
package com.tencent.supersonic.integration.mapper;
|
||||
package com.tencent.supersonic.chat.integration.mapper;
|
||||
|
||||
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;
|
||||
|
||||
import com.tencent.supersonic.chat.integration.BaseTest;
|
||||
import com.tencent.supersonic.chat.integration.util.DataUtils;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
@@ -9,13 +13,9 @@ import com.tencent.supersonic.chat.core.query.rule.metric.MetricTagQuery;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.integration.BaseQueryTest;
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
import org.junit.Test;
|
||||
|
||||
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;
|
||||
|
||||
public class MapperTest extends BaseQueryTest {
|
||||
public class MapperTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void hanlp() throws Exception {
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.integration.model;
|
||||
package com.tencent.supersonic.chat.integration.model;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.StandaloneLauncher;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.integration.plugin;
|
||||
package com.tencent.supersonic.chat.integration.plugin;
|
||||
|
||||
import com.tencent.supersonic.StandaloneLauncher;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.tencent.supersonic.integration.plugin;
|
||||
package com.tencent.supersonic.chat.integration.plugin;
|
||||
|
||||
import com.tencent.supersonic.chat.integration.util.DataUtils;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
@@ -10,8 +11,7 @@ import com.tencent.supersonic.chat.core.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.QueryService;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.integration.MockConfiguration;
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
import com.tencent.supersonic.chat.integration.MockConfiguration;
|
||||
import org.junit.Assert;
|
||||
import org.junit.Test;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.util;
|
||||
package com.tencent.supersonic.chat.integration.util;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
@@ -0,0 +1,43 @@
|
||||
package com.tencent.supersonic.headless.integration;
|
||||
|
||||
import com.tencent.supersonic.StandaloneLauncher;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.server.service.QueryService;
|
||||
import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
import org.junit.runner.RunWith;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
import org.springframework.test.context.ActiveProfiles;
|
||||
import org.springframework.test.context.junit4.SpringRunner;
|
||||
|
||||
@RunWith(SpringRunner.class)
|
||||
@SpringBootTest(classes = StandaloneLauncher.class)
|
||||
@ActiveProfiles("local")
|
||||
public class BaseTest {
|
||||
|
||||
@Autowired
|
||||
private QueryService queryService;
|
||||
|
||||
protected SemanticQueryResp queryBySql(String sql) throws Exception {
|
||||
return queryBySql(sql, User.getFakeUser());
|
||||
}
|
||||
|
||||
protected SemanticQueryResp queryBySql(String sql, User user) throws Exception {
|
||||
return queryService.queryByReq(buildQuerySqlReq(sql), user);
|
||||
}
|
||||
|
||||
protected QuerySqlReq buildQuerySqlReq(String sql) {
|
||||
QuerySqlReq querySqlCmd = new QuerySqlReq();
|
||||
querySqlCmd.setSql(sql);
|
||||
Set<Long> modelIds = new HashSet<>();
|
||||
modelIds.add(1L);
|
||||
modelIds.add(2L);
|
||||
modelIds.add(3L);
|
||||
querySqlCmd.setModelIds(modelIds);
|
||||
return querySqlCmd;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
package com.tencent.supersonic.headless.integration;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertThrows;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
import com.tencent.supersonic.common.pojo.exception.InvalidPermissionException;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import org.junit.Test;
|
||||
|
||||
public class QueryBySqlTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void testSumQuery() throws Exception {
|
||||
SemanticQueryResp semanticQueryResp = queryBySql("SELECT SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 ");
|
||||
|
||||
assertEquals(1, semanticQueryResp.getColumns().size());
|
||||
QueryColumn queryColumn = semanticQueryResp.getColumns().get(0);
|
||||
assertEquals("访问次数", queryColumn.getName());
|
||||
assertEquals(1, semanticQueryResp.getResultList().size());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testGroupByQuery() throws Exception {
|
||||
SemanticQueryResp result = queryBySql("SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 ");
|
||||
assertEquals(2, result.getColumns().size());
|
||||
QueryColumn firstColumn = result.getColumns().get(0);
|
||||
QueryColumn secondColumn = result.getColumns().get(1);
|
||||
assertEquals("部门", firstColumn.getName());
|
||||
assertEquals("访问次数", secondColumn.getName());
|
||||
assertEquals(4, result.getResultList().size());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testCacheQuery() throws Exception {
|
||||
SemanticQueryResp result1 = queryBySql("SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 ");
|
||||
SemanticQueryResp result2 = queryBySql("SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 ");
|
||||
assertEquals(result1, result2);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testBizNameQuery() throws Exception {
|
||||
SemanticQueryResp result1 = queryBySql("SELECT SUM(pv) FROM 超音数PVUV统计 WHERE department ='HR'");
|
||||
SemanticQueryResp result2 = queryBySql("SELECT SUM(访问次数) FROM 超音数PVUV统计 WHERE 部门 ='HR'");
|
||||
assertEquals(1, result1.getColumns().size());
|
||||
assertEquals(1, result2.getColumns().size());
|
||||
assertEquals(result1.getColumns().get(0), result2.getColumns().get(0));
|
||||
assertEquals(result1.getResultList(), result2.getResultList());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAuthorization() throws Exception {
|
||||
User alice = new User(2L, "alice", "alice", "alice@email", 0);
|
||||
assertThrows(InvalidPermissionException.class,
|
||||
() -> queryBySql("SELECT SUM(pv) FROM 超音数PVUV统计 WHERE department ='HR'", alice));
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user