diff --git a/common/src/main/java/dev/langchain4j/provider/OpenAiModelFactory.java b/common/src/main/java/dev/langchain4j/provider/OpenAiModelFactory.java index 088bab1f9..c55f5a3c8 100644 --- a/common/src/main/java/dev/langchain4j/provider/OpenAiModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/OpenAiModelFactory.java @@ -22,15 +22,17 @@ public class OpenAiModelFactory implements ModelFactory, InitializingBean { @Override public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { - OpenAiChatModel.OpenAiChatModelBuilder openAiChatModelBuilder = OpenAiChatModel.builder().baseUrl(modelConfig.getBaseUrl()) - .modelName(modelConfig.getModelName()).apiKey(modelConfig.keyDecrypt()) - .apiVersion(modelConfig.getApiVersion()).temperature(modelConfig.getTemperature()) - .topP(modelConfig.getTopP()).maxRetries(modelConfig.getMaxRetries()) + OpenAiChatModel.OpenAiChatModelBuilder openAiChatModelBuilder = OpenAiChatModel.builder() + .baseUrl(modelConfig.getBaseUrl()).modelName(modelConfig.getModelName()) + .apiKey(modelConfig.keyDecrypt()).apiVersion(modelConfig.getApiVersion()) + .temperature(modelConfig.getTemperature()).topP(modelConfig.getTopP()) + .maxRetries(modelConfig.getMaxRetries()) .timeout(Duration.ofSeconds(modelConfig.getTimeOut())) .logRequests(modelConfig.getLogRequests()) .logResponses(modelConfig.getLogResponses()); if (modelConfig.getJsonFormat()) { - openAiChatModelBuilder.strictJsonSchema(true).responseFormat(modelConfig.getJsonFormatType()); + openAiChatModelBuilder.strictJsonSchema(true) + .responseFormat(modelConfig.getJsonFormatType()); } return openAiChatModelBuilder.build(); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java index 785d0643e..b6e702e9d 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java @@ -87,7 +87,8 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy { ChatModelConfig chatModelConfig = chatApp.getChatModelConfig(); if (!StringUtils.isBlank(parserConfig.getParameterValue(PARSER_FORMAT_JSON_TYPE))) { chatModelConfig.setJsonFormat(true); - chatModelConfig.setJsonFormatType(parserConfig.getParameterValue(PARSER_FORMAT_JSON_TYPE)); + chatModelConfig + .setJsonFormatType(parserConfig.getParameterValue(PARSER_FORMAT_JSON_TYPE)); } ChatLanguageModel chatLanguageModel = getChatLanguageModel(chatModelConfig); SemanticSqlExtractor extractor = diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/BaseTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/BaseTest.java index 3a29e9d47..ca444469b 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/BaseTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/BaseTest.java @@ -82,6 +82,7 @@ public class BaseTest extends BaseApplication { protected SemanticQueryReq buildQuerySqlReq(String sql) { QuerySqlReq querySqlCmd = new QuerySqlReq(); querySqlCmd.setSql(sql); + querySqlCmd.getSqlInfo().setCorrectedS2SQL(sql); querySqlCmd.setModelIds(DataUtils.getMetricAgentIModelIds()); return querySqlCmd; } diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/QueryBySqlTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/QueryBySqlTest.java index a12daa7ea..956885d89 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/QueryBySqlTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/QueryBySqlTest.java @@ -33,8 +33,7 @@ public class QueryBySqlTest extends BaseTest { @Test @SetSystemProperty(key = "s2.test", value = "true") public void testSumQuery() throws Exception { - SemanticQueryResp semanticQueryResp = - queryBySql("SELECT SUM(访问次数) AS 总访问次数 FROM 超音数PVUV统计 "); + SemanticQueryResp semanticQueryResp = queryBySql("SELECT SUM(访问次数) AS 总访问次数 FROM 超音数数据集 "); assertEquals(1, semanticQueryResp.getColumns().size()); QueryColumn queryColumn = semanticQueryResp.getColumns().get(0); @@ -45,7 +44,7 @@ public class QueryBySqlTest extends BaseTest { @Test public void testGroupByQuery() throws Exception { SemanticQueryResp result = - queryBySql("SELECT 部门, SUM(访问次数) AS 总访问次数 FROM 超音数PVUV统计 GROUP BY 部门 "); + queryBySql("SELECT 部门, SUM(访问次数) AS 总访问次数 FROM 超音数数据集 GROUP BY 部门 "); assertEquals(2, result.getColumns().size()); QueryColumn firstColumn = result.getColumns().get(0); QueryColumn secondColumn = result.getColumns().get(1); @@ -56,8 +55,8 @@ public class QueryBySqlTest extends BaseTest { @Test public void testFilterQuery() throws Exception { - SemanticQueryResp result = queryBySql( - "SELECT 部门, SUM(访问次数) AS 总访问次数 FROM 超音数PVUV统计 WHERE 部门 ='HR' GROUP BY 部门 "); + SemanticQueryResp result = + queryBySql("SELECT 部门, SUM(访问次数) AS 总访问次数 FROM 超音数数据集 WHERE 部门 ='HR' GROUP BY 部门 "); assertEquals(2, result.getColumns().size()); QueryColumn firstColumn = result.getColumns().get(0); QueryColumn secondColumn = result.getColumns().get(1); @@ -71,8 +70,7 @@ public class QueryBySqlTest extends BaseTest { public void testDateSumQuery() throws Exception { String startDate = now().plusDays(-365).toString(); String endDate = now().plusDays(0).toString(); - String sql = - "SELECT SUM(访问次数) AS 总访问次数 FROM 超音数PVUV统计 WHERE 数据日期 >= '%s' AND 数据日期 <= '%s' "; + String sql = "SELECT SUM(访问次数) AS 总访问次数 FROM 超音数数据集 WHERE 数据日期 >= '%s' AND 数据日期 <= '%s' "; SemanticQueryResp semanticQueryResp = queryBySql(String.format(sql, startDate, endDate)); assertEquals(1, semanticQueryResp.getColumns().size()); QueryColumn queryColumn = semanticQueryResp.getColumns().get(0); @@ -82,9 +80,9 @@ public class QueryBySqlTest extends BaseTest { @Test public void testCacheQuery() throws Exception { - queryBySql("SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 "); + queryBySql("SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数数据集 GROUP BY 部门 "); SemanticQueryResp result2 = - queryBySql("SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 "); + queryBySql("SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数数据集 GROUP BY 部门 "); assertTrue(result2.isUseCache()); }