mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-20 06:34:55 +00:00
[improvement][headless]Clean code logic of headless translator.
This commit is contained in:
@@ -7,26 +7,12 @@ import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.server.agent.DatasetTool;
|
||||
import com.tencent.supersonic.chat.server.agent.ToolConfig;
|
||||
import com.tencent.supersonic.chat.server.processor.execute.DataInterpretProcessor;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.JoinCondition;
|
||||
import com.tencent.supersonic.common.pojo.ModelRela;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.AppModule;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeMode;
|
||||
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
|
||||
import com.tencent.supersonic.common.pojo.enums.*;
|
||||
import com.tencent.supersonic.common.util.ChatAppManager;
|
||||
import com.tencent.supersonic.headless.api.pojo.AggregateTypeDefaultConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetDetail;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetModelConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.Dim;
|
||||
import com.tencent.supersonic.headless.api.pojo.DimensionTimeTypeParams;
|
||||
import com.tencent.supersonic.headless.api.pojo.Identify;
|
||||
import com.tencent.supersonic.headless.api.pojo.Measure;
|
||||
import com.tencent.supersonic.headless.api.pojo.ModelDetail;
|
||||
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.*;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.DimensionType;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.IdentifyType;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DataSetReq;
|
||||
@@ -40,11 +26,7 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.core.annotation.Order;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.*;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
@@ -272,7 +254,6 @@ public class S2CompanyDemo extends S2BaseDemo {
|
||||
Map<String, ChatApp> chatAppConfig =
|
||||
Maps.newHashMap(ChatAppManager.getAllApps(AppModule.CHAT));
|
||||
chatAppConfig.values().forEach(app -> app.setChatModelId(demoChatModel.getId()));
|
||||
chatAppConfig.get(DataInterpretProcessor.APP_KEY).setEnable(true);
|
||||
agent.setChatAppConfig(chatAppConfig);
|
||||
|
||||
agentService.createAgent(agent, defaultUser);
|
||||
|
||||
@@ -146,7 +146,8 @@ public class S2VisitsDemo extends S2BaseDemo {
|
||||
agent.setStatus(1);
|
||||
agent.setEnableSearch(1);
|
||||
agent.setExamples(Lists.newArrayList("近15天超音数访问次数汇总", "按部门统计超音数的访问人数", "对比alice和lucy的停留时长",
|
||||
"过去30天访问次数最高的部门top3", "近1个月总访问次数超过100次的部门有几个", "过去半个月每个核心用户的总停留时长"));
|
||||
"过去30天访问次数最高的部门top3", "近1个月总访问次数超过100次的部门有几个", "过去半个月每个核心用户的总停留时长",
|
||||
"今年以来访问次数最高的一天是哪一天"));
|
||||
|
||||
// configure tools
|
||||
ToolConfig toolConfig = new ToolConfig();
|
||||
|
||||
@@ -26,9 +26,10 @@ com.tencent.supersonic.headless.chat.parser.llm.DataSetResolver=\
|
||||
|
||||
com.tencent.supersonic.headless.core.translator.converter.QueryConverter=\
|
||||
com.tencent.supersonic.headless.core.translator.converter.DefaultDimValueConverter,\
|
||||
com.tencent.supersonic.headless.core.translator.converter.SqlVariableParseConverter,\
|
||||
com.tencent.supersonic.headless.core.translator.converter.CalculateAggConverter,\
|
||||
com.tencent.supersonic.headless.core.translator.converter.ParserDefaultConverter
|
||||
com.tencent.supersonic.headless.core.translator.converter.SqlVariableConverter,\
|
||||
com.tencent.supersonic.headless.core.translator.converter.MetricRatioConverter,\
|
||||
com.tencent.supersonic.headless.core.translator.converter.SqlQueryConverter,\
|
||||
com.tencent.supersonic.headless.core.translator.converter.StructQueryConverter
|
||||
|
||||
com.tencent.supersonic.headless.core.translator.QueryOptimizer=\
|
||||
com.tencent.supersonic.headless.core.translator.DetailQueryOptimizer
|
||||
|
||||
@@ -12,7 +12,6 @@ import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum;
|
||||
import com.tencent.supersonic.common.service.ChatModelService;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
|
||||
@@ -20,7 +20,7 @@ public class DetailTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void test_detail_dimension() throws Exception {
|
||||
QueryResult actualResult = submitNewChat("周杰伦流派和代表作", DataUtils.tagAgentId);
|
||||
QueryResult actualResult = submitNewChat("周杰伦流派和代表作", DataUtils.singerAgentId);
|
||||
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
|
||||
@@ -31,7 +31,7 @@ public class DetailTest extends BaseTest {
|
||||
expectedParseInfo.setAggType(AggregateTypeEnum.NONE);
|
||||
|
||||
QueryFilter dimensionFilter =
|
||||
DataUtils.getFilter("singer_name", FilterOperatorEnum.EQUALS, "周杰伦", "歌手名", 8L);
|
||||
DataUtils.getFilter("singer_name", FilterOperatorEnum.EQUALS, "周杰伦", "歌手名", 17L);
|
||||
expectedParseInfo.getDimensionFilters().add(dimensionFilter);
|
||||
|
||||
expectedParseInfo.getDimensions()
|
||||
@@ -43,7 +43,7 @@ public class DetailTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void test_detail_filter() throws Exception {
|
||||
QueryResult actualResult = submitNewChat("国风歌手", DataUtils.tagAgentId);
|
||||
QueryResult actualResult = submitNewChat("国风歌手", DataUtils.singerAgentId);
|
||||
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
|
||||
|
||||
@@ -9,6 +9,7 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricFilterQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricGroupByQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricModelQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricTopNQuery;
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
import org.junit.jupiter.api.Order;
|
||||
@@ -28,13 +29,28 @@ import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM;
|
||||
public class MetricTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void testMetric() throws Exception {
|
||||
QueryResult actualResult = submitNewChat("超音数 访问次数", DataUtils.metricAgentId);
|
||||
public void testMetricModel() throws Exception {
|
||||
QueryResult actualResult = submitNewChat("超音数 访问次数", DataUtils.productAgentId);
|
||||
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
|
||||
expectedResult.setChatContext(expectedParseInfo);
|
||||
|
||||
expectedResult.setQueryMode(MetricModelQuery.QUERY_MODE);
|
||||
expectedParseInfo.setAggType(NONE);
|
||||
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
|
||||
|
||||
expectedParseInfo.setDateInfo(
|
||||
DataUtils.getDateConf(DateConf.DateMode.BETWEEN, unit, period, startDay, endDay));
|
||||
expectedParseInfo.setQueryType(QueryType.AGGREGATE);
|
||||
|
||||
assertQueryResult(expectedResult, actualResult);
|
||||
assert actualResult.getQueryResults().size() == 1;
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMetricFilter() throws Exception {
|
||||
QueryResult actualResult = submitNewChat("alice的访问次数", DataUtils.metricAgentId);
|
||||
QueryResult actualResult = submitNewChat("alice的访问次数", DataUtils.productAgentId);
|
||||
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
|
||||
@@ -57,7 +73,8 @@ public class MetricTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void testMetricGroupBy() throws Exception {
|
||||
QueryResult actualResult = submitNewChat("近7天超音数各部门的访问次数", DataUtils.metricAgentId);
|
||||
System.setProperty("s2.test", "true");
|
||||
QueryResult actualResult = submitNewChat("近7天超音数各部门的访问次数", DataUtils.productAgentId);
|
||||
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
|
||||
@@ -79,7 +96,7 @@ public class MetricTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void testMetricFilterCompare() throws Exception {
|
||||
QueryResult actualResult = submitNewChat("对比alice和lucy的访问次数", DataUtils.metricAgentId);
|
||||
QueryResult actualResult = submitNewChat("对比alice和lucy的访问次数", DataUtils.productAgentId);
|
||||
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
|
||||
@@ -107,7 +124,7 @@ public class MetricTest extends BaseTest {
|
||||
@Test
|
||||
@Order(3)
|
||||
public void testMetricTopN() throws Exception {
|
||||
QueryResult actualResult = submitNewChat("近3天访问次数最多的用户", DataUtils.metricAgentId);
|
||||
QueryResult actualResult = submitNewChat("近3天访问次数最多的用户", DataUtils.productAgentId);
|
||||
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
|
||||
@@ -128,7 +145,7 @@ public class MetricTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void testMetricGroupBySum() throws Exception {
|
||||
QueryResult actualResult = submitNewChat("近7天超音数各部门的访问次数总和", DataUtils.metricAgentId);
|
||||
QueryResult actualResult = submitNewChat("近7天超音数各部门的访问次数总和", DataUtils.productAgentId);
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
|
||||
expectedResult.setChatContext(expectedParseInfo);
|
||||
@@ -154,7 +171,7 @@ public class MetricTest extends BaseTest {
|
||||
String dateStr = textFormat.format(format.parse(startDay));
|
||||
|
||||
QueryResult actualResult =
|
||||
submitNewChat(String.format("alice在%s的访问次数", dateStr), DataUtils.metricAgentId);
|
||||
submitNewChat(String.format("alice在%s的访问次数", dateStr), DataUtils.productAgentId);
|
||||
|
||||
QueryResult expectedResult = new QueryResult();
|
||||
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
|
||||
|
||||
@@ -20,6 +20,7 @@ public class QueryByMetricTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void testWithMetricAndDimensionBizNames() throws Exception {
|
||||
System.setProperty("s2.test", "true");
|
||||
QueryMetricReq queryMetricReq = new QueryMetricReq();
|
||||
queryMetricReq.setMetricNames(Arrays.asList("stay_hours", "pv"));
|
||||
queryMetricReq.setDimensionNames(Arrays.asList("user_name", "department"));
|
||||
|
||||
@@ -46,6 +46,7 @@ public class QueryByStructTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void testDetailQuery() throws Exception {
|
||||
System.setProperty("s2.test", "true");
|
||||
QueryStructReq queryStructReq =
|
||||
buildQueryStructReq(Arrays.asList("user_name", "department"), QueryType.DETAIL);
|
||||
SemanticQueryResp semanticQueryResp =
|
||||
@@ -86,6 +87,7 @@ public class QueryByStructTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
public void testFilterQuery() throws Exception {
|
||||
System.setProperty("s2.test", "true");
|
||||
QueryStructReq queryStructReq = buildQueryStructReq(Arrays.asList("department"));
|
||||
List<Filter> dimensionFilters = new ArrayList<>();
|
||||
Filter filter = new Filter();
|
||||
|
||||
@@ -15,10 +15,10 @@ import static java.time.LocalDate.now;
|
||||
|
||||
public class DataUtils {
|
||||
|
||||
public static final Integer metricAgentId = 1;
|
||||
public static final Integer tagAgentId = 2;
|
||||
public static final Integer productAgentId = 1;
|
||||
public static final Integer companyAgentId = 2;
|
||||
public static final Integer singerAgentId = 3;
|
||||
public static final Integer ONE_TURNS_CHAT_ID = 10;
|
||||
public static final Integer MULTI_TURNS_CHAT_ID = 11;
|
||||
private static final User user_test = User.getDefaultUser();
|
||||
|
||||
public static User getUser() {
|
||||
@@ -40,7 +40,7 @@ public class DataUtils {
|
||||
public static ChatParseReq getChatParseReq(Integer id, String query, boolean enableLLM) {
|
||||
ChatParseReq chatParseReq = new ChatParseReq();
|
||||
chatParseReq.setQueryText(query);
|
||||
chatParseReq.setAgentId(metricAgentId);
|
||||
chatParseReq.setAgentId(productAgentId);
|
||||
chatParseReq.setChatId(id);
|
||||
chatParseReq.setUser(user_test);
|
||||
chatParseReq.setDisableLLM(!enableLLM);
|
||||
|
||||
@@ -21,7 +21,7 @@ s2:
|
||||
date: true
|
||||
|
||||
demo:
|
||||
names: S2VisitsDemo,S2SingerDemo
|
||||
names: S2VisitsDemo,S2SingerDemo,S2CompanyDemo
|
||||
enableLLM: false
|
||||
|
||||
authentication:
|
||||
|
||||
Reference in New Issue
Block a user