(improvement)(headless)Remove redundant SqlGenStrategy impl

This commit is contained in:
jerryjzhang
2024-05-30 15:23:52 +08:00
parent 3b09a5c0ed
commit 18f268f590
13 changed files with 81 additions and 222 deletions

View File

@@ -95,13 +95,13 @@ public class MultiTurnParser implements ChatParser {
variables.put("histSchema", context.getHistSchema()); variables.put("histSchema", context.getHistSchema());
Prompt prompt = promptTemplate.apply(variables); Prompt prompt = promptTemplate.apply(variables);
keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage()); keyPipelineLog.info("MultiTurnParser reqPrompt:{}", prompt.toSystemMessage());
ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(context.getLlmConfig()); ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(context.getLlmConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage()); Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
String result = response.content().text(); String result = response.content().text();
keyPipelineLog.info("model response:{}", result); keyPipelineLog.info("MultiTurnParser modelResp:{}", result);
return response.content().text(); return response.content().text();
} }

View File

@@ -48,62 +48,62 @@ public class SysParameter {
admins = Lists.newArrayList("admin"); admins = Lists.newArrayList("admin");
//detect config //detect config
parameters.add(new Parameter("one.detection.size", "8", parameters.add(new Parameter("s2.one.detection.size", "8",
"一次探测返回结果个数", "在每次探测后, 将前后缀匹配的结果合并, 并根据相似度阈值过滤后的结果个数", "一次探测返回结果个数", "在每次探测后, 将前后缀匹配的结果合并, 并根据相似度阈值过滤后的结果个数",
"number", "Mapper相关配置")); "number", "Mapper相关配置"));
parameters.add(new Parameter("one.detection.max.size", "20", parameters.add(new Parameter("s2.one.detection.max.size", "20",
"一次探测前后缀匹配结果返回个数", "单次前后缀匹配返回的结果个数", "number", "Mapper相关配置")); "一次探测前后缀匹配结果返回个数", "单次前后缀匹配返回的结果个数", "number", "Mapper相关配置"));
//mapper config //mapper config
parameters.add(new Parameter("metric.dimension.threshold", "0.3", parameters.add(new Parameter("s2.metric.dimension.threshold", "0.3",
"指标名、维度名文本相似度阈值", "文本片段和匹配到的指标、维度名计算出来的编辑距离阈值, 若超出该阈值, 则舍弃", "指标名、维度名文本相似度阈值", "文本片段和匹配到的指标、维度名计算出来的编辑距离阈值, 若超出该阈值, 则舍弃",
"number", "Mapper相关配置")); "number", "Mapper相关配置"));
parameters.add(new Parameter("metric.dimension.min.threshold", "0.25", parameters.add(new Parameter("s2.metric.dimension.min.threshold", "0.25",
"指标名、维度名最小文本相似度阈值", "指标名、维度名相似度阈值在动态调整中的最低值", "指标名、维度名最小文本相似度阈值", "指标名、维度名相似度阈值在动态调整中的最低值",
"number", "Mapper相关配置")); "number", "Mapper相关配置"));
parameters.add(new Parameter("dimension.value.threshold", "0.5", parameters.add(new Parameter("s2.dimension.value.threshold", "0.5",
"维度值文本相似度阈值", "文本片段和匹配到的维度值计算出来的编辑距离阈值, 若超出该阈值, 则舍弃", "维度值文本相似度阈值", "文本片段和匹配到的维度值计算出来的编辑距离阈值, 若超出该阈值, 则舍弃",
"number", "Mapper相关配置")); "number", "Mapper相关配置"));
parameters.add(new Parameter("dimension.value.min.threshold", "0.3", parameters.add(new Parameter("s2.dimension.value.min.threshold", "0.3",
"维度值最小文本相似度阈值", "维度值相似度阈值在动态调整中的最低值", "维度值最小文本相似度阈值", "维度值相似度阈值在动态调整中的最低值",
"number", "Mapper相关配置")); "number", "Mapper相关配置"));
//embedding mapper config //embedding mapper config
parameters.add(new Parameter("embedding.mapper.word.min", parameters.add(new Parameter("s2.embedding.mapper.word.min",
"4", "用于向量召回最小的文本长度", "为提高向量召回效率, 小于该长度的文本不进行向量语义召回", "number", "Mapper相关配置")); "4", "用于向量召回最小的文本长度", "为提高向量召回效率, 小于该长度的文本不进行向量语义召回", "number", "Mapper相关配置"));
parameters.add(new Parameter("embedding.mapper.word.max", "5", parameters.add(new Parameter("s2.embedding.mapper.word.max", "5",
"用于向量召回最大的文本长度", "为提高向量召回效率, 大于该长度的文本不进行向量语义召回", "number", "Mapper相关配置")); "用于向量召回最大的文本长度", "为提高向量召回效率, 大于该长度的文本不进行向量语义召回", "number", "Mapper相关配置"));
parameters.add(new Parameter("embedding.mapper.batch", "50", parameters.add(new Parameter("s2.embedding.mapper.batch", "50",
"批量向量召回文本请求个数", "每次进行向量语义召回的原始文本片段个数", "number", "Mapper相关配置")); "批量向量召回文本请求个数", "每次进行向量语义召回的原始文本片段个数", "number", "Mapper相关配置"));
parameters.add(new Parameter("embedding.mapper.number", "5", parameters.add(new Parameter("s2.embedding.mapper.number", "5",
"批量向量召回文本返回结果个数", "每个文本进行向量语义召回的文本结果个数", "number", "Mapper相关配置")); "批量向量召回文本返回结果个数", "每个文本进行向量语义召回的文本结果个数", "number", "Mapper相关配置"));
parameters.add(new Parameter("embedding.mapper.threshold", parameters.add(new Parameter("s2.embedding.mapper.threshold",
"0.99", "向量召回相似度阈值", "相似度小于该阈值的则舍弃", "number", "Mapper相关配置")); "0.99", "向量召回相似度阈值", "相似度小于该阈值的则舍弃", "number", "Mapper相关配置"));
parameters.add(new Parameter("embedding.mapper.min.threshold", parameters.add(new Parameter("s2.embedding.mapper.min.threshold",
"0.9", "向量召回最小相似度阈值", "向量召回相似度阈值在动态调整中的最低值", "number", "Mapper相关配置")); "0.9", "向量召回最小相似度阈值", "向量召回相似度阈值在动态调整中的最低值", "number", "Mapper相关配置"));
//parser config //parser config
Parameter s2SQLParameter = new Parameter("s2SQL.generation", "TWO_PASS_AUTO_COT", Parameter s2SQLParameter = new Parameter("s2.parser.strategy",
"S2SQL生成方式", "ONE_PASS_AUTO_COT: 通过思维链方式一步生成sql" "TWO_PASS_AUTO_COT_SELF_CONSISTENCY",
+ "\nONE_PASS_AUTO_COT_SELF_CONSISTENCY: 通过思维链且投票方式一步生成sql" "LLM解析生成S2SQL策略",
+ "\nTWO_PASS_AUTO_COT: 通过思维链方式两步生成sql" "ONE_PASS_AUTO_COT_SELF_CONSISTENCY: 通过思维链且投票方式一步生成sql"
+ "\nTWO_PASS_AUTO_COT_SELF_CONSISTENCY: 通过思维链且投票方式两步生成sql", "list", "Parser相关配置"); + "\nTWO_PASS_AUTO_COT_SELF_CONSISTENCY: 通过思维链且投票方式两步生成sql", "list", "Parser相关配置");
s2SQLParameter.setCandidateValues(Lists.newArrayList("ONE_PASS_AUTO_COT", "ONE_PASS_AUTO_COT_SELF_CONSISTENCY", s2SQLParameter.setCandidateValues(Lists.newArrayList(
"TWO_PASS_AUTO_COT", "TWO_PASS_AUTO_COT_SELF_CONSISTENCY")); "ONE_PASS_AUTO_COT_SELF_CONSISTENCY", "TWO_PASS_AUTO_COT_SELF_CONSISTENCY"));
parameters.add(s2SQLParameter); parameters.add(s2SQLParameter);
parameters.add(new Parameter("s2SQL.linking.value.switch", "true", parameters.add(new Parameter("s2.s2SQL.linking.value.switch", "true",
"是否将Mapper探测识别到的维度值提供给大模型", "为了数据安全考虑, 这里可进行开关选择", "是否将Mapper探测识别到的维度值提供给大模型", "为了数据安全考虑, 这里可进行开关选择",
"bool", "Parser相关配置")); "bool", "Parser相关配置"));
parameters.add(new Parameter("query.text.length.threshold", "10", parameters.add(new Parameter("s2.query.text.length.threshold", "10",
"用户输入文本长短阈值", "文本超过该阈值为长文本", "number", "Parser相关配置")); "用户输入文本长短阈值", "文本超过该阈值为长文本", "number", "Parser相关配置"));
parameters.add(new Parameter("short.text.threshold", "0.5", parameters.add(new Parameter("s2.short.text.threshold", "0.5",
"短文本匹配阈值", "由于请求大模型耗时较长, 因此如果有规则类型的Query得分达到阈值,则跳过大模型的调用,\n如果是短文本, 若query得分/文本长度>该阈值, 则跳过当前parser", "短文本匹配阈值", "由于请求大模型耗时较长, 因此如果有规则类型的Query得分达到阈值,则跳过大模型的调用,\n如果是短文本, 若query得分/文本长度>该阈值, 则跳过当前parser",
"number", "Parser相关配置")); "number", "Parser相关配置"));
parameters.add(new Parameter("long.text.threshold", "0.8", parameters.add(new Parameter("s2.long.text.threshold", "0.8",
"长文本匹配阈值", "如果是长文本, 若query得分/文本长度>该阈值, 则跳过当前parser", "长文本匹配阈值", "如果是长文本, 若query得分/文本长度>该阈值, 则跳过当前parser",
"number", "Parser相关配置")); "number", "Parser相关配置"));
parameters.add(new Parameter("parse.show.count", "3", parameters.add(new Parameter("s2.parse.show-count", "3",
"解析结果个数", "前端展示的解析个数", "解析结果个数", "前端展示的解析个数",
"number", "Parser相关配置")); "number", "Parser相关配置"));
} }

View File

@@ -27,4 +27,13 @@ public class LLMConfig {
this.apiKey = apiKey; this.apiKey = apiKey;
this.modelName = modelName; this.modelName = modelName;
} }
public LLMConfig(String provider, String baseUrl, String apiKey, String modelName,
double temperature) {
this.provider = provider;
this.baseUrl = baseUrl;
this.apiKey = apiKey;
this.modelName = modelName;
this.temperature = temperature;
}
} }

View File

@@ -25,7 +25,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
@Override @Override
public LLMResp generate(LLMReq llmReq) { public LLMResp generate(LLMReq llmReq) {
//1.retriever sqlExamples and generate exampleListPool //1.retriever sqlExamples and generate exampleListPool
keyPipelineLog.info("llmReq:{}", llmReq); keyPipelineLog.info("OnePassSCSqlGenStrategy llmReq:{}", llmReq);
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(), List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
optimizationConfig.getText2sqlExampleNum()); optimizationConfig.getText2sqlExampleNum());
@@ -39,12 +39,12 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
linkingSqlPromptPool.parallelStream().forEach(linkingSqlPrompt -> { linkingSqlPromptPool.parallelStream().forEach(linkingSqlPrompt -> {
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingSqlPrompt)) Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingSqlPrompt))
.apply(new HashMap<>()); .apply(new HashMap<>());
keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage()); keyPipelineLog.info("OnePassSCSqlGenStrategy reqPrompt:{}", prompt.toSystemMessage());
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig()); ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage()); Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
String result = response.content().text(); String result = response.content().text();
llmResults.add(result); llmResults.add(result);
keyPipelineLog.info("model response:{}", result); keyPipelineLog.info("OnePassSCSqlGenStrategy modelResp:{}", result);
} }
); );
//3.format response. //3.format response.
@@ -56,7 +56,6 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
.map(llmResult -> OutputFormat.getSql(llmResult)).collect(Collectors.toList()); .map(llmResult -> OutputFormat.getSql(llmResult)).collect(Collectors.toList());
Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(sqlList); Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(sqlList);
keyPipelineLog.info("linkingMap:{} sqlMap:{}", linkingMap, sqlMapPair.getRight());
LLMResp result = new LLMResp(); LLMResp result = new LLMResp();
result.setQuery(llmReq.getQueryText()); result.setQuery(llmReq.getQueryText());

View File

@@ -1,56 +0,0 @@
package com.tencent.supersonic.headless.core.chat.parser.llm;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlResp;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@Service
@Slf4j
public class OnePassSqlGenStrategy extends SqlGenStrategy {
@Override
public LLMResp generate(LLMReq llmReq) {
//1.retriever sqlExamples
keyPipelineLog.info("llmReq:{}", llmReq);
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
optimizationConfig.getText2sqlExampleNum());
//2.generator linking and sql prompt by sqlExamples,and generate response.
String promptStr = promptGenerator.generatorLinkingAndSqlPrompt(llmReq, sqlExamples);
Prompt prompt = PromptTemplate.from(JsonUtil.toString(promptStr)).apply(new HashMap<>());
keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage());
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
String result = response.content().text();
keyPipelineLog.info("model response:{}", result);
//3.format response.
String schemaLinkStr = OutputFormat.getSchemaLinks(response.content().text());
String sql = OutputFormat.getSql(response.content().text());
Map<String, LLMSqlResp> sqlRespMap = new HashMap<>();
sqlRespMap.put(sql, LLMSqlResp.builder().sqlWeight(1D).fewShots(sqlExamples).build());
keyPipelineLog.info("schemaLinkStr:{},sqlRespMap:{}", schemaLinkStr, sqlRespMap);
LLMResp llmResp = new LLMResp();
llmResp.setQuery(llmReq.getQueryText());
llmResp.setSqlRespMap(sqlRespMap);
return llmResp;
}
@Override
public void afterPropertiesSet() {
SqlGenStrategyFactory.addSqlGenerationForFactory(LLMReq.SqlGenType.ONE_PASS_AUTO_COT, this);
}
}

View File

@@ -32,7 +32,7 @@ public class PythonLLMProxy implements LLMProxy {
public LLMResp text2sql(LLMReq llmReq) { public LLMResp text2sql(LLMReq llmReq) {
long startTime = System.currentTimeMillis(); long startTime = System.currentTimeMillis();
log.info("requestLLM request, llmReq:{}", llmReq); log.info("requestLLM request, llmReq:{}", llmReq);
keyPipelineLog.info("llmReq:{}", llmReq); keyPipelineLog.info("PythonLLMProxy llmReq:{}", llmReq);
try { try {
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class); LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);
@@ -47,7 +47,7 @@ public class PythonLLMProxy implements LLMProxy {
LLMResp llmResp = responseEntity.getBody(); LLMResp llmResp = responseEntity.getBody();
log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}", log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
System.currentTimeMillis() - startTime, url, entity, llmResp); System.currentTimeMillis() - startTime, url, entity, llmResp);
keyPipelineLog.info("LLMResp:{}", llmResp); keyPipelineLog.info("PythonLLMProxy llmResp:{}", llmResp);
if (MapUtils.isEmpty(llmResp.getSqlRespMap())) { if (MapUtils.isEmpty(llmResp.getSqlRespMap())) {
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(new ArrayList<>(), llmResp.getSqlWeight())); llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(new ArrayList<>(), llmResp.getSqlWeight()));

View File

@@ -23,7 +23,7 @@ public class TwoPassSCSqlGenStrategy extends SqlGenStrategy {
@Override @Override
public LLMResp generate(LLMReq llmReq) { public LLMResp generate(LLMReq llmReq) {
//1.retriever sqlExamples and generate exampleListPool //1.retriever sqlExamples and generate exampleListPool
keyPipelineLog.info("llmReq:{}", llmReq); keyPipelineLog.info("TwoPassSCSqlGenStrategy llmReq:{}", llmReq);
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(), List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
optimizationConfig.getText2sqlExampleNum()); optimizationConfig.getText2sqlExampleNum());
@@ -37,10 +37,10 @@ public class TwoPassSCSqlGenStrategy extends SqlGenStrategy {
linkingPromptPool.parallelStream().forEach( linkingPromptPool.parallelStream().forEach(
linkingPrompt -> { linkingPrompt -> {
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPrompt)).apply(new HashMap<>()); Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPrompt)).apply(new HashMap<>());
keyPipelineLog.info("step one request prompt:{}", prompt.toSystemMessage()); keyPipelineLog.info("TwoPassSCSqlGenStrategy step one reqPrompt:{}", prompt.toSystemMessage());
Response<AiMessage> linkingResult = chatLanguageModel.generate(prompt.toSystemMessage()); Response<AiMessage> linkingResult = chatLanguageModel.generate(prompt.toSystemMessage());
String result = linkingResult.content().text(); String result = linkingResult.content().text();
keyPipelineLog.info("step one model response:{}", result); keyPipelineLog.info("TwoPassSCSqlGenStrategy step one modelResp:{}", result);
linkingResults.add(OutputFormat.getSchemaLink(result)); linkingResults.add(OutputFormat.getSchemaLink(result));
} }
); );
@@ -51,15 +51,14 @@ public class TwoPassSCSqlGenStrategy extends SqlGenStrategy {
List<String> sqlTaskPool = new CopyOnWriteArrayList<>(); List<String> sqlTaskPool = new CopyOnWriteArrayList<>();
sqlPromptPool.parallelStream().forEach(sqlPrompt -> { sqlPromptPool.parallelStream().forEach(sqlPrompt -> {
Prompt linkingPrompt = PromptTemplate.from(JsonUtil.toString(sqlPrompt)).apply(new HashMap<>()); Prompt linkingPrompt = PromptTemplate.from(JsonUtil.toString(sqlPrompt)).apply(new HashMap<>());
keyPipelineLog.info("step two request prompt:{}", linkingPrompt.toSystemMessage()); keyPipelineLog.info("TwoPassSCSqlGenStrategy step two reqPrompt:{}", linkingPrompt.toSystemMessage());
Response<AiMessage> sqlResult = chatLanguageModel.generate(linkingPrompt.toSystemMessage()); Response<AiMessage> sqlResult = chatLanguageModel.generate(linkingPrompt.toSystemMessage());
String result = sqlResult.content().text(); String result = sqlResult.content().text();
keyPipelineLog.info("step two model response:{}", result); keyPipelineLog.info("TwoPassSCSqlGenStrategy step two modelResp:{}", result);
sqlTaskPool.add(result); sqlTaskPool.add(result);
}); });
//4.format response. //4.format response.
Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(sqlTaskPool); Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(sqlTaskPool);
keyPipelineLog.info("linkingMap:{} sqlMap:{}", linkingMap, sqlMapPair.getRight());
LLMResp llmResp = new LLMResp(); LLMResp llmResp = new LLMResp();
llmResp.setQuery(llmReq.getQueryText()); llmResp.setQuery(llmReq.getQueryText());

View File

@@ -1,58 +0,0 @@
package com.tencent.supersonic.headless.core.chat.parser.llm;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenType;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@Service
@Slf4j
public class TwoPassSqlGenStrategy extends SqlGenStrategy {
@Override
public LLMResp generate(LLMReq llmReq) {
keyPipelineLog.info("llmReq:{}", llmReq);
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
optimizationConfig.getText2sqlExampleNum());
String linkingPromptStr = promptGenerator.generateLinkingPrompt(llmReq, sqlExamples);
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPromptStr)).apply(new HashMap<>());
keyPipelineLog.info("step one request prompt:{}", prompt.toSystemMessage());
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
keyPipelineLog.info("step one model response:{}", response.content().text());
String schemaLinkStr = OutputFormat.getSchemaLink(response.content().text());
String generateSqlPrompt = promptGenerator.generateSqlPrompt(llmReq, schemaLinkStr, sqlExamples);
Prompt sqlPrompt = PromptTemplate.from(JsonUtil.toString(generateSqlPrompt)).apply(new HashMap<>());
keyPipelineLog.info("step two request prompt:{}", sqlPrompt.toSystemMessage());
Response<AiMessage> sqlResult = chatLanguageModel.generate(sqlPrompt.toSystemMessage());
String result = sqlResult.content().text();
keyPipelineLog.info("step two model response:{}", result);
Map<String, Double> sqlMap = new HashMap<>();
sqlMap.put(result, 1D);
keyPipelineLog.info("schemaLinkStr:{},sqlMap:{}", schemaLinkStr, sqlMap);
LLMResp llmResp = new LLMResp();
llmResp.setQuery(llmReq.getQueryText());
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(sqlExamples, sqlMap));
return llmResp;
}
@Override
public void afterPropertiesSet() {
SqlGenStrategyFactory.addSqlGenerationForFactory(SqlGenType.TWO_PASS_AUTO_COT, this);
}
}

View File

@@ -68,16 +68,9 @@ public class LLMReq {
} }
public enum SqlGenType { public enum SqlGenType {
ONE_PASS_AUTO_COT("1_pass_auto_cot"),
ONE_PASS_AUTO_COT_SELF_CONSISTENCY("1_pass_auto_cot_self_consistency"), ONE_PASS_AUTO_COT_SELF_CONSISTENCY("1_pass_auto_cot_self_consistency"),
TWO_PASS_AUTO_COT("2_pass_auto_cot"),
TWO_PASS_AUTO_COT_SELF_CONSISTENCY("2_pass_auto_cot_self_consistency"); TWO_PASS_AUTO_COT_SELF_CONSISTENCY("2_pass_auto_cot_self_consistency");
private String name; private String name;
SqlGenType(String name) { SqlGenType(String name) {

View File

@@ -16,6 +16,7 @@ public class OptimizationConfig {
@Value("${s2.one.detection.size:8}") @Value("${s2.one.detection.size:8}")
private Integer oneDetectionSize; private Integer oneDetectionSize;
@Value("${s2.one.detection.max.size:20}") @Value("${s2.one.detection.max.size:20}")
private Integer oneDetectionMaxSize; private Integer oneDetectionMaxSize;
@@ -67,19 +68,19 @@ public class OptimizationConfig {
@Value("${s2.parser.linking.value.switch:true}") @Value("${s2.parser.linking.value.switch:true}")
private boolean useLinkingValueSwitch; private boolean useLinkingValueSwitch;
@Value("${s2.parser.generation:TWO_PASS_AUTO_COT}") @Value("${s2.parser.strategy:TWO_PASS_AUTO_COT_SELF_CONSISTENCY}")
private LLMReq.SqlGenType sqlGenType; private LLMReq.SqlGenType sqlGenType;
@Value("${s2.parser.use.switch:true}") @Value("${s2.parser.use.switch:true}")
private boolean useS2SqlSwitch; private boolean useS2SqlSwitch;
@Value("${s2.parser.exemplar-recall.num:15}") @Value("${s2.parser.exemplar-recall.number:15}")
private int text2sqlExampleNum; private int text2sqlExampleNum;
@Value("${s2.parser.few-shot.num:10}") @Value("${s2.parser.few-shot.number:5}")
private int text2sqlFewShotsNum; private int text2sqlFewShotsNum;
@Value("${s2.parser.self-consistency.num:5}") @Value("${s2.parser.self-consistency.number:5}")
private int text2sqlSelfConsistencyNum; private int text2sqlSelfConsistencyNum;
@Value("${s2.parser.show-count:3}") @Value("${s2.parser.show-count:3}")
@@ -89,83 +90,83 @@ public class OptimizationConfig {
private SysParameterService sysParameterService; private SysParameterService sysParameterService;
public Integer getOneDetectionSize() { public Integer getOneDetectionSize() {
return convertValue("one.detection.size", Integer.class, oneDetectionSize); return convertValue("s2.one.detection.size", Integer.class, oneDetectionSize);
} }
public Integer getOneDetectionMaxSize() { public Integer getOneDetectionMaxSize() {
return convertValue("one.detection.max.size", Integer.class, oneDetectionMaxSize); return convertValue("s2.one.detection.max.size", Integer.class, oneDetectionMaxSize);
} }
public Double getMetricDimensionMinThresholdConfig() { public Double getMetricDimensionMinThresholdConfig() {
return convertValue("metric.dimension.min.threshold", Double.class, metricDimensionMinThresholdConfig); return convertValue("s2.metric.dimension.min.threshold", Double.class, metricDimensionMinThresholdConfig);
} }
public Double getMetricDimensionThresholdConfig() { public Double getMetricDimensionThresholdConfig() {
return convertValue("metric.dimension.threshold", Double.class, metricDimensionThresholdConfig); return convertValue("s2.metric.dimension.threshold", Double.class, metricDimensionThresholdConfig);
} }
public Double getDimensionValueMinThresholdConfig() { public Double getDimensionValueMinThresholdConfig() {
return convertValue("dimension.value.min.threshold", Double.class, dimensionValueMinThresholdConfig); return convertValue("s2.dimension.value.min.threshold", Double.class, dimensionValueMinThresholdConfig);
} }
public Double getDimensionValueThresholdConfig() { public Double getDimensionValueThresholdConfig() {
return convertValue("dimension.value.threshold", Double.class, dimensionValueThresholdConfig); return convertValue("s2.dimension.value.threshold", Double.class, dimensionValueThresholdConfig);
} }
public Double getLongTextThreshold() { public Double getLongTextThreshold() {
return convertValue("long.text.threshold", Double.class, longTextThreshold); return convertValue("s2.long.text.threshold", Double.class, longTextThreshold);
} }
public Double getShortTextThreshold() { public Double getShortTextThreshold() {
return convertValue("short.text.threshold", Double.class, shortTextThreshold); return convertValue("s2.short.text.threshold", Double.class, shortTextThreshold);
} }
public Integer getQueryTextLengthThreshold() { public Integer getQueryTextLengthThreshold() {
return convertValue("query.text.length.threshold", Integer.class, queryTextLengthThreshold); return convertValue("s2.query.text.length.threshold", Integer.class, queryTextLengthThreshold);
}
public boolean isUseS2SqlSwitch() {
return convertValue("use.s2SQL.switch", Boolean.class, useS2SqlSwitch);
} }
public Integer getEmbeddingMapperWordMin() { public Integer getEmbeddingMapperWordMin() {
return convertValue("embedding.mapper.word.min", Integer.class, embeddingMapperWordMin); return convertValue("s2.embedding.mapper.word.min", Integer.class, embeddingMapperWordMin);
} }
public Integer getEmbeddingMapperWordMax() { public Integer getEmbeddingMapperWordMax() {
return convertValue("embedding.mapper.word.max", Integer.class, embeddingMapperWordMax); return convertValue("s2.embedding.mapper.word.max", Integer.class, embeddingMapperWordMax);
} }
public Integer getEmbeddingMapperBatch() { public Integer getEmbeddingMapperBatch() {
return convertValue("embedding.mapper.batch", Integer.class, embeddingMapperBatch); return convertValue("s2.embedding.mapper.batch", Integer.class, embeddingMapperBatch);
} }
public Integer getEmbeddingMapperNumber() { public Integer getEmbeddingMapperNumber() {
return convertValue("embedding.mapper.number", Integer.class, embeddingMapperNumber); return convertValue("s2.embedding.mapper.number", Integer.class, embeddingMapperNumber);
} }
public Integer getEmbeddingMapperRoundNumber() { public Integer getEmbeddingMapperRoundNumber() {
return convertValue("embedding.mapper.round.number", Integer.class, embeddingMapperRoundNumber); return convertValue("s2.embedding.mapper.round.number", Integer.class, embeddingMapperRoundNumber);
} }
public Double getEmbeddingMapperMinThreshold() { public Double getEmbeddingMapperMinThreshold() {
return convertValue("embedding.mapper.min.threshold", Double.class, embeddingMapperMinThreshold); return convertValue("s2.embedding.mapper.min.threshold", Double.class, embeddingMapperMinThreshold);
} }
public Double getEmbeddingMapperThreshold() { public Double getEmbeddingMapperThreshold() {
return convertValue("embedding.mapper.threshold", Double.class, embeddingMapperThreshold); return convertValue("s2.embedding.mapper.threshold", Double.class, embeddingMapperThreshold);
}
public boolean isUseS2SqlSwitch() {
return convertValue("s2.parser.use.switch", Boolean.class, useS2SqlSwitch);
} }
public boolean isUseLinkingValueSwitch() { public boolean isUseLinkingValueSwitch() {
return convertValue("s2SQL.linking.value.switch", Boolean.class, useLinkingValueSwitch); return convertValue("s2.parser.linking.value.switch", Boolean.class, useLinkingValueSwitch);
} }
public LLMReq.SqlGenType getSqlGenType() { public LLMReq.SqlGenType getSqlGenType() {
return convertValue("s2SQL.generation", LLMReq.SqlGenType.class, sqlGenType); return convertValue("s2.parser.strategy", LLMReq.SqlGenType.class, sqlGenType);
} }
public Integer getParseShowCount() { public Integer getParseShowCount() {
return convertValue("parse.show.count", Integer.class, parseShowCount); return convertValue("s2.parse.show-count", Integer.class, parseShowCount);
} }
public <T> T convertValue(String paramName, Class<T> targetType, T defaultValue) { public <T> T convertValue(String paramName, Class<T> targetType, T defaultValue) {

View File

@@ -75,7 +75,7 @@ public class SqlInfoProcessor implements ResultProcessor {
} }
SqlInfo sqlInfo = parseInfo.getSqlInfo(); SqlInfo sqlInfo = parseInfo.getSqlInfo();
if (semanticQuery instanceof LLMSqlQuery) { if (semanticQuery instanceof LLMSqlQuery) {
keyPipelineLog.info("\nparsedS2SQL:{}\ncorrectedS2SQL:{}\nfinalSQL:{}", sqlInfo.getS2SQL(), keyPipelineLog.info("SqlInfoProcessor results:\nParsed S2SQL:{}\nCorrected S2SQL:{}\nFinal SQL:{}", sqlInfo.getS2SQL(),
sqlInfo.getCorrectS2SQL(), explainSql); sqlInfo.getCorrectS2SQL(), explainSql);
} }
sqlInfo.setQuerySQL(explainSql); sqlInfo.setQuerySQL(explainSql);

View File

@@ -41,12 +41,13 @@ s2:
parser: parser:
url: ${s2.pyllm.url} url: ${s2.pyllm.url}
strategy: TWO_PASS_AUTO_COT_SELF_CONSISTENCY
exemplar-recall: exemplar-recall:
number: 10 number: 10
few-shot: few-shot:
number: 5 number: 5
self-consistency: self-consistency:
number: 5 number: 1
multi-turn: multi-turn:
enable: false enable: false

View File

@@ -36,17 +36,14 @@ logging:
dev.ai4j.openai4j: DEBUG dev.ai4j.openai4j: DEBUG
s2: s2:
pyllm:
url: http://127.0.0.1:9092
parser: parser:
url: ${s2.pyllm.url} strategy: TWO_PASS_AUTO_COT_SELF_CONSISTENCY
exemplar-recall: exemplar-recall:
number: 10 number: 5
few-shot: few-shot:
number: 5 number: 1
self-consistency: self-consistency:
number: 5 number: 1
multi-turn: multi-turn:
enable: false enable: false
@@ -54,17 +51,9 @@ s2:
additional: additional:
information: true information: true
functionCall:
url: ${s2.pyllm.url}
embedding:
url: ${s2.pyllm.url}
persistent:
path: /tmp
demo: demo:
names: S2VisitsDemo,S2ArtistDemo names: S2VisitsDemo,S2ArtistDemo
enableLLM: true enableLLM: false
schema: schema:
cache: cache:
@@ -87,21 +76,3 @@ s2:
#2.1 in_memory(default) #2.1 in_memory(default)
embedding-model: embedding-model:
provider: in_process provider: in_process
# inProcess:
# modelPath: /data/model.onnx
# vocabularyPath: /data/onnx_vocab.txt
# shibing624/text2vec-base-chinese
#2.2 open_ai
# embedding-model:
# provider: open_ai
# openai:
# api-key: api_key
# modelName: all-minilm-l6-v2.onnx
#2.2 hugging_face
# embedding-model:
# provider: hugging_face
# hugging-face:
# access-token: hg_access_token
# model-id: sentence-transformers/all-MiniLM-L6-v2
# timeout: 1h