[fix]Fix unit test cases.
Some checks failed
supersonic CentOS CI / build (21) (push) Has been cancelled
supersonic mac CI / build (21) (push) Has been cancelled
supersonic ubuntu CI / build (21) (push) Has been cancelled
supersonic windows CI / build (21) (push) Has been cancelled

This commit is contained in:
jerryjzhang
2025-08-05 17:53:58 +08:00
parent af28bc7c2a
commit 1f6d217b26
4 changed files with 17 additions and 15 deletions

View File

@@ -22,15 +22,17 @@ public class OpenAiModelFactory implements ModelFactory, InitializingBean {
@Override @Override
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
OpenAiChatModel.OpenAiChatModelBuilder openAiChatModelBuilder = OpenAiChatModel.builder().baseUrl(modelConfig.getBaseUrl()) OpenAiChatModel.OpenAiChatModelBuilder openAiChatModelBuilder = OpenAiChatModel.builder()
.modelName(modelConfig.getModelName()).apiKey(modelConfig.keyDecrypt()) .baseUrl(modelConfig.getBaseUrl()).modelName(modelConfig.getModelName())
.apiVersion(modelConfig.getApiVersion()).temperature(modelConfig.getTemperature()) .apiKey(modelConfig.keyDecrypt()).apiVersion(modelConfig.getApiVersion())
.topP(modelConfig.getTopP()).maxRetries(modelConfig.getMaxRetries()) .temperature(modelConfig.getTemperature()).topP(modelConfig.getTopP())
.maxRetries(modelConfig.getMaxRetries())
.timeout(Duration.ofSeconds(modelConfig.getTimeOut())) .timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
.logRequests(modelConfig.getLogRequests()) .logRequests(modelConfig.getLogRequests())
.logResponses(modelConfig.getLogResponses()); .logResponses(modelConfig.getLogResponses());
if (modelConfig.getJsonFormat()) { if (modelConfig.getJsonFormat()) {
openAiChatModelBuilder.strictJsonSchema(true).responseFormat(modelConfig.getJsonFormatType()); openAiChatModelBuilder.strictJsonSchema(true)
.responseFormat(modelConfig.getJsonFormatType());
} }
return openAiChatModelBuilder.build(); return openAiChatModelBuilder.build();
} }

View File

@@ -87,7 +87,8 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
ChatModelConfig chatModelConfig = chatApp.getChatModelConfig(); ChatModelConfig chatModelConfig = chatApp.getChatModelConfig();
if (!StringUtils.isBlank(parserConfig.getParameterValue(PARSER_FORMAT_JSON_TYPE))) { if (!StringUtils.isBlank(parserConfig.getParameterValue(PARSER_FORMAT_JSON_TYPE))) {
chatModelConfig.setJsonFormat(true); chatModelConfig.setJsonFormat(true);
chatModelConfig.setJsonFormatType(parserConfig.getParameterValue(PARSER_FORMAT_JSON_TYPE)); chatModelConfig
.setJsonFormatType(parserConfig.getParameterValue(PARSER_FORMAT_JSON_TYPE));
} }
ChatLanguageModel chatLanguageModel = getChatLanguageModel(chatModelConfig); ChatLanguageModel chatLanguageModel = getChatLanguageModel(chatModelConfig);
SemanticSqlExtractor extractor = SemanticSqlExtractor extractor =

View File

@@ -82,6 +82,7 @@ public class BaseTest extends BaseApplication {
protected SemanticQueryReq buildQuerySqlReq(String sql) { protected SemanticQueryReq buildQuerySqlReq(String sql) {
QuerySqlReq querySqlCmd = new QuerySqlReq(); QuerySqlReq querySqlCmd = new QuerySqlReq();
querySqlCmd.setSql(sql); querySqlCmd.setSql(sql);
querySqlCmd.getSqlInfo().setCorrectedS2SQL(sql);
querySqlCmd.setModelIds(DataUtils.getMetricAgentIModelIds()); querySqlCmd.setModelIds(DataUtils.getMetricAgentIModelIds());
return querySqlCmd; return querySqlCmd;
} }

View File

@@ -33,8 +33,7 @@ public class QueryBySqlTest extends BaseTest {
@Test @Test
@SetSystemProperty(key = "s2.test", value = "true") @SetSystemProperty(key = "s2.test", value = "true")
public void testSumQuery() throws Exception { public void testSumQuery() throws Exception {
SemanticQueryResp semanticQueryResp = SemanticQueryResp semanticQueryResp = queryBySql("SELECT SUM(访问次数) AS 总访问次数 FROM 超音数数据集 ");
queryBySql("SELECT SUM(访问次数) AS 总访问次数 FROM 超音数PVUV统计 ");
assertEquals(1, semanticQueryResp.getColumns().size()); assertEquals(1, semanticQueryResp.getColumns().size());
QueryColumn queryColumn = semanticQueryResp.getColumns().get(0); QueryColumn queryColumn = semanticQueryResp.getColumns().get(0);
@@ -45,7 +44,7 @@ public class QueryBySqlTest extends BaseTest {
@Test @Test
public void testGroupByQuery() throws Exception { public void testGroupByQuery() throws Exception {
SemanticQueryResp result = SemanticQueryResp result =
queryBySql("SELECT 部门, SUM(访问次数) AS 总访问次数 FROM 超音数PVUV统计 GROUP BY 部门 "); queryBySql("SELECT 部门, SUM(访问次数) AS 总访问次数 FROM 超音数数据集 GROUP BY 部门 ");
assertEquals(2, result.getColumns().size()); assertEquals(2, result.getColumns().size());
QueryColumn firstColumn = result.getColumns().get(0); QueryColumn firstColumn = result.getColumns().get(0);
QueryColumn secondColumn = result.getColumns().get(1); QueryColumn secondColumn = result.getColumns().get(1);
@@ -56,8 +55,8 @@ public class QueryBySqlTest extends BaseTest {
@Test @Test
public void testFilterQuery() throws Exception { public void testFilterQuery() throws Exception {
SemanticQueryResp result = queryBySql( SemanticQueryResp result =
"SELECT 部门, SUM(访问次数) AS 总访问次数 FROM 超音数PVUV统计 WHERE 部门 ='HR' GROUP BY 部门 "); queryBySql("SELECT 部门, SUM(访问次数) AS 总访问次数 FROM 超音数数据集 WHERE 部门 ='HR' GROUP BY 部门 ");
assertEquals(2, result.getColumns().size()); assertEquals(2, result.getColumns().size());
QueryColumn firstColumn = result.getColumns().get(0); QueryColumn firstColumn = result.getColumns().get(0);
QueryColumn secondColumn = result.getColumns().get(1); QueryColumn secondColumn = result.getColumns().get(1);
@@ -71,8 +70,7 @@ public class QueryBySqlTest extends BaseTest {
public void testDateSumQuery() throws Exception { public void testDateSumQuery() throws Exception {
String startDate = now().plusDays(-365).toString(); String startDate = now().plusDays(-365).toString();
String endDate = now().plusDays(0).toString(); String endDate = now().plusDays(0).toString();
String sql = String sql = "SELECT SUM(访问次数) AS 总访问次数 FROM 超音数数据集 WHERE 数据日期 >= '%s' AND 数据日期 <= '%s' ";
"SELECT SUM(访问次数) AS 总访问次数 FROM 超音数PVUV统计 WHERE 数据日期 >= '%s' AND 数据日期 <= '%s' ";
SemanticQueryResp semanticQueryResp = queryBySql(String.format(sql, startDate, endDate)); SemanticQueryResp semanticQueryResp = queryBySql(String.format(sql, startDate, endDate));
assertEquals(1, semanticQueryResp.getColumns().size()); assertEquals(1, semanticQueryResp.getColumns().size());
QueryColumn queryColumn = semanticQueryResp.getColumns().get(0); QueryColumn queryColumn = semanticQueryResp.getColumns().get(0);
@@ -82,9 +80,9 @@ public class QueryBySqlTest extends BaseTest {
@Test @Test
public void testCacheQuery() throws Exception { public void testCacheQuery() throws Exception {
queryBySql("SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 "); queryBySql("SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数数据集 GROUP BY 部门 ");
SemanticQueryResp result2 = SemanticQueryResp result2 =
queryBySql("SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 "); queryBySql("SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数数据集 GROUP BY 部门 ");
assertTrue(result2.isUseCache()); assertTrue(result2.isUseCache());
} }