diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/config/OptimizationConfig.java b/chat/core/src/main/java/com/tencent/supersonic/chat/config/OptimizationConfig.java index 1f43151c6..51b4ac39d 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/config/OptimizationConfig.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/config/OptimizationConfig.java @@ -1,5 +1,6 @@ package com.tencent.supersonic.chat.config; +import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode; import com.tencent.supersonic.common.service.SysParameterService; import lombok.Data; import lombok.extern.slf4j.Slf4j; @@ -57,6 +58,9 @@ public class OptimizationConfig { @Value("${s2SQL.linking.value.switch:true}") private boolean useLinkingValueSwitch; + @Value("${s2SQL.generation:TWO_PASS_AUTO_COT}") + private SqlGenerationMode sqlGenerationMode; + @Value("${s2SQL.use.switch:true}") private boolean useS2SqlSwitch; @@ -139,6 +143,10 @@ public class OptimizationConfig { return convertValue("s2SQL.linking.value.switch", Boolean.class, useLinkingValueSwitch); } + public SqlGenerationMode getSqlGenerationMode() { + return convertValue("s2SQL.generation", SqlGenerationMode.class, sqlGenerationMode); + } + public T convertValue(String paramName, Class targetType, T defaultValue) { try { String value = sysParameterService.getSysParameter().getParameterByName(paramName); @@ -151,6 +159,8 @@ public class OptimizationConfig { return targetType.cast(Integer.parseInt(value)); } else if (targetType == Boolean.class) { return targetType.cast(Boolean.parseBoolean(value)); + } else if (targetType == SqlGenerationMode.class) { + return targetType.cast(SqlGenerationMode.valueOf(value)); } } catch (Exception e) { log.error("convertValue", e); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/LLMRequestService.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/LLMRequestService.java index 09b45d89c..391d86893 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/LLMRequestService.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/LLMRequestService.java @@ -132,6 +132,7 @@ public class LLMRequestService { currentDate = DateUtils.getBeforeDate(0); } llmReq.setCurrentDate(currentDate); + llmReq.setSqlGenerationMode(optimizationConfig.getSqlGenerationMode()); return llmReq; } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/SqlGenerationFactory.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/SqlGenerationFactory.java index cd6fd0bc2..6aa0586bc 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/SqlGenerationFactory.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/sql/llm/SqlGenerationFactory.java @@ -5,9 +5,6 @@ import com.tencent.supersonic.chat.query.llm.s2sql.LLMReq.SqlGenerationMode; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -/** - * Sql generation factory - */ public class SqlGenerationFactory { private static Map sqlGenerationMap = new ConcurrentHashMap<>(); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2sql/LLMReq.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2sql/LLMReq.java index 2be690891..f4a78ab2a 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2sql/LLMReq.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/s2sql/LLMReq.java @@ -19,7 +19,7 @@ public class LLMReq { private String priorExts; - private SqlGenerationMode sqlGenerationMode = SqlGenerationMode.TWO_PASS_AUTO_COT; + private SqlGenerationMode sqlGenerationMode; @Data public static class ElementValue { diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/SysParameter.java b/common/src/main/java/com/tencent/supersonic/common/pojo/SysParameter.java index d21b7105b..f257d1ce0 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/SysParameter.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/SysParameter.java @@ -87,10 +87,10 @@ public class SysParameter { parameters.add(new Parameter("llm.temperature", "0.0", "温度值", "number", "Parser相关配置")); - Parameter s2SQLParameter = new Parameter("s2SQL.generation", "2_pass_auto_cot_self_consistency", + Parameter s2SQLParameter = new Parameter("s2SQL.generation", "TWO_PASS_AUTO_COT", "S2SQL生成方式", "list", "Parser相关配置"); - s2SQLParameter.setCandidateValues(Lists.newArrayList("1_pass_auto_cot", "1_pass_auto_cot_self_consistency", - "2_pass_auto_cot", "2_pass_auto_cot_self_consistency")); + s2SQLParameter.setCandidateValues(Lists.newArrayList("ONE_PASS_AUTO_COT", "ONE_PASS_AUTO_COT_SELF_CONSISTENCY", + "TWO_PASS_AUTO_COT", "TWO_PASS_AUTO_COT_SELF_CONSISTENCY")); parameters.add(s2SQLParameter); parameters.add(new Parameter("s2SQL.linking.value.switch", "true", "是否将Mapper探测识别到的维度值提供给大模型", "为了数据安全考虑, 这里可进行开关选择",