(improvement)(headless) Remove redundant SqlGenStrategy impl

jerryjzhang
2024-05-30 15:23:52 +08:00
parent 3b09a5c0ed
commit 18f268f590
13 changed files with 81 additions and 222 deletions

View File

@@ -95,13 +95,13 @@ public class MultiTurnParser implements ChatParser {
variables.put("histSchema", context.getHistSchema());
Prompt prompt = promptTemplate.apply(variables);
keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage());
keyPipelineLog.info("MultiTurnParser reqPrompt:{}", prompt.toSystemMessage());
ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(context.getLlmConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
String result = response.content().text();
keyPipelineLog.info("model response:{}", result);
keyPipelineLog.info("MultiTurnParser modelResp:{}", result);
return response.content().text();
}

View File

@@ -48,62 +48,62 @@ public class SysParameter {
admins = Lists.newArrayList("admin");
//detect config
parameters.add(new Parameter("one.detection.size", "8",
parameters.add(new Parameter("s2.one.detection.size", "8",
"一次探测返回结果个数", "在每次探测后, 将前后缀匹配的结果合并, 并根据相似度阈值过滤后的结果个数",
"number", "Mapper相关配置"));
parameters.add(new Parameter("one.detection.max.size", "20",
parameters.add(new Parameter("s2.one.detection.max.size", "20",
"一次探测前后缀匹配结果返回个数", "单次前后缀匹配返回的结果个数", "number", "Mapper相关配置"));
//mapper config
parameters.add(new Parameter("metric.dimension.threshold", "0.3",
parameters.add(new Parameter("s2.metric.dimension.threshold", "0.3",
"指标名、维度名文本相似度阈值", "文本片段和匹配到的指标、维度名计算出来的编辑距离阈值, 若超出该阈值, 则舍弃",
"number", "Mapper相关配置"));
parameters.add(new Parameter("metric.dimension.min.threshold", "0.25",
parameters.add(new Parameter("s2.metric.dimension.min.threshold", "0.25",
"指标名、维度名最小文本相似度阈值", "指标名、维度名相似度阈值在动态调整中的最低值",
"number", "Mapper相关配置"));
parameters.add(new Parameter("dimension.value.threshold", "0.5",
parameters.add(new Parameter("s2.dimension.value.threshold", "0.5",
"维度值文本相似度阈值", "文本片段和匹配到的维度值计算出来的编辑距离阈值, 若超出该阈值, 则舍弃",
"number", "Mapper相关配置"));
parameters.add(new Parameter("dimension.value.min.threshold", "0.3",
parameters.add(new Parameter("s2.dimension.value.min.threshold", "0.3",
"维度值最小文本相似度阈值", "维度值相似度阈值在动态调整中的最低值",
"number", "Mapper相关配置"));
//embedding mapper config
parameters.add(new Parameter("embedding.mapper.word.min",
parameters.add(new Parameter("s2.embedding.mapper.word.min",
"4", "用于向量召回最小的文本长度", "为提高向量召回效率, 小于该长度的文本不进行向量语义召回", "number", "Mapper相关配置"));
parameters.add(new Parameter("embedding.mapper.word.max", "5",
parameters.add(new Parameter("s2.embedding.mapper.word.max", "5",
"用于向量召回最大的文本长度", "为提高向量召回效率, 大于该长度的文本不进行向量语义召回", "number", "Mapper相关配置"));
parameters.add(new Parameter("embedding.mapper.batch", "50",
parameters.add(new Parameter("s2.embedding.mapper.batch", "50",
"批量向量召回文本请求个数", "每次进行向量语义召回的原始文本片段个数", "number", "Mapper相关配置"));
parameters.add(new Parameter("embedding.mapper.number", "5",
parameters.add(new Parameter("s2.embedding.mapper.number", "5",
"批量向量召回文本返回结果个数", "每个文本进行向量语义召回的文本结果个数", "number", "Mapper相关配置"));
parameters.add(new Parameter("embedding.mapper.threshold",
parameters.add(new Parameter("s2.embedding.mapper.threshold",
"0.99", "向量召回相似度阈值", "相似度小于该阈值的则舍弃", "number", "Mapper相关配置"));
parameters.add(new Parameter("embedding.mapper.min.threshold",
parameters.add(new Parameter("s2.embedding.mapper.min.threshold",
"0.9", "向量召回最小相似度阈值", "向量召回相似度阈值在动态调整中的最低值", "number", "Mapper相关配置"));
//parser config
Parameter s2SQLParameter = new Parameter("s2SQL.generation", "TWO_PASS_AUTO_COT",
"S2SQL生成方式", "ONE_PASS_AUTO_COT: 通过思维链方式一步生成sql"
+ "\nONE_PASS_AUTO_COT_SELF_CONSISTENCY: 通过思维链且投票方式一步生成sql"
+ "\nTWO_PASS_AUTO_COT: 通过思维链方式步生成sql"
Parameter s2SQLParameter = new Parameter("s2.parser.strategy",
"TWO_PASS_AUTO_COT_SELF_CONSISTENCY",
"LLM解析生成S2SQL策略",
"ONE_PASS_AUTO_COT_SELF_CONSISTENCY: 通过思维链且投票方式步生成sql"
+ "\nTWO_PASS_AUTO_COT_SELF_CONSISTENCY: 通过思维链且投票方式两步生成sql", "list", "Parser相关配置");
s2SQLParameter.setCandidateValues(Lists.newArrayList("ONE_PASS_AUTO_COT", "ONE_PASS_AUTO_COT_SELF_CONSISTENCY",
"TWO_PASS_AUTO_COT", "TWO_PASS_AUTO_COT_SELF_CONSISTENCY"));
s2SQLParameter.setCandidateValues(Lists.newArrayList(
"ONE_PASS_AUTO_COT_SELF_CONSISTENCY", "TWO_PASS_AUTO_COT_SELF_CONSISTENCY"));
parameters.add(s2SQLParameter);
parameters.add(new Parameter("s2SQL.linking.value.switch", "true",
parameters.add(new Parameter("s2.s2SQL.linking.value.switch", "true",
"是否将Mapper探测识别到的维度值提供给大模型", "为了数据安全考虑, 这里可进行开关选择",
"bool", "Parser相关配置"));
parameters.add(new Parameter("query.text.length.threshold", "10",
parameters.add(new Parameter("s2.query.text.length.threshold", "10",
"用户输入文本长短阈值", "文本超过该阈值为长文本", "number", "Parser相关配置"));
parameters.add(new Parameter("short.text.threshold", "0.5",
parameters.add(new Parameter("s2.short.text.threshold", "0.5",
"短文本匹配阈值", "由于请求大模型耗时较长, 因此如果有规则类型的Query得分达到阈值,则跳过大模型的调用,\n如果是短文本, 若query得分/文本长度>该阈值, 则跳过当前parser",
"number", "Parser相关配置"));
parameters.add(new Parameter("long.text.threshold", "0.8",
parameters.add(new Parameter("s2.long.text.threshold", "0.8",
"长文本匹配阈值", "如果是长文本, 若query得分/文本长度>该阈值, 则跳过当前parser",
"number", "Parser相关配置"));
parameters.add(new Parameter("parse.show.count", "3",
parameters.add(new Parameter("s2.parse.show-count", "3",
"解析结果个数", "前端展示的解析个数",
"number", "Parser相关配置"));
}

View File

@@ -27,4 +27,13 @@ public class LLMConfig {
this.apiKey = apiKey;
this.modelName = modelName;
}
public LLMConfig(String provider, String baseUrl, String apiKey, String modelName,
double temperature) {
this.provider = provider;
this.baseUrl = baseUrl;
this.apiKey = apiKey;
this.modelName = modelName;
this.temperature = temperature;
}
}

View File

@@ -25,7 +25,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
@Override
public LLMResp generate(LLMReq llmReq) {
//1.retriever sqlExamples and generate exampleListPool
keyPipelineLog.info("llmReq:{}", llmReq);
keyPipelineLog.info("OnePassSCSqlGenStrategy llmReq:{}", llmReq);
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
optimizationConfig.getText2sqlExampleNum());
@@ -39,12 +39,12 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
linkingSqlPromptPool.parallelStream().forEach(linkingSqlPrompt -> {
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingSqlPrompt))
.apply(new HashMap<>());
keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage());
keyPipelineLog.info("OnePassSCSqlGenStrategy reqPrompt:{}", prompt.toSystemMessage());
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
String result = response.content().text();
llmResults.add(result);
keyPipelineLog.info("model response:{}", result);
keyPipelineLog.info("OnePassSCSqlGenStrategy modelResp:{}", result);
}
);
//3.format response.
@@ -56,7 +56,6 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
.map(llmResult -> OutputFormat.getSql(llmResult)).collect(Collectors.toList());
Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(sqlList);
keyPipelineLog.info("linkingMap:{} sqlMap:{}", linkingMap, sqlMapPair.getRight());
LLMResp result = new LLMResp();
result.setQuery(llmReq.getQueryText());
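
Note: the strategy above (like TwoPassSCSqlGenStrategy further down) hands the collected candidates to OutputFormat.selfConsistencyVote, which this diff does not touch. The class below is only a minimal, self-contained sketch of what such a vote could look like (frequency-weighted majority over candidate SQL strings); the class name, weighting scheme, and return shape are assumptions, not the project's implementation.

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Sketch only, NOT the project's OutputFormat: each candidate SQL gets a weight equal
// to its relative frequency, and the most frequent candidate wins the vote.
public class SelfConsistencyVoteSketch {

    static Map.Entry<String, Map<String, Double>> vote(List<String> candidates) {
        Map<String, Double> weights = new HashMap<>();
        for (String sql : candidates) {
            weights.merge(sql, 1.0 / candidates.size(), Double::sum);
        }
        String winner = weights.entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .map(Map.Entry::getKey)
                .orElse("");
        return Map.entry(winner, weights);
    }

    public static void main(String[] args) {
        List<String> llmResults = Arrays.asList(
                "SELECT city, SUM(pv) FROM visits GROUP BY city",
                "SELECT city, SUM(pv) FROM visits GROUP BY city",
                "SELECT city, COUNT(*) FROM visits GROUP BY city");
        System.out.println(vote(llmResults)); // the duplicated candidate wins with weight ~0.67
    }
}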

View File

@@ -1,56 +0,0 @@
package com.tencent.supersonic.headless.core.chat.parser.llm;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlResp;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@Service
@Slf4j
public class OnePassSqlGenStrategy extends SqlGenStrategy {
@Override
public LLMResp generate(LLMReq llmReq) {
//1.retriever sqlExamples
keyPipelineLog.info("llmReq:{}", llmReq);
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
optimizationConfig.getText2sqlExampleNum());
//2.generator linking and sql prompt by sqlExamples,and generate response.
String promptStr = promptGenerator.generatorLinkingAndSqlPrompt(llmReq, sqlExamples);
Prompt prompt = PromptTemplate.from(JsonUtil.toString(promptStr)).apply(new HashMap<>());
keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage());
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
String result = response.content().text();
keyPipelineLog.info("model response:{}", result);
//3.format response.
String schemaLinkStr = OutputFormat.getSchemaLinks(response.content().text());
String sql = OutputFormat.getSql(response.content().text());
Map<String, LLMSqlResp> sqlRespMap = new HashMap<>();
sqlRespMap.put(sql, LLMSqlResp.builder().sqlWeight(1D).fewShots(sqlExamples).build());
keyPipelineLog.info("schemaLinkStr:{},sqlRespMap:{}", schemaLinkStr, sqlRespMap);
LLMResp llmResp = new LLMResp();
llmResp.setQuery(llmReq.getQueryText());
llmResp.setSqlRespMap(sqlRespMap);
return llmResp;
}
@Override
public void afterPropertiesSet() {
SqlGenStrategyFactory.addSqlGenerationForFactory(LLMReq.SqlGenType.ONE_PASS_AUTO_COT, this);
}
}

View File

@@ -32,7 +32,7 @@ public class PythonLLMProxy implements LLMProxy {
public LLMResp text2sql(LLMReq llmReq) {
long startTime = System.currentTimeMillis();
log.info("requestLLM request, llmReq:{}", llmReq);
keyPipelineLog.info("llmReq:{}", llmReq);
keyPipelineLog.info("PythonLLMProxy llmReq:{}", llmReq);
try {
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);
@@ -47,7 +47,7 @@ public class PythonLLMProxy implements LLMProxy {
LLMResp llmResp = responseEntity.getBody();
log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
System.currentTimeMillis() - startTime, url, entity, llmResp);
keyPipelineLog.info("LLMResp:{}", llmResp);
keyPipelineLog.info("PythonLLMProxy llmResp:{}", llmResp);
if (MapUtils.isEmpty(llmResp.getSqlRespMap())) {
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(new ArrayList<>(), llmResp.getSqlWeight()));
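
For context, PythonLLMProxy delegates SQL generation to an external Python service over HTTP instead of calling a chat model in-process. The snippet below is only an illustrative, self-contained sketch of that request/response shape using the JDK HTTP client; the endpoint path and payload are made up, and the real proxy builds its URL and request entity from LLMParserConfig.

import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

// Sketch only, not the project's PythonLLMProxy: POST a serialized request to a
// text2sql endpoint and read the JSON body that would map onto LLMResp.
public class PyLlmCallSketch {

    public static void main(String[] args) throws Exception {
        // assumption: base URL matches s2.pyllm.url from the config below; the path is illustrative
        String url = "http://127.0.0.1:9092/query2sql";
        String llmReqJson = "{\"queryText\":\"近30天各城市的访问量\"}";

        HttpRequest request = HttpRequest.newBuilder()
                .uri(URI.create(url))
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(llmReqJson))
                .build();

        HttpResponse<String> response = HttpClient.newHttpClient()
                .send(request, HttpResponse.BodyHandlers.ofString());
        System.out.println(response.statusCode());
        System.out.println(response.body()); // JSON body the caller would deserialize into LLMResp
    }
}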

View File

@@ -23,7 +23,7 @@ public class TwoPassSCSqlGenStrategy extends SqlGenStrategy {
@Override
public LLMResp generate(LLMReq llmReq) {
//1.retriever sqlExamples and generate exampleListPool
keyPipelineLog.info("llmReq:{}", llmReq);
keyPipelineLog.info("TwoPassSCSqlGenStrategy llmReq:{}", llmReq);
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
optimizationConfig.getText2sqlExampleNum());
@@ -37,10 +37,10 @@ public class TwoPassSCSqlGenStrategy extends SqlGenStrategy {
linkingPromptPool.parallelStream().forEach(
linkingPrompt -> {
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPrompt)).apply(new HashMap<>());
keyPipelineLog.info("step one request prompt:{}", prompt.toSystemMessage());
keyPipelineLog.info("TwoPassSCSqlGenStrategy step one reqPrompt:{}", prompt.toSystemMessage());
Response<AiMessage> linkingResult = chatLanguageModel.generate(prompt.toSystemMessage());
String result = linkingResult.content().text();
keyPipelineLog.info("step one model response:{}", result);
keyPipelineLog.info("TwoPassSCSqlGenStrategy step one modelResp:{}", result);
linkingResults.add(OutputFormat.getSchemaLink(result));
}
);
@@ -51,15 +51,14 @@ public class TwoPassSCSqlGenStrategy extends SqlGenStrategy {
List<String> sqlTaskPool = new CopyOnWriteArrayList<>();
sqlPromptPool.parallelStream().forEach(sqlPrompt -> {
Prompt linkingPrompt = PromptTemplate.from(JsonUtil.toString(sqlPrompt)).apply(new HashMap<>());
keyPipelineLog.info("step two request prompt:{}", linkingPrompt.toSystemMessage());
keyPipelineLog.info("TwoPassSCSqlGenStrategy step two reqPrompt:{}", linkingPrompt.toSystemMessage());
Response<AiMessage> sqlResult = chatLanguageModel.generate(linkingPrompt.toSystemMessage());
String result = sqlResult.content().text();
keyPipelineLog.info("step two model response:{}", result);
keyPipelineLog.info("TwoPassSCSqlGenStrategy step two modelResp:{}", result);
sqlTaskPool.add(result);
});
//4.format response.
Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(sqlTaskPool);
keyPipelineLog.info("linkingMap:{} sqlMap:{}", linkingMap, sqlMapPair.getRight());
LLMResp llmResp = new LLMResp();
llmResp.setQuery(llmReq.getQueryText());

View File

@@ -1,58 +0,0 @@
package com.tencent.supersonic.headless.core.chat.parser.llm;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenType;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@Service
@Slf4j
public class TwoPassSqlGenStrategy extends SqlGenStrategy {
@Override
public LLMResp generate(LLMReq llmReq) {
keyPipelineLog.info("llmReq:{}", llmReq);
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
optimizationConfig.getText2sqlExampleNum());
String linkingPromptStr = promptGenerator.generateLinkingPrompt(llmReq, sqlExamples);
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPromptStr)).apply(new HashMap<>());
keyPipelineLog.info("step one request prompt:{}", prompt.toSystemMessage());
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
keyPipelineLog.info("step one model response:{}", response.content().text());
String schemaLinkStr = OutputFormat.getSchemaLink(response.content().text());
String generateSqlPrompt = promptGenerator.generateSqlPrompt(llmReq, schemaLinkStr, sqlExamples);
Prompt sqlPrompt = PromptTemplate.from(JsonUtil.toString(generateSqlPrompt)).apply(new HashMap<>());
keyPipelineLog.info("step two request prompt:{}", sqlPrompt.toSystemMessage());
Response<AiMessage> sqlResult = chatLanguageModel.generate(sqlPrompt.toSystemMessage());
String result = sqlResult.content().text();
keyPipelineLog.info("step two model response:{}", result);
Map<String, Double> sqlMap = new HashMap<>();
sqlMap.put(result, 1D);
keyPipelineLog.info("schemaLinkStr:{},sqlMap:{}", schemaLinkStr, sqlMap);
LLMResp llmResp = new LLMResp();
llmResp.setQuery(llmReq.getQueryText());
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(sqlExamples, sqlMap));
return llmResp;
}
@Override
public void afterPropertiesSet() {
SqlGenStrategyFactory.addSqlGenerationForFactory(SqlGenType.TWO_PASS_AUTO_COT, this);
}
}

View File

@@ -68,16 +68,9 @@ public class LLMReq {
}
public enum SqlGenType {
ONE_PASS_AUTO_COT("1_pass_auto_cot"),
ONE_PASS_AUTO_COT_SELF_CONSISTENCY("1_pass_auto_cot_self_consistency"),
TWO_PASS_AUTO_COT("2_pass_auto_cot"),
TWO_PASS_AUTO_COT_SELF_CONSISTENCY("2_pass_auto_cot_self_consistency");
private String name;
SqlGenType(String name) {
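
Since ONE_PASS_AUTO_COT and TWO_PASS_AUTO_COT are removed from SqlGenType, an s2.parser.strategy value persisted with one of the old names can no longer be resolved by Enum.valueOf. The class below is only a hypothetical, self-contained sketch of a defensive parse with a fallback; neither the trimmed-down enum copy nor the helper is part of this commit.

// Illustrative sketch only; SqlGenType here mirrors the two constants that remain
// after this commit, and parseOrDefault is a hypothetical helper, not project code.
public class SqlGenTypeFallbackSketch {

    enum SqlGenType {
        ONE_PASS_AUTO_COT_SELF_CONSISTENCY,
        TWO_PASS_AUTO_COT_SELF_CONSISTENCY
    }

    static SqlGenType parseOrDefault(String raw, SqlGenType fallback) {
        if (raw == null) {
            return fallback;
        }
        try {
            return SqlGenType.valueOf(raw.trim());
        } catch (IllegalArgumentException e) {
            // e.g. a stored "TWO_PASS_AUTO_COT" from before this commit
            return fallback;
        }
    }

    public static void main(String[] args) {
        System.out.println(parseOrDefault("TWO_PASS_AUTO_COT",
                SqlGenType.TWO_PASS_AUTO_COT_SELF_CONSISTENCY)); // falls back to the new default
        System.out.println(parseOrDefault("ONE_PASS_AUTO_COT_SELF_CONSISTENCY",
                SqlGenType.TWO_PASS_AUTO_COT_SELF_CONSISTENCY)); // parses as-is
    }
}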

View File

@@ -16,6 +16,7 @@ public class OptimizationConfig {
@Value("${s2.one.detection.size:8}")
private Integer oneDetectionSize;
@Value("${s2.one.detection.max.size:20}")
private Integer oneDetectionMaxSize;
@@ -67,19 +68,19 @@ public class OptimizationConfig {
@Value("${s2.parser.linking.value.switch:true}")
private boolean useLinkingValueSwitch;
@Value("${s2.parser.generation:TWO_PASS_AUTO_COT}")
@Value("${s2.parser.strategy:TWO_PASS_AUTO_COT_SELF_CONSISTENCY}")
private LLMReq.SqlGenType sqlGenType;
@Value("${s2.parser.use.switch:true}")
private boolean useS2SqlSwitch;
@Value("${s2.parser.exemplar-recall.num:15}")
@Value("${s2.parser.exemplar-recall.number:15}")
private int text2sqlExampleNum;
@Value("${s2.parser.few-shot.num:10}")
@Value("${s2.parser.few-shot.number:5}")
private int text2sqlFewShotsNum;
@Value("${s2.parser.self-consistency.num:5}")
@Value("${s2.parser.self-consistency.number:5}")
private int text2sqlSelfConsistencyNum;
@Value("${s2.parser.show-count:3}")
@@ -89,83 +90,83 @@ public class OptimizationConfig {
private SysParameterService sysParameterService;
public Integer getOneDetectionSize() {
return convertValue("one.detection.size", Integer.class, oneDetectionSize);
return convertValue("s2.one.detection.size", Integer.class, oneDetectionSize);
}
public Integer getOneDetectionMaxSize() {
return convertValue("one.detection.max.size", Integer.class, oneDetectionMaxSize);
return convertValue("s2.one.detection.max.size", Integer.class, oneDetectionMaxSize);
}
public Double getMetricDimensionMinThresholdConfig() {
return convertValue("metric.dimension.min.threshold", Double.class, metricDimensionMinThresholdConfig);
return convertValue("s2.metric.dimension.min.threshold", Double.class, metricDimensionMinThresholdConfig);
}
public Double getMetricDimensionThresholdConfig() {
return convertValue("metric.dimension.threshold", Double.class, metricDimensionThresholdConfig);
return convertValue("s2.metric.dimension.threshold", Double.class, metricDimensionThresholdConfig);
}
public Double getDimensionValueMinThresholdConfig() {
return convertValue("dimension.value.min.threshold", Double.class, dimensionValueMinThresholdConfig);
return convertValue("s2.dimension.value.min.threshold", Double.class, dimensionValueMinThresholdConfig);
}
public Double getDimensionValueThresholdConfig() {
return convertValue("dimension.value.threshold", Double.class, dimensionValueThresholdConfig);
return convertValue("s2.dimension.value.threshold", Double.class, dimensionValueThresholdConfig);
}
public Double getLongTextThreshold() {
return convertValue("long.text.threshold", Double.class, longTextThreshold);
return convertValue("s2.long.text.threshold", Double.class, longTextThreshold);
}
public Double getShortTextThreshold() {
return convertValue("short.text.threshold", Double.class, shortTextThreshold);
return convertValue("s2.short.text.threshold", Double.class, shortTextThreshold);
}
public Integer getQueryTextLengthThreshold() {
return convertValue("query.text.length.threshold", Integer.class, queryTextLengthThreshold);
return convertValue("s2.query.text.length.threshold", Integer.class, queryTextLengthThreshold);
}
public boolean isUseS2SqlSwitch() {
return convertValue("use.s2SQL.switch", Boolean.class, useS2SqlSwitch);
}
public Integer getEmbeddingMapperWordMin() {
return convertValue("embedding.mapper.word.min", Integer.class, embeddingMapperWordMin);
return convertValue("s2.embedding.mapper.word.min", Integer.class, embeddingMapperWordMin);
}
public Integer getEmbeddingMapperWordMax() {
return convertValue("embedding.mapper.word.max", Integer.class, embeddingMapperWordMax);
return convertValue("s2.embedding.mapper.word.max", Integer.class, embeddingMapperWordMax);
}
public Integer getEmbeddingMapperBatch() {
return convertValue("embedding.mapper.batch", Integer.class, embeddingMapperBatch);
return convertValue("s2.embedding.mapper.batch", Integer.class, embeddingMapperBatch);
}
public Integer getEmbeddingMapperNumber() {
return convertValue("embedding.mapper.number", Integer.class, embeddingMapperNumber);
return convertValue("s2.embedding.mapper.number", Integer.class, embeddingMapperNumber);
}
public Integer getEmbeddingMapperRoundNumber() {
return convertValue("embedding.mapper.round.number", Integer.class, embeddingMapperRoundNumber);
return convertValue("s2.embedding.mapper.round.number", Integer.class, embeddingMapperRoundNumber);
}
public Double getEmbeddingMapperMinThreshold() {
return convertValue("embedding.mapper.min.threshold", Double.class, embeddingMapperMinThreshold);
return convertValue("s2.embedding.mapper.min.threshold", Double.class, embeddingMapperMinThreshold);
}
public Double getEmbeddingMapperThreshold() {
return convertValue("embedding.mapper.threshold", Double.class, embeddingMapperThreshold);
return convertValue("s2.embedding.mapper.threshold", Double.class, embeddingMapperThreshold);
}
public boolean isUseS2SqlSwitch() {
return convertValue("s2.parser.use.switch", Boolean.class, useS2SqlSwitch);
}
public boolean isUseLinkingValueSwitch() {
return convertValue("s2SQL.linking.value.switch", Boolean.class, useLinkingValueSwitch);
return convertValue("s2.parser.linking.value.switch", Boolean.class, useLinkingValueSwitch);
}
public LLMReq.SqlGenType getSqlGenType() {
return convertValue("s2SQL.generation", LLMReq.SqlGenType.class, sqlGenType);
return convertValue("s2.parser.strategy", LLMReq.SqlGenType.class, sqlGenType);
}
public Integer getParseShowCount() {
return convertValue("parse.show.count", Integer.class, parseShowCount);
return convertValue("s2.parse.show-count", Integer.class, parseShowCount);
}
public <T> T convertValue(String paramName, Class<T> targetType, T defaultValue) {
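
Every getter above funnels through convertValue, which prefers an admin-managed system parameter over the @Value default; that is why the parameter names in SysParameter and the @Value keys must agree after the rename to the s2. prefix. The class below is only a rough, self-contained sketch of that lookup-with-fallback pattern, with an in-memory map standing in for SysParameterService and assumed type handling rather than the project's actual conversion logic.

import java.util.HashMap;
import java.util.Map;

// Rough sketch of the lookup-with-fallback idea behind convertValue; the map below
// is a stand-in for SysParameterService, and the type handling is assumed, not copied.
public class ParamLookupSketch {

    private final Map<String, String> sysParams = new HashMap<>();

    @SuppressWarnings("unchecked")
    <T> T convertValue(String paramName, Class<T> targetType, T defaultValue) {
        String raw = sysParams.get(paramName);
        if (raw == null || raw.isEmpty()) {
            return defaultValue; // no admin override: keep the @Value default
        }
        try {
            Object converted;
            if (targetType == Integer.class) {
                converted = Integer.valueOf(raw);
            } else if (targetType == Double.class) {
                converted = Double.valueOf(raw);
            } else if (targetType == Boolean.class) {
                converted = Boolean.valueOf(raw);
            } else {
                converted = raw; // enums and other target types omitted in this sketch
            }
            return (T) converted;
        } catch (NumberFormatException e) {
            return defaultValue; // malformed override: fall back to the default
        }
    }

    public static void main(String[] args) {
        ParamLookupSketch cfg = new ParamLookupSketch();
        cfg.sysParams.put("s2.one.detection.size", "16");
        System.out.println(cfg.convertValue("s2.one.detection.size", Integer.class, 8)); // 16 (override wins)
        System.out.println(cfg.convertValue("s2.parse.show-count", Integer.class, 3));   // 3 (default kept)
    }
}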

View File

@@ -75,7 +75,7 @@ public class SqlInfoProcessor implements ResultProcessor {
}
SqlInfo sqlInfo = parseInfo.getSqlInfo();
if (semanticQuery instanceof LLMSqlQuery) {
keyPipelineLog.info("\nparsedS2SQL:{}\ncorrectedS2SQL:{}\nfinalSQL:{}", sqlInfo.getS2SQL(),
keyPipelineLog.info("SqlInfoProcessor results:\nParsed S2SQL:{}\nCorrected S2SQL:{}\nFinal SQL:{}", sqlInfo.getS2SQL(),
sqlInfo.getCorrectS2SQL(), explainSql);
}
sqlInfo.setQuerySQL(explainSql);

View File

@@ -41,12 +41,13 @@ s2:
parser:
url: ${s2.pyllm.url}
strategy: TWO_PASS_AUTO_COT_SELF_CONSISTENCY
exemplar-recall:
number: 10
few-shot:
number: 5
self-consistency:
number: 5
number: 1
multi-turn:
enable: false

View File

@@ -36,17 +36,14 @@ logging:
dev.ai4j.openai4j: DEBUG
s2:
pyllm:
url: http://127.0.0.1:9092
parser:
url: ${s2.pyllm.url}
strategy: TWO_PASS_AUTO_COT_SELF_CONSISTENCY
exemplar-recall:
number: 10
number: 5
few-shot:
number: 5
number: 1
self-consistency:
number: 5
number: 1
multi-turn:
enable: false
@@ -54,17 +51,9 @@ s2:
additional:
information: true
functionCall:
url: ${s2.pyllm.url}
embedding:
url: ${s2.pyllm.url}
persistent:
path: /tmp
demo:
names: S2VisitsDemo,S2ArtistDemo
enableLLM: true
enableLLM: false
schema:
cache:
@@ -86,22 +75,4 @@ s2:
#2.embedding-model
#2.1 in_memory(default)
embedding-model:
provider: in_process
# inProcess:
# modelPath: /data/model.onnx
# vocabularyPath: /data/onnx_vocab.txt
# shibing624/text2vec-base-chinese
#2.2 open_ai
# embedding-model:
# provider: open_ai
# openai:
# api-key: api_key
# modelName: all-minilm-l6-v2.onnx
#2.2 hugging_face
# embedding-model:
# provider: hugging_face
# hugging-face:
# access-token: hg_access_token
# model-id: sentence-transformers/all-MiniLM-L6-v2
# timeout: 1h
provider: in_process