[feature][chat] Refactor chat model config related code. #1739

This commit is contained in:
jerryjzhang
2024-10-09 17:27:07 +08:00
parent 60b0a1a1a1
commit 248f4f83f6
53 changed files with 275 additions and 251 deletions

View File

@@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.chat;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
@@ -54,7 +53,7 @@ public class ChatQueryContext {
private ChatWorkflowState chatWorkflowState;
private QueryDataType queryDataType = QueryDataType.ALL;
private ChatModelConfig modelConfig;
private PromptConfig promptConfig;
private String customPrompt;
private List<Text2SQLExemplar> dynamicExemplars;
public List<SemanticQuery> getCandidateQueries() {

View File

@@ -14,40 +14,40 @@ public class ParserConfig extends ParameterConfig {
public static final Parameter PARSER_STRATEGY_TYPE =
new Parameter("s2.parser.s2sql.strategy", "ONE_PASS_SELF_CONSISTENCY", "LLM解析生成S2SQL策略",
"ONE_PASS_SELF_CONSISTENCY: 通过投票方式一步生成sql", "list", "Parser相关配置",
"ONE_PASS_SELF_CONSISTENCY: 通过投票方式一步生成sql", "list", "语义解析配置",
Lists.newArrayList("ONE_PASS_SELF_CONSISTENCY"));
public static final Parameter PARSER_LINKING_VALUE_ENABLE =
new Parameter("s2.parser.linking.value.enable", "true", "是否将Mapper探测识别到的维度值提供给大模型",
"为了数据安全考虑, 这里可进行开关选择", "bool", "Parser相关配置");
"为了数据安全考虑, 这里可进行开关选择", "bool", "语义解析配置");
public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD =
new Parameter("s2.parser.text.length.threshold", "10", "用户输入文本长短阈值", "文本超过该阈值为长文本",
"number", "Parser相关配置");
"number", "语义解析配置");
public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD_SHORT =
new Parameter("s2.parser.text.threshold.short", "0.5", "短文本匹配阈值",
"由于请求大模型耗时较长, 因此如果有规则类型的Query得分达到阈值,则跳过大模型的调用,"
+ "\n如果是短文本, 若query得分/文本长度>该阈值, 则跳过当前parser",
"number", "Parser相关配置");
"number", "语义解析配置");
public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD_LONG =
new Parameter("s2.parser.text.threshold.long", "0.8", "长文本匹配阈值",
"如果是长文本, 若query得分/文本长度>该阈值, 则跳过当前parser", "number", "Parser相关配置");
"如果是长文本, 若query得分/文本长度>该阈值, 则跳过当前parser", "number", "语义解析配置");
public static final Parameter PARSER_EXEMPLAR_RECALL_NUMBER = new Parameter(
"s2.parser.exemplar-recall.number", "10", "exemplar召回个数", "", "number", "Parser相关配置");
"s2.parser.exemplar-recall.number", "10", "exemplar召回个数", "", "number", "语义解析配置");
public static final Parameter PARSER_FEW_SHOT_NUMBER =
new Parameter("s2.parser.few-shot.number", "3", "few-shot样例个数", "样例越多效果可能越好但token消耗越大",
"number", "Parser相关配置");
"number", "语义解析配置");
public static final Parameter PARSER_SELF_CONSISTENCY_NUMBER =
new Parameter("s2.parser.self-consistency.number", "1", "self-consistency执行个数",
"执行越多效果可能越好但token消耗越大", "number", "Parser相关配置");
"执行越多效果可能越好但token消耗越大", "number", "语义解析配置");
public static final Parameter PARSER_SHOW_COUNT = new Parameter("s2.parser.show.count", "3",
"解析结果展示个数", "前端展示的解析个数", "number", "Parser相关配置");
public static final Parameter PARSER_SHOW_COUNT =
new Parameter("s2.parser.show.count", "3", "解析结果展示个数", "前端展示的解析个数", "number", "语义解析配置");
@Override
public List<Parameter> getSysParameters() {

View File

@@ -75,7 +75,7 @@ public class LLMRequestService {
llmReq.setSqlGenType(
LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE)));
llmReq.setModelConfig(queryCtx.getModelConfig());
llmReq.setPromptConfig(queryCtx.getPromptConfig());
llmReq.setCustomPrompt(queryCtx.getCustomPrompt());
llmReq.setDynamicExemplars(queryCtx.getDynamicExemplars());
return llmReq;

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.headless.chat.parser.llm;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
@@ -112,10 +111,9 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
variable.put("information", sideInformation);
// use custom prompt template if provided.
PromptConfig promptConfig = llmReq.getPromptConfig();
String promptTemplate = INSTRUCTION;
if (promptConfig != null && StringUtils.isNotBlank(promptConfig.getPromptTemplate())) {
promptTemplate = promptConfig.getPromptTemplate();
if (StringUtils.isNotBlank(llmReq.getCustomPrompt())) {
promptTemplate = llmReq.getCustomPrompt();
}
return PromptTemplate.from(promptTemplate).apply(variable);
}

View File

@@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.chat.query.llm.s2sql;
import com.fasterxml.jackson.annotation.JsonValue;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
@@ -23,7 +22,7 @@ public class LLMReq {
private String priorExts;
private SqlGenType sqlGenType;
private ChatModelConfig modelConfig;
private PromptConfig promptConfig;
private String customPrompt;
private List<Text2SQLExemplar> dynamicExemplars;
@Data