[improvement][launcher]Use API to get element ID avoiding hard-code.

This commit is contained in:
jerryjzhang
2024-11-27 22:26:30 +08:00
parent 25559fdaa5
commit 111304486b
5 changed files with 45 additions and 17 deletions

View File

@@ -87,7 +87,7 @@ public class SqlQueryConverter implements QueryConverter {
if (!SqlSelectFunctionHelper.hasAggregateFunction(sql) && !SqlSelectHelper.hasGroupBy(sql)
&& !SqlSelectHelper.hasWith(sql) && !SqlSelectHelper.hasSubSelect(sql)) {
log.debug("getAggOption simple sql set to DEFAULT");
return AggOption.DEFAULT;
return AggOption.NATIVE;
}
// if there is no group by in S2SQL,set MetricTable's aggOption to "NATIVE"
@@ -107,7 +107,7 @@ public class SqlQueryConverter implements QueryConverter {
.count();
if (defaultAggNullCnt > 0) {
log.debug("getAggOption find null defaultAgg metric set to NATIVE");
return AggOption.OUTER;
return AggOption.DEFAULT;
}
return AggOption.DEFAULT;
}

View File

@@ -39,7 +39,7 @@ public class SqlQueryApiController {
@PostMapping("/sql")
public Object queryBySql(@RequestBody QuerySqlReq querySqlReq, HttpServletRequest request,
HttpServletResponse response) throws Exception {
HttpServletResponse response) throws Exception {
User user = UserHolder.findUser(request, response);
String sql = querySqlReq.getSql();
querySqlReq.setSql(StringUtil.replaceBackticks(sql));
@@ -49,7 +49,7 @@ public class SqlQueryApiController {
@PostMapping("/sqls")
public Object queryBySqls(@RequestBody QuerySqlsReq querySqlsReq, HttpServletRequest request,
HttpServletResponse response) throws Exception {
HttpServletResponse response) throws Exception {
User user = UserHolder.findUser(request, response);
List<SemanticQueryReq> semanticQueryReqs = querySqlsReq.getSqls().stream().map(sql -> {
QuerySqlReq querySqlReq = new QuerySqlReq();
@@ -73,7 +73,7 @@ public class SqlQueryApiController {
@PostMapping("/sqlsWithException")
public Object queryBySqlsWithException(@RequestBody QuerySqlsReq querySqlsReq,
HttpServletRequest request, HttpServletResponse response) throws Exception {
HttpServletRequest request, HttpServletResponse response) throws Exception {
User user = UserHolder.findUser(request, response);
List<SemanticQueryReq> semanticQueryReqs = querySqlsReq.getSqls().stream().map(sql -> {
QuerySqlReq querySqlReq = new QuerySqlReq();
@@ -97,7 +97,7 @@ public class SqlQueryApiController {
@PostMapping("/validate")
public Object validate(@RequestBody QuerySqlReq querySqlReq, HttpServletRequest request,
HttpServletResponse response) throws Exception {
HttpServletResponse response) throws Exception {
User user = UserHolder.findUser(request, response);
String sql = querySqlReq.getSql();
querySqlReq.setSql(StringUtil.replaceBackticks(sql));
@@ -106,7 +106,7 @@ public class SqlQueryApiController {
@PostMapping("/validateAndQuery")
public Object validateAndQuery(@RequestBody QuerySqlsReq querySqlsReq,
HttpServletRequest request, HttpServletResponse response) throws Exception {
HttpServletRequest request, HttpServletResponse response) throws Exception {
User user = UserHolder.findUser(request, response);
List<QuerySqlReq> convert = convert(querySqlsReq);
for (QuerySqlReq querySqlReq : convert) {

View File

@@ -13,12 +13,14 @@ import com.tencent.supersonic.common.service.ChatModelService;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import com.tencent.supersonic.headless.server.service.SchemaService;
import com.tencent.supersonic.util.DataUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import java.time.LocalDate;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
@@ -37,6 +39,8 @@ public class BaseTest extends BaseApplication {
protected AgentService agentService;
@Autowired
protected ChatModelService chatModelService;
@Autowired
protected SchemaService schemaService;
@Value("${s2.demo.enableLLM:false}")
protected boolean enableLLM;
@@ -106,4 +110,10 @@ public class BaseTest extends BaseApplication {
assertEquals(expectedParseInfo.getDateInfo(), actualParseInfo.getDateInfo());
}
protected SchemaElement getSchemaElementByName(Set<SchemaElement> elementSet, String name) {
Optional<SchemaElement> matchElement =
elementSet.stream().filter(e -> e.getName().equals(name)).findFirst();
return matchElement.orElse(null);
}
}

View File

@@ -5,6 +5,7 @@ import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
@@ -12,6 +13,7 @@ import com.tencent.supersonic.headless.chat.query.rule.detail.DetailDimensionQue
import com.tencent.supersonic.util.DataUtils;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.junitpioneer.jupiter.SetSystemProperty;
import org.springframework.boot.test.context.SpringBootTest;
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
@@ -19,6 +21,7 @@ import org.springframework.boot.test.context.SpringBootTest;
public class DetailTest extends BaseTest {
@Test
@SetSystemProperty(key = "s2.test", value = "true")
public void test_detail_dimension() throws Exception {
QueryResult actualResult = submitNewChat("周杰伦流派和代表作", DataUtils.singerAgentId);
@@ -30,8 +33,11 @@ public class DetailTest extends BaseTest {
expectedParseInfo.setQueryType(QueryType.DETAIL);
expectedParseInfo.setAggType(AggregateTypeEnum.NONE);
QueryFilter dimensionFilter =
DataUtils.getFilter("singer_name", FilterOperatorEnum.EQUALS, "周杰伦", "歌手名", 17L);
DataSetSchema schema = schemaService.getDataSetSchema(DataUtils.singerDatasettId);
SchemaElement singerElement = getSchemaElementByName(schema.getDimensions(), "歌手名");
QueryFilter dimensionFilter = DataUtils.getFilter("singer_name", FilterOperatorEnum.EQUALS,
"周杰伦", "歌手名", singerElement.getId());
expectedParseInfo.getDimensionFilters().add(dimensionFilter);
expectedParseInfo.getDimensions()
@@ -53,8 +59,10 @@ public class DetailTest extends BaseTest {
expectedParseInfo.setQueryType(QueryType.DETAIL);
expectedParseInfo.setAggType(AggregateTypeEnum.NONE);
QueryFilter dimensionFilter =
DataUtils.getFilter("genre", FilterOperatorEnum.EQUALS, "国风", "流派", 7L);
DataSetSchema schema = schemaService.getDataSetSchema(DataUtils.singerDatasettId);
SchemaElement genreElement = getSchemaElementByName(schema.getDimensions(), "流派");
QueryFilter dimensionFilter = DataUtils.getFilter("genre", FilterOperatorEnum.EQUALS, "国风",
"流派", genreElement.getId());
expectedParseInfo.getDimensionFilters().add(dimensionFilter);
expectedParseInfo.getDimensions()
.addAll(Lists.newArrayList(SchemaElement.builder().name("歌手名").build()));

View File

@@ -5,6 +5,8 @@ import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricFilterQuery;
@@ -61,8 +63,11 @@ public class MetricTest extends BaseTest {
expectedParseInfo.setAggType(NONE);
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getDimensionFilters().add(
DataUtils.getFilter("user_name", FilterOperatorEnum.EQUALS, "alice", "用户", 2L));
DataSetSchema schema = schemaService.getDataSetSchema(DataUtils.productDatasetId);
SchemaElement userElement = getSchemaElementByName(schema.getDimensions(), "用户");
expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name",
FilterOperatorEnum.EQUALS, "alice", "用户", userElement.getId()));
expectedParseInfo.setDateInfo(
DataUtils.getDateConf(DateConf.DateMode.BETWEEN, unit, period, startDay, endDay));
@@ -111,8 +116,11 @@ public class MetricTest extends BaseTest {
List<String> list = new ArrayList<>();
list.add("alice");
list.add("lucy");
QueryFilter dimensionFilter =
DataUtils.getFilter("user_name", FilterOperatorEnum.IN, list, "用户", 2L);
DataSetSchema schema = schemaService.getDataSetSchema(DataUtils.productDatasetId);
SchemaElement userElement = getSchemaElementByName(schema.getDimensions(), "用户");
QueryFilter dimensionFilter = DataUtils.getFilter("user_name", FilterOperatorEnum.IN, list,
"用户", userElement.getId());
expectedParseInfo.getDimensionFilters().add(dimensionFilter);
expectedParseInfo.setDateInfo(
@@ -182,9 +190,11 @@ public class MetricTest extends BaseTest {
expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE);
expectedParseInfo.setAggType(NONE);
DataSetSchema schema = schemaService.getDataSetSchema(DataUtils.productDatasetId);
SchemaElement userElement = getSchemaElementByName(schema.getDimensions(), "用户");
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getDimensionFilters().add(
DataUtils.getFilter("user_name", FilterOperatorEnum.EQUALS, "alice", "用户", 2L));
expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name",
FilterOperatorEnum.EQUALS, "alice", "用户", userElement.getId()));
expectedParseInfo.setDateInfo(
DataUtils.getDateConf(DateConf.DateMode.BETWEEN, 1, period, startDay, startDay));