mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(improvement)(chat) Support configuring embeddingModel or embeddingStore at the agent level. (#1361)
This commit is contained in:
@@ -4,7 +4,9 @@ package com.tencent.supersonic.chat.server.agent;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.common.config.LLMConfig;
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingStoreConfig;
|
||||
import com.tencent.supersonic.common.config.ModelConfig;
|
||||
import com.tencent.supersonic.common.config.PromptConfig;
|
||||
import com.tencent.supersonic.common.config.VisualConfig;
|
||||
import com.tencent.supersonic.common.pojo.RecordInfo;
|
||||
@@ -33,7 +35,9 @@ public class Agent extends RecordInfo {
|
||||
private Integer status;
|
||||
private List<String> examples;
|
||||
private String agentConfig;
|
||||
private LLMConfig llmConfig;
|
||||
private ChatModelConfig llmConfig;
|
||||
private ModelConfig modelConfig;
|
||||
private EmbeddingStoreConfig embeddingStore;
|
||||
private PromptConfig promptConfig;
|
||||
private MultiTurnConfig multiTurnConfig;
|
||||
private VisualConfig visualConfig;
|
||||
|
||||
@@ -15,7 +15,7 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.model.provider.ChatLanguageModelProvider;
|
||||
import dev.langchain4j.provider.ModelProvider;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
@@ -46,7 +46,7 @@ public class PlainTextExecutor implements ChatExecutor {
|
||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||
Agent chatAgent = agentService.getAgent(chatExecuteContext.getAgentId());
|
||||
|
||||
ChatLanguageModel chatLanguageModel = ChatLanguageModelProvider.provide(chatAgent.getLlmConfig());
|
||||
ChatLanguageModel chatLanguageModel = ModelProvider.provideChatModel(chatAgent.getLlmConfig());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||
|
||||
QueryResult result = new QueryResult();
|
||||
|
||||
@@ -7,7 +7,7 @@ import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.provider.ChatLanguageModelProvider;
|
||||
import dev.langchain4j.provider.ModelProvider;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
@@ -56,7 +56,7 @@ public class MemoryReviewTask {
|
||||
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
||||
|
||||
keyPipelineLog.info("MemoryReviewTask reqPrompt:{}", promptStr);
|
||||
ChatLanguageModel chatLanguageModel = ChatLanguageModelProvider.provide(
|
||||
ChatLanguageModel chatLanguageModel = ModelProvider.provideChatModel(
|
||||
chatAgent.getLlmConfig());
|
||||
if (Objects.nonNull(chatLanguageModel)) {
|
||||
String response = chatLanguageModel.generate(prompt.toUserMessage()).content().text();
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
package com.tencent.supersonic.chat.server.parser;
|
||||
|
||||
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
|
||||
|
||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.config.LLMConfig;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
@@ -26,7 +24,14 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.model.provider.ChatLanguageModelProvider;
|
||||
import dev.langchain4j.provider.ModelProvider;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
@@ -34,12 +39,8 @@ import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
|
||||
|
||||
@Slf4j
|
||||
public class NL2SQLParser implements ChatParser {
|
||||
@@ -180,7 +181,7 @@ public class NL2SQLParser implements ChatParser {
|
||||
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
||||
keyPipelineLog.info("NL2SQLParser reqPrompt:{}", promptStr);
|
||||
|
||||
ChatLanguageModel chatLanguageModel = ChatLanguageModelProvider.provide(context.getLlmConfig());
|
||||
ChatLanguageModel chatLanguageModel = ModelProvider.provideChatModel(context.getLlmConfig());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||
|
||||
String result = response.content().text();
|
||||
@@ -242,7 +243,7 @@ public class NL2SQLParser implements ChatParser {
|
||||
private String curtSchema;
|
||||
private String histSchema;
|
||||
private String histSQL;
|
||||
private LLMConfig llmConfig;
|
||||
private ChatModelConfig llmConfig;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.util.LLMConnHelper;
|
||||
import com.tencent.supersonic.common.config.LLMConfig;
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.web.bind.annotation.DeleteMapping;
|
||||
import org.springframework.web.bind.annotation.PathVariable;
|
||||
@@ -15,6 +15,7 @@ import org.springframework.web.bind.annotation.PutMapping;
|
||||
import org.springframework.web.bind.annotation.RequestBody;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import java.util.List;
|
||||
@@ -50,7 +51,7 @@ public class AgentController {
|
||||
}
|
||||
|
||||
@PostMapping("/testLLMConn")
|
||||
public boolean testLLMConn(@RequestBody LLMConfig llmConfig) {
|
||||
public boolean testLLMConn(@RequestBody ChatModelConfig llmConfig) {
|
||||
return LLMConnHelper.testConnection(llmConfig);
|
||||
}
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatService;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import com.tencent.supersonic.chat.server.util.LLMConnHelper;
|
||||
import com.tencent.supersonic.common.config.LLMConfig;
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.PromptConfig;
|
||||
import com.tencent.supersonic.common.config.VisualConfig;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
@@ -80,6 +80,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
|
||||
/**
|
||||
* the example in the agent will be executed by default,
|
||||
* if the result is correct, it will be put into memory as a reference for LLM
|
||||
*
|
||||
* @param agent
|
||||
*/
|
||||
private void executeAgentExamplesAsync(Agent agent) {
|
||||
@@ -121,7 +122,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
|
||||
BeanUtils.copyProperties(agentDO, agent);
|
||||
agent.setAgentConfig(agentDO.getConfig());
|
||||
agent.setExamples(JsonUtil.toList(agentDO.getExamples(), String.class));
|
||||
agent.setLlmConfig(JsonUtil.toObject(agentDO.getLlmConfig(), LLMConfig.class));
|
||||
agent.setLlmConfig(JsonUtil.toObject(agentDO.getLlmConfig(), ChatModelConfig.class));
|
||||
agent.setPromptConfig(JsonUtil.toObject(agentDO.getPromptConfig(), PromptConfig.class));
|
||||
agent.setMultiTurnConfig(JsonUtil.toObject(agentDO.getMultiTurnConfig(), MultiTurnConfig.class));
|
||||
agent.setVisualConfig(JsonUtil.toObject(agentDO.getVisualConfig(), VisualConfig.class));
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
package com.tencent.supersonic.chat.server.util;
|
||||
|
||||
import com.tencent.supersonic.common.config.LLMConfig;
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.provider.ChatLanguageModelProvider;
|
||||
import dev.langchain4j.provider.ModelProvider;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
@Slf4j
|
||||
public class LLMConnHelper {
|
||||
public static boolean testConnection(LLMConfig llmConfig) {
|
||||
public static boolean testConnection(ChatModelConfig chatModel) {
|
||||
try {
|
||||
if (llmConfig == null || StringUtils.isBlank(llmConfig.getBaseUrl())) {
|
||||
if (chatModel == null || StringUtils.isBlank(chatModel.getBaseUrl())) {
|
||||
return false;
|
||||
}
|
||||
ChatLanguageModel chatLanguageModel = ChatLanguageModelProvider.provide(llmConfig);
|
||||
ChatLanguageModel chatLanguageModel = ModelProvider.provideChatModel(chatModel);
|
||||
String response = chatLanguageModel.generate("Hi there");
|
||||
return StringUtils.isNotEmpty(response) ? true : false;
|
||||
} catch (Exception e) {
|
||||
|
||||
@@ -185,6 +185,8 @@
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
<artifactId>langchain4j-embeddings</artifactId>
|
||||
|
||||
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>dev.langchain4j</groupId>
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
package com.tencent.supersonic.common.config;
|
||||
|
||||
import com.tencent.supersonic.common.util.AESEncryptionUtil;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
/**
 * Connection settings for a chat (LLM) model: provider name, endpoint,
 * credentials, model name, sampling temperature and timeout.
 */
@Data
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class ChatModelConfig implements Serializable {
|
||||
private static final long serialVersionUID = 1L;
|
||||
// Provider name; ModelProvider upper-cases it to look up the registered factory.
private String provider;
|
||||
// Endpoint URL handed to the model builders.
private String baseUrl;
|
||||
// API key as stored; see keyDecrypt() for the decrypted form.
private String apiKey;
|
||||
private String modelName;
|
||||
// Defaults to 0.0 when not set.
private Double temperature = 0.0d;
|
||||
// Defaults to 60; the model factories interpret it as seconds (Duration.ofSeconds).
private Long timeOut = 60L;
|
||||
|
||||
/**
 * Returns the result of {@code AESEncryptionUtil.aesDecryptECB} applied to the API key.
 * NOTE(review): presumably the key is stored AES/ECB-encrypted — confirm against the utility.
 */
public String keyDecrypt() {
|
||||
return AESEncryptionUtil.aesDecryptECB(getApiKey());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
package com.tencent.supersonic.common.config;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
/**
 * Settings for an embedding model. Which fields are read depends on the
 * provider's factory: remote providers use baseUrl/apiKey/modelName, while the
 * in-memory factory reads modelPath/vocabularyPath/modelName.
 */
@Data
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class EmbeddingModelConfig implements Serializable {
|
||||
private static final long serialVersionUID = 1L;
|
||||
// Provider name; ModelProvider upper-cases it to look up the registered factory.
private String provider;
|
||||
private String baseUrl;
|
||||
private String apiKey;
|
||||
private String modelName;
|
||||
// Path of a local model file; used together with vocabularyPath by the in-memory factory.
private String modelPath;
|
||||
private String vocabularyPath;
|
||||
// Nullable: no default is assigned here.
private Integer maxRetries;
|
||||
private Integer maxToken;
|
||||
// Nullable flags: callers must not auto-unbox them without a null check.
private Boolean logRequests;
|
||||
private Boolean logResponses;
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package com.tencent.supersonic.common.config;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
/**
 * Settings for an embedding store: provider name, optional persistence path,
 * endpoint, credentials and timeout.
 */
@Data
|
||||
public class EmbeddingStoreConfig implements Serializable {
|
||||
private static final long serialVersionUID = 1L;
|
||||
// NOTE(review): presumably one of the EmbeddingStoreType names — confirm against the store factories.
private String provider;
|
||||
private String persistPath;
|
||||
private String baseUrl;
|
||||
private String apiKey;
|
||||
// Defaults to 60. NOTE(review): presumably seconds, like ChatModelConfig.timeOut — confirm.
private Long timeOut = 60L;
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
package com.tencent.supersonic.common.config;
|
||||
|
||||
import com.tencent.supersonic.common.util.AESEncryptionUtil;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Data
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class LLMConfig {
|
||||
|
||||
private String provider;
|
||||
|
||||
private String baseUrl;
|
||||
|
||||
private String apiKey;
|
||||
|
||||
private String modelName;
|
||||
|
||||
private Double temperature = 0.0d;
|
||||
|
||||
private Long timeOut = 60L;
|
||||
|
||||
public LLMConfig(String provider, String baseUrl, String apiKey, String modelName) {
|
||||
this.provider = provider;
|
||||
this.baseUrl = baseUrl;
|
||||
this.apiKey = apiKey;
|
||||
this.modelName = modelName;
|
||||
}
|
||||
|
||||
public LLMConfig(String provider, String baseUrl, String apiKey, String modelName,
|
||||
double temperature) {
|
||||
this.provider = provider;
|
||||
this.baseUrl = baseUrl;
|
||||
this.apiKey = apiKey;
|
||||
this.modelName = modelName;
|
||||
this.temperature = temperature;
|
||||
}
|
||||
|
||||
public String keyDecrypt() {
|
||||
return AESEncryptionUtil.aesDecryptECB(apiKey);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package com.tencent.supersonic.common.config;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
/**
 * Bundles the chat-model and embedding-model settings so both can be
 * configured together; either part may be null.
 */
@Data
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class ModelConfig implements Serializable {
|
||||
private static final long serialVersionUID = 1L;
|
||||
// Chat (LLM) settings; null means "not configured".
private ChatModelConfig chatModel;
|
||||
// Embedding settings; null means "not configured".
private EmbeddingModelConfig embeddingModel;
|
||||
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
package dev.langchain4j.model.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.LLMConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
|
||||
public interface ChatLanguageModelFactory {
|
||||
ChatLanguageModel create(LLMConfig llmConfig);
|
||||
}
|
||||
@@ -1,33 +0,0 @@
|
||||
package dev.langchain4j.model.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.LLMConfig;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
public class ChatLanguageModelProvider {
|
||||
private static final Map<String, ChatLanguageModelFactory> factories = new HashMap<>();
|
||||
|
||||
static {
|
||||
factories.put(ModelProvider.OPEN_AI.name(), new OpenAiChatModelFactory());
|
||||
factories.put(ModelProvider.LOCAL_AI.name(), new LocalAiChatModelFactory());
|
||||
factories.put(ModelProvider.OLLAMA.name(), new OllamaChatModelFactory());
|
||||
}
|
||||
|
||||
public static ChatLanguageModel provide(LLMConfig llmConfig) {
|
||||
if (llmConfig == null || StringUtils.isBlank(llmConfig.getProvider())
|
||||
|| StringUtils.isBlank(llmConfig.getBaseUrl())) {
|
||||
return ContextUtils.getBean(ChatLanguageModel.class);
|
||||
}
|
||||
|
||||
ChatLanguageModelFactory factory = factories.get(llmConfig.getProvider().toUpperCase());
|
||||
if (factory != null) {
|
||||
return factory.create(llmConfig);
|
||||
}
|
||||
|
||||
throw new RuntimeException("Unsupported provider: " + llmConfig.getProvider());
|
||||
}
|
||||
}
|
||||
@@ -1,20 +0,0 @@
|
||||
package dev.langchain4j.model.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.LLMConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.localai.LocalAiChatModel;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
public class LocalAiChatModelFactory implements ChatLanguageModelFactory {
|
||||
@Override
|
||||
public ChatLanguageModel create(LLMConfig llmConfig) {
|
||||
return LocalAiChatModel
|
||||
.builder()
|
||||
.baseUrl(llmConfig.getBaseUrl())
|
||||
.modelName(llmConfig.getModelName())
|
||||
.temperature(llmConfig.getTemperature())
|
||||
.timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
|
||||
.build();
|
||||
}
|
||||
}
|
||||
@@ -1,9 +0,0 @@
|
||||
package dev.langchain4j.model.provider;
|
||||
|
||||
public enum ModelProvider {
|
||||
OPEN_AI,
|
||||
HUGGING_FACE,
|
||||
LOCAL_AI,
|
||||
IN_PROCESS,
|
||||
OLLAMA
|
||||
}
|
||||
@@ -1,20 +0,0 @@
|
||||
package dev.langchain4j.model.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.LLMConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.ollama.OllamaChatModel;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
public class OllamaChatModelFactory implements ChatLanguageModelFactory {
|
||||
@Override
|
||||
public ChatLanguageModel create(LLMConfig llmConfig) {
|
||||
return OllamaChatModel
|
||||
.builder()
|
||||
.baseUrl(llmConfig.getBaseUrl())
|
||||
.modelName(llmConfig.getModelName())
|
||||
.temperature(llmConfig.getTemperature())
|
||||
.timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
|
||||
.build();
|
||||
}
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
package dev.langchain4j.model.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.LLMConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.openai.OpenAiChatModel;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
public class OpenAiChatModelFactory implements ChatLanguageModelFactory {
|
||||
@Override
|
||||
public ChatLanguageModel create(LLMConfig llmConfig) {
|
||||
return OpenAiChatModel
|
||||
.builder()
|
||||
.baseUrl(llmConfig.getBaseUrl())
|
||||
.modelName(llmConfig.getModelName())
|
||||
.apiKey(llmConfig.keyDecrypt())
|
||||
.temperature(llmConfig.getTemperature())
|
||||
.timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
|
||||
.build();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.azure.AzureOpenAiChatModel;
|
||||
import dev.langchain4j.model.azure.AzureOpenAiEmbeddingModel;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
@Service
|
||||
public class AzureModelFactory implements ModelFactory, InitializingBean {
|
||||
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
|
||||
AzureOpenAiChatModel.Builder builder = AzureOpenAiChatModel.builder()
|
||||
.endpoint(chatModel.getBaseUrl())
|
||||
.apiKey(chatModel.getApiKey())
|
||||
.deploymentName(chatModel.getModelName())
|
||||
.temperature(chatModel.getTemperature())
|
||||
.timeout(Duration.ofSeconds(chatModel.getTimeOut() == null ? 0L : chatModel.getTimeOut()));
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
||||
AzureOpenAiEmbeddingModel.Builder builder = AzureOpenAiEmbeddingModel.builder()
|
||||
.endpoint(embeddingModelConfig.getBaseUrl())
|
||||
.apiKey(embeddingModelConfig.getApiKey())
|
||||
.deploymentName(embeddingModelConfig.getModelName())
|
||||
.maxRetries(embeddingModelConfig.getMaxRetries())
|
||||
.logRequestsAndResponses(embeddingModelConfig.getLogRequests() != null
|
||||
&& embeddingModelConfig.getLogResponses());
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(Provider.AZURE, this);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.dashscope.QwenChatModel;
|
||||
import dev.langchain4j.model.dashscope.QwenEmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
public class DashscopeModelFactory implements ModelFactory, InitializingBean {
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
|
||||
return QwenChatModel.builder()
|
||||
.baseUrl(chatModel.getBaseUrl())
|
||||
.apiKey(chatModel.getApiKey())
|
||||
.modelName(chatModel.getModelName())
|
||||
.temperature(chatModel.getTemperature() == null ? 0L :
|
||||
chatModel.getTemperature().floatValue())
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
||||
return QwenEmbeddingModel.builder()
|
||||
.apiKey(embeddingModelConfig.getApiKey())
|
||||
.modelName(embeddingModelConfig.getModelName())
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(Provider.DASHSCOPE, this);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.EmbeddingStoreConfig;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
|
||||
/**
 * Creates an {@code EmbeddingStore} from an {@code EmbeddingStoreConfig}.
 */
public interface EmbeddingStoreFactory {
|
||||
// NOTE(review): the return type is the raw EmbeddingStore; the element type is not
// fixed by this signature — confirm what implementations store.
EmbeddingStore createEmbeddingStore(EmbeddingStoreConfig config);
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
/** Names of the embedding store kinds known to this package. */
public enum EmbeddingStoreType {
|
||||
IN_MEMORY,
|
||||
MILVUS,
|
||||
CHROMA
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.embedding.S2OnnxEmbeddingModel;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import static dev.langchain4j.inmemory.spring.InMemoryAutoConfig.ALL_MINILM_L6_V2;
|
||||
import static dev.langchain4j.inmemory.spring.InMemoryAutoConfig.BGE_SMALL_ZH;
|
||||
|
||||
@Service
|
||||
public class InMemoryModelFactory implements ModelFactory, InitializingBean {
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModel) {
|
||||
String modelPath = embeddingModel.getModelPath();
|
||||
String vocabularyPath = embeddingModel.getVocabularyPath();
|
||||
if (StringUtils.isNotBlank(modelPath) && StringUtils.isNotBlank(vocabularyPath)) {
|
||||
return new S2OnnxEmbeddingModel(modelPath, vocabularyPath);
|
||||
}
|
||||
String modelName = embeddingModel.getModelName();
|
||||
if (BGE_SMALL_ZH.equalsIgnoreCase(modelName)) {
|
||||
return new BgeSmallZhEmbeddingModel();
|
||||
}
|
||||
if (ALL_MINILM_L6_V2.equalsIgnoreCase(modelName)) {
|
||||
return new AllMiniLmL6V2QuantizedEmbeddingModel();
|
||||
}
|
||||
return new BgeSmallZhEmbeddingModel();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(Provider.IN_MEMORY, this);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.localai.LocalAiChatModel;
|
||||
import dev.langchain4j.model.localai.LocalAiEmbeddingModel;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
@Service
|
||||
public class LocalAiModelFactory implements ModelFactory, InitializingBean {
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
|
||||
return LocalAiChatModel
|
||||
.builder()
|
||||
.baseUrl(chatModel.getBaseUrl())
|
||||
.modelName(chatModel.getModelName())
|
||||
.temperature(chatModel.getTemperature())
|
||||
.timeout(Duration.ofSeconds(chatModel.getTimeOut()))
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModel) {
|
||||
return LocalAiEmbeddingModel.builder()
|
||||
.baseUrl(embeddingModel.getBaseUrl())
|
||||
.modelName(embeddingModel.getModelName())
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(Provider.LOCAL_AI, this);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
|
||||
/**
 * Provider-specific factory for chat and embedding models. Implementations in
 * this package register themselves with ModelProvider.
 */
public interface ModelFactory {
|
||||
// Builds a chat model from the given settings; the in-memory implementation returns null.
ChatLanguageModel createChatModel(ChatModelConfig llmConfig);
|
||||
|
||||
// Builds an embedding model from the given settings.
EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModel);
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
||||
import com.tencent.supersonic.common.config.ModelConfig;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
public class ModelProvider {
|
||||
private static final Map<String, ModelFactory> factories = new HashMap<>();
|
||||
|
||||
public static void add(Provider provider, ModelFactory modelFactory) {
|
||||
factories.put(provider.name(), modelFactory);
|
||||
}
|
||||
|
||||
public static ChatLanguageModel provideChatModel(ChatModelConfig llmConfig) {
|
||||
if (llmConfig == null || StringUtils.isBlank(llmConfig.getProvider())
|
||||
|| StringUtils.isBlank(llmConfig.getBaseUrl())) {
|
||||
return ContextUtils.getBean(ChatLanguageModel.class);
|
||||
}
|
||||
ModelFactory modelFactory = factories.get(llmConfig.getProvider().toUpperCase());
|
||||
if (modelFactory != null) {
|
||||
return modelFactory.createChatModel(llmConfig);
|
||||
}
|
||||
|
||||
throw new RuntimeException("Unsupported ChatLanguageModel provider: " + llmConfig.getProvider());
|
||||
}
|
||||
|
||||
public static ChatLanguageModel provideChatModelNew(ModelConfig modelConfig) {
|
||||
if (modelConfig == null || modelConfig.getChatModel() == null
|
||||
|| StringUtils.isBlank(modelConfig.getChatModel().getProvider())
|
||||
|| StringUtils.isBlank(modelConfig.getChatModel().getBaseUrl())) {
|
||||
return ContextUtils.getBean(ChatLanguageModel.class);
|
||||
}
|
||||
ChatModelConfig chatModel = modelConfig.getChatModel();
|
||||
ModelFactory modelFactory = factories.get(chatModel.getProvider().toUpperCase());
|
||||
if (modelFactory != null) {
|
||||
return modelFactory.createChatModel(chatModel);
|
||||
}
|
||||
|
||||
throw new RuntimeException("Unsupported ChatLanguageModel provider: " + chatModel.getProvider());
|
||||
}
|
||||
|
||||
public static EmbeddingModel provideEmbeddingModel(ModelConfig modelConfig) {
|
||||
if (modelConfig == null || Objects.isNull(modelConfig.getEmbeddingModel())
|
||||
|| StringUtils.isBlank(modelConfig.getEmbeddingModel().getBaseUrl())
|
||||
|| StringUtils.isBlank(modelConfig.getEmbeddingModel().getProvider())) {
|
||||
return ContextUtils.getBean(EmbeddingModel.class);
|
||||
}
|
||||
EmbeddingModelConfig embeddingModel = modelConfig.getEmbeddingModel();
|
||||
|
||||
ModelFactory modelFactory = factories.get(embeddingModel.getProvider().toUpperCase());
|
||||
if (modelFactory != null) {
|
||||
return modelFactory.createEmbeddingModel(embeddingModel);
|
||||
}
|
||||
|
||||
throw new RuntimeException("Unsupported EmbeddingModel provider: " + embeddingModel.getProvider());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,42 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.ollama.OllamaChatModel;
|
||||
import dev.langchain4j.model.ollama.OllamaEmbeddingModel;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
@Service
|
||||
public class OllamaModelFactory implements ModelFactory, InitializingBean {
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
|
||||
return OllamaChatModel
|
||||
.builder()
|
||||
.baseUrl(chatModel.getBaseUrl())
|
||||
.modelName(chatModel.getModelName())
|
||||
.temperature(chatModel.getTemperature())
|
||||
.timeout(Duration.ofSeconds(chatModel.getTimeOut()))
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
||||
return OllamaEmbeddingModel.builder()
|
||||
.baseUrl(embeddingModelConfig.getBaseUrl())
|
||||
.modelName(embeddingModelConfig.getModelName())
|
||||
.maxRetries(embeddingModelConfig.getMaxRetries())
|
||||
.logRequests(embeddingModelConfig.getLogRequests())
|
||||
.logResponses(embeddingModelConfig.getLogResponses())
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(Provider.OLLAMA, this);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.openai.OpenAiChatModel;
|
||||
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
@Service
|
||||
public class OpenAiModelFactory implements ModelFactory, InitializingBean {
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
|
||||
return OpenAiChatModel
|
||||
.builder()
|
||||
.baseUrl(chatModel.getBaseUrl())
|
||||
.modelName(chatModel.getModelName())
|
||||
.apiKey(chatModel.keyDecrypt())
|
||||
.temperature(chatModel.getTemperature())
|
||||
.timeout(Duration.ofSeconds(chatModel.getTimeOut()))
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModel) {
|
||||
return OpenAiEmbeddingModel.builder()
|
||||
.baseUrl(embeddingModel.getBaseUrl())
|
||||
.apiKey(embeddingModel.getApiKey())
|
||||
.modelName(embeddingModel.getModelName())
|
||||
.maxRetries(embeddingModel.getMaxRetries())
|
||||
.logRequests(embeddingModel.getLogRequests())
|
||||
.logResponses(embeddingModel.getLogResponses())
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(Provider.OPEN_AI, this);
|
||||
}
|
||||
}
|
||||
12
common/src/main/java/dev/langchain4j/provider/Provider.java
Normal file
12
common/src/main/java/dev/langchain4j/provider/Provider.java
Normal file
@@ -0,0 +1,12 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
/**
 * Identifiers of the model providers that a {@code ModelFactory} can be registered under
 * via {@code ModelProvider.add(Provider, ModelFactory)}.
 *
 * <p>{@code ModelProvider} looks factories up with the upper-cased provider string taken from
 * the model configuration, so a configured value such as {@code "open_ai"} is expected to
 * correspond to {@link #OPEN_AI} -- presumably by constant name; verify against
 * {@code ModelProvider.add}.
 *
 * <p>The declaration order is kept stable so existing ordinals do not shift.
 */
public enum Provider {
    OPEN_AI,
    OLLAMA,
    LOCAL_AI,
    IN_MEMORY,
    ZHIPU,
    AZURE,
    QIANFAN,
    DASHSCOPE
}
|
||||
@@ -0,0 +1,34 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.qianfan.QianfanEmbeddingModel;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
public class QianfanModelFactory implements ModelFactory, InitializingBean {
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
||||
return QianfanEmbeddingModel.builder()
|
||||
.baseUrl(embeddingModelConfig.getBaseUrl())
|
||||
.apiKey(embeddingModelConfig.getApiKey())
|
||||
.modelName(embeddingModelConfig.getModelName())
|
||||
.maxRetries(embeddingModelConfig.getMaxRetries())
|
||||
.logRequests(embeddingModelConfig.getLogRequests())
|
||||
.logResponses(embeddingModelConfig.getLogResponses())
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(Provider.QIANFAN, this);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,34 @@
|
||||
package dev.langchain4j.provider;
|
||||
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
public class ZhipuModelFactory implements ModelFactory, InitializingBean {
|
||||
@Override
|
||||
public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
|
||||
return ZhipuAiEmbeddingModel.builder()
|
||||
.baseUrl(embeddingModelConfig.getBaseUrl())
|
||||
.apiKey(embeddingModelConfig.getApiKey())
|
||||
.model(embeddingModelConfig.getModelName())
|
||||
.maxRetries(embeddingModelConfig.getMaxRetries())
|
||||
.logRequests(embeddingModelConfig.getLogRequests())
|
||||
.logResponses(embeddingModelConfig.getLogResponses())
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
ModelProvider.add(Provider.ZHIPU, this);
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,3 @@
|
||||
version: '3.8'
|
||||
services:
|
||||
chroma:
|
||||
image: chromadb/chroma:0.5.3
|
||||
|
||||
@@ -3,12 +3,12 @@ package com.tencent.supersonic.headless.api.pojo.request;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.PromptConfig;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
import com.tencent.supersonic.common.config.LLMConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||
import lombok.Data;
|
||||
|
||||
@@ -27,7 +27,7 @@ public class QueryReq {
|
||||
private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
|
||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||
private QueryDataType queryDataType = QueryDataType.ALL;
|
||||
private LLMConfig llmConfig;
|
||||
private ChatModelConfig llmConfig;
|
||||
private PromptConfig promptConfig;
|
||||
private List<SqlExemplar> dynamicExemplars = Lists.newArrayList();
|
||||
}
|
||||
|
||||
@@ -2,10 +2,11 @@ package com.tencent.supersonic.headless.chat;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.ModelConfig;
|
||||
import com.tencent.supersonic.common.config.PromptConfig;
|
||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.config.LLMConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
@@ -50,7 +51,8 @@ public class QueryContext {
|
||||
@JsonIgnore
|
||||
private WorkflowState workflowState;
|
||||
private QueryDataType queryDataType = QueryDataType.ALL;
|
||||
private LLMConfig llmConfig;
|
||||
private ModelConfig modelConfig;
|
||||
private ChatModelConfig llmConfig;
|
||||
private PromptConfig promptConfig;
|
||||
private List<SqlExemplar> dynamicExemplars;
|
||||
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_LINKING_VALUE_ENABLE;
|
||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_STRATEGY_TYPE;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
@@ -17,6 +14,12 @@ import com.tencent.supersonic.headless.chat.parser.SatisfactionChecker;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
@@ -26,11 +29,9 @@ import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_LINKING_VALUE_ENABLE;
|
||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_STRATEGY_TYPE;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
@@ -101,6 +102,7 @@ public class LLMRequestService {
|
||||
|
||||
llmReq.setCurrentDate(DateUtils.getBeforeDate(0));
|
||||
llmReq.setSqlGenType(LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE)));
|
||||
llmReq.setModelConfig(queryCtx.getModelConfig());
|
||||
llmReq.setLlmConfig(queryCtx.getLlmConfig());
|
||||
llmReq.setPromptConfig(queryCtx.getPromptConfig());
|
||||
llmReq.setDynamicExemplars(queryCtx.getDynamicExemplars());
|
||||
@@ -118,7 +120,7 @@ public class LLMRequestService {
|
||||
}
|
||||
|
||||
protected List<String> getFieldNameList(QueryContext queryCtx, Long dataSetId,
|
||||
LLMParserConfig llmParserConfig) {
|
||||
LLMParserConfig llmParserConfig) {
|
||||
|
||||
Set<String> results = getTopNFieldNames(queryCtx, dataSetId, llmParserConfig);
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.config.LLMConfig;
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.provider.ChatLanguageModelProvider;
|
||||
import dev.langchain4j.provider.ModelProvider;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
@@ -23,8 +23,8 @@ public abstract class SqlGenStrategy implements InitializingBean {
|
||||
@Autowired
|
||||
protected PromptHelper promptHelper;
|
||||
|
||||
protected ChatLanguageModel getChatLanguageModel(LLMConfig llmConfig) {
|
||||
return ChatLanguageModelProvider.provide(llmConfig);
|
||||
protected ChatLanguageModel getChatLanguageModel(ChatModelConfig llmConfig) {
|
||||
return ModelProvider.provideChatModel(llmConfig);
|
||||
}
|
||||
|
||||
abstract LLMResp generate(LLMReq llmReq);
|
||||
|
||||
@@ -2,7 +2,8 @@ package com.tencent.supersonic.headless.chat.query.llm.s2sql;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonValue;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.config.LLMConfig;
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.ModelConfig;
|
||||
import com.tencent.supersonic.common.config.PromptConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
@@ -27,8 +28,8 @@ public class LLMReq {
|
||||
|
||||
private SqlGenType sqlGenType;
|
||||
|
||||
private LLMConfig llmConfig;
|
||||
|
||||
private ModelConfig modelConfig;
|
||||
private ChatModelConfig llmConfig;
|
||||
private PromptConfig promptConfig;
|
||||
|
||||
private List<SqlExemplar> dynamicExemplars;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
spring:
|
||||
datasource:
|
||||
url: jdbc:mysql://${DB_HOST}:3306/${DB_NAME}?useUnicode=true&characterEncoding=UTF-8&useSSL=false&allowMultiQueries=true&allowPublicKeyRetrieval=true
|
||||
url: jdbc:mysql://${DB_HOST}:${DB_PORT:3306}/${DB_NAME}?useUnicode=true&characterEncoding=UTF-8&useSSL=false&allowMultiQueries=true&allowPublicKeyRetrieval=true
|
||||
username: ${DB_USERNAME}
|
||||
password: ${DB_PASSWORD}
|
||||
driver-class-name: com.mysql.jdbc.Driver
|
||||
|
||||
@@ -14,11 +14,11 @@ langchain4j:
|
||||
temperature: ${OPENAI_TEMPERATURE:0.0}
|
||||
timeout: ${OPENAI_TIMEOUT:PT60S}
|
||||
|
||||
# embedding-model:
|
||||
# base-url: https://api.openai.com/v1
|
||||
# api-key: demo
|
||||
# model-name: text-embedding-3-small
|
||||
# timeout: PT60S
|
||||
# embedding-model:
|
||||
# base-url: https://api.openai.com/v1
|
||||
# api-key: demo
|
||||
# model-name: text-embedding-3-small
|
||||
# timeout: PT60S
|
||||
|
||||
in-memory:
|
||||
embedding-model:
|
||||
|
||||
@@ -8,7 +8,8 @@ import com.tencent.supersonic.chat.server.agent.AgentConfig;
|
||||
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||
import com.tencent.supersonic.chat.server.agent.RuleParserTool;
|
||||
import com.tencent.supersonic.common.config.LLMConfig;
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.config.ModelConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
@@ -85,7 +86,8 @@ public class Text2SQLEval extends BaseTest {
|
||||
AgentConfig agentConfig = new AgentConfig();
|
||||
agentConfig.getTools().add(getLLMQueryTool());
|
||||
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
|
||||
agent.setLlmConfig(getLLMConfig(LLMType.GPT));
|
||||
agent.setModelConfig(getLLMConfig(LLMType.GPT));
|
||||
agent.setLlmConfig(getLLMConfig(LLMType.GPT).getChatModel());
|
||||
MultiTurnConfig multiTurnConfig = new MultiTurnConfig();
|
||||
multiTurnConfig.setEnableMultiTurn(enableMultiturn);
|
||||
agent.setMultiTurnConfig(multiTurnConfig);
|
||||
@@ -108,7 +110,7 @@ public class Text2SQLEval extends BaseTest {
|
||||
GLM
|
||||
}
|
||||
|
||||
private static LLMConfig getLLMConfig(LLMType type) {
|
||||
private static ModelConfig getLLMConfig(LLMType type) {
|
||||
String baseUrl;
|
||||
String apiKey;
|
||||
String modelName;
|
||||
@@ -143,9 +145,16 @@ public class Text2SQLEval extends BaseTest {
|
||||
modelName = "gpt-3.5-turbo";
|
||||
temperature = 0.0;
|
||||
}
|
||||
ChatModelConfig chatModel = new ChatModelConfig();
|
||||
chatModel.setModelName(modelName);
|
||||
chatModel.setBaseUrl(baseUrl);
|
||||
chatModel.setApiKey(apiKey);
|
||||
chatModel.setTemperature(temperature);
|
||||
chatModel.setProvider("open_ai");
|
||||
|
||||
return new LLMConfig("open_ai",
|
||||
baseUrl, apiKey, modelName, temperature);
|
||||
ModelConfig modelConfig = new ModelConfig();
|
||||
modelConfig.setChatModel(chatModel);
|
||||
return modelConfig;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
File diff suppressed because one or more lines are too long
|
Before Width: | Height: | Size: 97 KiB After Width: | Height: | Size: 96 KiB |
Reference in New Issue
Block a user