mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 20:25:12 +00:00
[feature][chat]Refactor chat model config related codes.#1739
This commit is contained in:
@@ -3,7 +3,6 @@ package com.tencent.supersonic.headless.api.pojo.request;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.config.PromptConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
@@ -28,7 +27,7 @@ public class QueryNLReq {
|
||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||
private QueryDataType queryDataType = QueryDataType.ALL;
|
||||
private ChatModelConfig modelConfig;
|
||||
private PromptConfig promptConfig;
|
||||
private String customPrompt;
|
||||
private List<Text2SQLExemplar> dynamicExemplars = Lists.newArrayList();
|
||||
private SemanticParseInfo contextParseInfo;
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.chat;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.config.PromptConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
@@ -54,7 +53,7 @@ public class ChatQueryContext {
|
||||
private ChatWorkflowState chatWorkflowState;
|
||||
private QueryDataType queryDataType = QueryDataType.ALL;
|
||||
private ChatModelConfig modelConfig;
|
||||
private PromptConfig promptConfig;
|
||||
private String customPrompt;
|
||||
private List<Text2SQLExemplar> dynamicExemplars;
|
||||
|
||||
public List<SemanticQuery> getCandidateQueries() {
|
||||
|
||||
@@ -14,40 +14,40 @@ public class ParserConfig extends ParameterConfig {
|
||||
|
||||
public static final Parameter PARSER_STRATEGY_TYPE =
|
||||
new Parameter("s2.parser.s2sql.strategy", "ONE_PASS_SELF_CONSISTENCY", "LLM解析生成S2SQL策略",
|
||||
"ONE_PASS_SELF_CONSISTENCY: 通过投票方式一步生成sql", "list", "Parser相关配置",
|
||||
"ONE_PASS_SELF_CONSISTENCY: 通过投票方式一步生成sql", "list", "语义解析配置",
|
||||
Lists.newArrayList("ONE_PASS_SELF_CONSISTENCY"));
|
||||
|
||||
public static final Parameter PARSER_LINKING_VALUE_ENABLE =
|
||||
new Parameter("s2.parser.linking.value.enable", "true", "是否将Mapper探测识别到的维度值提供给大模型",
|
||||
"为了数据安全考虑, 这里可进行开关选择", "bool", "Parser相关配置");
|
||||
"为了数据安全考虑, 这里可进行开关选择", "bool", "语义解析配置");
|
||||
|
||||
public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD =
|
||||
new Parameter("s2.parser.text.length.threshold", "10", "用户输入文本长短阈值", "文本超过该阈值为长文本",
|
||||
"number", "Parser相关配置");
|
||||
"number", "语义解析配置");
|
||||
|
||||
public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD_SHORT =
|
||||
new Parameter("s2.parser.text.threshold.short", "0.5", "短文本匹配阈值",
|
||||
"由于请求大模型耗时较长, 因此如果有规则类型的Query得分达到阈值,则跳过大模型的调用,"
|
||||
+ "\n如果是短文本, 若query得分/文本长度>该阈值, 则跳过当前parser",
|
||||
"number", "Parser相关配置");
|
||||
"number", "语义解析配置");
|
||||
|
||||
public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD_LONG =
|
||||
new Parameter("s2.parser.text.threshold.long", "0.8", "长文本匹配阈值",
|
||||
"如果是长文本, 若query得分/文本长度>该阈值, 则跳过当前parser", "number", "Parser相关配置");
|
||||
"如果是长文本, 若query得分/文本长度>该阈值, 则跳过当前parser", "number", "语义解析配置");
|
||||
|
||||
public static final Parameter PARSER_EXEMPLAR_RECALL_NUMBER = new Parameter(
|
||||
"s2.parser.exemplar-recall.number", "10", "exemplar召回个数", "", "number", "Parser相关配置");
|
||||
"s2.parser.exemplar-recall.number", "10", "exemplar召回个数", "", "number", "语义解析配置");
|
||||
|
||||
public static final Parameter PARSER_FEW_SHOT_NUMBER =
|
||||
new Parameter("s2.parser.few-shot.number", "3", "few-shot样例个数", "样例越多效果可能越好,但token消耗越大",
|
||||
"number", "Parser相关配置");
|
||||
"number", "语义解析配置");
|
||||
|
||||
public static final Parameter PARSER_SELF_CONSISTENCY_NUMBER =
|
||||
new Parameter("s2.parser.self-consistency.number", "1", "self-consistency执行个数",
|
||||
"执行越多效果可能越好,但token消耗越大", "number", "Parser相关配置");
|
||||
"执行越多效果可能越好,但token消耗越大", "number", "语义解析配置");
|
||||
|
||||
public static final Parameter PARSER_SHOW_COUNT = new Parameter("s2.parser.show.count", "3",
|
||||
"解析结果展示个数", "前端展示的解析个数", "number", "Parser相关配置");
|
||||
public static final Parameter PARSER_SHOW_COUNT =
|
||||
new Parameter("s2.parser.show.count", "3", "解析结果展示个数", "前端展示的解析个数", "number", "语义解析配置");
|
||||
|
||||
@Override
|
||||
public List<Parameter> getSysParameters() {
|
||||
|
||||
@@ -75,7 +75,7 @@ public class LLMRequestService {
|
||||
llmReq.setSqlGenType(
|
||||
LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE)));
|
||||
llmReq.setModelConfig(queryCtx.getModelConfig());
|
||||
llmReq.setPromptConfig(queryCtx.getPromptConfig());
|
||||
llmReq.setCustomPrompt(queryCtx.getCustomPrompt());
|
||||
llmReq.setDynamicExemplars(queryCtx.getDynamicExemplars());
|
||||
|
||||
return llmReq;
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.config.PromptConfig;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||
@@ -112,10 +111,9 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
variable.put("information", sideInformation);
|
||||
|
||||
// use custom prompt template if provided.
|
||||
PromptConfig promptConfig = llmReq.getPromptConfig();
|
||||
String promptTemplate = INSTRUCTION;
|
||||
if (promptConfig != null && StringUtils.isNotBlank(promptConfig.getPromptTemplate())) {
|
||||
promptTemplate = promptConfig.getPromptTemplate();
|
||||
if (StringUtils.isNotBlank(llmReq.getCustomPrompt())) {
|
||||
promptTemplate = llmReq.getCustomPrompt();
|
||||
}
|
||||
return PromptTemplate.from(promptTemplate).apply(variable);
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.chat.query.llm.s2sql;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonValue;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.config.PromptConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
@@ -23,7 +22,7 @@ public class LLMReq {
|
||||
private String priorExts;
|
||||
private SqlGenType sqlGenType;
|
||||
private ChatModelConfig modelConfig;
|
||||
private PromptConfig promptConfig;
|
||||
private String customPrompt;
|
||||
private List<Text2SQLExemplar> dynamicExemplars;
|
||||
|
||||
@Data
|
||||
|
||||
@@ -254,7 +254,7 @@ public class DataSetServiceImpl extends ServiceImpl<DataSetDOMapper, DataSetDO>
|
||||
|
||||
@Override
|
||||
public Map<Long, List<Long>> getModelIdToDataSetIds() {
|
||||
return getModelIdToDataSetIds(Lists.newArrayList(), User.getFakeUser());
|
||||
return getModelIdToDataSetIds(Lists.newArrayList(), User.getDefaultUser());
|
||||
}
|
||||
|
||||
private void conflictCheck(DataSetResp dataSetResp) {
|
||||
|
||||
@@ -70,7 +70,7 @@ public class RetrieveServiceImpl implements RetrieveService {
|
||||
List<SchemaElement> metricsDb = semanticSchemaDb.getMetrics();
|
||||
final Map<Long, String> dataSetIdToName = semanticSchemaDb.getDataSetIdToName();
|
||||
Map<Long, List<Long>> modelIdToDataSetIds = dataSetService.getModelIdToDataSetIds(
|
||||
new ArrayList<>(dataSetIdToName.keySet()), User.getFakeUser());
|
||||
new ArrayList<>(dataSetIdToName.keySet()), User.getDefaultUser());
|
||||
// 2.detect by segment
|
||||
List<S2Term> originals = knowledgeBaseService.getTerms(queryText, modelIdToDataSetIds);
|
||||
log.debug("hanlp parse result: {}", originals);
|
||||
|
||||
@@ -162,7 +162,7 @@ public class DictUtils {
|
||||
dictItemResp.setBizName(dimension.getBizName());
|
||||
}
|
||||
if (TypeEnums.TAG.equals(TypeEnums.valueOf(dictConfDO.getType()))) {
|
||||
TagResp tagResp = tagMetaService.getTag(dictConfDO.getItemId(), User.getFakeUser());
|
||||
TagResp tagResp = tagMetaService.getTag(dictConfDO.getItemId(), User.getDefaultUser());
|
||||
dictItemResp.setModelId(tagResp.getModelId());
|
||||
dictItemResp.setBizName(tagResp.getBizName());
|
||||
}
|
||||
|
||||
@@ -43,7 +43,7 @@ public class MetricServiceImplTest {
|
||||
MetricReq metricReq = buildMetricReq();
|
||||
when(modelService.getModel(metricReq.getModelId())).thenReturn(mockModelResp());
|
||||
when(modelService.getModelByDomainIds(any())).thenReturn(Lists.newArrayList());
|
||||
MetricResp actualMetricResp = metricService.createMetric(metricReq, User.getFakeUser());
|
||||
MetricResp actualMetricResp = metricService.createMetric(metricReq, User.getDefaultUser());
|
||||
MetricResp expectedMetricResp = buildExpectedMetricResp();
|
||||
Assertions.assertEquals(expectedMetricResp, actualMetricResp);
|
||||
}
|
||||
@@ -58,7 +58,7 @@ public class MetricServiceImplTest {
|
||||
when(modelService.getModelByDomainIds(any())).thenReturn(Lists.newArrayList());
|
||||
MetricDO metricDO = MetricConverter.convert2MetricDO(buildMetricReq());
|
||||
when(metricRepository.getMetricById(metricDO.getId())).thenReturn(metricDO);
|
||||
MetricResp actualMetricResp = metricService.updateMetric(metricReq, User.getFakeUser());
|
||||
MetricResp actualMetricResp = metricService.updateMetric(metricReq, User.getDefaultUser());
|
||||
MetricResp expectedMetricResp = buildExpectedMetricResp();
|
||||
Assertions.assertEquals(expectedMetricResp, actualMetricResp);
|
||||
}
|
||||
|
||||
@@ -34,7 +34,7 @@ class ModelServiceImplTest {
|
||||
void createModel() throws Exception {
|
||||
ModelRepository modelRepository = Mockito.mock(ModelRepository.class);
|
||||
ModelService modelService = mockModelService(modelRepository);
|
||||
ModelResp actualModelResp = modelService.createModel(mockModelReq(), User.getFakeUser());
|
||||
ModelResp actualModelResp = modelService.createModel(mockModelReq(), User.getDefaultUser());
|
||||
ModelResp expectedModelResp = buildExpectedModelResp();
|
||||
Assertions.assertEquals(expectedModelResp, actualModelResp);
|
||||
}
|
||||
@@ -44,9 +44,9 @@ class ModelServiceImplTest {
|
||||
ModelRepository modelRepository = Mockito.mock(ModelRepository.class);
|
||||
ModelService modelService = mockModelService(modelRepository);
|
||||
ModelReq modelReq = mockModelReq_update();
|
||||
ModelDO modelDO = ModelConverter.convert(mockModelReq(), User.getFakeUser());
|
||||
ModelDO modelDO = ModelConverter.convert(mockModelReq(), User.getDefaultUser());
|
||||
when(modelRepository.getModelById(modelReq.getId())).thenReturn(modelDO);
|
||||
User user = User.getFakeUser();
|
||||
User user = User.getDefaultUser();
|
||||
user.setName("alice");
|
||||
ModelResp actualModelResp = modelService.updateModel(modelReq, user);
|
||||
ModelResp expectedModelResp = buildExpectedModelResp_update();
|
||||
@@ -60,9 +60,9 @@ class ModelServiceImplTest {
|
||||
ModelRepository modelRepository = Mockito.mock(ModelRepository.class);
|
||||
ModelService modelService = mockModelService(modelRepository);
|
||||
ModelReq modelReq = mockModelReq_updateAdmin();
|
||||
ModelDO modelDO = ModelConverter.convert(mockModelReq(), User.getFakeUser());
|
||||
ModelDO modelDO = ModelConverter.convert(mockModelReq(), User.getDefaultUser());
|
||||
when(modelRepository.getModelById(modelReq.getId())).thenReturn(modelDO);
|
||||
ModelResp actualModelResp = modelService.updateModel(modelReq, User.getFakeUser());
|
||||
ModelResp actualModelResp = modelService.updateModel(modelReq, User.getDefaultUser());
|
||||
ModelResp expectedModelResp = buildExpectedModelResp();
|
||||
Assertions.assertEquals(expectedModelResp, actualModelResp);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user