(improvement)(chat) 优化提示工程、重试机制 (#1658)

This commit is contained in:
yudong
2024-09-13 09:25:55 +08:00
committed by GitHub
parent 37f12391b0
commit 0ff31ddf61
2 changed files with 31 additions and 15 deletions

View File

@@ -73,6 +73,11 @@ public class LLMSqlParser implements SemanticParser {
} catch (Exception e) { } catch (Exception e) {
log.error("currentRetryRound:{}, runText2SQL failed", currentRetry, e); log.error("currentRetryRound:{}, runText2SQL failed", currentRetry, e);
} }
Double temperature = llmReq.getModelConfig().getTemperature();
if (temperature == 0) {
// 报错时增加随机性,减少无效重试
llmReq.getModelConfig().setTemperature(0.5);
}
currentRetry++; currentRetry++;
} }
if (MapUtils.isEmpty(sqlRespMap)) { if (MapUtils.isEmpty(sqlRespMap)) {

View File

@@ -5,11 +5,12 @@ import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt; import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate; import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.structured.Description;
import dev.langchain4j.service.AiServices;
import lombok.Data;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.Pair;
@@ -24,6 +25,19 @@ import java.util.concurrent.ConcurrentHashMap;
@Slf4j @Slf4j
public class OnePassSCSqlGenStrategy extends SqlGenStrategy { public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
@Data
static class SemanticSql {
@Description("thought or remarks to tell users about the sql, make it short.")
private String thought;
@Description("sql to generate")
private String sql;
}
interface SemanticSqlExtractor {
SemanticSql generateSemanticSql(String text);
}
private static final String INSTRUCTION = private static final String INSTRUCTION =
"" ""
+ "\n#Role: You are a data analyst experienced in SQL languages." + "\n#Role: You are a data analyst experienced in SQL languages."
@@ -36,9 +50,9 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
+ "3.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`." + "3.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`."
+ "4.DO NOT calculate date range using functions." + "4.DO NOT calculate date range using functions."
+ "5.DO NOT miss the AGGREGATE operator of metrics, always add it as needed." + "5.DO NOT miss the AGGREGATE operator of metrics, always add it as needed."
+ "6.ONLY respond with the converted SQL statement."
+ "\n#Exemplars:\n{{exemplar}}" + "\n#Exemplars:\n{{exemplar}}"
+ "Question:{{question}},Schema:{{schema}},SideInfo:{{information}},SQL:"; + "\n#Question:"
+ "Question:{{question}},Schema:{{schema}},SideInfo:{{information}}";
@Override @Override
public LLMResp generate(LLMReq llmReq) { public LLMResp generate(LLMReq llmReq) {
@@ -55,6 +69,9 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
Prompt prompt = generatePrompt(llmReq, llmResp); Prompt prompt = generatePrompt(llmReq, llmResp);
prompt2Exemplar.put(prompt, exemplars); prompt2Exemplar.put(prompt, exemplars);
} }
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getModelConfig());
SemanticSqlExtractor extractor =
AiServices.create(SemanticSqlExtractor.class, chatLanguageModel);
// 3.perform multiple self-consistency inferences parallelly // 3.perform multiple self-consistency inferences parallelly
Map<String, Prompt> output2Prompt = new ConcurrentHashMap<>(); Map<String, Prompt> output2Prompt = new ConcurrentHashMap<>();
@@ -66,18 +83,12 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
keyPipelineLog.info( keyPipelineLog.info(
"OnePassSCSqlGenStrategy reqPrompt:\n{}", "OnePassSCSqlGenStrategy reqPrompt:\n{}",
prompt.toUserMessage()); prompt.toUserMessage());
ChatLanguageModel chatLanguageModel = SemanticSql s2Sql =
getChatLanguageModel(llmReq.getModelConfig()); extractor.generateSemanticSql(
Response<AiMessage> response = prompt.toUserMessage().singleText());
chatLanguageModel.generate(prompt.toUserMessage()); output2Prompt.put(s2Sql.getSql(), prompt);
String sqlOutput =
StringUtils.normalizeSpace(response.content().text());
// replace ```
String sqlOutputFormat =
sqlOutput.replaceAll("(?s)```sql\\s*(.*?)\\s*```", "$1").trim();
output2Prompt.put(sqlOutputFormat, prompt);
keyPipelineLog.info( keyPipelineLog.info(
"OnePassSCSqlGenStrategy modelResp:\n{}", sqlOutputFormat); "OnePassSCSqlGenStrategy modelResp:\n{}", s2Sql.getSql());
}); });
// 4.format response. // 4.format response.