mirror of
https://github.com/tencentmusic/supersonic.git
synced 2026-04-29 04:14:20 +08:00
[fix]Fix unit test cases.
This commit is contained in:
@@ -22,15 +22,17 @@ public class OpenAiModelFactory implements ModelFactory, InitializingBean {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||||
OpenAiChatModel.OpenAiChatModelBuilder openAiChatModelBuilder = OpenAiChatModel.builder().baseUrl(modelConfig.getBaseUrl())
|
OpenAiChatModel.OpenAiChatModelBuilder openAiChatModelBuilder = OpenAiChatModel.builder()
|
||||||
.modelName(modelConfig.getModelName()).apiKey(modelConfig.keyDecrypt())
|
.baseUrl(modelConfig.getBaseUrl()).modelName(modelConfig.getModelName())
|
||||||
.apiVersion(modelConfig.getApiVersion()).temperature(modelConfig.getTemperature())
|
.apiKey(modelConfig.keyDecrypt()).apiVersion(modelConfig.getApiVersion())
|
||||||
.topP(modelConfig.getTopP()).maxRetries(modelConfig.getMaxRetries())
|
.temperature(modelConfig.getTemperature()).topP(modelConfig.getTopP())
|
||||||
|
.maxRetries(modelConfig.getMaxRetries())
|
||||||
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
|
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
|
||||||
.logRequests(modelConfig.getLogRequests())
|
.logRequests(modelConfig.getLogRequests())
|
||||||
.logResponses(modelConfig.getLogResponses());
|
.logResponses(modelConfig.getLogResponses());
|
||||||
if (modelConfig.getJsonFormat()) {
|
if (modelConfig.getJsonFormat()) {
|
||||||
openAiChatModelBuilder.strictJsonSchema(true).responseFormat(modelConfig.getJsonFormatType());
|
openAiChatModelBuilder.strictJsonSchema(true)
|
||||||
|
.responseFormat(modelConfig.getJsonFormatType());
|
||||||
}
|
}
|
||||||
return openAiChatModelBuilder.build();
|
return openAiChatModelBuilder.build();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -87,7 +87,8 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
|||||||
ChatModelConfig chatModelConfig = chatApp.getChatModelConfig();
|
ChatModelConfig chatModelConfig = chatApp.getChatModelConfig();
|
||||||
if (!StringUtils.isBlank(parserConfig.getParameterValue(PARSER_FORMAT_JSON_TYPE))) {
|
if (!StringUtils.isBlank(parserConfig.getParameterValue(PARSER_FORMAT_JSON_TYPE))) {
|
||||||
chatModelConfig.setJsonFormat(true);
|
chatModelConfig.setJsonFormat(true);
|
||||||
chatModelConfig.setJsonFormatType(parserConfig.getParameterValue(PARSER_FORMAT_JSON_TYPE));
|
chatModelConfig
|
||||||
|
.setJsonFormatType(parserConfig.getParameterValue(PARSER_FORMAT_JSON_TYPE));
|
||||||
}
|
}
|
||||||
ChatLanguageModel chatLanguageModel = getChatLanguageModel(chatModelConfig);
|
ChatLanguageModel chatLanguageModel = getChatLanguageModel(chatModelConfig);
|
||||||
SemanticSqlExtractor extractor =
|
SemanticSqlExtractor extractor =
|
||||||
|
|||||||
@@ -82,6 +82,7 @@ public class BaseTest extends BaseApplication {
|
|||||||
protected SemanticQueryReq buildQuerySqlReq(String sql) {
|
protected SemanticQueryReq buildQuerySqlReq(String sql) {
|
||||||
QuerySqlReq querySqlCmd = new QuerySqlReq();
|
QuerySqlReq querySqlCmd = new QuerySqlReq();
|
||||||
querySqlCmd.setSql(sql);
|
querySqlCmd.setSql(sql);
|
||||||
|
querySqlCmd.getSqlInfo().setCorrectedS2SQL(sql);
|
||||||
querySqlCmd.setModelIds(DataUtils.getMetricAgentIModelIds());
|
querySqlCmd.setModelIds(DataUtils.getMetricAgentIModelIds());
|
||||||
return querySqlCmd;
|
return querySqlCmd;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -33,8 +33,7 @@ public class QueryBySqlTest extends BaseTest {
|
|||||||
@Test
|
@Test
|
||||||
@SetSystemProperty(key = "s2.test", value = "true")
|
@SetSystemProperty(key = "s2.test", value = "true")
|
||||||
public void testSumQuery() throws Exception {
|
public void testSumQuery() throws Exception {
|
||||||
SemanticQueryResp semanticQueryResp =
|
SemanticQueryResp semanticQueryResp = queryBySql("SELECT SUM(访问次数) AS 总访问次数 FROM 超音数数据集 ");
|
||||||
queryBySql("SELECT SUM(访问次数) AS 总访问次数 FROM 超音数PVUV统计 ");
|
|
||||||
|
|
||||||
assertEquals(1, semanticQueryResp.getColumns().size());
|
assertEquals(1, semanticQueryResp.getColumns().size());
|
||||||
QueryColumn queryColumn = semanticQueryResp.getColumns().get(0);
|
QueryColumn queryColumn = semanticQueryResp.getColumns().get(0);
|
||||||
@@ -45,7 +44,7 @@ public class QueryBySqlTest extends BaseTest {
|
|||||||
@Test
|
@Test
|
||||||
public void testGroupByQuery() throws Exception {
|
public void testGroupByQuery() throws Exception {
|
||||||
SemanticQueryResp result =
|
SemanticQueryResp result =
|
||||||
queryBySql("SELECT 部门, SUM(访问次数) AS 总访问次数 FROM 超音数PVUV统计 GROUP BY 部门 ");
|
queryBySql("SELECT 部门, SUM(访问次数) AS 总访问次数 FROM 超音数数据集 GROUP BY 部门 ");
|
||||||
assertEquals(2, result.getColumns().size());
|
assertEquals(2, result.getColumns().size());
|
||||||
QueryColumn firstColumn = result.getColumns().get(0);
|
QueryColumn firstColumn = result.getColumns().get(0);
|
||||||
QueryColumn secondColumn = result.getColumns().get(1);
|
QueryColumn secondColumn = result.getColumns().get(1);
|
||||||
@@ -56,8 +55,8 @@ public class QueryBySqlTest extends BaseTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testFilterQuery() throws Exception {
|
public void testFilterQuery() throws Exception {
|
||||||
SemanticQueryResp result = queryBySql(
|
SemanticQueryResp result =
|
||||||
"SELECT 部门, SUM(访问次数) AS 总访问次数 FROM 超音数PVUV统计 WHERE 部门 ='HR' GROUP BY 部门 ");
|
queryBySql("SELECT 部门, SUM(访问次数) AS 总访问次数 FROM 超音数数据集 WHERE 部门 ='HR' GROUP BY 部门 ");
|
||||||
assertEquals(2, result.getColumns().size());
|
assertEquals(2, result.getColumns().size());
|
||||||
QueryColumn firstColumn = result.getColumns().get(0);
|
QueryColumn firstColumn = result.getColumns().get(0);
|
||||||
QueryColumn secondColumn = result.getColumns().get(1);
|
QueryColumn secondColumn = result.getColumns().get(1);
|
||||||
@@ -71,8 +70,7 @@ public class QueryBySqlTest extends BaseTest {
|
|||||||
public void testDateSumQuery() throws Exception {
|
public void testDateSumQuery() throws Exception {
|
||||||
String startDate = now().plusDays(-365).toString();
|
String startDate = now().plusDays(-365).toString();
|
||||||
String endDate = now().plusDays(0).toString();
|
String endDate = now().plusDays(0).toString();
|
||||||
String sql =
|
String sql = "SELECT SUM(访问次数) AS 总访问次数 FROM 超音数数据集 WHERE 数据日期 >= '%s' AND 数据日期 <= '%s' ";
|
||||||
"SELECT SUM(访问次数) AS 总访问次数 FROM 超音数PVUV统计 WHERE 数据日期 >= '%s' AND 数据日期 <= '%s' ";
|
|
||||||
SemanticQueryResp semanticQueryResp = queryBySql(String.format(sql, startDate, endDate));
|
SemanticQueryResp semanticQueryResp = queryBySql(String.format(sql, startDate, endDate));
|
||||||
assertEquals(1, semanticQueryResp.getColumns().size());
|
assertEquals(1, semanticQueryResp.getColumns().size());
|
||||||
QueryColumn queryColumn = semanticQueryResp.getColumns().get(0);
|
QueryColumn queryColumn = semanticQueryResp.getColumns().get(0);
|
||||||
@@ -82,9 +80,9 @@ public class QueryBySqlTest extends BaseTest {
|
|||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testCacheQuery() throws Exception {
|
public void testCacheQuery() throws Exception {
|
||||||
queryBySql("SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 ");
|
queryBySql("SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数数据集 GROUP BY 部门 ");
|
||||||
SemanticQueryResp result2 =
|
SemanticQueryResp result2 =
|
||||||
queryBySql("SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 ");
|
queryBySql("SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数数据集 GROUP BY 部门 ");
|
||||||
assertTrue(result2.isUseCache());
|
assertTrue(result2.isUseCache());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user