[improvement][headless]Clean code logic of headless translator.

This commit is contained in:
jerryjzhang
2024-11-24 19:07:56 +08:00
parent c22e3ef2e8
commit 860fd5d299
45 changed files with 795 additions and 1058 deletions

View File

@@ -12,7 +12,6 @@ import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum;
import com.tencent.supersonic.common.service.ChatModelService;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import com.tencent.supersonic.util.DataUtils;
import org.springframework.beans.factory.annotation.Autowired;

View File

@@ -20,7 +20,7 @@ public class DetailTest extends BaseTest {
@Test
public void test_detail_dimension() throws Exception {
QueryResult actualResult = submitNewChat("周杰伦流派和代表作", DataUtils.tagAgentId);
QueryResult actualResult = submitNewChat("周杰伦流派和代表作", DataUtils.singerAgentId);
QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
@@ -31,7 +31,7 @@ public class DetailTest extends BaseTest {
expectedParseInfo.setAggType(AggregateTypeEnum.NONE);
QueryFilter dimensionFilter =
DataUtils.getFilter("singer_name", FilterOperatorEnum.EQUALS, "周杰伦", "歌手名", 8L);
DataUtils.getFilter("singer_name", FilterOperatorEnum.EQUALS, "周杰伦", "歌手名", 17L);
expectedParseInfo.getDimensionFilters().add(dimensionFilter);
expectedParseInfo.getDimensions()
@@ -43,7 +43,7 @@ public class DetailTest extends BaseTest {
@Test
public void test_detail_filter() throws Exception {
QueryResult actualResult = submitNewChat("国风歌手", DataUtils.tagAgentId);
QueryResult actualResult = submitNewChat("国风歌手", DataUtils.singerAgentId);
QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();

View File

@@ -9,6 +9,7 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricFilterQuery;
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricGroupByQuery;
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricModelQuery;
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricTopNQuery;
import com.tencent.supersonic.util.DataUtils;
import org.junit.jupiter.api.Order;
@@ -28,13 +29,28 @@ import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM;
public class MetricTest extends BaseTest {
@Test
public void testMetric() throws Exception {
QueryResult actualResult = submitNewChat("超音数 访问次数", DataUtils.metricAgentId);
public void testMetricModel() throws Exception {
QueryResult actualResult = submitNewChat("超音数 访问次数", DataUtils.productAgentId);
QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
expectedResult.setChatContext(expectedParseInfo);
expectedResult.setQueryMode(MetricModelQuery.QUERY_MODE);
expectedParseInfo.setAggType(NONE);
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.setDateInfo(
DataUtils.getDateConf(DateConf.DateMode.BETWEEN, unit, period, startDay, endDay));
expectedParseInfo.setQueryType(QueryType.AGGREGATE);
assertQueryResult(expectedResult, actualResult);
assert actualResult.getQueryResults().size() == 1;
}
@Test
public void testMetricFilter() throws Exception {
QueryResult actualResult = submitNewChat("alice的访问次数", DataUtils.metricAgentId);
QueryResult actualResult = submitNewChat("alice的访问次数", DataUtils.productAgentId);
QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
@@ -57,7 +73,8 @@ public class MetricTest extends BaseTest {
@Test
public void testMetricGroupBy() throws Exception {
QueryResult actualResult = submitNewChat("近7天超音数各部门的访问次数", DataUtils.metricAgentId);
System.setProperty("s2.test", "true");
QueryResult actualResult = submitNewChat("近7天超音数各部门的访问次数", DataUtils.productAgentId);
QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
@@ -79,7 +96,7 @@ public class MetricTest extends BaseTest {
@Test
public void testMetricFilterCompare() throws Exception {
QueryResult actualResult = submitNewChat("对比alice和lucy的访问次数", DataUtils.metricAgentId);
QueryResult actualResult = submitNewChat("对比alice和lucy的访问次数", DataUtils.productAgentId);
QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
@@ -107,7 +124,7 @@ public class MetricTest extends BaseTest {
@Test
@Order(3)
public void testMetricTopN() throws Exception {
QueryResult actualResult = submitNewChat("近3天访问次数最多的用户", DataUtils.metricAgentId);
QueryResult actualResult = submitNewChat("近3天访问次数最多的用户", DataUtils.productAgentId);
QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
@@ -128,7 +145,7 @@ public class MetricTest extends BaseTest {
@Test
public void testMetricGroupBySum() throws Exception {
QueryResult actualResult = submitNewChat("近7天超音数各部门的访问次数总和", DataUtils.metricAgentId);
QueryResult actualResult = submitNewChat("近7天超音数各部门的访问次数总和", DataUtils.productAgentId);
QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
expectedResult.setChatContext(expectedParseInfo);
@@ -154,7 +171,7 @@ public class MetricTest extends BaseTest {
String dateStr = textFormat.format(format.parse(startDay));
QueryResult actualResult =
submitNewChat(String.format("alice在%s的访问次数", dateStr), DataUtils.metricAgentId);
submitNewChat(String.format("alice在%s的访问次数", dateStr), DataUtils.productAgentId);
QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();

View File

@@ -20,6 +20,7 @@ public class QueryByMetricTest extends BaseTest {
@Test
public void testWithMetricAndDimensionBizNames() throws Exception {
System.setProperty("s2.test", "true");
QueryMetricReq queryMetricReq = new QueryMetricReq();
queryMetricReq.setMetricNames(Arrays.asList("stay_hours", "pv"));
queryMetricReq.setDimensionNames(Arrays.asList("user_name", "department"));

View File

@@ -46,6 +46,7 @@ public class QueryByStructTest extends BaseTest {
@Test
public void testDetailQuery() throws Exception {
System.setProperty("s2.test", "true");
QueryStructReq queryStructReq =
buildQueryStructReq(Arrays.asList("user_name", "department"), QueryType.DETAIL);
SemanticQueryResp semanticQueryResp =
@@ -86,6 +87,7 @@ public class QueryByStructTest extends BaseTest {
@Test
public void testFilterQuery() throws Exception {
System.setProperty("s2.test", "true");
QueryStructReq queryStructReq = buildQueryStructReq(Arrays.asList("department"));
List<Filter> dimensionFilters = new ArrayList<>();
Filter filter = new Filter();

View File

@@ -15,10 +15,10 @@ import static java.time.LocalDate.now;
public class DataUtils {
public static final Integer metricAgentId = 1;
public static final Integer tagAgentId = 2;
public static final Integer productAgentId = 1;
public static final Integer companyAgentId = 2;
public static final Integer singerAgentId = 3;
public static final Integer ONE_TURNS_CHAT_ID = 10;
public static final Integer MULTI_TURNS_CHAT_ID = 11;
private static final User user_test = User.getDefaultUser();
public static User getUser() {
@@ -40,7 +40,7 @@ public class DataUtils {
public static ChatParseReq getChatParseReq(Integer id, String query, boolean enableLLM) {
ChatParseReq chatParseReq = new ChatParseReq();
chatParseReq.setQueryText(query);
chatParseReq.setAgentId(metricAgentId);
chatParseReq.setAgentId(productAgentId);
chatParseReq.setChatId(id);
chatParseReq.setUser(user_test);
chatParseReq.setDisableLLM(!enableLLM);

View File

@@ -21,7 +21,7 @@ s2:
date: true
demo:
names: S2VisitsDemo,S2SingerDemo
names: S2VisitsDemo,S2SingerDemo,S2CompanyDemo
enableLLM: false
authentication: