mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(parser) Add json format to LLM request for performance improvement (#2352)
This commit is contained in:
@@ -28,6 +28,8 @@ public class ChatModelConfig implements Serializable {
|
|||||||
private Boolean logRequests = false;
|
private Boolean logRequests = false;
|
||||||
private Boolean logResponses = false;
|
private Boolean logResponses = false;
|
||||||
private Boolean enableSearch = false;
|
private Boolean enableSearch = false;
|
||||||
|
private Boolean jsonFormat = false;
|
||||||
|
private String jsonFormatType = "json_schema";
|
||||||
|
|
||||||
public String keyDecrypt() {
|
public String keyDecrypt() {
|
||||||
return AESEncryptionUtil.aesDecryptECB(getApiKey());
|
return AESEncryptionUtil.aesDecryptECB(getApiKey());
|
||||||
|
|||||||
@@ -22,13 +22,17 @@ public class OpenAiModelFactory implements ModelFactory, InitializingBean {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) {
|
||||||
return OpenAiChatModel.builder().baseUrl(modelConfig.getBaseUrl())
|
OpenAiChatModel.OpenAiChatModelBuilder openAiChatModelBuilder = OpenAiChatModel.builder().baseUrl(modelConfig.getBaseUrl())
|
||||||
.modelName(modelConfig.getModelName()).apiKey(modelConfig.keyDecrypt())
|
.modelName(modelConfig.getModelName()).apiKey(modelConfig.keyDecrypt())
|
||||||
.apiVersion(modelConfig.getApiVersion()).temperature(modelConfig.getTemperature())
|
.apiVersion(modelConfig.getApiVersion()).temperature(modelConfig.getTemperature())
|
||||||
.topP(modelConfig.getTopP()).maxRetries(modelConfig.getMaxRetries())
|
.topP(modelConfig.getTopP()).maxRetries(modelConfig.getMaxRetries())
|
||||||
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
|
.timeout(Duration.ofSeconds(modelConfig.getTimeOut()))
|
||||||
.logRequests(modelConfig.getLogRequests())
|
.logRequests(modelConfig.getLogRequests())
|
||||||
.logResponses(modelConfig.getLogResponses()).build();
|
.logResponses(modelConfig.getLogResponses());
|
||||||
|
if (modelConfig.getJsonFormat()) {
|
||||||
|
openAiChatModelBuilder.strictJsonSchema(true).responseFormat(modelConfig.getJsonFormatType());
|
||||||
|
}
|
||||||
|
return openAiChatModelBuilder.build();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -57,6 +57,10 @@ public class ParserConfig extends ParameterConfig {
|
|||||||
new Parameter("s2.parser.field.count.threshold", "0", "语义字段个数阈值",
|
new Parameter("s2.parser.field.count.threshold", "0", "语义字段个数阈值",
|
||||||
"如果映射字段小于该阈值,则将数据集所有字段输入LLM", "number", "语义解析配置");
|
"如果映射字段小于该阈值,则将数据集所有字段输入LLM", "number", "语义解析配置");
|
||||||
|
|
||||||
|
public static final Parameter PARSER_FORMAT_JSON_TYPE =
|
||||||
|
new Parameter("s2.parser.format.json-type", "", "请求llm返回json格式,默认不设置json格式",
|
||||||
|
"选项:json_schema或者json_object", "string", "语义解析配置");
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<Parameter> getSysParameters() {
|
public List<Parameter> getSysParameters() {
|
||||||
return Lists.newArrayList(PARSER_LINKING_VALUE_ENABLE, PARSER_RULE_CORRECTOR_ENABLE,
|
return Lists.newArrayList(PARSER_LINKING_VALUE_ENABLE, PARSER_RULE_CORRECTOR_ENABLE,
|
||||||
|
|||||||
@@ -2,9 +2,11 @@ package com.tencent.supersonic.headless.chat.parser.llm;
|
|||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||||
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||||
import com.tencent.supersonic.common.pojo.enums.AppModule;
|
import com.tencent.supersonic.common.pojo.enums.AppModule;
|
||||||
import com.tencent.supersonic.common.util.ChatAppManager;
|
import com.tencent.supersonic.common.util.ChatAppManager;
|
||||||
|
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
@@ -14,9 +16,11 @@ import dev.langchain4j.model.output.structured.Description;
|
|||||||
import dev.langchain4j.service.AiServices;
|
import dev.langchain4j.service.AiServices;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang.StringUtils;
|
||||||
import org.apache.commons.lang3.tuple.Pair;
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
@@ -24,6 +28,8 @@ import java.util.List;
|
|||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.concurrent.ConcurrentHashMap;
|
import java.util.concurrent.ConcurrentHashMap;
|
||||||
|
|
||||||
|
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_FORMAT_JSON_TYPE;
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||||
@@ -31,6 +37,10 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
|||||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||||
|
|
||||||
public static final String APP_KEY = "S2SQL_PARSER";
|
public static final String APP_KEY = "S2SQL_PARSER";
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private ParserConfig parserConfig;
|
||||||
|
|
||||||
public static final String INSTRUCTION =
|
public static final String INSTRUCTION =
|
||||||
"#Role: You are a data analyst experienced in SQL languages."
|
"#Role: You are a data analyst experienced in SQL languages."
|
||||||
+ "\n#Task: You will be provided with a natural language question asked by users,"
|
+ "\n#Task: You will be provided with a natural language question asked by users,"
|
||||||
@@ -74,7 +84,12 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
|||||||
|
|
||||||
// 2.generate sql generation prompt for each self-consistency inference
|
// 2.generate sql generation prompt for each self-consistency inference
|
||||||
ChatApp chatApp = llmReq.getChatAppConfig().get(APP_KEY);
|
ChatApp chatApp = llmReq.getChatAppConfig().get(APP_KEY);
|
||||||
ChatLanguageModel chatLanguageModel = getChatLanguageModel(chatApp.getChatModelConfig());
|
ChatModelConfig chatModelConfig = chatApp.getChatModelConfig();
|
||||||
|
if (!StringUtils.isBlank(parserConfig.getParameterValue(PARSER_FORMAT_JSON_TYPE))) {
|
||||||
|
chatModelConfig.setJsonFormat(true);
|
||||||
|
chatModelConfig.setJsonFormatType(parserConfig.getParameterValue(PARSER_FORMAT_JSON_TYPE));
|
||||||
|
}
|
||||||
|
ChatLanguageModel chatLanguageModel = getChatLanguageModel(chatModelConfig);
|
||||||
SemanticSqlExtractor extractor =
|
SemanticSqlExtractor extractor =
|
||||||
AiServices.create(SemanticSqlExtractor.class, chatLanguageModel);
|
AiServices.create(SemanticSqlExtractor.class, chatLanguageModel);
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user