(improvement)(headless)Remove redundant SqlGenStrategy impl
@@ -95,13 +95,13 @@ public class MultiTurnParser implements ChatParser {
        variables.put("histSchema", context.getHistSchema());

        Prompt prompt = promptTemplate.apply(variables);
-       keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage());
+       keyPipelineLog.info("MultiTurnParser reqPrompt:{}", prompt.toSystemMessage());

        ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(context.getLlmConfig());
        Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());

        String result = response.content().text();
-       keyPipelineLog.info("model response:{}", result);
+       keyPipelineLog.info("MultiTurnParser modelResp:{}", result);

        return response.content().text();
    }
@@ -48,62 +48,62 @@ public class SysParameter {
        admins = Lists.newArrayList("admin");

        //detect config
-       parameters.add(new Parameter("one.detection.size", "8",
+       parameters.add(new Parameter("s2.one.detection.size", "8",
                "一次探测返回结果个数", "在每次探测后, 将前后缀匹配的结果合并, 并根据相似度阈值过滤后的结果个数",
                "number", "Mapper相关配置"));
-       parameters.add(new Parameter("one.detection.max.size", "20",
+       parameters.add(new Parameter("s2.one.detection.max.size", "20",
                "一次探测前后缀匹配结果返回个数", "单次前后缀匹配返回的结果个数", "number", "Mapper相关配置"));

        //mapper config
-       parameters.add(new Parameter("metric.dimension.threshold", "0.3",
+       parameters.add(new Parameter("s2.metric.dimension.threshold", "0.3",
                "指标名、维度名文本相似度阈值", "文本片段和匹配到的指标、维度名计算出来的编辑距离阈值, 若超出该阈值, 则舍弃",
                "number", "Mapper相关配置"));
-       parameters.add(new Parameter("metric.dimension.min.threshold", "0.25",
+       parameters.add(new Parameter("s2.metric.dimension.min.threshold", "0.25",
                "指标名、维度名最小文本相似度阈值", "指标名、维度名相似度阈值在动态调整中的最低值",
                "number", "Mapper相关配置"));
-       parameters.add(new Parameter("dimension.value.threshold", "0.5",
+       parameters.add(new Parameter("s2.dimension.value.threshold", "0.5",
                "维度值文本相似度阈值", "文本片段和匹配到的维度值计算出来的编辑距离阈值, 若超出该阈值, 则舍弃",
                "number", "Mapper相关配置"));
-       parameters.add(new Parameter("dimension.value.min.threshold", "0.3",
+       parameters.add(new Parameter("s2.dimension.value.min.threshold", "0.3",
                "维度值最小文本相似度阈值", "维度值相似度阈值在动态调整中的最低值",
                "number", "Mapper相关配置"));

        //embedding mapper config
-       parameters.add(new Parameter("embedding.mapper.word.min",
+       parameters.add(new Parameter("s2.embedding.mapper.word.min",
                "4", "用于向量召回最小的文本长度", "为提高向量召回效率, 小于该长度的文本不进行向量语义召回", "number", "Mapper相关配置"));
-       parameters.add(new Parameter("embedding.mapper.word.max", "5",
+       parameters.add(new Parameter("s2.embedding.mapper.word.max", "5",
                "用于向量召回最大的文本长度", "为提高向量召回效率, 大于该长度的文本不进行向量语义召回", "number", "Mapper相关配置"));
-       parameters.add(new Parameter("embedding.mapper.batch", "50",
+       parameters.add(new Parameter("s2.embedding.mapper.batch", "50",
                "批量向量召回文本请求个数", "每次进行向量语义召回的原始文本片段个数", "number", "Mapper相关配置"));
-       parameters.add(new Parameter("embedding.mapper.number", "5",
+       parameters.add(new Parameter("s2.embedding.mapper.number", "5",
                "批量向量召回文本返回结果个数", "每个文本进行向量语义召回的文本结果个数", "number", "Mapper相关配置"));
-       parameters.add(new Parameter("embedding.mapper.threshold",
+       parameters.add(new Parameter("s2.embedding.mapper.threshold",
                "0.99", "向量召回相似度阈值", "相似度小于该阈值的则舍弃", "number", "Mapper相关配置"));
-       parameters.add(new Parameter("embedding.mapper.min.threshold",
+       parameters.add(new Parameter("s2.embedding.mapper.min.threshold",
                "0.9", "向量召回最小相似度阈值", "向量召回相似度阈值在动态调整中的最低值", "number", "Mapper相关配置"));

        //parser config
-       Parameter s2SQLParameter = new Parameter("s2SQL.generation", "TWO_PASS_AUTO_COT",
-               "S2SQL生成方式", "ONE_PASS_AUTO_COT: 通过思维链方式一步生成sql"
-               + "\nONE_PASS_AUTO_COT_SELF_CONSISTENCY: 通过思维链且投票方式一步生成sql"
-               + "\nTWO_PASS_AUTO_COT: 通过思维链方式两步生成sql"
+       Parameter s2SQLParameter = new Parameter("s2.parser.strategy",
+               "TWO_PASS_AUTO_COT_SELF_CONSISTENCY",
+               "LLM解析生成S2SQL策略",
+               "ONE_PASS_AUTO_COT_SELF_CONSISTENCY: 通过思维链且投票方式一步生成sql"
                + "\nTWO_PASS_AUTO_COT_SELF_CONSISTENCY: 通过思维链且投票方式两步生成sql", "list", "Parser相关配置");

-       s2SQLParameter.setCandidateValues(Lists.newArrayList("ONE_PASS_AUTO_COT", "ONE_PASS_AUTO_COT_SELF_CONSISTENCY",
-               "TWO_PASS_AUTO_COT", "TWO_PASS_AUTO_COT_SELF_CONSISTENCY"));
+       s2SQLParameter.setCandidateValues(Lists.newArrayList(
+               "ONE_PASS_AUTO_COT_SELF_CONSISTENCY", "TWO_PASS_AUTO_COT_SELF_CONSISTENCY"));
        parameters.add(s2SQLParameter);
-       parameters.add(new Parameter("s2SQL.linking.value.switch", "true",
+       parameters.add(new Parameter("s2.s2SQL.linking.value.switch", "true",
                "是否将Mapper探测识别到的维度值提供给大模型", "为了数据安全考虑, 这里可进行开关选择",
                "bool", "Parser相关配置"));
-       parameters.add(new Parameter("query.text.length.threshold", "10",
+       parameters.add(new Parameter("s2.query.text.length.threshold", "10",
                "用户输入文本长短阈值", "文本超过该阈值为长文本", "number", "Parser相关配置"));
-       parameters.add(new Parameter("short.text.threshold", "0.5",
+       parameters.add(new Parameter("s2.short.text.threshold", "0.5",
                "短文本匹配阈值", "由于请求大模型耗时较长, 因此如果有规则类型的Query得分达到阈值,则跳过大模型的调用,\n如果是短文本, 若query得分/文本长度>该阈值, 则跳过当前parser",
                "number", "Parser相关配置"));
-       parameters.add(new Parameter("long.text.threshold", "0.8",
+       parameters.add(new Parameter("s2.long.text.threshold", "0.8",
                "长文本匹配阈值", "如果是长文本, 若query得分/文本长度>该阈值, 则跳过当前parser",
                "number", "Parser相关配置"));
-       parameters.add(new Parameter("parse.show.count", "3",
+       parameters.add(new Parameter("s2.parse.show-count", "3",
                "解析结果个数", "前端展示的解析个数",
                "number", "Parser相关配置"));
    }
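The Parameter constructor used throughout this block takes (name, defaultValue, display name, description, type, group), as can be read off the calls above. A minimal, hypothetical sketch of registering one more tunable under the new s2.-prefixed naming, purely for illustration:

// Illustration only, not part of this commit; the argument order is inferred from the calls above.
Parameter demoThreshold = new Parameter("s2.demo.threshold", "0.5",
        "demo threshold", "an example tunable value", "number", "Mapper相关配置");
parameters.add(demoThreshold);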
@@ -27,4 +27,13 @@ public class LLMConfig {
        this.apiKey = apiKey;
        this.modelName = modelName;
    }

+   public LLMConfig(String provider, String baseUrl, String apiKey, String modelName,
+           double temperature) {
+       this.provider = provider;
+       this.baseUrl = baseUrl;
+       this.apiKey = apiKey;
+       this.modelName = modelName;
+       this.temperature = temperature;
+   }
}
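A small usage sketch of the newly added overload; the parameter list comes from the hunk above, while every argument value here is a placeholder and not taken from the project:

// Placeholder values for illustration; only the constructor shape reflects the diff above.
LLMConfig llmConfig = new LLMConfig(
        "open_ai",                    // provider (placeholder)
        "https://api.example.com/v1", // baseUrl (placeholder)
        "your-api-key",               // apiKey (placeholder)
        "gpt-4",                      // modelName (placeholder)
        0.0);                         // temperature, the newly added argument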
@@ -25,7 +25,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
    @Override
    public LLMResp generate(LLMReq llmReq) {
        //1.retriever sqlExamples and generate exampleListPool
-       keyPipelineLog.info("llmReq:{}", llmReq);
+       keyPipelineLog.info("OnePassSCSqlGenStrategy llmReq:{}", llmReq);

        List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
                optimizationConfig.getText2sqlExampleNum());
@@ -39,12 +39,12 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
        linkingSqlPromptPool.parallelStream().forEach(linkingSqlPrompt -> {
            Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingSqlPrompt))
                    .apply(new HashMap<>());
-           keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage());
+           keyPipelineLog.info("OnePassSCSqlGenStrategy reqPrompt:{}", prompt.toSystemMessage());
            ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
            Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
            String result = response.content().text();
            llmResults.add(result);
-           keyPipelineLog.info("model response:{}", result);
+           keyPipelineLog.info("OnePassSCSqlGenStrategy modelResp:{}", result);
            }
        );
        //3.format response.
@@ -56,7 +56,6 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
                .map(llmResult -> OutputFormat.getSql(llmResult)).collect(Collectors.toList());

        Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(sqlList);
        keyPipelineLog.info("linkingMap:{} sqlMap:{}", linkingMap, sqlMapPair.getRight());

        LLMResp result = new LLMResp();
        result.setQuery(llmReq.getQueryText());
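The strategy above fans out several prompts in parallel and then hands the candidate SQLs to OutputFormat.selfConsistencyVote. The real OutputFormat code is not part of this diff; the following is only a rough sketch of the self-consistency idea (weighting identical candidates and picking the most frequent one), with the class name and the Pair type (Apache Commons Lang) assumed:

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.tuple.Pair;

// Sketch only: not the actual OutputFormat implementation.
public class SelfConsistencyVoteSketch {
    public static Pair<String, Map<String, Double>> vote(List<String> sqlCandidates) {
        // Each identical candidate contributes an equal share of the total weight.
        Map<String, Double> weights = new HashMap<>();
        for (String sql : sqlCandidates) {
            weights.merge(sql, 1.0 / sqlCandidates.size(), Double::sum);
        }
        // The highest-weighted (most frequently generated) SQL wins the vote.
        String best = weights.entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .map(Map.Entry::getKey)
                .orElse("");
        return Pair.of(best, weights);
    }
}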
@@ -1,56 +0,0 @@
-package com.tencent.supersonic.headless.core.chat.parser.llm;
-
-import com.tencent.supersonic.common.util.JsonUtil;
-import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
-import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
-import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlResp;
-import dev.langchain4j.data.message.AiMessage;
-import dev.langchain4j.model.chat.ChatLanguageModel;
-import dev.langchain4j.model.input.Prompt;
-import dev.langchain4j.model.input.PromptTemplate;
-import dev.langchain4j.model.output.Response;
-import lombok.extern.slf4j.Slf4j;
-import org.springframework.stereotype.Service;
-
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-
-@Service
-@Slf4j
-public class OnePassSqlGenStrategy extends SqlGenStrategy {
-
-    @Override
-    public LLMResp generate(LLMReq llmReq) {
-        //1.retriever sqlExamples
-        keyPipelineLog.info("llmReq:{}", llmReq);
-        List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
-                optimizationConfig.getText2sqlExampleNum());
-
-        //2.generator linking and sql prompt by sqlExamples,and generate response.
-        String promptStr = promptGenerator.generatorLinkingAndSqlPrompt(llmReq, sqlExamples);
-
-        Prompt prompt = PromptTemplate.from(JsonUtil.toString(promptStr)).apply(new HashMap<>());
-        keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage());
-        ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
-        Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
-        String result = response.content().text();
-        keyPipelineLog.info("model response:{}", result);
-        //3.format response.
-        String schemaLinkStr = OutputFormat.getSchemaLinks(response.content().text());
-        String sql = OutputFormat.getSql(response.content().text());
-        Map<String, LLMSqlResp> sqlRespMap = new HashMap<>();
-        sqlRespMap.put(sql, LLMSqlResp.builder().sqlWeight(1D).fewShots(sqlExamples).build());
-        keyPipelineLog.info("schemaLinkStr:{},sqlRespMap:{}", schemaLinkStr, sqlRespMap);
-
-        LLMResp llmResp = new LLMResp();
-        llmResp.setQuery(llmReq.getQueryText());
-        llmResp.setSqlRespMap(sqlRespMap);
-        return llmResp;
-    }
-
-    @Override
-    public void afterPropertiesSet() {
-        SqlGenStrategyFactory.addSqlGenerationForFactory(LLMReq.SqlGenType.ONE_PASS_AUTO_COT, this);
-    }
-}
@@ -32,7 +32,7 @@ public class PythonLLMProxy implements LLMProxy {
    public LLMResp text2sql(LLMReq llmReq) {
        long startTime = System.currentTimeMillis();
        log.info("requestLLM request, llmReq:{}", llmReq);
-       keyPipelineLog.info("llmReq:{}", llmReq);
+       keyPipelineLog.info("PythonLLMProxy llmReq:{}", llmReq);
        try {
            LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);
@@ -47,7 +47,7 @@ public class PythonLLMProxy implements LLMProxy {
            LLMResp llmResp = responseEntity.getBody();
            log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
                    System.currentTimeMillis() - startTime, url, entity, llmResp);
-           keyPipelineLog.info("LLMResp:{}", llmResp);
+           keyPipelineLog.info("PythonLLMProxy llmResp:{}", llmResp);

            if (MapUtils.isEmpty(llmResp.getSqlRespMap())) {
                llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(new ArrayList<>(), llmResp.getSqlWeight()));
@@ -23,7 +23,7 @@ public class TwoPassSCSqlGenStrategy extends SqlGenStrategy {
    @Override
    public LLMResp generate(LLMReq llmReq) {
        //1.retriever sqlExamples and generate exampleListPool
-       keyPipelineLog.info("llmReq:{}", llmReq);
+       keyPipelineLog.info("TwoPassSCSqlGenStrategy llmReq:{}", llmReq);
        List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
                optimizationConfig.getText2sqlExampleNum());
@@ -37,10 +37,10 @@ public class TwoPassSCSqlGenStrategy extends SqlGenStrategy {
        linkingPromptPool.parallelStream().forEach(
                linkingPrompt -> {
                    Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPrompt)).apply(new HashMap<>());
-                   keyPipelineLog.info("step one request prompt:{}", prompt.toSystemMessage());
+                   keyPipelineLog.info("TwoPassSCSqlGenStrategy step one reqPrompt:{}", prompt.toSystemMessage());
                    Response<AiMessage> linkingResult = chatLanguageModel.generate(prompt.toSystemMessage());
                    String result = linkingResult.content().text();
-                   keyPipelineLog.info("step one model response:{}", result);
+                   keyPipelineLog.info("TwoPassSCSqlGenStrategy step one modelResp:{}", result);
                    linkingResults.add(OutputFormat.getSchemaLink(result));
                }
        );
@@ -51,15 +51,14 @@ public class TwoPassSCSqlGenStrategy extends SqlGenStrategy {
        List<String> sqlTaskPool = new CopyOnWriteArrayList<>();
        sqlPromptPool.parallelStream().forEach(sqlPrompt -> {
            Prompt linkingPrompt = PromptTemplate.from(JsonUtil.toString(sqlPrompt)).apply(new HashMap<>());
-           keyPipelineLog.info("step two request prompt:{}", linkingPrompt.toSystemMessage());
+           keyPipelineLog.info("TwoPassSCSqlGenStrategy step two reqPrompt:{}", linkingPrompt.toSystemMessage());
            Response<AiMessage> sqlResult = chatLanguageModel.generate(linkingPrompt.toSystemMessage());
            String result = sqlResult.content().text();
-           keyPipelineLog.info("step two model response:{}", result);
+           keyPipelineLog.info("TwoPassSCSqlGenStrategy step two modelResp:{}", result);
            sqlTaskPool.add(result);
        });
        //4.format response.
        Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(sqlTaskPool);
        keyPipelineLog.info("linkingMap:{} sqlMap:{}", linkingMap, sqlMapPair.getRight());

        LLMResp llmResp = new LLMResp();
        llmResp.setQuery(llmReq.getQueryText());
@@ -1,58 +0,0 @@
-package com.tencent.supersonic.headless.core.chat.parser.llm;
-
-import com.tencent.supersonic.common.util.JsonUtil;
-import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
-import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenType;
-import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
-import dev.langchain4j.data.message.AiMessage;
-import dev.langchain4j.model.chat.ChatLanguageModel;
-import dev.langchain4j.model.input.Prompt;
-import dev.langchain4j.model.input.PromptTemplate;
-import dev.langchain4j.model.output.Response;
-import lombok.extern.slf4j.Slf4j;
-import org.springframework.stereotype.Service;
-
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-
-@Service
-@Slf4j
-public class TwoPassSqlGenStrategy extends SqlGenStrategy {
-
-    @Override
-    public LLMResp generate(LLMReq llmReq) {
-        keyPipelineLog.info("llmReq:{}", llmReq);
-        List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
-                optimizationConfig.getText2sqlExampleNum());
-
-        String linkingPromptStr = promptGenerator.generateLinkingPrompt(llmReq, sqlExamples);
-
-        Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPromptStr)).apply(new HashMap<>());
-        keyPipelineLog.info("step one request prompt:{}", prompt.toSystemMessage());
-        ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
-        Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
-        keyPipelineLog.info("step one model response:{}", response.content().text());
-        String schemaLinkStr = OutputFormat.getSchemaLink(response.content().text());
-        String generateSqlPrompt = promptGenerator.generateSqlPrompt(llmReq, schemaLinkStr, sqlExamples);
-
-        Prompt sqlPrompt = PromptTemplate.from(JsonUtil.toString(generateSqlPrompt)).apply(new HashMap<>());
-        keyPipelineLog.info("step two request prompt:{}", sqlPrompt.toSystemMessage());
-        Response<AiMessage> sqlResult = chatLanguageModel.generate(sqlPrompt.toSystemMessage());
-        String result = sqlResult.content().text();
-        keyPipelineLog.info("step two model response:{}", result);
-        Map<String, Double> sqlMap = new HashMap<>();
-        sqlMap.put(result, 1D);
-        keyPipelineLog.info("schemaLinkStr:{},sqlMap:{}", schemaLinkStr, sqlMap);
-
-        LLMResp llmResp = new LLMResp();
-        llmResp.setQuery(llmReq.getQueryText());
-        llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(sqlExamples, sqlMap));
-        return llmResp;
-    }
-
-    @Override
-    public void afterPropertiesSet() {
-        SqlGenStrategyFactory.addSqlGenerationForFactory(SqlGenType.TWO_PASS_AUTO_COT, this);
-    }
-}
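Both files deleted above registered themselves with the strategy factory in afterPropertiesSet. The surviving self-consistency strategies presumably keep the same registration pattern for the remaining SqlGenType values; a sketch of that registration, mirroring the deleted code (the actual kept classes are not shown in this diff):

    // Assumed registration in the kept TwoPassSCSqlGenStrategy, mirroring the
    // afterPropertiesSet pattern of the deleted classes above.
    @Override
    public void afterPropertiesSet() {
        SqlGenStrategyFactory.addSqlGenerationForFactory(
                LLMReq.SqlGenType.TWO_PASS_AUTO_COT_SELF_CONSISTENCY, this);
    }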
@@ -68,16 +68,9 @@ public class LLMReq {
    }

    public enum SqlGenType {

-       ONE_PASS_AUTO_COT("1_pass_auto_cot"),

        ONE_PASS_AUTO_COT_SELF_CONSISTENCY("1_pass_auto_cot_self_consistency"),

-       TWO_PASS_AUTO_COT("2_pass_auto_cot"),

        TWO_PASS_AUTO_COT_SELF_CONSISTENCY("2_pass_auto_cot_self_consistency");

        private String name;

        SqlGenType(String name) {
@@ -16,6 +16,7 @@ public class OptimizationConfig {

    @Value("${s2.one.detection.size:8}")
    private Integer oneDetectionSize;

    @Value("${s2.one.detection.max.size:20}")
    private Integer oneDetectionMaxSize;
@@ -67,19 +68,19 @@ public class OptimizationConfig {
    @Value("${s2.parser.linking.value.switch:true}")
    private boolean useLinkingValueSwitch;

-   @Value("${s2.parser.generation:TWO_PASS_AUTO_COT}")
+   @Value("${s2.parser.strategy:TWO_PASS_AUTO_COT_SELF_CONSISTENCY}")
    private LLMReq.SqlGenType sqlGenType;

    @Value("${s2.parser.use.switch:true}")
    private boolean useS2SqlSwitch;

-   @Value("${s2.parser.exemplar-recall.num:15}")
+   @Value("${s2.parser.exemplar-recall.number:15}")
    private int text2sqlExampleNum;

-   @Value("${s2.parser.few-shot.num:10}")
+   @Value("${s2.parser.few-shot.number:5}")
    private int text2sqlFewShotsNum;

-   @Value("${s2.parser.self-consistency.num:5}")
+   @Value("${s2.parser.self-consistency.number:5}")
    private int text2sqlSelfConsistencyNum;

    @Value("${s2.parser.show-count:3}")
@@ -89,83 +90,83 @@ public class OptimizationConfig {
    private SysParameterService sysParameterService;

    public Integer getOneDetectionSize() {
-       return convertValue("one.detection.size", Integer.class, oneDetectionSize);
+       return convertValue("s2.one.detection.size", Integer.class, oneDetectionSize);
    }

    public Integer getOneDetectionMaxSize() {
-       return convertValue("one.detection.max.size", Integer.class, oneDetectionMaxSize);
+       return convertValue("s2.one.detection.max.size", Integer.class, oneDetectionMaxSize);
    }

    public Double getMetricDimensionMinThresholdConfig() {
-       return convertValue("metric.dimension.min.threshold", Double.class, metricDimensionMinThresholdConfig);
+       return convertValue("s2.metric.dimension.min.threshold", Double.class, metricDimensionMinThresholdConfig);
    }

    public Double getMetricDimensionThresholdConfig() {
-       return convertValue("metric.dimension.threshold", Double.class, metricDimensionThresholdConfig);
+       return convertValue("s2.metric.dimension.threshold", Double.class, metricDimensionThresholdConfig);
    }

    public Double getDimensionValueMinThresholdConfig() {
-       return convertValue("dimension.value.min.threshold", Double.class, dimensionValueMinThresholdConfig);
+       return convertValue("s2.dimension.value.min.threshold", Double.class, dimensionValueMinThresholdConfig);
    }

    public Double getDimensionValueThresholdConfig() {
-       return convertValue("dimension.value.threshold", Double.class, dimensionValueThresholdConfig);
+       return convertValue("s2.dimension.value.threshold", Double.class, dimensionValueThresholdConfig);
    }

    public Double getLongTextThreshold() {
-       return convertValue("long.text.threshold", Double.class, longTextThreshold);
+       return convertValue("s2.long.text.threshold", Double.class, longTextThreshold);
    }

    public Double getShortTextThreshold() {
-       return convertValue("short.text.threshold", Double.class, shortTextThreshold);
+       return convertValue("s2.short.text.threshold", Double.class, shortTextThreshold);
    }

    public Integer getQueryTextLengthThreshold() {
-       return convertValue("query.text.length.threshold", Integer.class, queryTextLengthThreshold);
-   }
-
-   public boolean isUseS2SqlSwitch() {
-       return convertValue("use.s2SQL.switch", Boolean.class, useS2SqlSwitch);
+       return convertValue("s2.query.text.length.threshold", Integer.class, queryTextLengthThreshold);
    }

    public Integer getEmbeddingMapperWordMin() {
-       return convertValue("embedding.mapper.word.min", Integer.class, embeddingMapperWordMin);
+       return convertValue("s2.embedding.mapper.word.min", Integer.class, embeddingMapperWordMin);
    }

    public Integer getEmbeddingMapperWordMax() {
-       return convertValue("embedding.mapper.word.max", Integer.class, embeddingMapperWordMax);
+       return convertValue("s2.embedding.mapper.word.max", Integer.class, embeddingMapperWordMax);
    }

    public Integer getEmbeddingMapperBatch() {
-       return convertValue("embedding.mapper.batch", Integer.class, embeddingMapperBatch);
+       return convertValue("s2.embedding.mapper.batch", Integer.class, embeddingMapperBatch);
    }

    public Integer getEmbeddingMapperNumber() {
-       return convertValue("embedding.mapper.number", Integer.class, embeddingMapperNumber);
+       return convertValue("s2.embedding.mapper.number", Integer.class, embeddingMapperNumber);
    }

    public Integer getEmbeddingMapperRoundNumber() {
-       return convertValue("embedding.mapper.round.number", Integer.class, embeddingMapperRoundNumber);
+       return convertValue("s2.embedding.mapper.round.number", Integer.class, embeddingMapperRoundNumber);
    }

    public Double getEmbeddingMapperMinThreshold() {
-       return convertValue("embedding.mapper.min.threshold", Double.class, embeddingMapperMinThreshold);
+       return convertValue("s2.embedding.mapper.min.threshold", Double.class, embeddingMapperMinThreshold);
    }

    public Double getEmbeddingMapperThreshold() {
-       return convertValue("embedding.mapper.threshold", Double.class, embeddingMapperThreshold);
+       return convertValue("s2.embedding.mapper.threshold", Double.class, embeddingMapperThreshold);
    }

+   public boolean isUseS2SqlSwitch() {
+       return convertValue("s2.parser.use.switch", Boolean.class, useS2SqlSwitch);
+   }
+
    public boolean isUseLinkingValueSwitch() {
-       return convertValue("s2SQL.linking.value.switch", Boolean.class, useLinkingValueSwitch);
+       return convertValue("s2.parser.linking.value.switch", Boolean.class, useLinkingValueSwitch);
    }

    public LLMReq.SqlGenType getSqlGenType() {
-       return convertValue("s2SQL.generation", LLMReq.SqlGenType.class, sqlGenType);
+       return convertValue("s2.parser.strategy", LLMReq.SqlGenType.class, sqlGenType);
    }

    public Integer getParseShowCount() {
-       return convertValue("parse.show.count", Integer.class, parseShowCount);
+       return convertValue("s2.parse.show-count", Integer.class, parseShowCount);
    }

    public <T> T convertValue(String paramName, Class<T> targetType, T defaultValue) {
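Every getter above resolves its value through convertValue(paramName, targetType, defaultValue), whose body is not part of this diff. A minimal sketch of the likely fallback pattern (look up an admin-configured override by the new s2.-prefixed name via SysParameterService, otherwise fall back to the @Value default), with the service method name assumed and the enum case left out:

// Sketch only: the real convertValue is not shown in this diff, and
// getValueByName is an assumed SysParameterService method, not a confirmed API.
public <T> T convertValue(String paramName, Class<T> targetType, T defaultValue) {
    try {
        String configured = sysParameterService.getValueByName(paramName);
        if (configured == null || configured.isEmpty()) {
            return defaultValue;
        }
        if (targetType == Integer.class) {
            return targetType.cast(Integer.valueOf(configured));
        }
        if (targetType == Double.class) {
            return targetType.cast(Double.valueOf(configured));
        }
        if (targetType == Boolean.class) {
            return targetType.cast(Boolean.valueOf(configured));
        }
        return targetType.cast(configured);
    } catch (Exception e) {
        // Any parse or lookup failure falls back to the @Value default.
        return defaultValue;
    }
}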
@@ -75,7 +75,7 @@ public class SqlInfoProcessor implements ResultProcessor {
        }
        SqlInfo sqlInfo = parseInfo.getSqlInfo();
        if (semanticQuery instanceof LLMSqlQuery) {
-           keyPipelineLog.info("\nparsedS2SQL:{}\ncorrectedS2SQL:{}\nfinalSQL:{}", sqlInfo.getS2SQL(),
+           keyPipelineLog.info("SqlInfoProcessor results:\nParsed S2SQL:{}\nCorrected S2SQL:{}\nFinal SQL:{}", sqlInfo.getS2SQL(),
                    sqlInfo.getCorrectS2SQL(), explainSql);
        }
        sqlInfo.setQuerySQL(explainSql);
@@ -41,12 +41,13 @@ s2:

  parser:
    url: ${s2.pyllm.url}
+   strategy: TWO_PASS_AUTO_COT_SELF_CONSISTENCY
    exemplar-recall:
      number: 10
    few-shot:
      number: 5
    self-consistency:
-     number: 5
+     number: 1
    multi-turn:
      enable: false
@@ -36,17 +36,14 @@ logging:
    dev.ai4j.openai4j: DEBUG

s2:
  pyllm:
    url: http://127.0.0.1:9092

  parser:
    url: ${s2.pyllm.url}
+   strategy: TWO_PASS_AUTO_COT_SELF_CONSISTENCY
    exemplar-recall:
-     number: 10
+     number: 5
    few-shot:
-     number: 5
+     number: 1
    self-consistency:
-     number: 5
+     number: 1
    multi-turn:
      enable: false
@@ -54,17 +51,9 @@ s2:
  additional:
    information: true

- functionCall:
-   url: ${s2.pyllm.url}
-
- embedding:
-   url: ${s2.pyllm.url}
-   persistent:
-     path: /tmp
-
  demo:
    names: S2VisitsDemo,S2ArtistDemo
-   enableLLM: true
+   enableLLM: false

  schema:
    cache:
@@ -86,22 +75,4 @@ s2:
  #2.embedding-model
  #2.1 in_memory(default)
  embedding-model:
    provider: in_process
-# inProcess:
-# modelPath: /data/model.onnx
-# vocabularyPath: /data/onnx_vocab.txt
-# shibing624/text2vec-base-chinese
-#2.2 open_ai
-# embedding-model:
-# provider: open_ai
-# openai:
-# api-key: api_key
-# modelName: all-minilm-l6-v2.onnx
-
-#2.2 hugging_face
-# embedding-model:
-# provider: hugging_face
-# hugging-face:
-# access-token: hg_access_token
-# model-id: sentence-transformers/all-MiniLM-L6-v2
-# timeout: 1h