(improvement)(parser) Add json format to LLM request for performance improvement (#2352)

This commit is contained in:
ChPi
2025-08-05 17:43:03 +08:00
committed by GitHub
parent 42bf355839
commit af28bc7c2a
4 changed files with 28 additions and 3 deletions

View File

@@ -28,6 +28,8 @@ public class ChatModelConfig implements Serializable {
     private Boolean logRequests = false;
     private Boolean logResponses = false;
     private Boolean enableSearch = false;
+    private Boolean jsonFormat = false;
+    private String jsonFormatType = "json_schema";
     public String keyDecrypt() {
         return AESEncryptionUtil.aesDecryptECB(getApiKey());

View File

@@ -22,13 +22,17 @@ public class OpenAiModelFactory implements ModelFactory, InitializingBean {
     @Override
     public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
-        return OpenAiChatModel.builder().baseUrl(modelConfig.getBaseUrl())
+        OpenAiChatModel.OpenAiChatModelBuilder openAiChatModelBuilder = OpenAiChatModel.builder().baseUrl(modelConfig.getBaseUrl())
                 .modelName(modelConfig.getModelName()).apiKey(modelConfig.keyDecrypt())
                 .apiVersion(modelConfig.getApiVersion()).temperature(modelConfig.getTemperature())
                 .topP(modelConfig.getTopP()).maxRetries(modelConfig.getMaxRetries())
                 .timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
                 .logRequests(modelConfig.getLogRequests())
-                .logResponses(modelConfig.getLogResponses()).build();
+                .logResponses(modelConfig.getLogResponses());
+        if (modelConfig.getJsonFormat()) {
+            openAiChatModelBuilder.strictJsonSchema(true).responseFormat(modelConfig.getJsonFormatType());
+        }
+        return openAiChatModelBuilder.build();
     }
     @Override

View File

@@ -57,6 +57,10 @@ public class ParserConfig extends ParameterConfig {
             new Parameter("s2.parser.field.count.threshold", "0", "语义字段个数阈值",
                     "如果映射字段小于该阈值则将数据集所有字段输入LLM", "number", "语义解析配置");
+    public static final Parameter PARSER_FORMAT_JSON_TYPE =
+            new Parameter("s2.parser.format.json-type", "", "请求llm返回json格式,默认不设置json格式",
+                    "选项json_schema或者json_object", "string", "语义解析配置");
     @Override
     public List<Parameter> getSysParameters() {
         return Lists.newArrayList(PARSER_LINKING_VALUE_ENABLE, PARSER_RULE_CORRECTOR_ENABLE,

View File

@@ -2,9 +2,11 @@ package com.tencent.supersonic.headless.chat.parser.llm;
 import com.google.common.collect.Lists;
 import com.tencent.supersonic.common.pojo.ChatApp;
+import com.tencent.supersonic.common.pojo.ChatModelConfig;
 import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
 import com.tencent.supersonic.common.pojo.enums.AppModule;
 import com.tencent.supersonic.common.util.ChatAppManager;
+import com.tencent.supersonic.headless.chat.parser.ParserConfig;
 import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
 import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
 import dev.langchain4j.model.chat.ChatLanguageModel;
@@ -14,9 +16,11 @@ import dev.langchain4j.model.output.structured.Description;
 import dev.langchain4j.service.AiServices;
 import lombok.Data;
 import lombok.extern.slf4j.Slf4j;
+import org.apache.commons.lang.StringUtils;
 import org.apache.commons.lang3.tuple.Pair;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.stereotype.Service;
 import java.util.HashMap;
@@ -24,6 +28,8 @@ import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
+import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_FORMAT_JSON_TYPE;
 @Service
 @Slf4j
 public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
@@ -31,6 +37,10 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
     private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
     public static final String APP_KEY = "S2SQL_PARSER";
+    @Autowired
+    private ParserConfig parserConfig;
     public static final String INSTRUCTION =
             "#Role: You are a data analyst experienced in SQL languages."
                     + "\n#Task: You will be provided with a natural language question asked by users,"
@@ -74,7 +84,12 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
         // 2.generate sql generation prompt for each self-consistency inference
         ChatApp chatApp = llmReq.getChatAppConfig().get(APP_KEY);
-        ChatLanguageModel chatLanguageModel = getChatLanguageModel(chatApp.getChatModelConfig());
+        ChatModelConfig chatModelConfig = chatApp.getChatModelConfig();
+        if (!StringUtils.isBlank(parserConfig.getParameterValue(PARSER_FORMAT_JSON_TYPE))) {
+            chatModelConfig.setJsonFormat(true);
+            chatModelConfig.setJsonFormatType(parserConfig.getParameterValue(PARSER_FORMAT_JSON_TYPE));
+        }
+        ChatLanguageModel chatLanguageModel = getChatLanguageModel(chatModelConfig);
         SemanticSqlExtractor extractor =
                 AiServices.create(SemanticSqlExtractor.class, chatLanguageModel);