mirror of
https://github.com/tencentmusic/supersonic.git
synced 2026-04-25 00:54:21 +08:00
Refactor translator module (#1932)
* [improvement][Chat] Support agent permission management #1143 * [improvement][chat]Iterate LLM prompts of parsing and correction. * [improvement][headless]Clean code logic of headless core. * (fix) (chat) 记忆管理更新不生效 (#1912) * [improvement][headless-fe] Added null-check conditions to the data formatting function. * [improvement][headless]Clean code logic of headless translator. * [improvement][headless-fe] Added permissions management for agents. * [improvement][headless-fe] Unified the assistant's permission settings interaction to match the system style. * [improvement](Dict)Support returns dict task list of dimensions by page * [improvement][headless-fe] Revised the interaction for semantic modeling routing and implemented the initial version of metric management switching. * [improvement][launcher]Set system property `s2.test` in junit tests in order to facilitate conditional breakpoints. * [improvement][headless] add validateAndQuery interface in SqlQueryApiController * [improvement][launcher]Use API to get element ID avoiding hard-code. * [improvement][launcher]Support DuckDB database and refactor translator code structure. --------- Co-authored-by: lxwcodemonkey <jolunoluo@tencent.com> Co-authored-by: tristanliu <tristanliu@tencent.com> Co-authored-by: daikon12 <1059907724@qq.com> Co-authored-by: lexluo09 <39718951+lexluo09@users.noreply.github.com>
This commit is contained in:
@@ -12,14 +12,15 @@ import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum;
|
||||
import com.tencent.supersonic.common.service.ChatModelService;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.headless.server.service.SchemaService;
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
|
||||
import java.time.LocalDate;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@@ -38,6 +39,8 @@ public class BaseTest extends BaseApplication {
|
||||
protected AgentService agentService;
|
||||
@Autowired
|
||||
protected ChatModelService chatModelService;
|
||||
@Autowired
|
||||
protected SchemaService schemaService;
|
||||
|
||||
@Value("${s2.demo.enableLLM:false}")
|
||||
protected boolean enableLLM;
|
||||
@@ -107,4 +110,10 @@ public class BaseTest extends BaseApplication {
|
||||
|
||||
assertEquals(expectedParseInfo.getDateInfo(), actualParseInfo.getDateInfo());
|
||||
}
|
||||
|
||||
protected SchemaElement getSchemaElementByName(Set<SchemaElement> elementSet, String name) {
|
||||
Optional<SchemaElement> matchElement =
|
||||
elementSet.stream().filter(e -> e.getName().equals(name)).findFirst();
|
||||
return matchElement.orElse(null);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
@@ -12,6 +13,7 @@ import com.tencent.supersonic.headless.chat.query.rule.detail.DetailDimensionQue
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junitpioneer.jupiter.SetSystemProperty;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
|
||||
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
|
||||
@@ -19,8 +21,9 @@ import org.springframework.boot.test.context.SpringBootTest;
|
||||
public class DetailTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
@SetSystemProperty(key = "s2.test", value = "true")
|
||||
public void test_detail_dimension() throws Exception {
|
||||
QueryResult actualResult = submitNewChat("周杰伦流派和代表作", DataUtils.tagAgentId);
|
||||
QueryResult actualResult = submitNewChat("周杰伦流派和代表作", DataUtils.singerAgentId);
|
||||
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
|
||||
@@ -30,8 +33,11 @@ public class DetailTest extends BaseTest {
|
||||
expectedParseInfo.setQueryType(QueryType.DETAIL);
|
||||
expectedParseInfo.setAggType(AggregateTypeEnum.NONE);
|
||||
|
||||
QueryFilter dimensionFilter =
|
||||
DataUtils.getFilter("singer_name", FilterOperatorEnum.EQUALS, "周杰伦", "歌手名", 8L);
|
||||
DataSetSchema schema = schemaService.getDataSetSchema(DataUtils.singerDatasettId);
|
||||
SchemaElement singerElement = getSchemaElementByName(schema.getDimensions(), "歌手名");
|
||||
|
||||
QueryFilter dimensionFilter = DataUtils.getFilter("singer_name", FilterOperatorEnum.EQUALS,
|
||||
"周杰伦", "歌手名", singerElement.getId());
|
||||
expectedParseInfo.getDimensionFilters().add(dimensionFilter);
|
||||
|
||||
expectedParseInfo.getDimensions()
|
||||
@@ -43,7 +49,7 @@ public class DetailTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void test_detail_filter() throws Exception {
|
||||
QueryResult actualResult = submitNewChat("国风歌手", DataUtils.tagAgentId);
|
||||
QueryResult actualResult = submitNewChat("国风歌手", DataUtils.singerAgentId);
|
||||
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
|
||||
@@ -53,8 +59,10 @@ public class DetailTest extends BaseTest {
|
||||
expectedParseInfo.setQueryType(QueryType.DETAIL);
|
||||
expectedParseInfo.setAggType(AggregateTypeEnum.NONE);
|
||||
|
||||
QueryFilter dimensionFilter =
|
||||
DataUtils.getFilter("genre", FilterOperatorEnum.EQUALS, "国风", "流派", 7L);
|
||||
DataSetSchema schema = schemaService.getDataSetSchema(DataUtils.singerDatasettId);
|
||||
SchemaElement genreElement = getSchemaElementByName(schema.getDimensions(), "流派");
|
||||
QueryFilter dimensionFilter = DataUtils.getFilter("genre", FilterOperatorEnum.EQUALS, "国风",
|
||||
"流派", genreElement.getId());
|
||||
expectedParseInfo.getDimensionFilters().add(dimensionFilter);
|
||||
expectedParseInfo.getDimensions()
|
||||
.addAll(Lists.newArrayList(SchemaElement.builder().name("歌手名").build()));
|
||||
|
||||
@@ -5,14 +5,18 @@ import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricFilterQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricGroupByQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricModelQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricTopNQuery;
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
import org.junit.jupiter.api.Order;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junitpioneer.jupiter.SetSystemProperty;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
|
||||
import java.text.DateFormat;
|
||||
@@ -28,24 +32,16 @@ import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM;
|
||||
public class MetricTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void testMetric() throws Exception {
|
||||
QueryResult actualResult = submitNewChat("超音数 访问次数", DataUtils.metricAgentId);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMetricFilter() throws Exception {
|
||||
QueryResult actualResult = submitNewChat("alice的访问次数", DataUtils.metricAgentId);
|
||||
public void testMetricModel() throws Exception {
|
||||
QueryResult actualResult = submitNewChat("超音数 访问次数", DataUtils.productAgentId);
|
||||
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
|
||||
expectedResult.setChatContext(expectedParseInfo);
|
||||
|
||||
expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE);
|
||||
expectedResult.setQueryMode(MetricModelQuery.QUERY_MODE);
|
||||
expectedParseInfo.setAggType(NONE);
|
||||
|
||||
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
|
||||
expectedParseInfo.getDimensionFilters().add(
|
||||
DataUtils.getFilter("user_name", FilterOperatorEnum.EQUALS, "alice", "用户", 2L));
|
||||
|
||||
expectedParseInfo.setDateInfo(
|
||||
DataUtils.getDateConf(DateConf.DateMode.BETWEEN, unit, period, startDay, endDay));
|
||||
@@ -56,8 +52,35 @@ public class MetricTest extends BaseTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMetricFilter() throws Exception {
|
||||
QueryResult actualResult = submitNewChat("alice的访问次数", DataUtils.productAgentId);
|
||||
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
|
||||
expectedResult.setChatContext(expectedParseInfo);
|
||||
|
||||
expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE);
|
||||
expectedParseInfo.setAggType(NONE);
|
||||
|
||||
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
|
||||
|
||||
DataSetSchema schema = schemaService.getDataSetSchema(DataUtils.productDatasetId);
|
||||
SchemaElement userElement = getSchemaElementByName(schema.getDimensions(), "用户");
|
||||
expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name",
|
||||
FilterOperatorEnum.EQUALS, "alice", "用户", userElement.getId()));
|
||||
|
||||
expectedParseInfo.setDateInfo(
|
||||
DataUtils.getDateConf(DateConf.DateMode.BETWEEN, unit, period, startDay, endDay));
|
||||
expectedParseInfo.setQueryType(QueryType.AGGREGATE);
|
||||
|
||||
assertQueryResult(expectedResult, actualResult);
|
||||
assert actualResult.getQueryResults().size() == 1;
|
||||
}
|
||||
|
||||
@Test
|
||||
@SetSystemProperty(key = "s2.test", value = "true")
|
||||
public void testMetricGroupBy() throws Exception {
|
||||
QueryResult actualResult = submitNewChat("近7天超音数各部门的访问次数", DataUtils.metricAgentId);
|
||||
QueryResult actualResult = submitNewChat("近7天超音数各部门的访问次数和停留时长", DataUtils.productAgentId);
|
||||
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
|
||||
@@ -67,6 +90,7 @@ public class MetricTest extends BaseTest {
|
||||
expectedParseInfo.setAggType(NONE);
|
||||
|
||||
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
|
||||
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("停留时长"));
|
||||
expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("部门"));
|
||||
|
||||
expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.BETWEEN, 7,
|
||||
@@ -79,7 +103,7 @@ public class MetricTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void testMetricFilterCompare() throws Exception {
|
||||
QueryResult actualResult = submitNewChat("对比alice和lucy的访问次数", DataUtils.metricAgentId);
|
||||
QueryResult actualResult = submitNewChat("对比alice和lucy的访问次数", DataUtils.productAgentId);
|
||||
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
|
||||
@@ -92,8 +116,11 @@ public class MetricTest extends BaseTest {
|
||||
List<String> list = new ArrayList<>();
|
||||
list.add("alice");
|
||||
list.add("lucy");
|
||||
QueryFilter dimensionFilter =
|
||||
DataUtils.getFilter("user_name", FilterOperatorEnum.IN, list, "用户", 2L);
|
||||
|
||||
DataSetSchema schema = schemaService.getDataSetSchema(DataUtils.productDatasetId);
|
||||
SchemaElement userElement = getSchemaElementByName(schema.getDimensions(), "用户");
|
||||
QueryFilter dimensionFilter = DataUtils.getFilter("user_name", FilterOperatorEnum.IN, list,
|
||||
"用户", userElement.getId());
|
||||
expectedParseInfo.getDimensionFilters().add(dimensionFilter);
|
||||
|
||||
expectedParseInfo.setDateInfo(
|
||||
@@ -107,7 +134,7 @@ public class MetricTest extends BaseTest {
|
||||
@Test
|
||||
@Order(3)
|
||||
public void testMetricTopN() throws Exception {
|
||||
QueryResult actualResult = submitNewChat("近3天访问次数最多的用户", DataUtils.metricAgentId);
|
||||
QueryResult actualResult = submitNewChat("近3天访问次数最多的用户", DataUtils.productAgentId);
|
||||
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
|
||||
@@ -128,7 +155,7 @@ public class MetricTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void testMetricGroupBySum() throws Exception {
|
||||
QueryResult actualResult = submitNewChat("近7天超音数各部门的访问次数总和", DataUtils.metricAgentId);
|
||||
QueryResult actualResult = submitNewChat("近7天超音数各部门的访问次数总和", DataUtils.productAgentId);
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
|
||||
expectedResult.setChatContext(expectedParseInfo);
|
||||
@@ -154,7 +181,7 @@ public class MetricTest extends BaseTest {
|
||||
String dateStr = textFormat.format(format.parse(startDay));
|
||||
|
||||
QueryResult actualResult =
|
||||
submitNewChat(String.format("alice在%s的访问次数", dateStr), DataUtils.metricAgentId);
|
||||
submitNewChat(String.format("alice在%s的访问次数", dateStr), DataUtils.productAgentId);
|
||||
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
|
||||
@@ -163,9 +190,11 @@ public class MetricTest extends BaseTest {
|
||||
expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE);
|
||||
expectedParseInfo.setAggType(NONE);
|
||||
|
||||
DataSetSchema schema = schemaService.getDataSetSchema(DataUtils.productDatasetId);
|
||||
SchemaElement userElement = getSchemaElementByName(schema.getDimensions(), "用户");
|
||||
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
|
||||
expectedParseInfo.getDimensionFilters().add(
|
||||
DataUtils.getFilter("user_name", FilterOperatorEnum.EQUALS, "alice", "用户", 2L));
|
||||
expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name",
|
||||
FilterOperatorEnum.EQUALS, "alice", "用户", userElement.getId()));
|
||||
|
||||
expectedParseInfo.setDateInfo(
|
||||
DataUtils.getDateConf(DateConf.DateMode.BETWEEN, 1, period, startDay, startDay));
|
||||
|
||||
@@ -5,7 +5,10 @@ import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Maps;
|
||||
import com.tencent.supersonic.chat.BaseTest;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.agent.*;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.server.agent.DatasetTool;
|
||||
import com.tencent.supersonic.chat.server.agent.ToolConfig;
|
||||
import com.tencent.supersonic.common.config.ChatModel;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
@@ -133,11 +136,28 @@ public class Text2SQLEval extends BaseTest {
|
||||
assert result.getTextResult().contains("3");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void test_detail_query() throws Exception {
|
||||
long start = System.currentTimeMillis();
|
||||
QueryResult result = submitNewChat("特斯拉旗下有哪些品牌", agentId);
|
||||
durations.add(System.currentTimeMillis() - start);
|
||||
assert result.getQueryColumns().size() >= 1;
|
||||
assert result.getTextResult().contains("Model Y");
|
||||
assert result.getTextResult().contains("Model 3");
|
||||
}
|
||||
|
||||
public Agent getLLMAgent() {
|
||||
Agent agent = new Agent();
|
||||
agent.setName("Agent for Test");
|
||||
ToolConfig toolConfig = new ToolConfig();
|
||||
toolConfig.getTools().add(getDatasetTool());
|
||||
DatasetTool datasetTool = new DatasetTool();
|
||||
datasetTool.setType(AgentToolType.DATASET);
|
||||
datasetTool.setDataSetIds(Lists.newArrayList(DataUtils.productDatasetId));
|
||||
toolConfig.getTools().add(datasetTool);
|
||||
DatasetTool datasetTool2 = new DatasetTool();
|
||||
datasetTool2.setType(AgentToolType.DATASET);
|
||||
datasetTool2.setDataSetIds(Lists.newArrayList(DataUtils.companyDatasetId));
|
||||
toolConfig.getTools().add(datasetTool2);
|
||||
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
|
||||
// create chat model for this evaluation
|
||||
ChatModel chatModel = new ChatModel();
|
||||
@@ -154,11 +174,4 @@ public class Text2SQLEval extends BaseTest {
|
||||
return agent;
|
||||
}
|
||||
|
||||
private static DatasetTool getDatasetTool() {
|
||||
DatasetTool datasetTool = new DatasetTool();
|
||||
datasetTool.setType(AgentToolType.DATASET);
|
||||
datasetTool.setDataSetIds(Lists.newArrayList(1L));
|
||||
|
||||
return datasetTool;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,14 +1,18 @@
|
||||
package com.tencent.supersonic.headless;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Filter;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryMetricReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.server.service.MetricService;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.junitpioneer.jupiter.SetSystemProperty;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
|
||||
import java.time.LocalDate;
|
||||
import java.util.Arrays;
|
||||
|
||||
import static org.junit.Assert.assertThrows;
|
||||
@@ -23,16 +27,24 @@ public class QueryByMetricTest extends BaseTest {
|
||||
QueryMetricReq queryMetricReq = new QueryMetricReq();
|
||||
queryMetricReq.setMetricNames(Arrays.asList("stay_hours", "pv"));
|
||||
queryMetricReq.setDimensionNames(Arrays.asList("user_name", "department"));
|
||||
queryMetricReq.getFilters().add(Filter.builder().name("imp_date")
|
||||
.operator(FilterOperatorEnum.MINOR_THAN_EQUALS).relation(Filter.Relation.FILTER)
|
||||
.value(LocalDate.now().toString()).build());
|
||||
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getDefaultUser());
|
||||
Assert.assertNotNull(queryResp.getResultList());
|
||||
Assert.assertEquals(6, queryResp.getResultList().size());
|
||||
}
|
||||
|
||||
@Test
|
||||
@SetSystemProperty(key = "s2.test", value = "true")
|
||||
public void testWithMetricAndDimensionNames() throws Exception {
|
||||
QueryMetricReq queryMetricReq = new QueryMetricReq();
|
||||
queryMetricReq.setMetricNames(Arrays.asList("停留时长", "访问次数"));
|
||||
queryMetricReq.setDimensionNames(Arrays.asList("用户", "部门"));
|
||||
queryMetricReq.getFilters()
|
||||
.add(Filter.builder().name("数据日期").operator(FilterOperatorEnum.MINOR_THAN_EQUALS)
|
||||
.relation(Filter.Relation.FILTER).value(LocalDate.now().toString())
|
||||
.build());
|
||||
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getDefaultUser());
|
||||
Assert.assertNotNull(queryResp.getResultList());
|
||||
Assert.assertEquals(6, queryResp.getResultList().size());
|
||||
@@ -44,6 +56,9 @@ public class QueryByMetricTest extends BaseTest {
|
||||
queryMetricReq.setDomainId(1L);
|
||||
queryMetricReq.setMetricNames(Arrays.asList("stay_hours", "pv"));
|
||||
queryMetricReq.setDimensionNames(Arrays.asList("user_name", "department"));
|
||||
queryMetricReq.getFilters().add(Filter.builder().name("imp_date")
|
||||
.operator(FilterOperatorEnum.MINOR_THAN_EQUALS).relation(Filter.Relation.FILTER)
|
||||
.value(LocalDate.now().toString()).build());
|
||||
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getDefaultUser());
|
||||
Assert.assertNotNull(queryResp.getResultList());
|
||||
Assert.assertEquals(6, queryResp.getResultList().size());
|
||||
@@ -61,6 +76,9 @@ public class QueryByMetricTest extends BaseTest {
|
||||
queryMetricReq.setDomainId(1L);
|
||||
queryMetricReq.setMetricIds(Arrays.asList(1L, 3L));
|
||||
queryMetricReq.setDimensionIds(Arrays.asList(1L, 2L));
|
||||
queryMetricReq.getFilters().add(Filter.builder().name("imp_date")
|
||||
.operator(FilterOperatorEnum.MINOR_THAN_EQUALS).relation(Filter.Relation.FILTER)
|
||||
.value(LocalDate.now().toString()).build());
|
||||
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getDefaultUser());
|
||||
Assert.assertNotNull(queryResp.getResultList());
|
||||
Assert.assertEquals(6, queryResp.getResultList().size());
|
||||
|
||||
@@ -18,7 +18,7 @@ public class TranslateTest extends BaseTest {
|
||||
public void testSqlExplain() throws Exception {
|
||||
String sql = "SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 ";
|
||||
SemanticTranslateResp explain = semanticLayerService.translate(
|
||||
QueryReqBuilder.buildS2SQLReq(sql, DataUtils.getMetricAgentView()),
|
||||
QueryReqBuilder.buildS2SQLReq(sql, DataUtils.productDatasetId),
|
||||
User.getDefaultUser());
|
||||
assertNotNull(explain);
|
||||
assertNotNull(explain.getQuerySQL());
|
||||
|
||||
@@ -15,10 +15,15 @@ import static java.time.LocalDate.now;
|
||||
|
||||
public class DataUtils {
|
||||
|
||||
public static final Integer metricAgentId = 1;
|
||||
public static final Integer tagAgentId = 2;
|
||||
public static final Integer productAgentId = 1;
|
||||
public static final Integer companyAgentId = 2;
|
||||
public static final Integer singerAgentId = 3;
|
||||
|
||||
public static final Long productDatasetId = 1L;
|
||||
public static final Long companyDatasetId = 2L;
|
||||
public static final Long singerDatasettId = 3L;
|
||||
|
||||
public static final Integer ONE_TURNS_CHAT_ID = 10;
|
||||
public static final Integer MULTI_TURNS_CHAT_ID = 11;
|
||||
private static final User user_test = User.getDefaultUser();
|
||||
|
||||
public static User getUser() {
|
||||
@@ -40,7 +45,7 @@ public class DataUtils {
|
||||
public static ChatParseReq getChatParseReq(Integer id, String query, boolean enableLLM) {
|
||||
ChatParseReq chatParseReq = new ChatParseReq();
|
||||
chatParseReq.setQueryText(query);
|
||||
chatParseReq.setAgentId(metricAgentId);
|
||||
chatParseReq.setAgentId(productAgentId);
|
||||
chatParseReq.setChatId(id);
|
||||
chatParseReq.setUser(user_test);
|
||||
chatParseReq.setDisableLLM(!enableLLM);
|
||||
@@ -92,7 +97,4 @@ public class DataUtils {
|
||||
return result;
|
||||
}
|
||||
|
||||
public static Long getMetricAgentView() {
|
||||
return 1L;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,7 +11,12 @@ public class LLMConfigUtils {
|
||||
OPENAI_GLM(false),
|
||||
OLLAMA_LLAMA3(true),
|
||||
OLLAMA_QWEN2(true),
|
||||
OLLAMA_QWEN25(true);
|
||||
OLLAMA_QWEN25_7B(true),
|
||||
OLLAMA_QWEN25_14B(true),
|
||||
OLLAMA_QWEN25_CODE_7B(true),
|
||||
OLLAMA_QWEN25_CODE_3B(true),
|
||||
OLLAMA_GLM4(true);
|
||||
|
||||
|
||||
public boolean isOllam;
|
||||
|
||||
@@ -35,10 +40,26 @@ public class LLMConfigUtils {
|
||||
baseUrl = "http://localhost:11434";
|
||||
modelName = "qwen2:7b";
|
||||
break;
|
||||
case OLLAMA_QWEN25:
|
||||
case OLLAMA_QWEN25_7B:
|
||||
baseUrl = "http://localhost:11434";
|
||||
modelName = "qwen2.5:7b";
|
||||
break;
|
||||
case OLLAMA_QWEN25_14B:
|
||||
baseUrl = "http://localhost:11434";
|
||||
modelName = "qwen2.5:14b";
|
||||
break;
|
||||
case OLLAMA_QWEN25_CODE_7B:
|
||||
baseUrl = "http://localhost:11434";
|
||||
modelName = "qwen2.5-coder:7b";
|
||||
break;
|
||||
case OLLAMA_QWEN25_CODE_3B:
|
||||
baseUrl = "http://localhost:11434";
|
||||
modelName = "qwen2.5-coder:3b";
|
||||
break;
|
||||
case OLLAMA_GLM4:
|
||||
baseUrl = "http://localhost:11434";
|
||||
modelName = "glm4:latest";
|
||||
break;
|
||||
case OPENAI_GLM:
|
||||
baseUrl = "https://open.bigmodel.cn/api/pas/v4/";
|
||||
apiKey = "REPLACE_WITH_YOUR_KEY";
|
||||
|
||||
@@ -1,14 +1,34 @@
|
||||
spring:
|
||||
datasource:
|
||||
driver-class-name: org.h2.Driver
|
||||
url: jdbc:h2:mem:semantic;DATABASE_TO_UPPER=false;QUERY_TIMEOUT=100
|
||||
url: jdbc:h2:mem:semantic;DATABASE_TO_UPPER=false;QUERY_TIMEOUT=30
|
||||
username: root
|
||||
password: semantic
|
||||
sql:
|
||||
init:
|
||||
schema-locations: classpath:db/schema-h2.sql
|
||||
data-locations: classpath:db/data-h2.sql
|
||||
schema-locations: classpath:db/schema-h2.sql,classpath:db/schema-h2-demo.sql
|
||||
data-locations: classpath:db/data-h2.sql,classpath:db/data-h2-demo.sql
|
||||
h2:
|
||||
console:
|
||||
path: /h2-console/semantic
|
||||
enabled: true
|
||||
enabled: true
|
||||
|
||||
### Comment out following lines if using MySQL
|
||||
#spring:
|
||||
# datasource:
|
||||
# driver-class-name: com.mysql.cj.jdbc.Driver
|
||||
# url: jdbc:mysql://localhost:3306/s2_database?user=root
|
||||
# username: root
|
||||
# password:
|
||||
# sql:
|
||||
# enabled: true
|
||||
# mode: always
|
||||
# username: root
|
||||
# password:
|
||||
# init:
|
||||
# schema-locations: classpath:db/schema-mysql.sql,classpath:db/schema-mysql-demo.sql
|
||||
# data-locations: classpath:db/data-mysql.sql,classpath:db/data-mysql-demo.sql
|
||||
# h2:
|
||||
# console:
|
||||
# path: /h2-console/semantic
|
||||
# enabled: true
|
||||
1083
launchers/standalone/src/test/resources/db/data-h2-demo.sql
Normal file
1083
launchers/standalone/src/test/resources/db/data-h2-demo.sql
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,85 @@
|
||||
-------S2VisitsDemo
|
||||
CREATE TABLE IF NOT EXISTS `s2_user_department` (
|
||||
`user_name` varchar(200) NOT NULL,
|
||||
`department` varchar(200) NOT NULL, -- department of user
|
||||
PRIMARY KEY (`user_name`,`department`)
|
||||
);
|
||||
COMMENT ON TABLE s2_user_department IS 'user_department_info';
|
||||
|
||||
CREATE TABLE IF NOT EXISTS `s2_pv_uv_statis` (
|
||||
`imp_date` varchar(200) NOT NULL,
|
||||
`user_name` varchar(200) NOT NULL,
|
||||
`page` varchar(200) NOT NULL
|
||||
);
|
||||
COMMENT ON TABLE s2_pv_uv_statis IS 's2_pv_uv_statis';
|
||||
|
||||
CREATE TABLE IF NOT EXISTS `s2_stay_time_statis` (
|
||||
`imp_date` varchar(200) NOT NULL,
|
||||
`user_name` varchar(200) NOT NULL,
|
||||
`stay_hours` DOUBLE NOT NULL,
|
||||
`page` varchar(200) NOT NULL
|
||||
);
|
||||
COMMENT ON TABLE s2_stay_time_statis IS 's2_stay_time_statis_info';
|
||||
|
||||
-------S2ArtistDemo
|
||||
CREATE TABLE IF NOT EXISTS `singer` (
|
||||
`singer_name` varchar(200) NOT NULL,
|
||||
`act_area` varchar(200) NOT NULL,
|
||||
`song_name` varchar(200) NOT NULL,
|
||||
`genre` varchar(200) NOT NULL,
|
||||
`js_play_cnt` bigINT DEFAULT NULL,
|
||||
`down_cnt` bigINT DEFAULT NULL,
|
||||
`favor_cnt` bigINT DEFAULT NULL,
|
||||
PRIMARY KEY (`singer_name`)
|
||||
);
|
||||
COMMENT ON TABLE singer IS 'singer_info';
|
||||
|
||||
CREATE TABLE IF NOT EXISTS `genre` (
|
||||
`g_name` varchar(20) NOT NULL , -- genre name
|
||||
`rating` INT ,
|
||||
`most_popular_in` varchar(50) ,
|
||||
PRIMARY KEY (`g_name`)
|
||||
);
|
||||
COMMENT ON TABLE genre IS 'genre';
|
||||
|
||||
CREATE TABLE IF NOT EXISTS `artist` (
|
||||
`artist_name` varchar(50) NOT NULL , -- genre name
|
||||
`citizenship` varchar(20) ,
|
||||
`gender` varchar(20) ,
|
||||
`g_name` varchar(50),
|
||||
PRIMARY KEY (`artist_name`,`citizenship`)
|
||||
);
|
||||
COMMENT ON TABLE artist IS 'artist';
|
||||
|
||||
-------S2CompanyDemo
|
||||
CREATE TABLE IF NOT EXISTS `company` (
|
||||
`company_id` varchar(50) NOT NULL ,
|
||||
`company_name` varchar(50) NOT NULL ,
|
||||
`headquarter_address` varchar(50) NOT NULL ,
|
||||
`company_established_time` varchar(20) NOT NULL ,
|
||||
`founder` varchar(20) NOT NULL ,
|
||||
`ceo` varchar(20) NOT NULL ,
|
||||
`annual_turnover` bigint(15) ,
|
||||
`employee_count` int(7) ,
|
||||
PRIMARY KEY (`company_id`)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS `brand` (
|
||||
`brand_id` varchar(50) NOT NULL ,
|
||||
`brand_name` varchar(50) NOT NULL ,
|
||||
`brand_established_time` varchar(20) NOT NULL ,
|
||||
`company_id` varchar(50) NOT NULL ,
|
||||
`legal_representative` varchar(20) NOT NULL ,
|
||||
`registered_capital` bigint(15) ,
|
||||
PRIMARY KEY (`brand_id`)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS `brand_revenue` (
|
||||
`year_time` varchar(10) NOT NULL ,
|
||||
`brand_id` varchar(50) NOT NULL ,
|
||||
`revenue` bigint(15) NOT NULL,
|
||||
`profit` bigint(15) NOT NULL ,
|
||||
`revenue_growth_year_on_year` double NOT NULL ,
|
||||
`profit_growth_year_on_year` double NOT NULL
|
||||
);
|
||||
|
||||
@@ -21,7 +21,7 @@ s2:
|
||||
date: true
|
||||
|
||||
demo:
|
||||
names: S2VisitsDemo,S2SingerDemo
|
||||
names: S2VisitsDemo,S2SingerDemo,S2CompanyDemo
|
||||
enableLLM: false
|
||||
|
||||
authentication:
|
||||
|
||||
Reference in New Issue
Block a user