mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(parser) Add JSON response format to LLM requests for performance improvement (#2352)
This commit is contained in:
@@ -28,6 +28,8 @@ public class ChatModelConfig implements Serializable {
|
||||
private Boolean logRequests = false;
|
||||
private Boolean logResponses = false;
|
||||
private Boolean enableSearch = false;
|
||||
private Boolean jsonFormat = false;
|
||||
private String jsonFormatType = "json_schema";
|
||||
|
||||
public String keyDecrypt() {
|
||||
return AESEncryptionUtil.aesDecryptECB(getApiKey());
|
||||
|
||||
@@ -22,13 +22,17 @@ public class OpenAiModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||
return OpenAiChatModel.builder().baseUrl(modelConfig.getBaseUrl())
|
||||
OpenAiChatModel.OpenAiChatModelBuilder openAiChatModelBuilder = OpenAiChatModel.builder().baseUrl(modelConfig.getBaseUrl())
|
||||
.modelName(modelConfig.getModelName()).apiKey(modelConfig.keyDecrypt())
|
||||
.apiVersion(modelConfig.getApiVersion()).temperature(modelConfig.getTemperature())
|
||||
.topP(modelConfig.getTopP()).maxRetries(modelConfig.getMaxRetries())
|
||||
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
|
||||
.logRequests(modelConfig.getLogRequests())
|
||||
.logResponses(modelConfig.getLogResponses()).build();
|
||||
.logResponses(modelConfig.getLogResponses());
|
||||
if (modelConfig.getJsonFormat()) {
|
||||
openAiChatModelBuilder.strictJsonSchema(true).responseFormat(modelConfig.getJsonFormatType());
|
||||
}
|
||||
return openAiChatModelBuilder.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -57,6 +57,10 @@ public class ParserConfig extends ParameterConfig {
|
||||
new Parameter("s2.parser.field.count.threshold", "0", "语义字段个数阈值",
|
||||
"如果映射字段小于该阈值,则将数据集所有字段输入LLM", "number", "语义解析配置");
|
||||
|
||||
public static final Parameter PARSER_FORMAT_JSON_TYPE =
|
||||
new Parameter("s2.parser.format.json-type", "", "请求llm返回json格式,默认不设置json格式",
|
||||
"选项:json_schema或者json_object", "string", "语义解析配置");
|
||||
|
||||
@Override
|
||||
public List<Parameter> getSysParameters() {
|
||||
return Lists.newArrayList(PARSER_LINKING_VALUE_ENABLE, PARSER_RULE_CORRECTOR_ENABLE,
|
||||
|
||||
@@ -2,9 +2,11 @@ package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.common.pojo.enums.AppModule;
|
||||
import com.tencent.supersonic.common.util.ChatAppManager;
|
||||
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
@@ -14,9 +16,11 @@ import dev.langchain4j.model.output.structured.Description;
|
||||
import dev.langchain4j.service.AiServices;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.HashMap;
|
||||
@@ -24,6 +28,8 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_FORMAT_JSON_TYPE;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
@@ -31,6 +37,10 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
|
||||
public static final String APP_KEY = "S2SQL_PARSER";
|
||||
|
||||
@Autowired
|
||||
private ParserConfig parserConfig;
|
||||
|
||||
public static final String INSTRUCTION =
|
||||
"#Role: You are a data analyst experienced in SQL languages."
|
||||
+ "\n#Task: You will be provided with a natural language question asked by users,"
|
||||
@@ -74,7 +84,12 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
|
||||
// 2.generate sql generation prompt for each self-consistency inference
|
||||
ChatApp chatApp = llmReq.getChatAppConfig().get(APP_KEY);
|
||||
ChatLanguageModel chatLanguageModel = getChatLanguageModel(chatApp.getChatModelConfig());
|
||||
ChatModelConfig chatModelConfig = chatApp.getChatModelConfig();
|
||||
if (!StringUtils.isBlank(parserConfig.getParameterValue(PARSER_FORMAT_JSON_TYPE))) {
|
||||
chatModelConfig.setJsonFormat(true);
|
||||
chatModelConfig.setJsonFormatType(parserConfig.getParameterValue(PARSER_FORMAT_JSON_TYPE));
|
||||
}
|
||||
ChatLanguageModel chatLanguageModel = getChatLanguageModel(chatModelConfig);
|
||||
SemanticSqlExtractor extractor =
|
||||
AiServices.create(SemanticSqlExtractor.class, chatLanguageModel);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user