[feature][chat]Refactor chat model config related codes.#1739

This commit is contained in:
jerryjzhang
2024-10-09 17:27:07 +08:00
parent 60b0a1a1a1
commit 248f4f83f6
53 changed files with 275 additions and 251 deletions

View File

@@ -3,7 +3,6 @@ package com.tencent.supersonic.headless.api.pojo.request;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
@@ -28,7 +27,7 @@ public class QueryNLReq {
private SchemaMapInfo mapInfo = new SchemaMapInfo();
private QueryDataType queryDataType = QueryDataType.ALL;
private ChatModelConfig modelConfig;
private PromptConfig promptConfig;
private String customPrompt;
private List<Text2SQLExemplar> dynamicExemplars = Lists.newArrayList();
private SemanticParseInfo contextParseInfo;
}

View File

@@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.chat;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
@@ -54,7 +53,7 @@ public class ChatQueryContext {
private ChatWorkflowState chatWorkflowState;
private QueryDataType queryDataType = QueryDataType.ALL;
private ChatModelConfig modelConfig;
private PromptConfig promptConfig;
private String customPrompt;
private List<Text2SQLExemplar> dynamicExemplars;
public List<SemanticQuery> getCandidateQueries() {

View File

@@ -14,40 +14,40 @@ public class ParserConfig extends ParameterConfig {
public static final Parameter PARSER_STRATEGY_TYPE =
new Parameter("s2.parser.s2sql.strategy", "ONE_PASS_SELF_CONSISTENCY", "LLM解析生成S2SQL策略",
"ONE_PASS_SELF_CONSISTENCY: 通过投票方式一步生成sql", "list", "Parser相关配置",
"ONE_PASS_SELF_CONSISTENCY: 通过投票方式一步生成sql", "list", "语义解析配置",
Lists.newArrayList("ONE_PASS_SELF_CONSISTENCY"));
public static final Parameter PARSER_LINKING_VALUE_ENABLE =
new Parameter("s2.parser.linking.value.enable", "true", "是否将Mapper探测识别到的维度值提供给大模型",
"为了数据安全考虑, 这里可进行开关选择", "bool", "Parser相关配置");
"为了数据安全考虑, 这里可进行开关选择", "bool", "语义解析配置");
public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD =
new Parameter("s2.parser.text.length.threshold", "10", "用户输入文本长短阈值", "文本超过该阈值为长文本",
"number", "Parser相关配置");
"number", "语义解析配置");
public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD_SHORT =
new Parameter("s2.parser.text.threshold.short", "0.5", "短文本匹配阈值",
"由于请求大模型耗时较长, 因此如果有规则类型的Query得分达到阈值,则跳过大模型的调用,"
+ "\n如果是短文本, 若query得分/文本长度>该阈值, 则跳过当前parser",
"number", "Parser相关配置");
"number", "语义解析配置");
public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD_LONG =
new Parameter("s2.parser.text.threshold.long", "0.8", "长文本匹配阈值",
"如果是长文本, 若query得分/文本长度>该阈值, 则跳过当前parser", "number", "Parser相关配置");
"如果是长文本, 若query得分/文本长度>该阈值, 则跳过当前parser", "number", "语义解析配置");
public static final Parameter PARSER_EXEMPLAR_RECALL_NUMBER = new Parameter(
"s2.parser.exemplar-recall.number", "10", "exemplar召回个数", "", "number", "Parser相关配置");
"s2.parser.exemplar-recall.number", "10", "exemplar召回个数", "", "number", "语义解析配置");
public static final Parameter PARSER_FEW_SHOT_NUMBER =
new Parameter("s2.parser.few-shot.number", "3", "few-shot样例个数", "样例越多效果可能越好但token消耗越大",
"number", "Parser相关配置");
"number", "语义解析配置");
public static final Parameter PARSER_SELF_CONSISTENCY_NUMBER =
new Parameter("s2.parser.self-consistency.number", "1", "self-consistency执行个数",
"执行越多效果可能越好但token消耗越大", "number", "Parser相关配置");
"执行越多效果可能越好但token消耗越大", "number", "语义解析配置");
public static final Parameter PARSER_SHOW_COUNT = new Parameter("s2.parser.show.count", "3",
"解析结果展示个数", "前端展示的解析个数", "number", "Parser相关配置");
public static final Parameter PARSER_SHOW_COUNT =
new Parameter("s2.parser.show.count", "3", "解析结果展示个数", "前端展示的解析个数", "number", "语义解析配置");
@Override
public List<Parameter> getSysParameters() {

View File

@@ -75,7 +75,7 @@ public class LLMRequestService {
llmReq.setSqlGenType(
LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE)));
llmReq.setModelConfig(queryCtx.getModelConfig());
llmReq.setPromptConfig(queryCtx.getPromptConfig());
llmReq.setCustomPrompt(queryCtx.getCustomPrompt());
llmReq.setDynamicExemplars(queryCtx.getDynamicExemplars());
return llmReq;

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.headless.chat.parser.llm;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
@@ -112,10 +111,9 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
variable.put("information", sideInformation);
// use custom prompt template if provided.
PromptConfig promptConfig = llmReq.getPromptConfig();
String promptTemplate = INSTRUCTION;
if (promptConfig != null && StringUtils.isNotBlank(promptConfig.getPromptTemplate())) {
promptTemplate = promptConfig.getPromptTemplate();
if (StringUtils.isNotBlank(llmReq.getCustomPrompt())) {
promptTemplate = llmReq.getCustomPrompt();
}
return PromptTemplate.from(promptTemplate).apply(variable);
}

View File

@@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.chat.query.llm.s2sql;
import com.fasterxml.jackson.annotation.JsonValue;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
@@ -23,7 +22,7 @@ public class LLMReq {
private String priorExts;
private SqlGenType sqlGenType;
private ChatModelConfig modelConfig;
private PromptConfig promptConfig;
private String customPrompt;
private List<Text2SQLExemplar> dynamicExemplars;
@Data

View File

@@ -254,7 +254,7 @@ public class DataSetServiceImpl extends ServiceImpl<DataSetDOMapper, DataSetDO>
@Override
public Map<Long, List<Long>> getModelIdToDataSetIds() {
return getModelIdToDataSetIds(Lists.newArrayList(), User.getFakeUser());
return getModelIdToDataSetIds(Lists.newArrayList(), User.getDefaultUser());
}
private void conflictCheck(DataSetResp dataSetResp) {

View File

@@ -70,7 +70,7 @@ public class RetrieveServiceImpl implements RetrieveService {
List<SchemaElement> metricsDb = semanticSchemaDb.getMetrics();
final Map<Long, String> dataSetIdToName = semanticSchemaDb.getDataSetIdToName();
Map<Long, List<Long>> modelIdToDataSetIds = dataSetService.getModelIdToDataSetIds(
new ArrayList<>(dataSetIdToName.keySet()), User.getFakeUser());
new ArrayList<>(dataSetIdToName.keySet()), User.getDefaultUser());
// 2.detect by segment
List<S2Term> originals = knowledgeBaseService.getTerms(queryText, modelIdToDataSetIds);
log.debug("hanlp parse result: {}", originals);

View File

@@ -162,7 +162,7 @@ public class DictUtils {
dictItemResp.setBizName(dimension.getBizName());
}
if (TypeEnums.TAG.equals(TypeEnums.valueOf(dictConfDO.getType()))) {
TagResp tagResp = tagMetaService.getTag(dictConfDO.getItemId(), User.getFakeUser());
TagResp tagResp = tagMetaService.getTag(dictConfDO.getItemId(), User.getDefaultUser());
dictItemResp.setModelId(tagResp.getModelId());
dictItemResp.setBizName(tagResp.getBizName());
}

View File

@@ -43,7 +43,7 @@ public class MetricServiceImplTest {
MetricReq metricReq = buildMetricReq();
when(modelService.getModel(metricReq.getModelId())).thenReturn(mockModelResp());
when(modelService.getModelByDomainIds(any())).thenReturn(Lists.newArrayList());
MetricResp actualMetricResp = metricService.createMetric(metricReq, User.getFakeUser());
MetricResp actualMetricResp = metricService.createMetric(metricReq, User.getDefaultUser());
MetricResp expectedMetricResp = buildExpectedMetricResp();
Assertions.assertEquals(expectedMetricResp, actualMetricResp);
}
@@ -58,7 +58,7 @@ public class MetricServiceImplTest {
when(modelService.getModelByDomainIds(any())).thenReturn(Lists.newArrayList());
MetricDO metricDO = MetricConverter.convert2MetricDO(buildMetricReq());
when(metricRepository.getMetricById(metricDO.getId())).thenReturn(metricDO);
MetricResp actualMetricResp = metricService.updateMetric(metricReq, User.getFakeUser());
MetricResp actualMetricResp = metricService.updateMetric(metricReq, User.getDefaultUser());
MetricResp expectedMetricResp = buildExpectedMetricResp();
Assertions.assertEquals(expectedMetricResp, actualMetricResp);
}

View File

@@ -34,7 +34,7 @@ class ModelServiceImplTest {
void createModel() throws Exception {
ModelRepository modelRepository = Mockito.mock(ModelRepository.class);
ModelService modelService = mockModelService(modelRepository);
ModelResp actualModelResp = modelService.createModel(mockModelReq(), User.getFakeUser());
ModelResp actualModelResp = modelService.createModel(mockModelReq(), User.getDefaultUser());
ModelResp expectedModelResp = buildExpectedModelResp();
Assertions.assertEquals(expectedModelResp, actualModelResp);
}
@@ -44,9 +44,9 @@ class ModelServiceImplTest {
ModelRepository modelRepository = Mockito.mock(ModelRepository.class);
ModelService modelService = mockModelService(modelRepository);
ModelReq modelReq = mockModelReq_update();
ModelDO modelDO = ModelConverter.convert(mockModelReq(), User.getFakeUser());
ModelDO modelDO = ModelConverter.convert(mockModelReq(), User.getDefaultUser());
when(modelRepository.getModelById(modelReq.getId())).thenReturn(modelDO);
User user = User.getFakeUser();
User user = User.getDefaultUser();
user.setName("alice");
ModelResp actualModelResp = modelService.updateModel(modelReq, user);
ModelResp expectedModelResp = buildExpectedModelResp_update();
@@ -60,9 +60,9 @@ class ModelServiceImplTest {
ModelRepository modelRepository = Mockito.mock(ModelRepository.class);
ModelService modelService = mockModelService(modelRepository);
ModelReq modelReq = mockModelReq_updateAdmin();
ModelDO modelDO = ModelConverter.convert(mockModelReq(), User.getFakeUser());
ModelDO modelDO = ModelConverter.convert(mockModelReq(), User.getDefaultUser());
when(modelRepository.getModelById(modelReq.getId())).thenReturn(modelDO);
ModelResp actualModelResp = modelService.updateModel(modelReq, User.getFakeUser());
ModelResp actualModelResp = modelService.updateModel(modelReq, User.getDefaultUser());
ModelResp expectedModelResp = buildExpectedModelResp();
Assertions.assertEquals(expectedModelResp, actualModelResp);
}