diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/SqlQueryConverter.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/SqlQueryConverter.java index f5d1a7696..d9ad61187 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/SqlQueryConverter.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/SqlQueryConverter.java @@ -87,7 +87,7 @@ public class SqlQueryConverter implements QueryConverter { if (!SqlSelectFunctionHelper.hasAggregateFunction(sql) && !SqlSelectHelper.hasGroupBy(sql) && !SqlSelectHelper.hasWith(sql) && !SqlSelectHelper.hasSubSelect(sql)) { log.debug("getAggOption simple sql set to DEFAULT"); - return AggOption.DEFAULT; + return AggOption.NATIVE; } // if there is no group by in S2SQL,set MetricTable's aggOption to "NATIVE" @@ -107,7 +107,7 @@ public class SqlQueryConverter implements QueryConverter { .count(); if (defaultAggNullCnt > 0) { log.debug("getAggOption find null defaultAgg metric set to NATIVE"); - return AggOption.OUTER; + return AggOption.DEFAULT; } return AggOption.DEFAULT; } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/rest/SqlQueryApiController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/rest/SqlQueryApiController.java index 5e2c46f4d..f8e2af71a 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/rest/SqlQueryApiController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/rest/SqlQueryApiController.java @@ -39,7 +39,7 @@ public class SqlQueryApiController { @PostMapping("/sql") public Object queryBySql(@RequestBody QuerySqlReq querySqlReq, HttpServletRequest request, - HttpServletResponse response) throws Exception { + HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); String sql = querySqlReq.getSql(); querySqlReq.setSql(StringUtil.replaceBackticks(sql)); @@ -49,7 +49,7 @@ public class SqlQueryApiController { @PostMapping("/sqls") public Object queryBySqls(@RequestBody QuerySqlsReq querySqlsReq, HttpServletRequest request, - HttpServletResponse response) throws Exception { + HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); List semanticQueryReqs = querySqlsReq.getSqls().stream().map(sql -> { QuerySqlReq querySqlReq = new QuerySqlReq(); @@ -73,7 +73,7 @@ public class SqlQueryApiController { @PostMapping("/sqlsWithException") public Object queryBySqlsWithException(@RequestBody QuerySqlsReq querySqlsReq, - HttpServletRequest request, HttpServletResponse response) throws Exception { + HttpServletRequest request, HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); List semanticQueryReqs = querySqlsReq.getSqls().stream().map(sql -> { QuerySqlReq querySqlReq = new QuerySqlReq(); @@ -97,7 +97,7 @@ public class SqlQueryApiController { @PostMapping("/validate") public Object validate(@RequestBody QuerySqlReq querySqlReq, HttpServletRequest request, - HttpServletResponse response) throws Exception { + HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); String sql = querySqlReq.getSql(); querySqlReq.setSql(StringUtil.replaceBackticks(sql)); @@ -106,7 +106,7 @@ public class SqlQueryApiController { @PostMapping("/validateAndQuery") public Object validateAndQuery(@RequestBody QuerySqlsReq querySqlsReq, - HttpServletRequest request, HttpServletResponse response) throws Exception { + HttpServletRequest request, HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); List convert = convert(querySqlsReq); for (QuerySqlReq querySqlReq : convert) { diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/BaseTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/BaseTest.java index cab07b66c..d92be2031 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/BaseTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/BaseTest.java @@ -13,12 +13,14 @@ import com.tencent.supersonic.common.service.ChatModelService; import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.response.QueryState; +import com.tencent.supersonic.headless.server.service.SchemaService; import com.tencent.supersonic.util.DataUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import java.time.LocalDate; import java.util.List; +import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -37,6 +39,8 @@ public class BaseTest extends BaseApplication { protected AgentService agentService; @Autowired protected ChatModelService chatModelService; + @Autowired + protected SchemaService schemaService; @Value("${s2.demo.enableLLM:false}") protected boolean enableLLM; @@ -106,4 +110,10 @@ public class BaseTest extends BaseApplication { assertEquals(expectedParseInfo.getDateInfo(), actualParseInfo.getDateInfo()); } + + protected SchemaElement getSchemaElementByName(Set elementSet, String name) { + Optional matchElement = + elementSet.stream().filter(e -> e.getName().equals(name)).findFirst(); + return matchElement.orElse(null); + } } diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/DetailTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/DetailTest.java index 966943e10..dcb5a7865 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/DetailTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/DetailTest.java @@ -5,6 +5,7 @@ import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; import com.tencent.supersonic.common.pojo.enums.QueryType; +import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; @@ -12,6 +13,7 @@ import com.tencent.supersonic.headless.chat.query.rule.detail.DetailDimensionQue import com.tencent.supersonic.util.DataUtils; import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.Test; +import org.junitpioneer.jupiter.SetSystemProperty; import org.springframework.boot.test.context.SpringBootTest; @SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT) @@ -19,6 +21,7 @@ import org.springframework.boot.test.context.SpringBootTest; public class DetailTest extends BaseTest { @Test + @SetSystemProperty(key = "s2.test", value = "true") public void test_detail_dimension() throws Exception { QueryResult actualResult = submitNewChat("周杰伦流派和代表作", DataUtils.singerAgentId); @@ -30,8 +33,11 @@ public class DetailTest extends BaseTest { expectedParseInfo.setQueryType(QueryType.DETAIL); expectedParseInfo.setAggType(AggregateTypeEnum.NONE); - QueryFilter dimensionFilter = - DataUtils.getFilter("singer_name", FilterOperatorEnum.EQUALS, "周杰伦", "歌手名", 17L); + DataSetSchema schema = schemaService.getDataSetSchema(DataUtils.singerDatasettId); + SchemaElement singerElement = getSchemaElementByName(schema.getDimensions(), "歌手名"); + + QueryFilter dimensionFilter = DataUtils.getFilter("singer_name", FilterOperatorEnum.EQUALS, + "周杰伦", "歌手名", singerElement.getId()); expectedParseInfo.getDimensionFilters().add(dimensionFilter); expectedParseInfo.getDimensions() @@ -53,8 +59,10 @@ public class DetailTest extends BaseTest { expectedParseInfo.setQueryType(QueryType.DETAIL); expectedParseInfo.setAggType(AggregateTypeEnum.NONE); - QueryFilter dimensionFilter = - DataUtils.getFilter("genre", FilterOperatorEnum.EQUALS, "国风", "流派", 7L); + DataSetSchema schema = schemaService.getDataSetSchema(DataUtils.singerDatasettId); + SchemaElement genreElement = getSchemaElementByName(schema.getDimensions(), "流派"); + QueryFilter dimensionFilter = DataUtils.getFilter("genre", FilterOperatorEnum.EQUALS, "国风", + "流派", genreElement.getId()); expectedParseInfo.getDimensionFilters().add(dimensionFilter); expectedParseInfo.getDimensions() .addAll(Lists.newArrayList(SchemaElement.builder().name("歌手名").build())); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MetricTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MetricTest.java index 0dbc119f9..e7c2683ad 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MetricTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MetricTest.java @@ -5,6 +5,8 @@ import com.tencent.supersonic.common.pojo.DateConf; import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; import com.tencent.supersonic.common.pojo.enums.QueryType; +import com.tencent.supersonic.headless.api.pojo.DataSetSchema; +import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.chat.query.rule.metric.MetricFilterQuery; @@ -61,8 +63,11 @@ public class MetricTest extends BaseTest { expectedParseInfo.setAggType(NONE); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); - expectedParseInfo.getDimensionFilters().add( - DataUtils.getFilter("user_name", FilterOperatorEnum.EQUALS, "alice", "用户", 2L)); + + DataSetSchema schema = schemaService.getDataSetSchema(DataUtils.productDatasetId); + SchemaElement userElement = getSchemaElementByName(schema.getDimensions(), "用户"); + expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name", + FilterOperatorEnum.EQUALS, "alice", "用户", userElement.getId())); expectedParseInfo.setDateInfo( DataUtils.getDateConf(DateConf.DateMode.BETWEEN, unit, period, startDay, endDay)); @@ -111,8 +116,11 @@ public class MetricTest extends BaseTest { List list = new ArrayList<>(); list.add("alice"); list.add("lucy"); - QueryFilter dimensionFilter = - DataUtils.getFilter("user_name", FilterOperatorEnum.IN, list, "用户", 2L); + + DataSetSchema schema = schemaService.getDataSetSchema(DataUtils.productDatasetId); + SchemaElement userElement = getSchemaElementByName(schema.getDimensions(), "用户"); + QueryFilter dimensionFilter = DataUtils.getFilter("user_name", FilterOperatorEnum.IN, list, + "用户", userElement.getId()); expectedParseInfo.getDimensionFilters().add(dimensionFilter); expectedParseInfo.setDateInfo( @@ -182,9 +190,11 @@ public class MetricTest extends BaseTest { expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE); expectedParseInfo.setAggType(NONE); + DataSetSchema schema = schemaService.getDataSetSchema(DataUtils.productDatasetId); + SchemaElement userElement = getSchemaElementByName(schema.getDimensions(), "用户"); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); - expectedParseInfo.getDimensionFilters().add( - DataUtils.getFilter("user_name", FilterOperatorEnum.EQUALS, "alice", "用户", 2L)); + expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name", + FilterOperatorEnum.EQUALS, "alice", "用户", userElement.getId())); expectedParseInfo.setDateInfo( DataUtils.getDateConf(DateConf.DateMode.BETWEEN, 1, period, startDay, startDay));