From 18f268f5907e670853eca201e9e0e1fdbbe0d072 Mon Sep 17 00:00:00 2001 From: jerryjzhang Date: Thu, 30 May 2024 15:23:52 +0800 Subject: [PATCH] (improvement)(headless)Remove redundant SqlGenStrategy impl --- .../chat/server/parser/MultiTurnParser.java | 4 +- .../supersonic/common/pojo/SysParameter.java | 46 +++++++-------- .../headless/api/pojo/LLMConfig.java | 9 +++ .../parser/llm/OnePassSCSqlGenStrategy.java | 7 +-- .../parser/llm/OnePassSqlGenStrategy.java | 56 ------------------ .../core/chat/parser/llm/PythonLLMProxy.java | 4 +- .../parser/llm/TwoPassSCSqlGenStrategy.java | 11 ++-- .../parser/llm/TwoPassSqlGenStrategy.java | 58 ------------------- .../core/chat/query/llm/s2sql/LLMReq.java | 7 --- .../core/config/OptimizationConfig.java | 55 +++++++++--------- .../server/processor/SqlInfoProcessor.java | 2 +- .../src/main/resources/application-local.yaml | 3 +- .../src/test/resources/application-local.yaml | 41 ++----------- 13 files changed, 81 insertions(+), 222 deletions(-) delete mode 100644 headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSqlGenStrategy.java delete mode 100644 headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/TwoPassSqlGenStrategy.java diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/MultiTurnParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/MultiTurnParser.java index 58961db7c..a6aac1fc2 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/MultiTurnParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/MultiTurnParser.java @@ -95,13 +95,13 @@ public class MultiTurnParser implements ChatParser { variables.put("histSchema", context.getHistSchema()); Prompt prompt = promptTemplate.apply(variables); - keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage()); + keyPipelineLog.info("MultiTurnParser reqPrompt:{}", prompt.toSystemMessage()); ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(context.getLlmConfig()); Response response = chatLanguageModel.generate(prompt.toSystemMessage()); String result = response.content().text(); - keyPipelineLog.info("model response:{}", result); + keyPipelineLog.info("MultiTurnParser modelResp:{}", result); return response.content().text(); } diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/SysParameter.java b/common/src/main/java/com/tencent/supersonic/common/pojo/SysParameter.java index 34979e691..5e92efaa1 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/SysParameter.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/SysParameter.java @@ -48,62 +48,62 @@ public class SysParameter { admins = Lists.newArrayList("admin"); //detect config - parameters.add(new Parameter("one.detection.size", "8", + parameters.add(new Parameter("s2.one.detection.size", "8", "一次探测返回结果个数", "在每次探测后, 将前后缀匹配的结果合并, 并根据相似度阈值过滤后的结果个数", "number", "Mapper相关配置")); - parameters.add(new Parameter("one.detection.max.size", "20", + parameters.add(new Parameter("s2.one.detection.max.size", "20", "一次探测前后缀匹配结果返回个数", "单次前后缀匹配返回的结果个数", "number", "Mapper相关配置")); //mapper config - parameters.add(new Parameter("metric.dimension.threshold", "0.3", + parameters.add(new Parameter("s2.metric.dimension.threshold", "0.3", "指标名、维度名文本相似度阈值", "文本片段和匹配到的指标、维度名计算出来的编辑距离阈值, 若超出该阈值, 则舍弃", "number", "Mapper相关配置")); - parameters.add(new Parameter("metric.dimension.min.threshold", "0.25", + parameters.add(new Parameter("s2.metric.dimension.min.threshold", "0.25", "指标名、维度名最小文本相似度阈值", "指标名、维度名相似度阈值在动态调整中的最低值", "number", "Mapper相关配置")); - parameters.add(new Parameter("dimension.value.threshold", "0.5", + parameters.add(new Parameter("s2.dimension.value.threshold", "0.5", "维度值文本相似度阈值", "文本片段和匹配到的维度值计算出来的编辑距离阈值, 若超出该阈值, 则舍弃", "number", "Mapper相关配置")); - parameters.add(new Parameter("dimension.value.min.threshold", "0.3", + parameters.add(new Parameter("s2.dimension.value.min.threshold", "0.3", "维度值最小文本相似度阈值", "维度值相似度阈值在动态调整中的最低值", "number", "Mapper相关配置")); //embedding mapper config - parameters.add(new Parameter("embedding.mapper.word.min", + parameters.add(new Parameter("s2.embedding.mapper.word.min", "4", "用于向量召回最小的文本长度", "为提高向量召回效率, 小于该长度的文本不进行向量语义召回", "number", "Mapper相关配置")); - parameters.add(new Parameter("embedding.mapper.word.max", "5", + parameters.add(new Parameter("s2.embedding.mapper.word.max", "5", "用于向量召回最大的文本长度", "为提高向量召回效率, 大于该长度的文本不进行向量语义召回", "number", "Mapper相关配置")); - parameters.add(new Parameter("embedding.mapper.batch", "50", + parameters.add(new Parameter("s2.embedding.mapper.batch", "50", "批量向量召回文本请求个数", "每次进行向量语义召回的原始文本片段个数", "number", "Mapper相关配置")); - parameters.add(new Parameter("embedding.mapper.number", "5", + parameters.add(new Parameter("s2.embedding.mapper.number", "5", "批量向量召回文本返回结果个数", "每个文本进行向量语义召回的文本结果个数", "number", "Mapper相关配置")); - parameters.add(new Parameter("embedding.mapper.threshold", + parameters.add(new Parameter("s2.embedding.mapper.threshold", "0.99", "向量召回相似度阈值", "相似度小于该阈值的则舍弃", "number", "Mapper相关配置")); - parameters.add(new Parameter("embedding.mapper.min.threshold", + parameters.add(new Parameter("s2.embedding.mapper.min.threshold", "0.9", "向量召回最小相似度阈值", "向量召回相似度阈值在动态调整中的最低值", "number", "Mapper相关配置")); //parser config - Parameter s2SQLParameter = new Parameter("s2SQL.generation", "TWO_PASS_AUTO_COT", - "S2SQL生成方式", "ONE_PASS_AUTO_COT: 通过思维链方式一步生成sql" - + "\nONE_PASS_AUTO_COT_SELF_CONSISTENCY: 通过思维链且投票方式一步生成sql" - + "\nTWO_PASS_AUTO_COT: 通过思维链方式两步生成sql" + Parameter s2SQLParameter = new Parameter("s2.parser.strategy", + "TWO_PASS_AUTO_COT_SELF_CONSISTENCY", + "LLM解析生成S2SQL策略", + "ONE_PASS_AUTO_COT_SELF_CONSISTENCY: 通过思维链且投票方式一步生成sql" + "\nTWO_PASS_AUTO_COT_SELF_CONSISTENCY: 通过思维链且投票方式两步生成sql", "list", "Parser相关配置"); - s2SQLParameter.setCandidateValues(Lists.newArrayList("ONE_PASS_AUTO_COT", "ONE_PASS_AUTO_COT_SELF_CONSISTENCY", - "TWO_PASS_AUTO_COT", "TWO_PASS_AUTO_COT_SELF_CONSISTENCY")); + s2SQLParameter.setCandidateValues(Lists.newArrayList( + "ONE_PASS_AUTO_COT_SELF_CONSISTENCY", "TWO_PASS_AUTO_COT_SELF_CONSISTENCY")); parameters.add(s2SQLParameter); - parameters.add(new Parameter("s2SQL.linking.value.switch", "true", + parameters.add(new Parameter("s2.s2SQL.linking.value.switch", "true", "是否将Mapper探测识别到的维度值提供给大模型", "为了数据安全考虑, 这里可进行开关选择", "bool", "Parser相关配置")); - parameters.add(new Parameter("query.text.length.threshold", "10", + parameters.add(new Parameter("s2.query.text.length.threshold", "10", "用户输入文本长短阈值", "文本超过该阈值为长文本", "number", "Parser相关配置")); - parameters.add(new Parameter("short.text.threshold", "0.5", + parameters.add(new Parameter("s2.short.text.threshold", "0.5", "短文本匹配阈值", "由于请求大模型耗时较长, 因此如果有规则类型的Query得分达到阈值,则跳过大模型的调用,\n如果是短文本, 若query得分/文本长度>该阈值, 则跳过当前parser", "number", "Parser相关配置")); - parameters.add(new Parameter("long.text.threshold", "0.8", + parameters.add(new Parameter("s2.long.text.threshold", "0.8", "长文本匹配阈值", "如果是长文本, 若query得分/文本长度>该阈值, 则跳过当前parser", "number", "Parser相关配置")); - parameters.add(new Parameter("parse.show.count", "3", + parameters.add(new Parameter("s2.parse.show-count", "3", "解析结果个数", "前端展示的解析个数", "number", "Parser相关配置")); } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/LLMConfig.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/LLMConfig.java index 711880e5f..7e9af287b 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/LLMConfig.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/LLMConfig.java @@ -27,4 +27,13 @@ public class LLMConfig { this.apiKey = apiKey; this.modelName = modelName; } + + public LLMConfig(String provider, String baseUrl, String apiKey, String modelName, + double temperature) { + this.provider = provider; + this.baseUrl = baseUrl; + this.apiKey = apiKey; + this.modelName = modelName; + this.temperature = temperature; + } } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSCSqlGenStrategy.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSCSqlGenStrategy.java index 3e32f278a..296796faa 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSCSqlGenStrategy.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSCSqlGenStrategy.java @@ -25,7 +25,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy { @Override public LLMResp generate(LLMReq llmReq) { //1.retriever sqlExamples and generate exampleListPool - keyPipelineLog.info("llmReq:{}", llmReq); + keyPipelineLog.info("OnePassSCSqlGenStrategy llmReq:{}", llmReq); List> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(), optimizationConfig.getText2sqlExampleNum()); @@ -39,12 +39,12 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy { linkingSqlPromptPool.parallelStream().forEach(linkingSqlPrompt -> { Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingSqlPrompt)) .apply(new HashMap<>()); - keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage()); + keyPipelineLog.info("OnePassSCSqlGenStrategy reqPrompt:{}", prompt.toSystemMessage()); ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig()); Response response = chatLanguageModel.generate(prompt.toSystemMessage()); String result = response.content().text(); llmResults.add(result); - keyPipelineLog.info("model response:{}", result); + keyPipelineLog.info("OnePassSCSqlGenStrategy modelResp:{}", result); } ); //3.format response. @@ -56,7 +56,6 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy { .map(llmResult -> OutputFormat.getSql(llmResult)).collect(Collectors.toList()); Pair> sqlMapPair = OutputFormat.selfConsistencyVote(sqlList); - keyPipelineLog.info("linkingMap:{} sqlMap:{}", linkingMap, sqlMapPair.getRight()); LLMResp result = new LLMResp(); result.setQuery(llmReq.getQueryText()); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSqlGenStrategy.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSqlGenStrategy.java deleted file mode 100644 index 77160e047..000000000 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/OnePassSqlGenStrategy.java +++ /dev/null @@ -1,56 +0,0 @@ -package com.tencent.supersonic.headless.core.chat.parser.llm; - -import com.tencent.supersonic.common.util.JsonUtil; -import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq; -import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp; -import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlResp; -import dev.langchain4j.data.message.AiMessage; -import dev.langchain4j.model.chat.ChatLanguageModel; -import dev.langchain4j.model.input.Prompt; -import dev.langchain4j.model.input.PromptTemplate; -import dev.langchain4j.model.output.Response; -import lombok.extern.slf4j.Slf4j; -import org.springframework.stereotype.Service; - -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -@Service -@Slf4j -public class OnePassSqlGenStrategy extends SqlGenStrategy { - - @Override - public LLMResp generate(LLMReq llmReq) { - //1.retriever sqlExamples - keyPipelineLog.info("llmReq:{}", llmReq); - List> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(), - optimizationConfig.getText2sqlExampleNum()); - - //2.generator linking and sql prompt by sqlExamples,and generate response. - String promptStr = promptGenerator.generatorLinkingAndSqlPrompt(llmReq, sqlExamples); - - Prompt prompt = PromptTemplate.from(JsonUtil.toString(promptStr)).apply(new HashMap<>()); - keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage()); - ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig()); - Response response = chatLanguageModel.generate(prompt.toSystemMessage()); - String result = response.content().text(); - keyPipelineLog.info("model response:{}", result); - //3.format response. - String schemaLinkStr = OutputFormat.getSchemaLinks(response.content().text()); - String sql = OutputFormat.getSql(response.content().text()); - Map sqlRespMap = new HashMap<>(); - sqlRespMap.put(sql, LLMSqlResp.builder().sqlWeight(1D).fewShots(sqlExamples).build()); - keyPipelineLog.info("schemaLinkStr:{},sqlRespMap:{}", schemaLinkStr, sqlRespMap); - - LLMResp llmResp = new LLMResp(); - llmResp.setQuery(llmReq.getQueryText()); - llmResp.setSqlRespMap(sqlRespMap); - return llmResp; - } - - @Override - public void afterPropertiesSet() { - SqlGenStrategyFactory.addSqlGenerationForFactory(LLMReq.SqlGenType.ONE_PASS_AUTO_COT, this); - } -} diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/PythonLLMProxy.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/PythonLLMProxy.java index be0c3a1f2..7c5dfcd73 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/PythonLLMProxy.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/PythonLLMProxy.java @@ -32,7 +32,7 @@ public class PythonLLMProxy implements LLMProxy { public LLMResp text2sql(LLMReq llmReq) { long startTime = System.currentTimeMillis(); log.info("requestLLM request, llmReq:{}", llmReq); - keyPipelineLog.info("llmReq:{}", llmReq); + keyPipelineLog.info("PythonLLMProxy llmReq:{}", llmReq); try { LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class); @@ -47,7 +47,7 @@ public class PythonLLMProxy implements LLMProxy { LLMResp llmResp = responseEntity.getBody(); log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}", System.currentTimeMillis() - startTime, url, entity, llmResp); - keyPipelineLog.info("LLMResp:{}", llmResp); + keyPipelineLog.info("PythonLLMProxy llmResp:{}", llmResp); if (MapUtils.isEmpty(llmResp.getSqlRespMap())) { llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(new ArrayList<>(), llmResp.getSqlWeight())); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/TwoPassSCSqlGenStrategy.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/TwoPassSCSqlGenStrategy.java index 0b51fd1ad..9d4ad6550 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/TwoPassSCSqlGenStrategy.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/TwoPassSCSqlGenStrategy.java @@ -23,7 +23,7 @@ public class TwoPassSCSqlGenStrategy extends SqlGenStrategy { @Override public LLMResp generate(LLMReq llmReq) { //1.retriever sqlExamples and generate exampleListPool - keyPipelineLog.info("llmReq:{}", llmReq); + keyPipelineLog.info("TwoPassSCSqlGenStrategy llmReq:{}", llmReq); List> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(), optimizationConfig.getText2sqlExampleNum()); @@ -37,10 +37,10 @@ public class TwoPassSCSqlGenStrategy extends SqlGenStrategy { linkingPromptPool.parallelStream().forEach( linkingPrompt -> { Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPrompt)).apply(new HashMap<>()); - keyPipelineLog.info("step one request prompt:{}", prompt.toSystemMessage()); + keyPipelineLog.info("TwoPassSCSqlGenStrategy step one reqPrompt:{}", prompt.toSystemMessage()); Response linkingResult = chatLanguageModel.generate(prompt.toSystemMessage()); String result = linkingResult.content().text(); - keyPipelineLog.info("step one model response:{}", result); + keyPipelineLog.info("TwoPassSCSqlGenStrategy step one modelResp:{}", result); linkingResults.add(OutputFormat.getSchemaLink(result)); } ); @@ -51,15 +51,14 @@ public class TwoPassSCSqlGenStrategy extends SqlGenStrategy { List sqlTaskPool = new CopyOnWriteArrayList<>(); sqlPromptPool.parallelStream().forEach(sqlPrompt -> { Prompt linkingPrompt = PromptTemplate.from(JsonUtil.toString(sqlPrompt)).apply(new HashMap<>()); - keyPipelineLog.info("step two request prompt:{}", linkingPrompt.toSystemMessage()); + keyPipelineLog.info("TwoPassSCSqlGenStrategy step two reqPrompt:{}", linkingPrompt.toSystemMessage()); Response sqlResult = chatLanguageModel.generate(linkingPrompt.toSystemMessage()); String result = sqlResult.content().text(); - keyPipelineLog.info("step two model response:{}", result); + keyPipelineLog.info("TwoPassSCSqlGenStrategy step two modelResp:{}", result); sqlTaskPool.add(result); }); //4.format response. Pair> sqlMapPair = OutputFormat.selfConsistencyVote(sqlTaskPool); - keyPipelineLog.info("linkingMap:{} sqlMap:{}", linkingMap, sqlMapPair.getRight()); LLMResp llmResp = new LLMResp(); llmResp.setQuery(llmReq.getQueryText()); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/TwoPassSqlGenStrategy.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/TwoPassSqlGenStrategy.java deleted file mode 100644 index 2310b67f2..000000000 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/parser/llm/TwoPassSqlGenStrategy.java +++ /dev/null @@ -1,58 +0,0 @@ -package com.tencent.supersonic.headless.core.chat.parser.llm; - -import com.tencent.supersonic.common.util.JsonUtil; -import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq; -import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenType; -import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp; -import dev.langchain4j.data.message.AiMessage; -import dev.langchain4j.model.chat.ChatLanguageModel; -import dev.langchain4j.model.input.Prompt; -import dev.langchain4j.model.input.PromptTemplate; -import dev.langchain4j.model.output.Response; -import lombok.extern.slf4j.Slf4j; -import org.springframework.stereotype.Service; - -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -@Service -@Slf4j -public class TwoPassSqlGenStrategy extends SqlGenStrategy { - - @Override - public LLMResp generate(LLMReq llmReq) { - keyPipelineLog.info("llmReq:{}", llmReq); - List> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(), - optimizationConfig.getText2sqlExampleNum()); - - String linkingPromptStr = promptGenerator.generateLinkingPrompt(llmReq, sqlExamples); - - Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPromptStr)).apply(new HashMap<>()); - keyPipelineLog.info("step one request prompt:{}", prompt.toSystemMessage()); - ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig()); - Response response = chatLanguageModel.generate(prompt.toSystemMessage()); - keyPipelineLog.info("step one model response:{}", response.content().text()); - String schemaLinkStr = OutputFormat.getSchemaLink(response.content().text()); - String generateSqlPrompt = promptGenerator.generateSqlPrompt(llmReq, schemaLinkStr, sqlExamples); - - Prompt sqlPrompt = PromptTemplate.from(JsonUtil.toString(generateSqlPrompt)).apply(new HashMap<>()); - keyPipelineLog.info("step two request prompt:{}", sqlPrompt.toSystemMessage()); - Response sqlResult = chatLanguageModel.generate(sqlPrompt.toSystemMessage()); - String result = sqlResult.content().text(); - keyPipelineLog.info("step two model response:{}", result); - Map sqlMap = new HashMap<>(); - sqlMap.put(result, 1D); - keyPipelineLog.info("schemaLinkStr:{},sqlMap:{}", schemaLinkStr, sqlMap); - - LLMResp llmResp = new LLMResp(); - llmResp.setQuery(llmReq.getQueryText()); - llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(sqlExamples, sqlMap)); - return llmResp; - } - - @Override - public void afterPropertiesSet() { - SqlGenStrategyFactory.addSqlGenerationForFactory(SqlGenType.TWO_PASS_AUTO_COT, this); - } -} diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/query/llm/s2sql/LLMReq.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/query/llm/s2sql/LLMReq.java index 9fe011da6..5dfd89c6c 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/query/llm/s2sql/LLMReq.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/query/llm/s2sql/LLMReq.java @@ -68,16 +68,9 @@ public class LLMReq { } public enum SqlGenType { - - ONE_PASS_AUTO_COT("1_pass_auto_cot"), - ONE_PASS_AUTO_COT_SELF_CONSISTENCY("1_pass_auto_cot_self_consistency"), - - TWO_PASS_AUTO_COT("2_pass_auto_cot"), - TWO_PASS_AUTO_COT_SELF_CONSISTENCY("2_pass_auto_cot_self_consistency"); - private String name; SqlGenType(String name) { diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/config/OptimizationConfig.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/config/OptimizationConfig.java index 52c179064..a940ec3d2 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/config/OptimizationConfig.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/config/OptimizationConfig.java @@ -16,6 +16,7 @@ public class OptimizationConfig { @Value("${s2.one.detection.size:8}") private Integer oneDetectionSize; + @Value("${s2.one.detection.max.size:20}") private Integer oneDetectionMaxSize; @@ -67,19 +68,19 @@ public class OptimizationConfig { @Value("${s2.parser.linking.value.switch:true}") private boolean useLinkingValueSwitch; - @Value("${s2.parser.generation:TWO_PASS_AUTO_COT}") + @Value("${s2.parser.strategy:TWO_PASS_AUTO_COT_SELF_CONSISTENCY}") private LLMReq.SqlGenType sqlGenType; @Value("${s2.parser.use.switch:true}") private boolean useS2SqlSwitch; - @Value("${s2.parser.exemplar-recall.num:15}") + @Value("${s2.parser.exemplar-recall.number:15}") private int text2sqlExampleNum; - @Value("${s2.parser.few-shot.num:10}") + @Value("${s2.parser.few-shot.number:5}") private int text2sqlFewShotsNum; - @Value("${s2.parser.self-consistency.num:5}") + @Value("${s2.parser.self-consistency.number:5}") private int text2sqlSelfConsistencyNum; @Value("${s2.parser.show-count:3}") @@ -89,83 +90,83 @@ public class OptimizationConfig { private SysParameterService sysParameterService; public Integer getOneDetectionSize() { - return convertValue("one.detection.size", Integer.class, oneDetectionSize); + return convertValue("s2.one.detection.size", Integer.class, oneDetectionSize); } public Integer getOneDetectionMaxSize() { - return convertValue("one.detection.max.size", Integer.class, oneDetectionMaxSize); + return convertValue("s2.one.detection.max.size", Integer.class, oneDetectionMaxSize); } public Double getMetricDimensionMinThresholdConfig() { - return convertValue("metric.dimension.min.threshold", Double.class, metricDimensionMinThresholdConfig); + return convertValue("s2.metric.dimension.min.threshold", Double.class, metricDimensionMinThresholdConfig); } public Double getMetricDimensionThresholdConfig() { - return convertValue("metric.dimension.threshold", Double.class, metricDimensionThresholdConfig); + return convertValue("s2.metric.dimension.threshold", Double.class, metricDimensionThresholdConfig); } public Double getDimensionValueMinThresholdConfig() { - return convertValue("dimension.value.min.threshold", Double.class, dimensionValueMinThresholdConfig); + return convertValue("s2.dimension.value.min.threshold", Double.class, dimensionValueMinThresholdConfig); } public Double getDimensionValueThresholdConfig() { - return convertValue("dimension.value.threshold", Double.class, dimensionValueThresholdConfig); + return convertValue("s2.dimension.value.threshold", Double.class, dimensionValueThresholdConfig); } public Double getLongTextThreshold() { - return convertValue("long.text.threshold", Double.class, longTextThreshold); + return convertValue("s2.long.text.threshold", Double.class, longTextThreshold); } public Double getShortTextThreshold() { - return convertValue("short.text.threshold", Double.class, shortTextThreshold); + return convertValue("s2.short.text.threshold", Double.class, shortTextThreshold); } public Integer getQueryTextLengthThreshold() { - return convertValue("query.text.length.threshold", Integer.class, queryTextLengthThreshold); - } - - public boolean isUseS2SqlSwitch() { - return convertValue("use.s2SQL.switch", Boolean.class, useS2SqlSwitch); + return convertValue("s2.query.text.length.threshold", Integer.class, queryTextLengthThreshold); } public Integer getEmbeddingMapperWordMin() { - return convertValue("embedding.mapper.word.min", Integer.class, embeddingMapperWordMin); + return convertValue("s2.embedding.mapper.word.min", Integer.class, embeddingMapperWordMin); } public Integer getEmbeddingMapperWordMax() { - return convertValue("embedding.mapper.word.max", Integer.class, embeddingMapperWordMax); + return convertValue("s2.embedding.mapper.word.max", Integer.class, embeddingMapperWordMax); } public Integer getEmbeddingMapperBatch() { - return convertValue("embedding.mapper.batch", Integer.class, embeddingMapperBatch); + return convertValue("s2.embedding.mapper.batch", Integer.class, embeddingMapperBatch); } public Integer getEmbeddingMapperNumber() { - return convertValue("embedding.mapper.number", Integer.class, embeddingMapperNumber); + return convertValue("s2.embedding.mapper.number", Integer.class, embeddingMapperNumber); } public Integer getEmbeddingMapperRoundNumber() { - return convertValue("embedding.mapper.round.number", Integer.class, embeddingMapperRoundNumber); + return convertValue("s2.embedding.mapper.round.number", Integer.class, embeddingMapperRoundNumber); } public Double getEmbeddingMapperMinThreshold() { - return convertValue("embedding.mapper.min.threshold", Double.class, embeddingMapperMinThreshold); + return convertValue("s2.embedding.mapper.min.threshold", Double.class, embeddingMapperMinThreshold); } public Double getEmbeddingMapperThreshold() { - return convertValue("embedding.mapper.threshold", Double.class, embeddingMapperThreshold); + return convertValue("s2.embedding.mapper.threshold", Double.class, embeddingMapperThreshold); + } + + public boolean isUseS2SqlSwitch() { + return convertValue("s2.parser.use.switch", Boolean.class, useS2SqlSwitch); } public boolean isUseLinkingValueSwitch() { - return convertValue("s2SQL.linking.value.switch", Boolean.class, useLinkingValueSwitch); + return convertValue("s2.parser.linking.value.switch", Boolean.class, useLinkingValueSwitch); } public LLMReq.SqlGenType getSqlGenType() { - return convertValue("s2SQL.generation", LLMReq.SqlGenType.class, sqlGenType); + return convertValue("s2.parser.strategy", LLMReq.SqlGenType.class, sqlGenType); } public Integer getParseShowCount() { - return convertValue("parse.show.count", Integer.class, parseShowCount); + return convertValue("s2.parse.show-count", Integer.class, parseShowCount); } public T convertValue(String paramName, Class targetType, T defaultValue) { diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/SqlInfoProcessor.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/SqlInfoProcessor.java index 361bf761e..74c686a76 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/SqlInfoProcessor.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/SqlInfoProcessor.java @@ -75,7 +75,7 @@ public class SqlInfoProcessor implements ResultProcessor { } SqlInfo sqlInfo = parseInfo.getSqlInfo(); if (semanticQuery instanceof LLMSqlQuery) { - keyPipelineLog.info("\nparsedS2SQL:{}\ncorrectedS2SQL:{}\nfinalSQL:{}", sqlInfo.getS2SQL(), + keyPipelineLog.info("SqlInfoProcessor results:\nParsed S2SQL:{}\nCorrected S2SQL:{}\nFinal SQL:{}", sqlInfo.getS2SQL(), sqlInfo.getCorrectS2SQL(), explainSql); } sqlInfo.setQuerySQL(explainSql); diff --git a/launchers/standalone/src/main/resources/application-local.yaml b/launchers/standalone/src/main/resources/application-local.yaml index 5fbd67c00..68927b5dd 100644 --- a/launchers/standalone/src/main/resources/application-local.yaml +++ b/launchers/standalone/src/main/resources/application-local.yaml @@ -41,12 +41,13 @@ s2: parser: url: ${s2.pyllm.url} + strategy: TWO_PASS_AUTO_COT_SELF_CONSISTENCY exemplar-recall: number: 10 few-shot: number: 5 self-consistency: - number: 5 + number: 1 multi-turn: enable: false diff --git a/launchers/standalone/src/test/resources/application-local.yaml b/launchers/standalone/src/test/resources/application-local.yaml index 510c4dc69..5d4deb3a6 100644 --- a/launchers/standalone/src/test/resources/application-local.yaml +++ b/launchers/standalone/src/test/resources/application-local.yaml @@ -36,17 +36,14 @@ logging: dev.ai4j.openai4j: DEBUG s2: - pyllm: - url: http://127.0.0.1:9092 - parser: - url: ${s2.pyllm.url} + strategy: TWO_PASS_AUTO_COT_SELF_CONSISTENCY exemplar-recall: - number: 10 + number: 5 few-shot: - number: 5 + number: 1 self-consistency: - number: 5 + number: 1 multi-turn: enable: false @@ -54,17 +51,9 @@ s2: additional: information: true - functionCall: - url: ${s2.pyllm.url} - - embedding: - url: ${s2.pyllm.url} - persistent: - path: /tmp - demo: names: S2VisitsDemo,S2ArtistDemo - enableLLM: true + enableLLM: false schema: cache: @@ -86,22 +75,4 @@ s2: #2.embedding-model #2.1 in_memory(default) embedding-model: - provider: in_process - # inProcess: - # modelPath: /data/model.onnx - # vocabularyPath: /data/onnx_vocab.txt - # shibing624/text2vec-base-chinese - #2.2 open_ai - # embedding-model: - # provider: open_ai - # openai: - # api-key: api_key - # modelName: all-minilm-l6-v2.onnx - - #2.2 hugging_face - # embedding-model: - # provider: hugging_face - # hugging-face: - # access-token: hg_access_token - # model-id: sentence-transformers/all-MiniLM-L6-v2 - # timeout: 1h \ No newline at end of file + provider: in_process \ No newline at end of file