(improvement)(Chat) add extend config for agent (#1010)

This commit is contained in:
LXW
2024-05-20 12:58:21 +08:00
committed by GitHub
parent 2d8c5c379c
commit 53b6c03288
37 changed files with 264 additions and 711 deletions

View File

@@ -0,0 +1,31 @@
package com.tencent.supersonic.headless.core.chat.parser.llm;
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.utils.S2ChatModelProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
/**
 * Shared base for the SQL-generation strategies (one-pass / two-pass, with or
 * without self-consistency). Centralizes the collaborators every strategy needs
 * and the per-request chat-model lookup, so concrete subclasses only implement
 * the prompt/response flow.
 *
 * <p>Note: no {@code @Service} here — Spring component scanning only registers
 * concrete classes, so the annotation on an abstract base is inert; each
 * concrete subclass carries its own {@code @Service}.
 *
 * <p>Subclasses must still implement {@link InitializingBean#afterPropertiesSet()}
 * (typically to register themselves in a strategy factory — TODO confirm against
 * the subclasses' implementations).
 */
public abstract class BaseSqlGeneration implements SqlGeneration, InitializingBean {

    /** Dedicated logger channel for key text-to-SQL pipeline events. */
    protected static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");

    /** Retrieves few-shot SQL examples used to build prompts. */
    @Autowired
    protected SqlExamplarLoader sqlExamplarLoader;

    /** Tuning knobs (e.g. example counts, generation mode) read by subclasses. */
    @Autowired
    protected OptimizationConfig optimizationConfig;

    /** Builds linking/SQL prompts from the request and retrieved examples. */
    @Autowired
    protected SqlPromptGenerator sqlPromptGenerator;

    /**
     * Resolves the chat model for this request.
     *
     * @param llmConfig per-request model configuration; when absent or incomplete
     *         the provider falls back to the globally configured model bean
     * @return the chat model to use for this generation
     */
    protected ChatLanguageModel getChatLanguageModel(LLMConfig llmConfig) {
        return S2ChatModelProvider.provide(llmConfig);
    }
}

View File

@@ -96,6 +96,7 @@ public class LLMRequestService {
}
llmReq.setCurrentDate(currentDate);
llmReq.setSqlGenerationMode(optimizationConfig.getSqlGenerationMode().getName());
llmReq.setLlmConfig(queryCtx.getLlmConfig());
return llmReq;
}

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.headless.core.chat.parser.llm;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
@@ -12,10 +11,6 @@ import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.HashMap;
@@ -26,20 +21,7 @@ import java.util.stream.Collectors;
@Service
@Slf4j
public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
@Autowired
private ChatLanguageModel chatLanguageModel;
@Autowired
private SqlExamplarLoader sqlExamplarLoader;
@Autowired
private OptimizationConfig optimizationConfig;
@Autowired
private SqlPromptGenerator sqlPromptGenerator;
public class OnePassSCSqlGeneration extends BaseSqlGeneration {
@Override
public LLMResp generation(LLMReq llmReq, Long dataSetId) {
@@ -59,7 +41,8 @@ public class OnePassSCSqlGeneration implements SqlGeneration, InitializingBean {
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingSqlPrompt))
.apply(new HashMap<>());
keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
String result = response.content().text();
llmResults.add(result);
keyPipelineLog.info("model response:{}", result);

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.headless.core.chat.parser.llm;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
@@ -12,10 +11,6 @@ import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.HashMap;
@@ -24,26 +19,13 @@ import java.util.Map;
@Service
@Slf4j
public class OnePassSqlGeneration implements SqlGeneration, InitializingBean {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
@Autowired
private ChatLanguageModel chatLanguageModel;
@Autowired
private SqlExamplarLoader sqlExampleLoader;
@Autowired
private OptimizationConfig optimizationConfig;
@Autowired
private SqlPromptGenerator sqlPromptGenerator;
public class OnePassSqlGeneration extends BaseSqlGeneration {
@Override
public LLMResp generation(LLMReq llmReq, Long dataSetId) {
//1.retriever sqlExamples
keyPipelineLog.info("dataSetId:{},llmReq:{}", dataSetId, llmReq);
List<Map<String, String>> sqlExamples = sqlExampleLoader.retrieverSqlExamples(llmReq.getQueryText(),
List<Map<String, String>> sqlExamples = sqlExamplarLoader.retrieverSqlExamples(llmReq.getQueryText(),
optimizationConfig.getText2sqlExampleNum());
//2.generator linking and sql prompt by sqlExamples,and generate response.
@@ -51,6 +33,7 @@ public class OnePassSqlGeneration implements SqlGeneration, InitializingBean {
Prompt prompt = PromptTemplate.from(JsonUtil.toString(promptStr)).apply(new HashMap<>());
keyPipelineLog.info("request prompt:{}", prompt.toSystemMessage());
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
String result = response.content().text();
keyPipelineLog.info("model response:{}", result);

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.headless.core.chat.parser.llm;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
@@ -11,10 +10,6 @@ import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.HashMap;
@@ -23,20 +18,7 @@ import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;
@Service
public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
@Autowired
private ChatLanguageModel chatLanguageModel;
@Autowired
private SqlExamplarLoader sqlExamplarLoader;
@Autowired
private OptimizationConfig optimizationConfig;
@Autowired
private SqlPromptGenerator sqlPromptGenerator;
public class TwoPassSCSqlGeneration extends BaseSqlGeneration {
@Override
public LLMResp generation(LLMReq llmReq, Long dataSetId) {
@@ -51,6 +33,7 @@ public class TwoPassSCSqlGeneration implements SqlGeneration, InitializingBean {
//2.generator linking prompt,and parallel generate response.
List<String> linkingPromptPool = sqlPromptGenerator.generatePromptPool(llmReq, exampleListPool, false);
List<String> linkingResults = new CopyOnWriteArrayList<>();
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
linkingPromptPool.parallelStream().forEach(
linkingPrompt -> {
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPrompt)).apply(new HashMap<>());

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.headless.core.chat.parser.llm;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMReq.SqlGenerationMode;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
@@ -11,10 +10,6 @@ import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.HashMap;
@@ -23,20 +18,7 @@ import java.util.Map;
@Service
@Slf4j
public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
@Autowired
private ChatLanguageModel chatLanguageModel;
@Autowired
private SqlExamplarLoader sqlExamplarLoader;
@Autowired
private OptimizationConfig optimizationConfig;
@Autowired
private SqlPromptGenerator sqlPromptGenerator;
public class TwoPassSqlGeneration extends BaseSqlGeneration {
@Override
public LLMResp generation(LLMReq llmReq, Long dataSetId) {
@@ -48,6 +30,7 @@ public class TwoPassSqlGeneration implements SqlGeneration, InitializingBean {
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPromptStr)).apply(new HashMap<>());
keyPipelineLog.info("step one request prompt:{}", prompt.toSystemMessage());
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
keyPipelineLog.info("step one model response:{}", response.content().text());
String schemaLinkStr = OutputFormat.getSchemaLink(response.content().text());

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.headless.core.chat.query.llm.s2sql;
import com.fasterxml.jackson.annotation.JsonValue;
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
import lombok.Data;
import java.util.List;
@@ -22,6 +23,8 @@ public class LLMReq {
private String sqlGenerationMode;
private LLMConfig llmConfig;
@Data
public static class ElementValue {

View File

@@ -4,6 +4,7 @@ import com.fasterxml.jackson.annotation.JsonIgnore;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
@@ -47,6 +48,7 @@ public class QueryContext {
@JsonIgnore
private WorkflowState workflowState;
private QueryDataType queryDataType = QueryDataType.ALL;
private LLMConfig llmConfig;
public List<SemanticQuery> getCandidateQueries() {
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);

View File

@@ -0,0 +1,41 @@
package com.tencent.supersonic.headless.core.utils;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.LLMConfig;
import com.tencent.supersonic.common.pojo.enums.S2ModelProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.localai.LocalAiChatModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import org.apache.commons.lang3.StringUtils;
import java.time.Duration;
/**
 * Factory that resolves the {@link ChatLanguageModel} to use for a request.
 *
 * <p>Resolution order:
 * <ol>
 *   <li>no config, or no provider/base-url configured → the globally configured
 *       {@code ChatLanguageModel} Spring bean;</li>
 *   <li>provider {@code OPEN_AI} → an OpenAI model built from the config;</li>
 *   <li>provider {@code LOCAL_AI} → a LocalAI model built from the config;</li>
 *   <li>anything else → {@link IllegalArgumentException}.</li>
 * </ol>
 */
public class S2ChatModelProvider {

    /** Static utility — not instantiable. */
    private S2ChatModelProvider() {
    }

    /**
     * Builds or looks up the chat model for the given configuration.
     *
     * @param llmConfig per-request model configuration; may be {@code null}
     *         (treated the same as an incomplete config: fall back to the default bean)
     * @return the chat model to use
     * @throws IllegalArgumentException if a provider is configured but unsupported
     */
    public static ChatLanguageModel provide(LLMConfig llmConfig) {
        // Null-safe fallback: a request without a usable custom config uses the
        // default bean. Looking the bean up only here (not eagerly) avoids a
        // needless context lookup when a custom provider is fully configured.
        if (llmConfig == null
                || StringUtils.isBlank(llmConfig.getProvider())
                || StringUtils.isBlank(llmConfig.getBaseUrl())) {
            return ContextUtils.getBean(ChatLanguageModel.class);
        }
        if (S2ModelProvider.OPEN_AI.name().equalsIgnoreCase(llmConfig.getProvider())) {
            return OpenAiChatModel
                    .builder()
                    .baseUrl(llmConfig.getBaseUrl())
                    .modelName(llmConfig.getModelName())
                    .apiKey(llmConfig.getApiKey())
                    .temperature(llmConfig.getTemperature())
                    .timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
                    .build();
        }
        if (S2ModelProvider.LOCAL_AI.name().equalsIgnoreCase(llmConfig.getProvider())) {
            // LocalAI endpoints do not take an API key.
            return LocalAiChatModel
                    .builder()
                    .baseUrl(llmConfig.getBaseUrl())
                    .modelName(llmConfig.getModelName())
                    .temperature(llmConfig.getTemperature())
                    .timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
                    .build();
        }
        throw new IllegalArgumentException("unsupported provider: " + llmConfig.getProvider());
    }
}