[improvement][Headless] Simplify the QueryService interface, optimize Query permissions, and add integration testing. (#687)

This commit is contained in:
lexluo09
2024-01-24 17:33:12 +08:00
committed by GitHub
parent 48fb01f6bc
commit 922201c181
29 changed files with 529 additions and 629 deletions

View File

@@ -1,10 +0,0 @@
package com.tencent.supersonic.benchmark;
import org.junit.Test;
public class CSpider {
@Test
public void case1(){
}
}

View File

@@ -1,7 +1,8 @@
package com.tencent.supersonic.integration;
package com.tencent.supersonic.chat.integration;
import static org.junit.Assert.assertEquals;
import com.tencent.supersonic.chat.integration.util.DataUtils;
import com.tencent.supersonic.StandaloneLauncher;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
@@ -15,7 +16,6 @@ import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.chat.server.service.ConfigService;
import com.tencent.supersonic.chat.server.service.QueryService;
import com.tencent.supersonic.util.DataUtils;
import java.time.LocalDate;
import java.util.Set;
import java.util.stream.Collectors;
@@ -30,7 +30,7 @@ import org.springframework.test.context.junit4.SpringRunner;
@RunWith(SpringRunner.class)
@SpringBootTest(classes = StandaloneLauncher.class)
@ActiveProfiles("local")
public class BaseQueryTest {
public class BaseTest {
protected final int unit = 7;
protected final String startDay = LocalDate.now().plusDays(-unit).toString();

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.integration;
package com.tencent.supersonic.chat.integration;
import com.tencent.supersonic.StandaloneLauncher;
import com.tencent.supersonic.chat.core.query.llm.analytics.LLMAnswerResp;

View File

@@ -1,4 +1,7 @@
package com.tencent.supersonic.integration;
package com.tencent.supersonic.chat.integration;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
@@ -8,27 +11,23 @@ import com.tencent.supersonic.chat.core.query.rule.metric.MetricFilterQuery;
import com.tencent.supersonic.chat.core.query.rule.metric.MetricGroupByQuery;
import com.tencent.supersonic.chat.core.query.rule.metric.MetricModelQuery;
import com.tencent.supersonic.chat.core.query.rule.metric.MetricTopNQuery;
import com.tencent.supersonic.chat.integration.util.DataUtils;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.util.DataUtils;
import org.junit.Assert;
import org.junit.Test;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM;
import org.junit.Assert;
import org.junit.Test;
public class MetricQueryTest extends BaseQueryTest {
public class MetricTest extends BaseTest {
@Test
public void queryTest_metric_filter() throws Exception {
public void testMetricFilter() throws Exception {
MockConfiguration.mockMetricAgent(agentService);
QueryResult actualResult = submitNewChat("alice的访问次数", DataUtils.metricAgentId);
@@ -52,7 +51,7 @@ public class MetricQueryTest extends BaseQueryTest {
}
@Test
public void queryTest_metric_filter_with_agent() {
public void testMetricFilterWithAgent() {
//agent only support METRIC_ENTITY, METRIC_FILTER
MockConfiguration.mockMetricAgent(agentService);
ParseResp parseResp = submitParseWithAgent("alice的访问次数", DataUtils.getMetricAgent().getId());
@@ -63,7 +62,7 @@ public class MetricQueryTest extends BaseQueryTest {
}
@Test
public void queryTest_metric_domain() throws Exception {
public void testMetricDomain() throws Exception {
MockConfiguration.mockMetricAgent(agentService);
QueryResult actualResult = submitNewChat("超音数的访问次数", DataUtils.metricAgentId);
@@ -83,7 +82,7 @@ public class MetricQueryTest extends BaseQueryTest {
}
@Test
public void queryTest_metric_model_with_agent() {
public void testMetricModelWithAgent() {
//agent only support METRIC_ENTITY, METRIC_FILTER
MockConfiguration.mockMetricAgent(agentService);
ParseResp parseResp = submitParseWithAgent("超音数的访问次数", DataUtils.getMetricAgent().getId());
@@ -93,7 +92,7 @@ public class MetricQueryTest extends BaseQueryTest {
}
@Test
public void queryTest_metric_groupby() throws Exception {
public void testMetricGroupBy() throws Exception {
QueryResult actualResult = submitNewChat("超音数各部门的访问次数", DataUtils.metricAgentId);
QueryResult expectedResult = new QueryResult();
@@ -114,7 +113,7 @@ public class MetricQueryTest extends BaseQueryTest {
}
@Test
public void queryTest_metric_filter_compare() throws Exception {
public void testMetricFilterCompare() throws Exception {
MockConfiguration.mockMetricAgent(agentService);
QueryResult actualResult = submitNewChat("对比alice和lucy的访问次数", DataUtils.metricAgentId);
@@ -139,7 +138,7 @@ public class MetricQueryTest extends BaseQueryTest {
}
@Test
public void queryTest_metric_topn() throws Exception {
public void testMetricTopN() throws Exception {
QueryResult actualResult = submitNewChat("近3天访问次数最多的用户", DataUtils.metricAgentId);
QueryResult expectedResult = new QueryResult();
@@ -161,7 +160,7 @@ public class MetricQueryTest extends BaseQueryTest {
}
@Test
public void queryTest_metric_groupby_sum() throws Exception {
public void testMetricGroupBySum() throws Exception {
QueryResult actualResult = submitNewChat("超音数各部门的访问次数总和", DataUtils.metricAgentId);
QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
@@ -181,7 +180,7 @@ public class MetricQueryTest extends BaseQueryTest {
}
@Test
public void queryTest_metric_filter_time() throws Exception {
public void testMetricFilterTime() throws Exception {
MockConfiguration.mockMetricAgent(agentService);
DateFormat format = new SimpleDateFormat("yyyy-mm-dd");
DateFormat textFormat = new SimpleDateFormat("yyyy年mm月dd日");

View File

@@ -1,15 +1,15 @@
package com.tencent.supersonic.integration;
package com.tencent.supersonic.chat.integration;
import static org.mockito.Mockito.when;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.integration.util.DataUtils;
import com.tencent.supersonic.chat.core.plugin.PluginManager;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.util.DataUtils;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Configuration;

View File

@@ -1,7 +1,8 @@
package com.tencent.supersonic.integration;
package com.tencent.supersonic.chat.integration;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;
import com.tencent.supersonic.chat.integration.util.DataUtils;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.core.query.rule.metric.MetricFilterQuery;
@@ -9,13 +10,12 @@ import com.tencent.supersonic.chat.core.query.rule.metric.MetricGroupByQuery;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.util.DataUtils;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import org.junit.Test;
import org.junit.jupiter.api.Order;
public class MultiTurnsTest extends BaseQueryTest {
public class MultiTurnsTest extends BaseTest {
@Test
@Order(1)

View File

@@ -1,5 +1,8 @@
package com.tencent.supersonic.integration;
package com.tencent.supersonic.chat.integration;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;
import com.tencent.supersonic.chat.integration.util.DataUtils;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
@@ -10,15 +13,11 @@ import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.util.DataUtils;
import org.junit.Test;
import java.util.ArrayList;
import java.util.List;
import org.junit.Test;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;
public class TagQueryTest extends BaseQueryTest {
public class TagTest extends BaseTest {
@Test
public void queryTest_metric_tag_query() throws Exception {

View File

@@ -1,5 +1,9 @@
package com.tencent.supersonic.integration.mapper;
package com.tencent.supersonic.chat.integration.mapper;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;
import com.tencent.supersonic.chat.integration.BaseTest;
import com.tencent.supersonic.chat.integration.util.DataUtils;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
@@ -9,13 +13,9 @@ import com.tencent.supersonic.chat.core.query.rule.metric.MetricTagQuery;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.integration.BaseQueryTest;
import com.tencent.supersonic.util.DataUtils;
import org.junit.Test;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;
public class MapperTest extends BaseQueryTest {
public class MapperTest extends BaseTest {
@Test
public void hanlp() throws Exception {

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.integration.model;
package com.tencent.supersonic.chat.integration.model;
import com.google.common.collect.Lists;
import com.tencent.supersonic.StandaloneLauncher;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.integration.plugin;
package com.tencent.supersonic.chat.integration.plugin;
import com.tencent.supersonic.StandaloneLauncher;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.integration.plugin;
package com.tencent.supersonic.chat.integration.plugin;
import com.tencent.supersonic.chat.integration.util.DataUtils;
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
@@ -10,8 +11,7 @@ import com.tencent.supersonic.chat.core.plugin.PluginManager;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.QueryService;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.integration.MockConfiguration;
import com.tencent.supersonic.util.DataUtils;
import com.tencent.supersonic.chat.integration.MockConfiguration;
import org.junit.Assert;
import org.junit.Test;
import org.springframework.beans.factory.annotation.Autowired;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.util;
package com.tencent.supersonic.chat.integration.util;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;

View File

@@ -0,0 +1,43 @@
package com.tencent.supersonic.headless.integration;
import com.tencent.supersonic.StandaloneLauncher;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.server.service.QueryService;
import java.util.HashSet;
import java.util.Set;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.test.context.ActiveProfiles;
import org.springframework.test.context.junit4.SpringRunner;
@RunWith(SpringRunner.class)
@SpringBootTest(classes = StandaloneLauncher.class)
@ActiveProfiles("local")
public class BaseTest {
@Autowired
private QueryService queryService;
protected SemanticQueryResp queryBySql(String sql) throws Exception {
return queryBySql(sql, User.getFakeUser());
}
protected SemanticQueryResp queryBySql(String sql, User user) throws Exception {
return queryService.queryByReq(buildQuerySqlReq(sql), user);
}
protected QuerySqlReq buildQuerySqlReq(String sql) {
QuerySqlReq querySqlCmd = new QuerySqlReq();
querySqlCmd.setSql(sql);
Set<Long> modelIds = new HashSet<>();
modelIds.add(1L);
modelIds.add(2L);
modelIds.add(3L);
querySqlCmd.setModelIds(modelIds);
return querySqlCmd;
}
}

View File

@@ -0,0 +1,59 @@
package com.tencent.supersonic.headless.integration;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.pojo.exception.InvalidPermissionException;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import org.junit.Test;
public class QueryBySqlTest extends BaseTest {
@Test
public void testSumQuery() throws Exception {
SemanticQueryResp semanticQueryResp = queryBySql("SELECT SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 ");
assertEquals(1, semanticQueryResp.getColumns().size());
QueryColumn queryColumn = semanticQueryResp.getColumns().get(0);
assertEquals("访问次数", queryColumn.getName());
assertEquals(1, semanticQueryResp.getResultList().size());
}
@Test
public void testGroupByQuery() throws Exception {
SemanticQueryResp result = queryBySql("SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 ");
assertEquals(2, result.getColumns().size());
QueryColumn firstColumn = result.getColumns().get(0);
QueryColumn secondColumn = result.getColumns().get(1);
assertEquals("部门", firstColumn.getName());
assertEquals("访问次数", secondColumn.getName());
assertEquals(4, result.getResultList().size());
}
@Test
public void testCacheQuery() throws Exception {
SemanticQueryResp result1 = queryBySql("SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 ");
SemanticQueryResp result2 = queryBySql("SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 ");
assertEquals(result1, result2);
}
@Test
public void testBizNameQuery() throws Exception {
SemanticQueryResp result1 = queryBySql("SELECT SUM(pv) FROM 超音数PVUV统计 WHERE department ='HR'");
SemanticQueryResp result2 = queryBySql("SELECT SUM(访问次数) FROM 超音数PVUV统计 WHERE 部门 ='HR'");
assertEquals(1, result1.getColumns().size());
assertEquals(1, result2.getColumns().size());
assertEquals(result1.getColumns().get(0), result2.getColumns().get(0));
assertEquals(result1.getResultList(), result2.getResultList());
}
@Test
public void testAuthorization() throws Exception {
User alice = new User(2L, "alice", "alice", "alice@email", 0);
assertThrows(InvalidPermissionException.class,
() -> queryBySql("SELECT SUM(pv) FROM 超音数PVUV统计 WHERE department ='HR'", alice));
}
}