[improvement][launcher]Use API to get element ID avoiding hard-code.

This commit is contained in:
jerryjzhang
2024-11-27 22:26:30 +08:00
parent 25559fdaa5
commit 111304486b
5 changed files with 45 additions and 17 deletions

View File

@@ -87,7 +87,7 @@ public class SqlQueryConverter implements QueryConverter {
if (!SqlSelectFunctionHelper.hasAggregateFunction(sql) && !SqlSelectHelper.hasGroupBy(sql) if (!SqlSelectFunctionHelper.hasAggregateFunction(sql) && !SqlSelectHelper.hasGroupBy(sql)
&& !SqlSelectHelper.hasWith(sql) && !SqlSelectHelper.hasSubSelect(sql)) { && !SqlSelectHelper.hasWith(sql) && !SqlSelectHelper.hasSubSelect(sql)) {
log.debug("getAggOption simple sql set to DEFAULT"); log.debug("getAggOption simple sql set to DEFAULT");
return AggOption.DEFAULT; return AggOption.NATIVE;
} }
// if there is no group by in S2SQL,set MetricTable's aggOption to "NATIVE" // if there is no group by in S2SQL,set MetricTable's aggOption to "NATIVE"
@@ -107,7 +107,7 @@ public class SqlQueryConverter implements QueryConverter {
.count(); .count();
if (defaultAggNullCnt > 0) { if (defaultAggNullCnt > 0) {
log.debug("getAggOption find null defaultAgg metric set to NATIVE"); log.debug("getAggOption find null defaultAgg metric set to NATIVE");
return AggOption.OUTER; return AggOption.DEFAULT;
} }
return AggOption.DEFAULT; return AggOption.DEFAULT;
} }

View File

@@ -13,12 +13,14 @@ import com.tencent.supersonic.common.service.ChatModelService;
import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.QueryState; import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import com.tencent.supersonic.headless.server.service.SchemaService;
import com.tencent.supersonic.util.DataUtils; import com.tencent.supersonic.util.DataUtils;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import java.time.LocalDate; import java.time.LocalDate;
import java.util.List; import java.util.List;
import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@@ -37,6 +39,8 @@ public class BaseTest extends BaseApplication {
protected AgentService agentService; protected AgentService agentService;
@Autowired @Autowired
protected ChatModelService chatModelService; protected ChatModelService chatModelService;
@Autowired
protected SchemaService schemaService;
@Value("${s2.demo.enableLLM:false}") @Value("${s2.demo.enableLLM:false}")
protected boolean enableLLM; protected boolean enableLLM;
@@ -106,4 +110,10 @@ public class BaseTest extends BaseApplication {
assertEquals(expectedParseInfo.getDateInfo(), actualParseInfo.getDateInfo()); assertEquals(expectedParseInfo.getDateInfo(), actualParseInfo.getDateInfo());
} }
protected SchemaElement getSchemaElementByName(Set<SchemaElement> elementSet, String name) {
Optional<SchemaElement> matchElement =
elementSet.stream().filter(e -> e.getName().equals(name)).findFirst();
return matchElement.orElse(null);
}
} }

View File

@@ -5,6 +5,7 @@ import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum; import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
@@ -12,6 +13,7 @@ import com.tencent.supersonic.headless.chat.query.rule.detail.DetailDimensionQue
import com.tencent.supersonic.util.DataUtils; import com.tencent.supersonic.util.DataUtils;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.junitpioneer.jupiter.SetSystemProperty;
import org.springframework.boot.test.context.SpringBootTest; import org.springframework.boot.test.context.SpringBootTest;
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT) @SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
@@ -19,6 +21,7 @@ import org.springframework.boot.test.context.SpringBootTest;
public class DetailTest extends BaseTest { public class DetailTest extends BaseTest {
@Test @Test
@SetSystemProperty(key = "s2.test", value = "true")
public void test_detail_dimension() throws Exception { public void test_detail_dimension() throws Exception {
QueryResult actualResult = submitNewChat("周杰伦流派和代表作", DataUtils.singerAgentId); QueryResult actualResult = submitNewChat("周杰伦流派和代表作", DataUtils.singerAgentId);
@@ -30,8 +33,11 @@ public class DetailTest extends BaseTest {
expectedParseInfo.setQueryType(QueryType.DETAIL); expectedParseInfo.setQueryType(QueryType.DETAIL);
expectedParseInfo.setAggType(AggregateTypeEnum.NONE); expectedParseInfo.setAggType(AggregateTypeEnum.NONE);
QueryFilter dimensionFilter = DataSetSchema schema = schemaService.getDataSetSchema(DataUtils.singerDatasettId);
DataUtils.getFilter("singer_name", FilterOperatorEnum.EQUALS, "周杰伦", "歌手名", 17L); SchemaElement singerElement = getSchemaElementByName(schema.getDimensions(), "歌手名");
QueryFilter dimensionFilter = DataUtils.getFilter("singer_name", FilterOperatorEnum.EQUALS,
"周杰伦", "歌手名", singerElement.getId());
expectedParseInfo.getDimensionFilters().add(dimensionFilter); expectedParseInfo.getDimensionFilters().add(dimensionFilter);
expectedParseInfo.getDimensions() expectedParseInfo.getDimensions()
@@ -53,8 +59,10 @@ public class DetailTest extends BaseTest {
expectedParseInfo.setQueryType(QueryType.DETAIL); expectedParseInfo.setQueryType(QueryType.DETAIL);
expectedParseInfo.setAggType(AggregateTypeEnum.NONE); expectedParseInfo.setAggType(AggregateTypeEnum.NONE);
QueryFilter dimensionFilter = DataSetSchema schema = schemaService.getDataSetSchema(DataUtils.singerDatasettId);
DataUtils.getFilter("genre", FilterOperatorEnum.EQUALS, "国风", "流派", 7L); SchemaElement genreElement = getSchemaElementByName(schema.getDimensions(), "流派");
QueryFilter dimensionFilter = DataUtils.getFilter("genre", FilterOperatorEnum.EQUALS, "国风",
"流派", genreElement.getId());
expectedParseInfo.getDimensionFilters().add(dimensionFilter); expectedParseInfo.getDimensionFilters().add(dimensionFilter);
expectedParseInfo.getDimensions() expectedParseInfo.getDimensions()
.addAll(Lists.newArrayList(SchemaElement.builder().name("歌手名").build())); .addAll(Lists.newArrayList(SchemaElement.builder().name("歌手名").build()));

View File

@@ -5,6 +5,8 @@ import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum; import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricFilterQuery; import com.tencent.supersonic.headless.chat.query.rule.metric.MetricFilterQuery;
@@ -61,8 +63,11 @@ public class MetricTest extends BaseTest {
expectedParseInfo.setAggType(NONE); expectedParseInfo.setAggType(NONE);
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getDimensionFilters().add(
DataUtils.getFilter("user_name", FilterOperatorEnum.EQUALS, "alice", "用户", 2L)); DataSetSchema schema = schemaService.getDataSetSchema(DataUtils.productDatasetId);
SchemaElement userElement = getSchemaElementByName(schema.getDimensions(), "用户");
expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name",
FilterOperatorEnum.EQUALS, "alice", "用户", userElement.getId()));
expectedParseInfo.setDateInfo( expectedParseInfo.setDateInfo(
DataUtils.getDateConf(DateConf.DateMode.BETWEEN, unit, period, startDay, endDay)); DataUtils.getDateConf(DateConf.DateMode.BETWEEN, unit, period, startDay, endDay));
@@ -111,8 +116,11 @@ public class MetricTest extends BaseTest {
List<String> list = new ArrayList<>(); List<String> list = new ArrayList<>();
list.add("alice"); list.add("alice");
list.add("lucy"); list.add("lucy");
QueryFilter dimensionFilter =
DataUtils.getFilter("user_name", FilterOperatorEnum.IN, list, "用户", 2L); DataSetSchema schema = schemaService.getDataSetSchema(DataUtils.productDatasetId);
SchemaElement userElement = getSchemaElementByName(schema.getDimensions(), "用户");
QueryFilter dimensionFilter = DataUtils.getFilter("user_name", FilterOperatorEnum.IN, list,
"用户", userElement.getId());
expectedParseInfo.getDimensionFilters().add(dimensionFilter); expectedParseInfo.getDimensionFilters().add(dimensionFilter);
expectedParseInfo.setDateInfo( expectedParseInfo.setDateInfo(
@@ -182,9 +190,11 @@ public class MetricTest extends BaseTest {
expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE); expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE);
expectedParseInfo.setAggType(NONE); expectedParseInfo.setAggType(NONE);
DataSetSchema schema = schemaService.getDataSetSchema(DataUtils.productDatasetId);
SchemaElement userElement = getSchemaElementByName(schema.getDimensions(), "用户");
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getDimensionFilters().add( expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name",
DataUtils.getFilter("user_name", FilterOperatorEnum.EQUALS, "alice", "用户", 2L)); FilterOperatorEnum.EQUALS, "alice", "用户", userElement.getId()));
expectedParseInfo.setDateInfo( expectedParseInfo.setDateInfo(
DataUtils.getDateConf(DateConf.DateMode.BETWEEN, 1, period, startDay, startDay)); DataUtils.getDateConf(DateConf.DateMode.BETWEEN, 1, period, startDay, startDay));