mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 12:07:42 +00:00
[feature][headless-chat]Introduce ChatApp to support more flexible chat model config.#1739
This commit is contained in:
@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.api.pojo.request;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
@@ -13,6 +14,7 @@ import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
@Data
|
||||
@@ -26,8 +28,7 @@ public class QueryNLReq {
|
||||
private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
|
||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||
private QueryDataType queryDataType = QueryDataType.ALL;
|
||||
private ChatModelConfig modelConfig;
|
||||
private String customPrompt;
|
||||
private Map<String, ChatApp> chatAppConfig;
|
||||
private List<Text2SQLExemplar> dynamicExemplars = Lists.newArrayList();
|
||||
private SemanticParseInfo contextParseInfo;
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.chat;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
@@ -47,13 +48,14 @@ public class ChatQueryContext {
|
||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||
private SemanticParseInfo contextParseInfo;
|
||||
private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
|
||||
private QueryDataType queryDataType = QueryDataType.ALL;
|
||||
@JsonIgnore
|
||||
private SemanticSchema semanticSchema;
|
||||
@JsonIgnore
|
||||
private ChatWorkflowState chatWorkflowState;
|
||||
private QueryDataType queryDataType = QueryDataType.ALL;
|
||||
private ChatModelConfig modelConfig;
|
||||
private String customPrompt;
|
||||
@JsonIgnore
|
||||
private Map<String, ChatApp> chatAppConfig;
|
||||
@JsonIgnore
|
||||
private List<Text2SQLExemplar> dynamicExemplars;
|
||||
|
||||
public List<SemanticQuery> getCandidateQueries() {
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package com.tencent.supersonic.headless.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.util.ChatAppManager;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
@@ -23,11 +25,11 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
|
||||
|
||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
|
||||
public static final String INSTRUCTION = ""
|
||||
public static final String APP_KEY = "S2SQL_CORRECTOR";
|
||||
private static final String INSTRUCTION = ""
|
||||
+ "\n#Role: You are a senior data engineer experienced in writing SQL."
|
||||
+ "\n#Task: Your will be provided with a user question and the SQL written by a junior engineer,"
|
||||
+ "please take a review and help correct it if necessary."
|
||||
+ "\n#Rules: "
|
||||
+ "please take a review and help correct it if necessary." + "\n#Rules: "
|
||||
+ "\n1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),sql=(corrected sql if NEGATIVE; empty string if POSITIVE)`."
|
||||
+ "\n2.NO NEED to check date filters as the junior engineer seldom makes mistakes in this regard."
|
||||
+ "\n3.DO NOT miss the AGGREGATE operator of metrics, always add it as needed."
|
||||
@@ -36,6 +38,11 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
|
||||
+ "\n6.ALWAYS translate alias created by `AS` command to the same language as the `#Question`."
|
||||
+ "\n#Question:{{question}} #InputSQL:{{sql}} #Response:";
|
||||
|
||||
/**
 * Registers this corrector's {@link ChatApp} with {@link ChatAppManager} at construction time.
 *
 * <p>The app is keyed by {@code APP_KEY}, uses {@code INSTRUCTION} as its prompt, has an empty
 * description, and is built with {@code enable(false)}. The display name is a Chinese label
 * (roughly "semantic SQL correction").
 */
public LLMSqlCorrector() {
|
||||
ChatAppManager.register(ChatApp.builder().key(APP_KEY).prompt(INSTRUCTION).name("语义SQL修正")
|
||||
.description("").enable(false).build());
|
||||
}
|
||||
|
||||
@Data
|
||||
@ToString
|
||||
static class SemanticSql {
|
||||
@@ -52,14 +59,16 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
if (!chatQueryContext.getText2SQLType().enableLLM()) {
|
||||
ChatApp chatApp = chatQueryContext.getChatAppConfig().get(APP_KEY);
|
||||
if (!chatQueryContext.getText2SQLType().enableLLM() || !chatApp.isEnable()) {
|
||||
return;
|
||||
}
|
||||
|
||||
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(chatQueryContext.getModelConfig());
|
||||
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(chatApp.getChatModelConfig());
|
||||
SemanticSqlExtractor extractor =
|
||||
AiServices.create(SemanticSqlExtractor.class, chatLanguageModel);
|
||||
Prompt prompt = generatePrompt(chatQueryContext.getQueryText(), semanticParseInfo);
|
||||
Prompt prompt = generatePrompt(chatQueryContext.getQueryText(), semanticParseInfo,
|
||||
chatApp.getPrompt());
|
||||
keyPipelineLog.info("LLMSqlCorrector reqPrompt:\n{}", prompt.text());
|
||||
SemanticSql s2Sql = extractor.generateSemanticSql(prompt.toUserMessage().singleText());
|
||||
keyPipelineLog.info("LLMSqlCorrector modelResp:\n{}", s2Sql);
|
||||
@@ -68,12 +77,12 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Builds the correction prompt by applying a template to the user question and the SQL under
 * review.
 *
 * <p>NOTE(review): this span appears to be a diff rendering that shows both the previous
 * two-argument signature and its three-argument replacement, as well as the earlier local
 * {@code promptTemplate = INSTRUCTION} assignment — confirm against the actual file.
 *
 * @param queryText the user question, bound to the {@code question} template variable
 * @param semanticParseInfo parse result whose {@code getSqlInfo().getCorrectedS2SQL()} value is
 *        bound to the {@code sql} template variable
 * @param promptTemplate template text handed to {@code PromptTemplate.from} (three-argument
 *        form only)
 * @return the prompt produced by applying the template to the two variables
 */
private Prompt generatePrompt(String queryText, SemanticParseInfo semanticParseInfo) {
|
||||
private Prompt generatePrompt(String queryText, SemanticParseInfo semanticParseInfo,
|
||||
String promptTemplate) {
|
||||
// Template variables, keyed by name.
Map<String, Object> variable = new HashMap<>();
|
||||
variable.put("question", queryText);
|
||||
variable.put("sql", semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
|
||||
|
||||
String promptTemplate = INSTRUCTION;
|
||||
return PromptTemplate.from(promptTemplate).apply(variable);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -74,8 +74,7 @@ public class LLMRequestService {
|
||||
llmReq.setTerms(getMappedTerms(queryCtx, dataSetId));
|
||||
llmReq.setSqlGenType(
|
||||
LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE)));
|
||||
llmReq.setModelConfig(queryCtx.getModelConfig());
|
||||
llmReq.setCustomPrompt(queryCtx.getCustomPrompt());
|
||||
llmReq.setChatAppConfig(queryCtx.getChatAppConfig());
|
||||
llmReq.setDynamicExemplars(queryCtx.getDynamicExemplars());
|
||||
|
||||
return llmReq;
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||
@@ -69,10 +70,12 @@ public class LLMSqlParser implements SemanticParser {
|
||||
} catch (Exception e) {
|
||||
log.error("currentRetryRound:{}, runText2SQL failed", currentRetry, e);
|
||||
}
|
||||
Double temperature = llmReq.getModelConfig().getTemperature();
|
||||
ChatModelConfig chatModelConfig = llmReq.getChatAppConfig()
|
||||
.get(OnePassSCSqlGenStrategy.APP_KEY).getChatModelConfig();
|
||||
Double temperature = chatModelConfig.getTemperature();
|
||||
if (temperature == 0) {
|
||||
// On failure, add some randomness to reduce futile retries
|
||||
llmReq.getModelConfig().setTemperature(0.5);
|
||||
chatModelConfig.setTemperature(0.5);
|
||||
}
|
||||
currentRetry++;
|
||||
}
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.common.util.ChatAppManager;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
@@ -11,7 +13,6 @@ import dev.langchain4j.model.output.structured.Description;
|
||||
import dev.langchain4j.service.AiServices;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@@ -24,6 +25,7 @@ import java.util.concurrent.ConcurrentHashMap;
|
||||
@Slf4j
|
||||
public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
|
||||
public static final String APP_KEY = "S2SQL_PARSER";
|
||||
public static final String INSTRUCTION = ""
|
||||
+ "\n#Role: You are a data analyst experienced in SQL languages."
|
||||
+ "\n#Task: You will be provided with a natural language question asked by users,"
|
||||
@@ -40,6 +42,11 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
+ "\n#Exemplars: {{exemplar}}"
|
||||
+ "\n#Question: Question:{{question}},Schema:{{schema}},SideInfo:{{information}}";
|
||||
|
||||
/**
 * Registers this strategy's {@link ChatApp} with {@link ChatAppManager} at construction time.
 *
 * <p>The app is keyed by {@code APP_KEY}, uses {@code INSTRUCTION} as its prompt, and is built
 * with {@code enable(true)}. The name and description are Chinese labels (roughly "semantic SQL
 * parsing" and "generate S2SQL through semantic parsing with a large model").
 */
public OnePassSCSqlGenStrategy() {
|
||||
ChatAppManager.register(ChatApp.builder().key(APP_KEY).prompt(INSTRUCTION).name("语义SQL解析")
|
||||
.description("通过大模型做语义解析生成S2SQL").enable(true).build());
|
||||
}
|
||||
|
||||
@Data
|
||||
static class SemanticSql {
|
||||
@Description("thought or remarks to tell users about the sql, make it short.")
|
||||
@@ -62,15 +69,17 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
List<List<Text2SQLExemplar>> exemplarsList = promptHelper.getFewShotExemplars(llmReq);
|
||||
|
||||
// 2.generate sql generation prompt for each self-consistency inference
|
||||
ChatApp chatApp = llmReq.getChatAppConfig().get(APP_KEY);
|
||||
ChatLanguageModel chatLanguageModel = getChatLanguageModel(chatApp.getChatModelConfig());
|
||||
SemanticSqlExtractor extractor =
|
||||
AiServices.create(SemanticSqlExtractor.class, chatLanguageModel);
|
||||
|
||||
Map<Prompt, List<Text2SQLExemplar>> prompt2Exemplar = new HashMap<>();
|
||||
for (List<Text2SQLExemplar> exemplars : exemplarsList) {
|
||||
llmReq.setDynamicExemplars(exemplars);
|
||||
Prompt prompt = generatePrompt(llmReq, llmResp);
|
||||
Prompt prompt = generatePrompt(llmReq, llmResp, chatApp);
|
||||
prompt2Exemplar.put(prompt, exemplars);
|
||||
}
|
||||
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getModelConfig());
|
||||
SemanticSqlExtractor extractor =
|
||||
AiServices.create(SemanticSqlExtractor.class, chatLanguageModel);
|
||||
|
||||
// 3.perform multiple self-consistency inferences parallelly
|
||||
Map<String, Prompt> output2Prompt = new ConcurrentHashMap<>();
|
||||
@@ -92,7 +101,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
return llmResp;
|
||||
}
|
||||
|
||||
private Prompt generatePrompt(LLMReq llmReq, LLMResp llmResp) {
|
||||
private Prompt generatePrompt(LLMReq llmReq, LLMResp llmResp, ChatApp chatApp) {
|
||||
StringBuilder exemplars = new StringBuilder();
|
||||
for (Text2SQLExemplar exemplar : llmReq.getDynamicExemplars()) {
|
||||
String exemplarStr = String.format("\nQuestion:%s,Schema:%s,SideInfo:%s,SQL:%s",
|
||||
@@ -112,10 +121,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
variable.put("information", sideInformation);
|
||||
|
||||
// use custom prompt template if provided.
|
||||
String promptTemplate = INSTRUCTION;
|
||||
if (StringUtils.isNotBlank(llmReq.getCustomPrompt())) {
|
||||
promptTemplate = llmReq.getCustomPrompt();
|
||||
}
|
||||
String promptTemplate = chatApp.getPrompt();
|
||||
return PromptTemplate.from(promptTemplate).apply(variable);
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.chat.query.llm.s2sql;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonValue;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
@@ -10,6 +11,7 @@ import org.apache.commons.collections4.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@@ -21,7 +23,7 @@ public class LLMReq {
|
||||
private String currentDate;
|
||||
private String priorExts;
|
||||
private SqlGenType sqlGenType;
|
||||
private ChatModelConfig modelConfig;
|
||||
private Map<String, ChatApp> chatAppConfig;
|
||||
private String customPrompt;
|
||||
private List<Text2SQLExemplar> dynamicExemplars;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user