mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(chat) 优化提示工程、重试机制 (#1658)
This commit is contained in:
@@ -73,6 +73,11 @@ public class LLMSqlParser implements SemanticParser {
|
||||
} catch (Exception e) {
|
||||
log.error("currentRetryRound:{}, runText2SQL failed", currentRetry, e);
|
||||
}
|
||||
Double temperature = llmReq.getModelConfig().getTemperature();
|
||||
if (temperature == 0) {
|
||||
// On failure, add randomness to reduce futile retries
|
||||
llmReq.getModelConfig().setTemperature(0.5);
|
||||
}
|
||||
currentRetry++;
|
||||
}
|
||||
if (MapUtils.isEmpty(sqlRespMap)) {
|
||||
|
||||
@@ -5,11 +5,12 @@ import com.tencent.supersonic.common.config.PromptConfig;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.model.output.structured.Description;
|
||||
import dev.langchain4j.service.AiServices;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
@@ -24,6 +25,19 @@ import java.util.concurrent.ConcurrentHashMap;
|
||||
@Slf4j
|
||||
public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
|
||||
/**
 * Structured form of the model's answer: a short remark for the user plus the generated SQL.
 *
 * <p>This is the return type of {@code SemanticSqlExtractor#generateSemanticSql}. Accessors
 * such as {@code getSql()} are supplied by Lombok's {@code @Data}. The {@code @Description}
 * texts describe each field; presumably langchain4j forwards them to the model as output
 * instructions — TODO confirm against the langchain4j version in use.
 */
@Data
|
||||
static class SemanticSql {
|
||||
@Description("thought or remarks to tell users about the sql, make it short.")
|
||||
// Short free-text remark accompanying the SQL.
private String thought;
|
||||
|
||||
@Description("sql to generate")
|
||||
// The generated SQL text.
private String sql;
|
||||
}
|
||||
|
||||
/**
 * Contract for turning prompt text into a {@link SemanticSql}.
 *
 * <p>No hand-written implementation is visible here; {@code generate} obtains an instance via
 * {@code AiServices.create(SemanticSqlExtractor.class, chatLanguageModel)}.
 */
interface SemanticSqlExtractor {
|
||||
/**
 * Produces the structured SQL answer for the given prompt text.
 *
 * @param text the prompt text; {@code generate} passes
 *     {@code prompt.toUserMessage().singleText()}
 * @return the {@link SemanticSql} carrying the model's thought and SQL
 */
SemanticSql generateSemanticSql(String text);
|
||||
}
|
||||
|
||||
private static final String INSTRUCTION =
|
||||
""
|
||||
+ "\n#Role: You are a data analyst experienced in SQL languages."
|
||||
@@ -36,9 +50,9 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
+ "3.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`."
|
||||
+ "4.DO NOT calculate date range using functions."
|
||||
+ "5.DO NOT miss the AGGREGATE operator of metrics, always add it as needed."
|
||||
+ "6.ONLY respond with the converted SQL statement."
|
||||
+ "\n#Exemplars:\n{{exemplar}}"
|
||||
+ "Question:{{question}},Schema:{{schema}},SideInfo:{{information}},SQL:";
|
||||
+ "\n#Question:"
|
||||
+ "Question:{{question}},Schema:{{schema}},SideInfo:{{information}}";
|
||||
|
||||
@Override
|
||||
public LLMResp generate(LLMReq llmReq) {
|
||||
@@ -55,6 +69,9 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
Prompt prompt = generatePrompt(llmReq, llmResp);
|
||||
prompt2Exemplar.put(prompt, exemplars);
|
||||
}
|
||||
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getModelConfig());
|
||||
SemanticSqlExtractor extractor =
|
||||
AiServices.create(SemanticSqlExtractor.class, chatLanguageModel);
|
||||
|
||||
// 3. perform multiple self-consistency inferences in parallel
|
||||
Map<String, Prompt> output2Prompt = new ConcurrentHashMap<>();
|
||||
@@ -66,18 +83,12 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
keyPipelineLog.info(
|
||||
"OnePassSCSqlGenStrategy reqPrompt:\n{}",
|
||||
prompt.toUserMessage());
|
||||
ChatLanguageModel chatLanguageModel =
|
||||
getChatLanguageModel(llmReq.getModelConfig());
|
||||
Response<AiMessage> response =
|
||||
chatLanguageModel.generate(prompt.toUserMessage());
|
||||
String sqlOutput =
|
||||
StringUtils.normalizeSpace(response.content().text());
|
||||
// replace ```
|
||||
String sqlOutputFormat =
|
||||
sqlOutput.replaceAll("(?s)```sql\\s*(.*?)\\s*```", "$1").trim();
|
||||
output2Prompt.put(sqlOutputFormat, prompt);
|
||||
SemanticSql s2Sql =
|
||||
extractor.generateSemanticSql(
|
||||
prompt.toUserMessage().singleText());
|
||||
output2Prompt.put(s2Sql.getSql(), prompt);
|
||||
keyPipelineLog.info(
|
||||
"OnePassSCSqlGenStrategy modelResp:\n{}", sqlOutputFormat);
|
||||
"OnePassSCSqlGenStrategy modelResp:\n{}", s2Sql.getSql());
|
||||
});
|
||||
|
||||
// 4.format response.
|
||||
|
||||
Reference in New Issue
Block a user