(improvement)(headless)Remove redundant SqlGenStrategy impl

This commit is contained in:
jerryjzhang
2024-05-30 15:23:52 +08:00
parent 3b09a5c0ed
commit 18f268f590
13 changed files with 81 additions and 222 deletions

View File

@@ -25,7 +25,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
@Override
public LLMResp generate(LLMReq llmReq) {
//1.retriever sqlExamples and generate exampleListPool
keyPipelineLog.info("llmReq:{}", llmReq);
keyPipelineLog.info("OnePassSCSqlGenStrategy llmReq:{}", llmReq);
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
optimizationConfig.getText2sqlExampleNum());
@@ -39,12 +39,12 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
linkingSqlPromptPool.parallelStream().forEach(linkingSqlPrompt -> {
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingSqlPrompt))
.apply(new HashMap<>());
keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage());
keyPipelineLog.info("OnePassSCSqlGenStrategy reqPrompt:{}", prompt.toSystemMessage());
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
String result = response.content().text();
llmResults.add(result);
keyPipelineLog.info("model response:{}", result);
keyPipelineLog.info("OnePassSCSqlGenStrategy modelResp:{}", result);
}
);
//3.format response.
@@ -56,7 +56,6 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
.map(llmResult -> OutputFormat.getSql(llmResult)).collect(Collectors.toList());
Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(sqlList);
keyPipelineLog.info("linkingMap:{} sqlMap:{}", linkingMap, sqlMapPair.getRight());
LLMResp result = new LLMResp();
result.setQuery(llmReq.getQueryText());

View File

@@ -1,56 +0,0 @@
package com.tencent.supersonic.headless.core.chat.parser.llm;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlResp;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@Service
@Slf4j
public class OnePassSqlGenStrategy extends SqlGenStrategy {
@Override
public LLMResp generate(LLMReq llmReq) {
//1.retriever sqlExamples
keyPipelineLog.info("llmReq:{}", llmReq);
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
optimizationConfig.getText2sqlExampleNum());
//2.generator linking and sql prompt by sqlExamples,and generate response.
String promptStr = promptGenerator.generatorLinkingAndSqlPrompt(llmReq, sqlExamples);
Prompt prompt = PromptTemplate.from(JsonUtil.toString(promptStr)).apply(new HashMap<>());
keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage());
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
String result = response.content().text();
keyPipelineLog.info("model response:{}", result);
//3.format response.
String schemaLinkStr = OutputFormat.getSchemaLinks(response.content().text());
String sql = OutputFormat.getSql(response.content().text());
Map<String, LLMSqlResp> sqlRespMap = new HashMap<>();
sqlRespMap.put(sql, LLMSqlResp.builder().sqlWeight(1D).fewShots(sqlExamples).build());
keyPipelineLog.info("schemaLinkStr:{},sqlRespMap:{}", schemaLinkStr, sqlRespMap);
LLMResp llmResp = new LLMResp();
llmResp.setQuery(llmReq.getQueryText());
llmResp.setSqlRespMap(sqlRespMap);
return llmResp;
}
@Override
public void afterPropertiesSet() {
SqlGenStrategyFactory.addSqlGenerationForFactory(LLMReq.SqlGenType.ONE_PASS_AUTO_COT, this);
}
}

View File

@@ -32,7 +32,7 @@ public class PythonLLMProxy implements LLMProxy {
public LLMResp text2sql(LLMReq llmReq) {
long startTime = System.currentTimeMillis();
log.info("requestLLM request, llmReq:{}", llmReq);
keyPipelineLog.info("llmReq:{}", llmReq);
keyPipelineLog.info("PythonLLMProxy llmReq:{}", llmReq);
try {
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);
@@ -47,7 +47,7 @@ public class PythonLLMProxy implements LLMProxy {
LLMResp llmResp = responseEntity.getBody();
log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
System.currentTimeMillis() - startTime, url, entity, llmResp);
keyPipelineLog.info("LLMResp:{}", llmResp);
keyPipelineLog.info("PythonLLMProxy llmResp:{}", llmResp);
if (MapUtils.isEmpty(llmResp.getSqlRespMap())) {
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(new ArrayList<>(), llmResp.getSqlWeight()));

View File

@@ -23,7 +23,7 @@ public class TwoPassSCSqlGenStrategy extends SqlGenStrategy {
@Override
public LLMResp generate(LLMReq llmReq) {
//1.retriever sqlExamples and generate exampleListPool
keyPipelineLog.info("llmReq:{}", llmReq);
keyPipelineLog.info("TwoPassSCSqlGenStrategy llmReq:{}", llmReq);
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
optimizationConfig.getText2sqlExampleNum());
@@ -37,10 +37,10 @@ public class TwoPassSCSqlGenStrategy extends SqlGenStrategy {
linkingPromptPool.parallelStream().forEach(
linkingPrompt -> {
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPrompt)).apply(new HashMap<>());
keyPipelineLog.info("step one request prompt:{}", prompt.toSystemMessage());
keyPipelineLog.info("TwoPassSCSqlGenStrategy step one reqPrompt:{}", prompt.toSystemMessage());
Response<AiMessage> linkingResult = chatLanguageModel.generate(prompt.toSystemMessage());
String result = linkingResult.content().text();
keyPipelineLog.info("step one model response:{}", result);
keyPipelineLog.info("TwoPassSCSqlGenStrategy step one modelResp:{}", result);
linkingResults.add(OutputFormat.getSchemaLink(result));
}
);
@@ -51,15 +51,14 @@ public class TwoPassSCSqlGenStrategy extends SqlGenStrategy {
List<String> sqlTaskPool = new CopyOnWriteArrayList<>();
sqlPromptPool.parallelStream().forEach(sqlPrompt -> {
Prompt linkingPrompt = PromptTemplate.from(JsonUtil.toString(sqlPrompt)).apply(new HashMap<>());
keyPipelineLog.info("step two request prompt:{}", linkingPrompt.toSystemMessage());
keyPipelineLog.info("TwoPassSCSqlGenStrategy step two reqPrompt:{}", linkingPrompt.toSystemMessage());
Response<AiMessage> sqlResult = chatLanguageModel.generate(linkingPrompt.toSystemMessage());
String result = sqlResult.content().text();
keyPipelineLog.info("step two model response:{}", result);
keyPipelineLog.info("TwoPassSCSqlGenStrategy step two modelResp:{}", result);
sqlTaskPool.add(result);
});
//4.format response.
Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(sqlTaskPool);
keyPipelineLog.info("linkingMap:{} sqlMap:{}", linkingMap, sqlMapPair.getRight());
LLMResp llmResp = new LLMResp();
llmResp.setQuery(llmReq.getQueryText());

View File

@@ -1,58 +0,0 @@
package com.tencent.supersonic.headless.core.chat.parser.llm;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenType;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@Service
@Slf4j
public class TwoPassSqlGenStrategy extends SqlGenStrategy {
@Override
public LLMResp generate(LLMReq llmReq) {
keyPipelineLog.info("llmReq:{}", llmReq);
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
optimizationConfig.getText2sqlExampleNum());
String linkingPromptStr = promptGenerator.generateLinkingPrompt(llmReq, sqlExamples);
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPromptStr)).apply(new HashMap<>());
keyPipelineLog.info("step one request prompt:{}", prompt.toSystemMessage());
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
keyPipelineLog.info("step one model response:{}", response.content().text());
String schemaLinkStr = OutputFormat.getSchemaLink(response.content().text());
String generateSqlPrompt = promptGenerator.generateSqlPrompt(llmReq, schemaLinkStr, sqlExamples);
Prompt sqlPrompt = PromptTemplate.from(JsonUtil.toString(generateSqlPrompt)).apply(new HashMap<>());
keyPipelineLog.info("step two request prompt:{}", sqlPrompt.toSystemMessage());
Response<AiMessage> sqlResult = chatLanguageModel.generate(sqlPrompt.toSystemMessage());
String result = sqlResult.content().text();
keyPipelineLog.info("step two model response:{}", result);
Map<String, Double> sqlMap = new HashMap<>();
sqlMap.put(result, 1D);
keyPipelineLog.info("schemaLinkStr:{},sqlMap:{}", schemaLinkStr, sqlMap);
LLMResp llmResp = new LLMResp();
llmResp.setQuery(llmReq.getQueryText());
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(sqlExamples, sqlMap));
return llmResp;
}
@Override
public void afterPropertiesSet() {
SqlGenStrategyFactory.addSqlGenerationForFactory(SqlGenType.TWO_PASS_AUTO_COT, this);
}
}

View File

@@ -68,16 +68,9 @@ public class LLMReq {
}
public enum SqlGenType {
ONE_PASS_AUTO_COT("1_pass_auto_cot"),
ONE_PASS_AUTO_COT_SELF_CONSISTENCY("1_pass_auto_cot_self_consistency"),
TWO_PASS_AUTO_COT("2_pass_auto_cot"),
TWO_PASS_AUTO_COT_SELF_CONSISTENCY("2_pass_auto_cot_self_consistency");
private String name;
SqlGenType(String name) {

View File

@@ -16,6 +16,7 @@ public class OptimizationConfig {
@Value("${s2.one.detection.size:8}")
private Integer oneDetectionSize;
@Value("${s2.one.detection.max.size:20}")
private Integer oneDetectionMaxSize;
@@ -67,19 +68,19 @@ public class OptimizationConfig {
@Value("${s2.parser.linking.value.switch:true}")
private boolean useLinkingValueSwitch;
@Value("${s2.parser.generation:TWO_PASS_AUTO_COT}")
@Value("${s2.parser.strategy:TWO_PASS_AUTO_COT_SELF_CONSISTENCY}")
private LLMReq.SqlGenType sqlGenType;
@Value("${s2.parser.use.switch:true}")
private boolean useS2SqlSwitch;
@Value("${s2.parser.exemplar-recall.num:15}")
@Value("${s2.parser.exemplar-recall.number:15}")
private int text2sqlExampleNum;
@Value("${s2.parser.few-shot.num:10}")
@Value("${s2.parser.few-shot.number:5}")
private int text2sqlFewShotsNum;
@Value("${s2.parser.self-consistency.num:5}")
@Value("${s2.parser.self-consistency.number:5}")
private int text2sqlSelfConsistencyNum;
@Value("${s2.parser.show-count:3}")
@@ -89,83 +90,83 @@ public class OptimizationConfig {
private SysParameterService sysParameterService;
public Integer getOneDetectionSize() {
return convertValue("one.detection.size", Integer.class, oneDetectionSize);
return convertValue("s2.one.detection.size", Integer.class, oneDetectionSize);
}
public Integer getOneDetectionMaxSize() {
return convertValue("one.detection.max.size", Integer.class, oneDetectionMaxSize);
return convertValue("s2.one.detection.max.size", Integer.class, oneDetectionMaxSize);
}
public Double getMetricDimensionMinThresholdConfig() {
return convertValue("metric.dimension.min.threshold", Double.class, metricDimensionMinThresholdConfig);
return convertValue("s2.metric.dimension.min.threshold", Double.class, metricDimensionMinThresholdConfig);
}
public Double getMetricDimensionThresholdConfig() {
return convertValue("metric.dimension.threshold", Double.class, metricDimensionThresholdConfig);
return convertValue("s2.metric.dimension.threshold", Double.class, metricDimensionThresholdConfig);
}
public Double getDimensionValueMinThresholdConfig() {
return convertValue("dimension.value.min.threshold", Double.class, dimensionValueMinThresholdConfig);
return convertValue("s2.dimension.value.min.threshold", Double.class, dimensionValueMinThresholdConfig);
}
public Double getDimensionValueThresholdConfig() {
return convertValue("dimension.value.threshold", Double.class, dimensionValueThresholdConfig);
return convertValue("s2.dimension.value.threshold", Double.class, dimensionValueThresholdConfig);
}
public Double getLongTextThreshold() {
return convertValue("long.text.threshold", Double.class, longTextThreshold);
return convertValue("s2.long.text.threshold", Double.class, longTextThreshold);
}
public Double getShortTextThreshold() {
return convertValue("short.text.threshold", Double.class, shortTextThreshold);
return convertValue("s2.short.text.threshold", Double.class, shortTextThreshold);
}
public Integer getQueryTextLengthThreshold() {
return convertValue("query.text.length.threshold", Integer.class, queryTextLengthThreshold);
}
public boolean isUseS2SqlSwitch() {
return convertValue("use.s2SQL.switch", Boolean.class, useS2SqlSwitch);
return convertValue("s2.query.text.length.threshold", Integer.class, queryTextLengthThreshold);
}
public Integer getEmbeddingMapperWordMin() {
return convertValue("embedding.mapper.word.min", Integer.class, embeddingMapperWordMin);
return convertValue("s2.embedding.mapper.word.min", Integer.class, embeddingMapperWordMin);
}
public Integer getEmbeddingMapperWordMax() {
return convertValue("embedding.mapper.word.max", Integer.class, embeddingMapperWordMax);
return convertValue("s2.embedding.mapper.word.max", Integer.class, embeddingMapperWordMax);
}
public Integer getEmbeddingMapperBatch() {
return convertValue("embedding.mapper.batch", Integer.class, embeddingMapperBatch);
return convertValue("s2.embedding.mapper.batch", Integer.class, embeddingMapperBatch);
}
public Integer getEmbeddingMapperNumber() {
return convertValue("embedding.mapper.number", Integer.class, embeddingMapperNumber);
return convertValue("s2.embedding.mapper.number", Integer.class, embeddingMapperNumber);
}
public Integer getEmbeddingMapperRoundNumber() {
return convertValue("embedding.mapper.round.number", Integer.class, embeddingMapperRoundNumber);
return convertValue("s2.embedding.mapper.round.number", Integer.class, embeddingMapperRoundNumber);
}
public Double getEmbeddingMapperMinThreshold() {
return convertValue("embedding.mapper.min.threshold", Double.class, embeddingMapperMinThreshold);
return convertValue("s2.embedding.mapper.min.threshold", Double.class, embeddingMapperMinThreshold);
}
public Double getEmbeddingMapperThreshold() {
return convertValue("embedding.mapper.threshold", Double.class, embeddingMapperThreshold);
return convertValue("s2.embedding.mapper.threshold", Double.class, embeddingMapperThreshold);
}
public boolean isUseS2SqlSwitch() {
return convertValue("s2.parser.use.switch", Boolean.class, useS2SqlSwitch);
}
public boolean isUseLinkingValueSwitch() {
return convertValue("s2SQL.linking.value.switch", Boolean.class, useLinkingValueSwitch);
return convertValue("s2.parser.linking.value.switch", Boolean.class, useLinkingValueSwitch);
}
public LLMReq.SqlGenType getSqlGenType() {
return convertValue("s2SQL.generation", LLMReq.SqlGenType.class, sqlGenType);
return convertValue("s2.parser.strategy", LLMReq.SqlGenType.class, sqlGenType);
}
public Integer getParseShowCount() {
return convertValue("parse.show.count", Integer.class, parseShowCount);
return convertValue("s2.parse.show-count", Integer.class, parseShowCount);
}
public <T> T convertValue(String paramName, Class<T> targetType, T defaultValue) {