mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-14 13:47:09 +00:00
(improvement)(headless)Remove redundant SqlGenStrategy impl
This commit is contained in:
@@ -25,7 +25,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
@Override
|
||||
public LLMResp generate(LLMReq llmReq) {
|
||||
//1.retriever sqlExamples and generate exampleListPool
|
||||
keyPipelineLog.info("llmReq:{}", llmReq);
|
||||
keyPipelineLog.info("OnePassSCSqlGenStrategy llmReq:{}", llmReq);
|
||||
|
||||
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlExampleNum());
|
||||
@@ -39,12 +39,12 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
linkingSqlPromptPool.parallelStream().forEach(linkingSqlPrompt -> {
|
||||
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingSqlPrompt))
|
||||
.apply(new HashMap<>());
|
||||
keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage());
|
||||
keyPipelineLog.info("OnePassSCSqlGenStrategy reqPrompt:{}", prompt.toSystemMessage());
|
||||
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
|
||||
String result = response.content().text();
|
||||
llmResults.add(result);
|
||||
keyPipelineLog.info("model response:{}", result);
|
||||
keyPipelineLog.info("OnePassSCSqlGenStrategy modelResp:{}", result);
|
||||
}
|
||||
);
|
||||
//3.format response.
|
||||
@@ -56,7 +56,6 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
.map(llmResult -> OutputFormat.getSql(llmResult)).collect(Collectors.toList());
|
||||
|
||||
Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(sqlList);
|
||||
keyPipelineLog.info("linkingMap:{} sqlMap:{}", linkingMap, sqlMapPair.getRight());
|
||||
|
||||
LLMResp result = new LLMResp();
|
||||
result.setQuery(llmReq.getQueryText());
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
package com.tencent.supersonic.headless.core.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlResp;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class OnePassSqlGenStrategy extends SqlGenStrategy {
|
||||
|
||||
@Override
|
||||
public LLMResp generate(LLMReq llmReq) {
|
||||
//1.retriever sqlExamples
|
||||
keyPipelineLog.info("llmReq:{}", llmReq);
|
||||
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlExampleNum());
|
||||
|
||||
//2.generator linking and sql prompt by sqlExamples,and generate response.
|
||||
String promptStr = promptGenerator.generatorLinkingAndSqlPrompt(llmReq, sqlExamples);
|
||||
|
||||
Prompt prompt = PromptTemplate.from(JsonUtil.toString(promptStr)).apply(new HashMap<>());
|
||||
keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage());
|
||||
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
|
||||
String result = response.content().text();
|
||||
keyPipelineLog.info("model response:{}", result);
|
||||
//3.format response.
|
||||
String schemaLinkStr = OutputFormat.getSchemaLinks(response.content().text());
|
||||
String sql = OutputFormat.getSql(response.content().text());
|
||||
Map<String, LLMSqlResp> sqlRespMap = new HashMap<>();
|
||||
sqlRespMap.put(sql, LLMSqlResp.builder().sqlWeight(1D).fewShots(sqlExamples).build());
|
||||
keyPipelineLog.info("schemaLinkStr:{},sqlRespMap:{}", schemaLinkStr, sqlRespMap);
|
||||
|
||||
LLMResp llmResp = new LLMResp();
|
||||
llmResp.setQuery(llmReq.getQueryText());
|
||||
llmResp.setSqlRespMap(sqlRespMap);
|
||||
return llmResp;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
SqlGenStrategyFactory.addSqlGenerationForFactory(LLMReq.SqlGenType.ONE_PASS_AUTO_COT, this);
|
||||
}
|
||||
}
|
||||
@@ -32,7 +32,7 @@ public class PythonLLMProxy implements LLMProxy {
|
||||
public LLMResp text2sql(LLMReq llmReq) {
|
||||
long startTime = System.currentTimeMillis();
|
||||
log.info("requestLLM request, llmReq:{}", llmReq);
|
||||
keyPipelineLog.info("llmReq:{}", llmReq);
|
||||
keyPipelineLog.info("PythonLLMProxy llmReq:{}", llmReq);
|
||||
try {
|
||||
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);
|
||||
|
||||
@@ -47,7 +47,7 @@ public class PythonLLMProxy implements LLMProxy {
|
||||
LLMResp llmResp = responseEntity.getBody();
|
||||
log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
|
||||
System.currentTimeMillis() - startTime, url, entity, llmResp);
|
||||
keyPipelineLog.info("LLMResp:{}", llmResp);
|
||||
keyPipelineLog.info("PythonLLMProxy llmResp:{}", llmResp);
|
||||
|
||||
if (MapUtils.isEmpty(llmResp.getSqlRespMap())) {
|
||||
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(new ArrayList<>(), llmResp.getSqlWeight()));
|
||||
|
||||
@@ -23,7 +23,7 @@ public class TwoPassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
@Override
|
||||
public LLMResp generate(LLMReq llmReq) {
|
||||
//1.retriever sqlExamples and generate exampleListPool
|
||||
keyPipelineLog.info("llmReq:{}", llmReq);
|
||||
keyPipelineLog.info("TwoPassSCSqlGenStrategy llmReq:{}", llmReq);
|
||||
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlExampleNum());
|
||||
|
||||
@@ -37,10 +37,10 @@ public class TwoPassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
linkingPromptPool.parallelStream().forEach(
|
||||
linkingPrompt -> {
|
||||
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPrompt)).apply(new HashMap<>());
|
||||
keyPipelineLog.info("step one request prompt:{}", prompt.toSystemMessage());
|
||||
keyPipelineLog.info("TwoPassSCSqlGenStrategy step one reqPrompt:{}", prompt.toSystemMessage());
|
||||
Response<AiMessage> linkingResult = chatLanguageModel.generate(prompt.toSystemMessage());
|
||||
String result = linkingResult.content().text();
|
||||
keyPipelineLog.info("step one model response:{}", result);
|
||||
keyPipelineLog.info("TwoPassSCSqlGenStrategy step one modelResp:{}", result);
|
||||
linkingResults.add(OutputFormat.getSchemaLink(result));
|
||||
}
|
||||
);
|
||||
@@ -51,15 +51,14 @@ public class TwoPassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
List<String> sqlTaskPool = new CopyOnWriteArrayList<>();
|
||||
sqlPromptPool.parallelStream().forEach(sqlPrompt -> {
|
||||
Prompt linkingPrompt = PromptTemplate.from(JsonUtil.toString(sqlPrompt)).apply(new HashMap<>());
|
||||
keyPipelineLog.info("step two request prompt:{}", linkingPrompt.toSystemMessage());
|
||||
keyPipelineLog.info("TwoPassSCSqlGenStrategy step two reqPrompt:{}", linkingPrompt.toSystemMessage());
|
||||
Response<AiMessage> sqlResult = chatLanguageModel.generate(linkingPrompt.toSystemMessage());
|
||||
String result = sqlResult.content().text();
|
||||
keyPipelineLog.info("step two model response:{}", result);
|
||||
keyPipelineLog.info("TwoPassSCSqlGenStrategy step two modelResp:{}", result);
|
||||
sqlTaskPool.add(result);
|
||||
});
|
||||
//4.format response.
|
||||
Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(sqlTaskPool);
|
||||
keyPipelineLog.info("linkingMap:{} sqlMap:{}", linkingMap, sqlMapPair.getRight());
|
||||
|
||||
LLMResp llmResp = new LLMResp();
|
||||
llmResp.setQuery(llmReq.getQueryText());
|
||||
|
||||
@@ -1,58 +0,0 @@
|
||||
package com.tencent.supersonic.headless.core.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenType;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class TwoPassSqlGenStrategy extends SqlGenStrategy {
|
||||
|
||||
@Override
|
||||
public LLMResp generate(LLMReq llmReq) {
|
||||
keyPipelineLog.info("llmReq:{}", llmReq);
|
||||
List<Map<String, String>> sqlExamples = exemplarManager.recallExemplars(llmReq.getQueryText(),
|
||||
optimizationConfig.getText2sqlExampleNum());
|
||||
|
||||
String linkingPromptStr = promptGenerator.generateLinkingPrompt(llmReq, sqlExamples);
|
||||
|
||||
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPromptStr)).apply(new HashMap<>());
|
||||
keyPipelineLog.info("step one request prompt:{}", prompt.toSystemMessage());
|
||||
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
|
||||
keyPipelineLog.info("step one model response:{}", response.content().text());
|
||||
String schemaLinkStr = OutputFormat.getSchemaLink(response.content().text());
|
||||
String generateSqlPrompt = promptGenerator.generateSqlPrompt(llmReq, schemaLinkStr, sqlExamples);
|
||||
|
||||
Prompt sqlPrompt = PromptTemplate.from(JsonUtil.toString(generateSqlPrompt)).apply(new HashMap<>());
|
||||
keyPipelineLog.info("step two request prompt:{}", sqlPrompt.toSystemMessage());
|
||||
Response<AiMessage> sqlResult = chatLanguageModel.generate(sqlPrompt.toSystemMessage());
|
||||
String result = sqlResult.content().text();
|
||||
keyPipelineLog.info("step two model response:{}", result);
|
||||
Map<String, Double> sqlMap = new HashMap<>();
|
||||
sqlMap.put(result, 1D);
|
||||
keyPipelineLog.info("schemaLinkStr:{},sqlMap:{}", schemaLinkStr, sqlMap);
|
||||
|
||||
LLMResp llmResp = new LLMResp();
|
||||
llmResp.setQuery(llmReq.getQueryText());
|
||||
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(sqlExamples, sqlMap));
|
||||
return llmResp;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
SqlGenStrategyFactory.addSqlGenerationForFactory(SqlGenType.TWO_PASS_AUTO_COT, this);
|
||||
}
|
||||
}
|
||||
@@ -68,16 +68,9 @@ public class LLMReq {
|
||||
}
|
||||
|
||||
public enum SqlGenType {
|
||||
|
||||
ONE_PASS_AUTO_COT("1_pass_auto_cot"),
|
||||
|
||||
ONE_PASS_AUTO_COT_SELF_CONSISTENCY("1_pass_auto_cot_self_consistency"),
|
||||
|
||||
TWO_PASS_AUTO_COT("2_pass_auto_cot"),
|
||||
|
||||
TWO_PASS_AUTO_COT_SELF_CONSISTENCY("2_pass_auto_cot_self_consistency");
|
||||
|
||||
|
||||
private String name;
|
||||
|
||||
SqlGenType(String name) {
|
||||
|
||||
@@ -16,6 +16,7 @@ public class OptimizationConfig {
|
||||
|
||||
@Value("${s2.one.detection.size:8}")
|
||||
private Integer oneDetectionSize;
|
||||
|
||||
@Value("${s2.one.detection.max.size:20}")
|
||||
private Integer oneDetectionMaxSize;
|
||||
|
||||
@@ -67,19 +68,19 @@ public class OptimizationConfig {
|
||||
@Value("${s2.parser.linking.value.switch:true}")
|
||||
private boolean useLinkingValueSwitch;
|
||||
|
||||
@Value("${s2.parser.generation:TWO_PASS_AUTO_COT}")
|
||||
@Value("${s2.parser.strategy:TWO_PASS_AUTO_COT_SELF_CONSISTENCY}")
|
||||
private LLMReq.SqlGenType sqlGenType;
|
||||
|
||||
@Value("${s2.parser.use.switch:true}")
|
||||
private boolean useS2SqlSwitch;
|
||||
|
||||
@Value("${s2.parser.exemplar-recall.num:15}")
|
||||
@Value("${s2.parser.exemplar-recall.number:15}")
|
||||
private int text2sqlExampleNum;
|
||||
|
||||
@Value("${s2.parser.few-shot.num:10}")
|
||||
@Value("${s2.parser.few-shot.number:5}")
|
||||
private int text2sqlFewShotsNum;
|
||||
|
||||
@Value("${s2.parser.self-consistency.num:5}")
|
||||
@Value("${s2.parser.self-consistency.number:5}")
|
||||
private int text2sqlSelfConsistencyNum;
|
||||
|
||||
@Value("${s2.parser.show-count:3}")
|
||||
@@ -89,83 +90,83 @@ public class OptimizationConfig {
|
||||
private SysParameterService sysParameterService;
|
||||
|
||||
public Integer getOneDetectionSize() {
|
||||
return convertValue("one.detection.size", Integer.class, oneDetectionSize);
|
||||
return convertValue("s2.one.detection.size", Integer.class, oneDetectionSize);
|
||||
}
|
||||
|
||||
public Integer getOneDetectionMaxSize() {
|
||||
return convertValue("one.detection.max.size", Integer.class, oneDetectionMaxSize);
|
||||
return convertValue("s2.one.detection.max.size", Integer.class, oneDetectionMaxSize);
|
||||
}
|
||||
|
||||
public Double getMetricDimensionMinThresholdConfig() {
|
||||
return convertValue("metric.dimension.min.threshold", Double.class, metricDimensionMinThresholdConfig);
|
||||
return convertValue("s2.metric.dimension.min.threshold", Double.class, metricDimensionMinThresholdConfig);
|
||||
}
|
||||
|
||||
public Double getMetricDimensionThresholdConfig() {
|
||||
return convertValue("metric.dimension.threshold", Double.class, metricDimensionThresholdConfig);
|
||||
return convertValue("s2.metric.dimension.threshold", Double.class, metricDimensionThresholdConfig);
|
||||
}
|
||||
|
||||
public Double getDimensionValueMinThresholdConfig() {
|
||||
return convertValue("dimension.value.min.threshold", Double.class, dimensionValueMinThresholdConfig);
|
||||
return convertValue("s2.dimension.value.min.threshold", Double.class, dimensionValueMinThresholdConfig);
|
||||
}
|
||||
|
||||
public Double getDimensionValueThresholdConfig() {
|
||||
return convertValue("dimension.value.threshold", Double.class, dimensionValueThresholdConfig);
|
||||
return convertValue("s2.dimension.value.threshold", Double.class, dimensionValueThresholdConfig);
|
||||
}
|
||||
|
||||
public Double getLongTextThreshold() {
|
||||
return convertValue("long.text.threshold", Double.class, longTextThreshold);
|
||||
return convertValue("s2.long.text.threshold", Double.class, longTextThreshold);
|
||||
}
|
||||
|
||||
public Double getShortTextThreshold() {
|
||||
return convertValue("short.text.threshold", Double.class, shortTextThreshold);
|
||||
return convertValue("s2.short.text.threshold", Double.class, shortTextThreshold);
|
||||
}
|
||||
|
||||
public Integer getQueryTextLengthThreshold() {
|
||||
return convertValue("query.text.length.threshold", Integer.class, queryTextLengthThreshold);
|
||||
}
|
||||
|
||||
public boolean isUseS2SqlSwitch() {
|
||||
return convertValue("use.s2SQL.switch", Boolean.class, useS2SqlSwitch);
|
||||
return convertValue("s2.query.text.length.threshold", Integer.class, queryTextLengthThreshold);
|
||||
}
|
||||
|
||||
public Integer getEmbeddingMapperWordMin() {
|
||||
return convertValue("embedding.mapper.word.min", Integer.class, embeddingMapperWordMin);
|
||||
return convertValue("s2.embedding.mapper.word.min", Integer.class, embeddingMapperWordMin);
|
||||
}
|
||||
|
||||
public Integer getEmbeddingMapperWordMax() {
|
||||
return convertValue("embedding.mapper.word.max", Integer.class, embeddingMapperWordMax);
|
||||
return convertValue("s2.embedding.mapper.word.max", Integer.class, embeddingMapperWordMax);
|
||||
}
|
||||
|
||||
public Integer getEmbeddingMapperBatch() {
|
||||
return convertValue("embedding.mapper.batch", Integer.class, embeddingMapperBatch);
|
||||
return convertValue("s2.embedding.mapper.batch", Integer.class, embeddingMapperBatch);
|
||||
}
|
||||
|
||||
public Integer getEmbeddingMapperNumber() {
|
||||
return convertValue("embedding.mapper.number", Integer.class, embeddingMapperNumber);
|
||||
return convertValue("s2.embedding.mapper.number", Integer.class, embeddingMapperNumber);
|
||||
}
|
||||
|
||||
public Integer getEmbeddingMapperRoundNumber() {
|
||||
return convertValue("embedding.mapper.round.number", Integer.class, embeddingMapperRoundNumber);
|
||||
return convertValue("s2.embedding.mapper.round.number", Integer.class, embeddingMapperRoundNumber);
|
||||
}
|
||||
|
||||
public Double getEmbeddingMapperMinThreshold() {
|
||||
return convertValue("embedding.mapper.min.threshold", Double.class, embeddingMapperMinThreshold);
|
||||
return convertValue("s2.embedding.mapper.min.threshold", Double.class, embeddingMapperMinThreshold);
|
||||
}
|
||||
|
||||
public Double getEmbeddingMapperThreshold() {
|
||||
return convertValue("embedding.mapper.threshold", Double.class, embeddingMapperThreshold);
|
||||
return convertValue("s2.embedding.mapper.threshold", Double.class, embeddingMapperThreshold);
|
||||
}
|
||||
|
||||
public boolean isUseS2SqlSwitch() {
|
||||
return convertValue("s2.parser.use.switch", Boolean.class, useS2SqlSwitch);
|
||||
}
|
||||
|
||||
public boolean isUseLinkingValueSwitch() {
|
||||
return convertValue("s2SQL.linking.value.switch", Boolean.class, useLinkingValueSwitch);
|
||||
return convertValue("s2.parser.linking.value.switch", Boolean.class, useLinkingValueSwitch);
|
||||
}
|
||||
|
||||
public LLMReq.SqlGenType getSqlGenType() {
|
||||
return convertValue("s2SQL.generation", LLMReq.SqlGenType.class, sqlGenType);
|
||||
return convertValue("s2.parser.strategy", LLMReq.SqlGenType.class, sqlGenType);
|
||||
}
|
||||
|
||||
public Integer getParseShowCount() {
|
||||
return convertValue("parse.show.count", Integer.class, parseShowCount);
|
||||
return convertValue("s2.parse.show-count", Integer.class, parseShowCount);
|
||||
}
|
||||
|
||||
public <T> T convertValue(String paramName, Class<T> targetType, T defaultValue) {
|
||||
|
||||
Reference in New Issue
Block a user