mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 20:51:48 +00:00
(improvement)(chat) 优化提示工程、重试机制 (#1658)
This commit is contained in:
@@ -73,6 +73,11 @@ public class LLMSqlParser implements SemanticParser {
|
|||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("currentRetryRound:{}, runText2SQL failed", currentRetry, e);
|
log.error("currentRetryRound:{}, runText2SQL failed", currentRetry, e);
|
||||||
}
|
}
|
||||||
|
Double temperature = llmReq.getModelConfig().getTemperature();
|
||||||
|
if (temperature == 0) {
|
||||||
|
// 报错时增加随机性,减少无效重试
|
||||||
|
llmReq.getModelConfig().setTemperature(0.5);
|
||||||
|
}
|
||||||
currentRetry++;
|
currentRetry++;
|
||||||
}
|
}
|
||||||
if (MapUtils.isEmpty(sqlRespMap)) {
|
if (MapUtils.isEmpty(sqlRespMap)) {
|
||||||
|
|||||||
@@ -5,11 +5,12 @@ import com.tencent.supersonic.common.config.PromptConfig;
|
|||||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||||
import dev.langchain4j.data.message.AiMessage;
|
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.model.input.Prompt;
|
import dev.langchain4j.model.input.Prompt;
|
||||||
import dev.langchain4j.model.input.PromptTemplate;
|
import dev.langchain4j.model.input.PromptTemplate;
|
||||||
import dev.langchain4j.model.output.Response;
|
import dev.langchain4j.model.output.structured.Description;
|
||||||
|
import dev.langchain4j.service.AiServices;
|
||||||
|
import lombok.Data;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
@@ -24,6 +25,19 @@ import java.util.concurrent.ConcurrentHashMap;
|
|||||||
@Slf4j
|
@Slf4j
|
||||||
public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||||
|
|
||||||
|
@Data
|
||||||
|
static class SemanticSql {
|
||||||
|
@Description("thought or remarks to tell users about the sql, make it short.")
|
||||||
|
private String thought;
|
||||||
|
|
||||||
|
@Description("sql to generate")
|
||||||
|
private String sql;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface SemanticSqlExtractor {
|
||||||
|
SemanticSql generateSemanticSql(String text);
|
||||||
|
}
|
||||||
|
|
||||||
private static final String INSTRUCTION =
|
private static final String INSTRUCTION =
|
||||||
""
|
""
|
||||||
+ "\n#Role: You are a data analyst experienced in SQL languages."
|
+ "\n#Role: You are a data analyst experienced in SQL languages."
|
||||||
@@ -36,9 +50,9 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
|||||||
+ "3.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`."
|
+ "3.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`."
|
||||||
+ "4.DO NOT calculate date range using functions."
|
+ "4.DO NOT calculate date range using functions."
|
||||||
+ "5.DO NOT miss the AGGREGATE operator of metrics, always add it as needed."
|
+ "5.DO NOT miss the AGGREGATE operator of metrics, always add it as needed."
|
||||||
+ "6.ONLY respond with the converted SQL statement."
|
|
||||||
+ "\n#Exemplars:\n{{exemplar}}"
|
+ "\n#Exemplars:\n{{exemplar}}"
|
||||||
+ "Question:{{question}},Schema:{{schema}},SideInfo:{{information}},SQL:";
|
+ "\n#Question:"
|
||||||
|
+ "Question:{{question}},Schema:{{schema}},SideInfo:{{information}}";
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public LLMResp generate(LLMReq llmReq) {
|
public LLMResp generate(LLMReq llmReq) {
|
||||||
@@ -55,6 +69,9 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
|||||||
Prompt prompt = generatePrompt(llmReq, llmResp);
|
Prompt prompt = generatePrompt(llmReq, llmResp);
|
||||||
prompt2Exemplar.put(prompt, exemplars);
|
prompt2Exemplar.put(prompt, exemplars);
|
||||||
}
|
}
|
||||||
|
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getModelConfig());
|
||||||
|
SemanticSqlExtractor extractor =
|
||||||
|
AiServices.create(SemanticSqlExtractor.class, chatLanguageModel);
|
||||||
|
|
||||||
// 3.perform multiple self-consistency inferences parallelly
|
// 3.perform multiple self-consistency inferences parallelly
|
||||||
Map<String, Prompt> output2Prompt = new ConcurrentHashMap<>();
|
Map<String, Prompt> output2Prompt = new ConcurrentHashMap<>();
|
||||||
@@ -66,18 +83,12 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
|||||||
keyPipelineLog.info(
|
keyPipelineLog.info(
|
||||||
"OnePassSCSqlGenStrategy reqPrompt:\n{}",
|
"OnePassSCSqlGenStrategy reqPrompt:\n{}",
|
||||||
prompt.toUserMessage());
|
prompt.toUserMessage());
|
||||||
ChatLanguageModel chatLanguageModel =
|
SemanticSql s2Sql =
|
||||||
getChatLanguageModel(llmReq.getModelConfig());
|
extractor.generateSemanticSql(
|
||||||
Response<AiMessage> response =
|
prompt.toUserMessage().singleText());
|
||||||
chatLanguageModel.generate(prompt.toUserMessage());
|
output2Prompt.put(s2Sql.getSql(), prompt);
|
||||||
String sqlOutput =
|
|
||||||
StringUtils.normalizeSpace(response.content().text());
|
|
||||||
// replace ```
|
|
||||||
String sqlOutputFormat =
|
|
||||||
sqlOutput.replaceAll("(?s)```sql\\s*(.*?)\\s*```", "$1").trim();
|
|
||||||
output2Prompt.put(sqlOutputFormat, prompt);
|
|
||||||
keyPipelineLog.info(
|
keyPipelineLog.info(
|
||||||
"OnePassSCSqlGenStrategy modelResp:\n{}", sqlOutputFormat);
|
"OnePassSCSqlGenStrategy modelResp:\n{}", s2Sql.getSql());
|
||||||
});
|
});
|
||||||
|
|
||||||
// 4.format response.
|
// 4.format response.
|
||||||
|
|||||||
Reference in New Issue
Block a user