(improvement)(chat) Support configuring embeddingModel or embeddingStore at the agent level. (#1361)

This commit is contained in:
lexluo09
2024-07-06 20:44:23 +08:00
committed by GitHub
parent d39db734c4
commit 6db6aaf98d
42 changed files with 669 additions and 299 deletions

View File

@@ -4,7 +4,9 @@ package com.tencent.supersonic.chat.server.agent;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.google.common.collect.Sets; import com.google.common.collect.Sets;
import com.tencent.supersonic.common.config.LLMConfig; import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.EmbeddingStoreConfig;
import com.tencent.supersonic.common.config.ModelConfig;
import com.tencent.supersonic.common.config.PromptConfig; import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.config.VisualConfig; import com.tencent.supersonic.common.config.VisualConfig;
import com.tencent.supersonic.common.pojo.RecordInfo; import com.tencent.supersonic.common.pojo.RecordInfo;
@@ -33,7 +35,9 @@ public class Agent extends RecordInfo {
private Integer status; private Integer status;
private List<String> examples; private List<String> examples;
private String agentConfig; private String agentConfig;
private LLMConfig llmConfig; private ChatModelConfig llmConfig;
private ModelConfig modelConfig;
private EmbeddingStoreConfig embeddingStore;
private PromptConfig promptConfig; private PromptConfig promptConfig;
private MultiTurnConfig multiTurnConfig; private MultiTurnConfig multiTurnConfig;
private VisualConfig visualConfig; private VisualConfig visualConfig;

View File

@@ -15,7 +15,7 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt; import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate; import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.provider.ChatLanguageModelProvider; import dev.langchain4j.provider.ModelProvider;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
@@ -46,7 +46,7 @@ public class PlainTextExecutor implements ChatExecutor {
AgentService agentService = ContextUtils.getBean(AgentService.class); AgentService agentService = ContextUtils.getBean(AgentService.class);
Agent chatAgent = agentService.getAgent(chatExecuteContext.getAgentId()); Agent chatAgent = agentService.getAgent(chatExecuteContext.getAgentId());
ChatLanguageModel chatLanguageModel = ChatLanguageModelProvider.provide(chatAgent.getLlmConfig()); ChatLanguageModel chatLanguageModel = ModelProvider.provideChatModel(chatAgent.getLlmConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage()); Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
QueryResult result = new QueryResult(); QueryResult result = new QueryResult();

View File

@@ -7,7 +7,7 @@ import com.tencent.supersonic.chat.server.service.MemoryService;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt; import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate; import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.provider.ChatLanguageModelProvider; import dev.langchain4j.provider.ModelProvider;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@@ -56,7 +56,7 @@ public class MemoryReviewTask {
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP); Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
keyPipelineLog.info("MemoryReviewTask reqPrompt:{}", promptStr); keyPipelineLog.info("MemoryReviewTask reqPrompt:{}", promptStr);
ChatLanguageModel chatLanguageModel = ChatLanguageModelProvider.provide( ChatLanguageModel chatLanguageModel = ModelProvider.provideChatModel(
chatAgent.getLlmConfig()); chatAgent.getLlmConfig());
if (Objects.nonNull(chatLanguageModel)) { if (Objects.nonNull(chatLanguageModel)) {
String response = chatLanguageModel.generate(prompt.toUserMessage()).content().text(); String response = chatLanguageModel.generate(prompt.toUserMessage()).content().text();

View File

@@ -1,14 +1,12 @@
package com.tencent.supersonic.chat.server.parser; package com.tencent.supersonic.chat.server.parser;
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig; import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository; import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager; import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext; import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.util.QueryReqConverter; import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.config.LLMConfig;
import com.tencent.supersonic.common.pojo.SqlExemplar; import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl; import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
@@ -26,7 +24,14 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt; import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate; import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.provider.ChatLanguageModelProvider; import dev.langchain4j.provider.ModelProvider;
import lombok.Builder;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
@@ -34,12 +39,8 @@ import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.Builder;
import lombok.Data; import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;
@Slf4j @Slf4j
public class NL2SQLParser implements ChatParser { public class NL2SQLParser implements ChatParser {
@@ -180,7 +181,7 @@ public class NL2SQLParser implements ChatParser {
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP); Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
keyPipelineLog.info("NL2SQLParser reqPrompt:{}", promptStr); keyPipelineLog.info("NL2SQLParser reqPrompt:{}", promptStr);
ChatLanguageModel chatLanguageModel = ChatLanguageModelProvider.provide(context.getLlmConfig()); ChatLanguageModel chatLanguageModel = ModelProvider.provideChatModel(context.getLlmConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage()); Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
String result = response.content().text(); String result = response.content().text();
@@ -242,7 +243,7 @@ public class NL2SQLParser implements ChatParser {
private String curtSchema; private String curtSchema;
private String histSchema; private String histSchema;
private String histSQL; private String histSQL;
private LLMConfig llmConfig; private ChatModelConfig llmConfig;
} }
} }

View File

@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentToolType; import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.util.LLMConnHelper; import com.tencent.supersonic.chat.server.util.LLMConnHelper;
import com.tencent.supersonic.common.config.LLMConfig; import com.tencent.supersonic.common.config.ChatModelConfig;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.DeleteMapping; import org.springframework.web.bind.annotation.DeleteMapping;
import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.PathVariable;
@@ -15,6 +15,7 @@ import org.springframework.web.bind.annotation.PutMapping;
import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import java.util.List; import java.util.List;
@@ -50,7 +51,7 @@ public class AgentController {
} }
@PostMapping("/testLLMConn") @PostMapping("/testLLMConn")
public boolean testLLMConn(@RequestBody LLMConfig llmConfig) { public boolean testLLMConn(@RequestBody ChatModelConfig llmConfig) {
return LLMConnHelper.testConnection(llmConfig); return LLMConnHelper.testConnection(llmConfig);
} }

View File

@@ -12,7 +12,7 @@ import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatService; import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.chat.server.service.MemoryService; import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.chat.server.util.LLMConnHelper; import com.tencent.supersonic.chat.server.util.LLMConnHelper;
import com.tencent.supersonic.common.config.LLMConfig; import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.PromptConfig; import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.config.VisualConfig; import com.tencent.supersonic.common.config.VisualConfig;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
@@ -80,6 +80,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
/** /**
* the example in the agent will be executed by default, * the example in the agent will be executed by default,
* if the result is correct, it will be put into memory as a reference for LLM * if the result is correct, it will be put into memory as a reference for LLM
*
* @param agent * @param agent
*/ */
private void executeAgentExamplesAsync(Agent agent) { private void executeAgentExamplesAsync(Agent agent) {
@@ -121,7 +122,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
BeanUtils.copyProperties(agentDO, agent); BeanUtils.copyProperties(agentDO, agent);
agent.setAgentConfig(agentDO.getConfig()); agent.setAgentConfig(agentDO.getConfig());
agent.setExamples(JsonUtil.toList(agentDO.getExamples(), String.class)); agent.setExamples(JsonUtil.toList(agentDO.getExamples(), String.class));
agent.setLlmConfig(JsonUtil.toObject(agentDO.getLlmConfig(), LLMConfig.class)); agent.setLlmConfig(JsonUtil.toObject(agentDO.getLlmConfig(), ChatModelConfig.class));
agent.setPromptConfig(JsonUtil.toObject(agentDO.getPromptConfig(), PromptConfig.class)); agent.setPromptConfig(JsonUtil.toObject(agentDO.getPromptConfig(), PromptConfig.class));
agent.setMultiTurnConfig(JsonUtil.toObject(agentDO.getMultiTurnConfig(), MultiTurnConfig.class)); agent.setMultiTurnConfig(JsonUtil.toObject(agentDO.getMultiTurnConfig(), MultiTurnConfig.class));
agent.setVisualConfig(JsonUtil.toObject(agentDO.getVisualConfig(), VisualConfig.class)); agent.setVisualConfig(JsonUtil.toObject(agentDO.getVisualConfig(), VisualConfig.class));

View File

@@ -1,20 +1,20 @@
package com.tencent.supersonic.chat.server.util; package com.tencent.supersonic.chat.server.util;
import com.tencent.supersonic.common.config.LLMConfig; import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException; import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.provider.ChatLanguageModelProvider; import dev.langchain4j.provider.ModelProvider;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
@Slf4j @Slf4j
public class LLMConnHelper { public class LLMConnHelper {
public static boolean testConnection(LLMConfig llmConfig) { public static boolean testConnection(ChatModelConfig chatModel) {
try { try {
if (llmConfig == null || StringUtils.isBlank(llmConfig.getBaseUrl())) { if (chatModel == null || StringUtils.isBlank(chatModel.getBaseUrl())) {
return false; return false;
} }
ChatLanguageModel chatLanguageModel = ChatLanguageModelProvider.provide(llmConfig); ChatLanguageModel chatLanguageModel = ModelProvider.provideChatModel(chatModel);
String response = chatLanguageModel.generate("Hi there"); String response = chatLanguageModel.generate("Hi there");
return StringUtils.isNotEmpty(response) ? true : false; return StringUtils.isNotEmpty(response) ? true : false;
} catch (Exception e) { } catch (Exception e) {

View File

@@ -185,6 +185,8 @@
<dependency> <dependency>
<groupId>dev.langchain4j</groupId> <groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings</artifactId> <artifactId>langchain4j-embeddings</artifactId>
</dependency> </dependency>
<dependency> <dependency>
<groupId>dev.langchain4j</groupId> <groupId>dev.langchain4j</groupId>

View File

@@ -0,0 +1,25 @@
package com.tencent.supersonic.common.config;
import com.tencent.supersonic.common.util.AESEncryptionUtil;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.io.Serializable;
/**
 * Connection settings for a chat (LLM) model.
 *
 * <p>Lombok generates the accessors, a no-args constructor and an all-args
 * constructor whose parameter order follows the field order below.
 */
@Data
@AllArgsConstructor
@NoArgsConstructor
public class ChatModelConfig implements Serializable {
private static final long serialVersionUID = 1L;
// Provider name; ModelProvider upper-cases it to look up the registered factory.
private String provider;
// Endpoint handed to the provider's model builder. ModelProvider falls back to the
// application-wide ChatLanguageModel bean when this (or provider) is blank.
private String baseUrl;
// API key as stored. NOTE(review): keyDecrypt() AES-decrypts it, so it is presumably
// persisted encrypted — confirm against the code that writes it.
private String apiKey;
// Model (or deployment) name passed to the provider's builder.
private String modelName;
// Sampling temperature; defaults to 0.0.
private Double temperature = 0.0d;
// Request timeout; the model factories feed it to Duration.ofSeconds, default 60.
private Long timeOut = 60L;
/**
 * Decrypts the stored API key.
 *
 * @return the result of {@code AESEncryptionUtil.aesDecryptECB} applied to {@link #getApiKey()}
 */
public String keyDecrypt() {
return AESEncryptionUtil.aesDecryptECB(getApiKey());
}
}

View File

@@ -0,0 +1,24 @@
package com.tencent.supersonic.common.config;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.io.Serializable;
/**
 * Settings for an embedding model.
 *
 * <p>Lombok generates the accessors, a no-args constructor and an all-args
 * constructor whose parameter order follows the field order below.
 */
@Data
@AllArgsConstructor
@NoArgsConstructor
public class EmbeddingModelConfig implements Serializable {
private static final long serialVersionUID = 1L;
// Provider name; ModelProvider upper-cases it to look up the registered factory.
private String provider;
// Endpoint of a remote embedding service; ModelProvider uses the default
// EmbeddingModel bean when this (or provider) is blank.
private String baseUrl;
// API key forwarded as-is to the remote embedding model builders.
private String apiKey;
// Model name; InMemoryModelFactory also uses it to pick a bundled model.
private String modelName;
// Model and vocabulary files: when both are non-blank InMemoryModelFactory
// builds an S2OnnxEmbeddingModel from them.
private String modelPath;
private String vocabularyPath;
// Forwarded to the builders' maxRetries(...) where the provider supports it.
private Integer maxRetries;
// NOTE(review): not read by any factory visible in this change — confirm usage.
private Integer maxToken;
// Request/response logging switches forwarded to the builders; may be null.
private Boolean logRequests;
private Boolean logResponses;
}

View File

@@ -0,0 +1,15 @@
package com.tencent.supersonic.common.config;
import lombok.Data;
import java.io.Serializable;
/**
 * Settings for an embedding store; consumed by
 * {@code EmbeddingStoreFactory#createEmbeddingStore}.
 */
@Data
public class EmbeddingStoreConfig implements Serializable {
private static final long serialVersionUID = 1L;
// Store provider name. NOTE(review): presumably one of EmbeddingStoreType — confirm.
private String provider;
// NOTE(review): presumably the on-disk location for a persisted in-memory store — confirm.
private String persistPath;
// Endpoint and credentials of a remote store.
private String baseUrl;
private String apiKey;
// Timeout, default 60. NOTE(review): unit is presumably seconds, as for ChatModelConfig — confirm.
private Long timeOut = 60L;
}

View File

@@ -1,45 +0,0 @@
package com.tencent.supersonic.common.config;
import com.tencent.supersonic.common.util.AESEncryptionUtil;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
 * Connection settings for a chat LLM: provider, endpoint, credentials,
 * model name, sampling temperature and request timeout.
 */
@Data
@AllArgsConstructor
@NoArgsConstructor
public class LLMConfig {

    private String provider;
    private String baseUrl;
    private String apiKey;
    private String modelName;
    private Double temperature = 0.0d;
    private Long timeOut = 60L;

    /**
     * Creates a config that keeps the default temperature and timeout.
     */
    public LLMConfig(String provider, String baseUrl, String apiKey, String modelName) {
        this.modelName = modelName;
        this.apiKey = apiKey;
        this.baseUrl = baseUrl;
        this.provider = provider;
    }

    /**
     * Creates a config with an explicit temperature and the default timeout.
     */
    public LLMConfig(String provider, String baseUrl, String apiKey, String modelName,
            double temperature) {
        this(provider, baseUrl, apiKey, modelName);
        this.temperature = temperature;
    }

    /**
     * @return the stored API key after AES/ECB decryption
     */
    public String keyDecrypt() {
        return AESEncryptionUtil.aesDecryptECB(this.apiKey);
    }
}

View File

@@ -0,0 +1,17 @@
package com.tencent.supersonic.common.config;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.io.Serializable;
/**
 * Bundles the chat-model and embedding-model settings; ModelProvider reads
 * both parts when resolving models from a single config object.
 */
@Data
@AllArgsConstructor
@NoArgsConstructor
public class ModelConfig implements Serializable {
private static final long serialVersionUID = 1L;
// Chat model settings; null means "use the default ChatLanguageModel bean".
private ChatModelConfig chatModel;
// Embedding model settings; null means "use the default EmbeddingModel bean".
private EmbeddingModelConfig embeddingModel;
}

View File

@@ -1,8 +0,0 @@
package dev.langchain4j.model.provider;
import com.tencent.supersonic.common.config.LLMConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
/**
 * Creates a {@link ChatLanguageModel} for one provider from an {@link LLMConfig}.
 */
public interface ChatLanguageModelFactory {
/**
 * @param llmConfig connection settings for the model
 * @return a chat model built from {@code llmConfig}
 */
ChatLanguageModel create(LLMConfig llmConfig);
}

View File

@@ -1,33 +0,0 @@
package dev.langchain4j.model.provider;
import com.tencent.supersonic.common.config.LLMConfig;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.model.chat.ChatLanguageModel;
import org.apache.commons.lang3.StringUtils;
import java.util.HashMap;
import java.util.Map;
/**
 * Resolves a {@link ChatLanguageModel} from an {@link LLMConfig} using a
 * fixed table of provider-specific factories.
 */
public class ChatLanguageModelProvider {

    /** Factories keyed by the upper-case provider name. */
    private static final Map<String, ChatLanguageModelFactory> factories = new HashMap<>();

    static {
        factories.put(ModelProvider.OPEN_AI.name(), new OpenAiChatModelFactory());
        factories.put(ModelProvider.LOCAL_AI.name(), new LocalAiChatModelFactory());
        factories.put(ModelProvider.OLLAMA.name(), new OllamaChatModelFactory());
    }

    /**
     * Returns the chat model described by {@code llmConfig}.
     *
     * @param llmConfig model settings; may be {@code null}
     * @return the application-wide {@link ChatLanguageModel} bean when the config is
     *         {@code null} or has a blank provider or base URL, otherwise a model
     *         created by the matching factory
     * @throws RuntimeException if no factory is registered for the provider
     */
    public static ChatLanguageModel provide(LLMConfig llmConfig) {
        boolean configured = llmConfig != null
                && !StringUtils.isBlank(llmConfig.getProvider())
                && !StringUtils.isBlank(llmConfig.getBaseUrl());
        if (!configured) {
            return ContextUtils.getBean(ChatLanguageModel.class);
        }

        String providerName = llmConfig.getProvider();
        ChatLanguageModelFactory matched = factories.get(providerName.toUpperCase());
        if (matched == null) {
            throw new RuntimeException("Unsupported provider: " + providerName);
        }
        return matched.create(llmConfig);
    }
}

View File

@@ -1,20 +0,0 @@
package dev.langchain4j.model.provider;
import com.tencent.supersonic.common.config.LLMConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.localai.LocalAiChatModel;
import java.time.Duration;
/**
 * Builds LocalAI chat models from an {@link LLMConfig}.
 */
public class LocalAiChatModelFactory implements ChatLanguageModelFactory {

    /**
     * @param llmConfig supplies base URL, model name, temperature and timeout (seconds)
     * @return a configured {@link LocalAiChatModel}
     */
    @Override
    public ChatLanguageModel create(LLMConfig llmConfig) {
        String endpoint = llmConfig.getBaseUrl();
        String model = llmConfig.getModelName();
        Double temperature = llmConfig.getTemperature();
        Duration requestTimeout = Duration.ofSeconds(llmConfig.getTimeOut());

        return LocalAiChatModel.builder()
                .baseUrl(endpoint)
                .modelName(model)
                .temperature(temperature)
                .timeout(requestTimeout)
                .build();
    }
}

View File

@@ -1,9 +0,0 @@
package dev.langchain4j.model.provider;
/**
 * Names of model back-ends. ChatLanguageModelProvider registers chat-model
 * factories under OPEN_AI, LOCAL_AI and OLLAMA; it registers none for
 * HUGGING_FACE or IN_PROCESS.
 */
public enum ModelProvider {
OPEN_AI,
HUGGING_FACE,
LOCAL_AI,
IN_PROCESS,
OLLAMA
}

View File

@@ -1,20 +0,0 @@
package dev.langchain4j.model.provider;
import com.tencent.supersonic.common.config.LLMConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.ollama.OllamaChatModel;
import java.time.Duration;
/**
 * Builds Ollama chat models from an {@link LLMConfig}.
 */
public class OllamaChatModelFactory implements ChatLanguageModelFactory {

    /**
     * @param llmConfig supplies base URL, model name, temperature and timeout (seconds)
     * @return a configured {@link OllamaChatModel}
     */
    @Override
    public ChatLanguageModel create(LLMConfig llmConfig) {
        String endpoint = llmConfig.getBaseUrl();
        String model = llmConfig.getModelName();
        Double temperature = llmConfig.getTemperature();
        Duration requestTimeout = Duration.ofSeconds(llmConfig.getTimeOut());

        return OllamaChatModel.builder()
                .baseUrl(endpoint)
                .modelName(model)
                .temperature(temperature)
                .timeout(requestTimeout)
                .build();
    }
}

View File

@@ -1,21 +0,0 @@
package dev.langchain4j.model.provider;
import com.tencent.supersonic.common.config.LLMConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import java.time.Duration;
/**
 * Builds OpenAI chat models from an {@link LLMConfig}.
 */
public class OpenAiChatModelFactory implements ChatLanguageModelFactory {

    /**
     * @param llmConfig supplies base URL, model name, (encrypted) API key,
     *                  temperature and timeout (seconds)
     * @return a configured {@link OpenAiChatModel}
     */
    @Override
    public ChatLanguageModel create(LLMConfig llmConfig) {
        String endpoint = llmConfig.getBaseUrl();
        String model = llmConfig.getModelName();
        // The key is decrypted before being handed to the client.
        String plainApiKey = llmConfig.keyDecrypt();
        Double temperature = llmConfig.getTemperature();
        Duration requestTimeout = Duration.ofSeconds(llmConfig.getTimeOut());

        return OpenAiChatModel.builder()
                .baseUrl(endpoint)
                .modelName(model)
                .apiKey(plainApiKey)
                .temperature(temperature)
                .timeout(requestTimeout)
                .build();
    }
}

View File

@@ -0,0 +1,44 @@
package dev.langchain4j.provider;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
import dev.langchain4j.model.azure.AzureOpenAiChatModel;
import dev.langchain4j.model.azure.AzureOpenAiEmbeddingModel;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Service;
import java.time.Duration;
/**
 * {@link ModelFactory} for Azure OpenAI; registers itself with
 * {@link ModelProvider} under {@code Provider.AZURE} once Spring has
 * initialised the bean.
 */
@Service
public class AzureModelFactory implements ModelFactory, InitializingBean {

    /**
     * Builds an Azure OpenAI chat model.
     *
     * @param chatModel supplies endpoint, API key, deployment name, temperature and
     *                  timeout (seconds; a {@code null} timeout is passed on as 0)
     * @return the configured chat model
     */
    @Override
    public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
        // NOTE(review): the key is passed as stored, whereas OpenAiModelFactory calls
        // keyDecrypt() — confirm whether Azure keys are stored encrypted as well.
        AzureOpenAiChatModel.Builder builder = AzureOpenAiChatModel.builder()
                .endpoint(chatModel.getBaseUrl())
                .apiKey(chatModel.getApiKey())
                .deploymentName(chatModel.getModelName())
                .temperature(chatModel.getTemperature())
                .timeout(Duration.ofSeconds(chatModel.getTimeOut() == null ? 0L : chatModel.getTimeOut()));
        return builder.build();
    }

    /**
     * Builds an Azure OpenAI embedding model.
     *
     * <p>Request/response logging is switched on only when both
     * {@code logRequests} and {@code logResponses} are {@code TRUE}; unset
     * ({@code null}) flags count as off instead of causing a
     * {@link NullPointerException} on unboxing.
     *
     * @param embeddingModelConfig supplies endpoint, API key, deployment name,
     *                             retry count and logging flags
     * @return the configured embedding model
     */
    @Override
    public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
        boolean logTraffic = Boolean.TRUE.equals(embeddingModelConfig.getLogRequests())
                && Boolean.TRUE.equals(embeddingModelConfig.getLogResponses());
        AzureOpenAiEmbeddingModel.Builder builder = AzureOpenAiEmbeddingModel.builder()
                .endpoint(embeddingModelConfig.getBaseUrl())
                .apiKey(embeddingModelConfig.getApiKey())
                .deploymentName(embeddingModelConfig.getModelName())
                .maxRetries(embeddingModelConfig.getMaxRetries())
                .logRequestsAndResponses(logTraffic);
        return builder.build();
    }

    /** Registers this factory under {@code Provider.AZURE}. */
    @Override
    public void afterPropertiesSet() {
        ModelProvider.add(Provider.AZURE, this);
    }
}

View File

@@ -0,0 +1,37 @@
package dev.langchain4j.provider;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.dashscope.QwenChatModel;
import dev.langchain4j.model.dashscope.QwenEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Service;
/**
 * {@link ModelFactory} for DashScope (Qwen) models; registers itself with
 * {@link ModelProvider} under {@code Provider.DASHSCOPE}.
 */
@Service
public class DashscopeModelFactory implements ModelFactory, InitializingBean {

    /**
     * Builds a Qwen chat model.
     *
     * @param chatModel supplies base URL, API key, model name and temperature
     *                  (a {@code null} temperature is treated as 0)
     * @return the configured chat model
     */
    @Override
    public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
        Double configured = chatModel.getTemperature();
        float temperature = configured == null ? 0.0f : configured.floatValue();

        return QwenChatModel.builder()
                .baseUrl(chatModel.getBaseUrl())
                .apiKey(chatModel.getApiKey())
                .modelName(chatModel.getModelName())
                .temperature(temperature)
                .build();
    }

    /**
     * Builds a Qwen embedding model from the API key and model name only.
     *
     * @param embeddingModelConfig embedding settings
     * @return the configured embedding model
     */
    @Override
    public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
        String apiKey = embeddingModelConfig.getApiKey();
        String model = embeddingModelConfig.getModelName();
        return QwenEmbeddingModel.builder()
                .apiKey(apiKey)
                .modelName(model)
                .build();
    }

    /** Registers this factory under {@code Provider.DASHSCOPE}. */
    @Override
    public void afterPropertiesSet() {
        ModelProvider.add(Provider.DASHSCOPE, this);
    }
}

View File

@@ -0,0 +1,8 @@
package dev.langchain4j.provider;
import com.tencent.supersonic.common.config.EmbeddingStoreConfig;
import dev.langchain4j.store.embedding.EmbeddingStore;
/**
 * Creates an {@link EmbeddingStore} from an {@link EmbeddingStoreConfig}.
 */
public interface EmbeddingStoreFactory {
/**
 * @param config store settings
 * @return the embedding store described by {@code config}
 */
// NOTE(review): the return type is the raw EmbeddingStore; consider a type argument.
EmbeddingStore createEmbeddingStore(EmbeddingStoreConfig config);
}

View File

@@ -0,0 +1,7 @@
package dev.langchain4j.provider;
/**
 * Kinds of embedding store.
 * NOTE(review): presumably matched against EmbeddingStoreConfig.provider — confirm.
 */
public enum EmbeddingStoreType {
IN_MEMORY,
MILVUS,
CHROMA
}

View File

@@ -0,0 +1,45 @@
package dev.langchain4j.provider;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.AllMiniLmL6V2QuantizedEmbeddingModel;
import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.embedding.S2OnnxEmbeddingModel;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Service;
import static dev.langchain4j.inmemory.spring.InMemoryAutoConfig.ALL_MINILM_L6_V2;
import static dev.langchain4j.inmemory.spring.InMemoryAutoConfig.BGE_SMALL_ZH;
/**
 * {@link ModelFactory} for embedding models that run inside the JVM;
 * registers itself with {@link ModelProvider} under {@code Provider.IN_MEMORY}.
 */
@Service
public class InMemoryModelFactory implements ModelFactory, InitializingBean {

    /**
     * This factory does not create chat models.
     *
     * @return always {@code null}
     */
    @Override
    public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
        // NOTE(review): ModelProvider.provideChatModel returns this value as-is,
        // so an IN_MEMORY chat config yields a null model.
        return null;
    }

    /**
     * Picks an in-process embedding model.
     *
     * <p>A custom ONNX model is loaded when both the model path and the
     * vocabulary path are non-blank. Otherwise the model name selects one of
     * the bundled models, and any unrecognised name falls back to BGE small zh.
     *
     * @param embeddingModel embedding settings
     * @return the selected embedding model
     */
    @Override
    public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModel) {
        final String onnxPath = embeddingModel.getModelPath();
        final String vocabPath = embeddingModel.getVocabularyPath();
        final boolean customModel = StringUtils.isNotBlank(onnxPath) && StringUtils.isNotBlank(vocabPath);
        if (customModel) {
            return new S2OnnxEmbeddingModel(onnxPath, vocabPath);
        }

        final String requested = embeddingModel.getModelName();
        // BGE keeps priority over MiniLM; it is also the default for any other name.
        final boolean wantsMiniLm = !BGE_SMALL_ZH.equalsIgnoreCase(requested)
                && ALL_MINILM_L6_V2.equalsIgnoreCase(requested);
        if (wantsMiniLm) {
            return new AllMiniLmL6V2QuantizedEmbeddingModel();
        }
        return new BgeSmallZhEmbeddingModel();
    }

    /** Registers this factory under {@code Provider.IN_MEMORY}. */
    @Override
    public void afterPropertiesSet() {
        ModelProvider.add(Provider.IN_MEMORY, this);
    }
}

View File

@@ -0,0 +1,39 @@
package dev.langchain4j.provider;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.localai.LocalAiChatModel;
import dev.langchain4j.model.localai.LocalAiEmbeddingModel;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Service;
import java.time.Duration;
/**
 * {@link ModelFactory} for LocalAI; registers itself with
 * {@link ModelProvider} under {@code Provider.LOCAL_AI}.
 */
@Service
public class LocalAiModelFactory implements ModelFactory, InitializingBean {

    /** Fallback timeout in seconds; mirrors the default declared on {@code ChatModelConfig.timeOut}. */
    private static final long DEFAULT_TIMEOUT_SECONDS = 60L;

    /**
     * Builds a LocalAI chat model.
     *
     * @param chatModel supplies base URL, model name, temperature and timeout
     *                  (seconds); a {@code null} timeout falls back to 60 seconds
     *                  instead of failing on unboxing
     * @return the configured chat model
     */
    @Override
    public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
        Long configuredTimeout = chatModel.getTimeOut();
        long timeoutSeconds = configuredTimeout == null ? DEFAULT_TIMEOUT_SECONDS : configuredTimeout;
        return LocalAiChatModel
                .builder()
                .baseUrl(chatModel.getBaseUrl())
                .modelName(chatModel.getModelName())
                .temperature(chatModel.getTemperature())
                .timeout(Duration.ofSeconds(timeoutSeconds))
                .build();
    }

    /**
     * Builds a LocalAI embedding model from the base URL and model name.
     *
     * @param embeddingModel embedding settings
     * @return the configured embedding model
     */
    @Override
    public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModel) {
        return LocalAiEmbeddingModel.builder()
                .baseUrl(embeddingModel.getBaseUrl())
                .modelName(embeddingModel.getModelName())
                .build();
    }

    /** Registers this factory under {@code Provider.LOCAL_AI}. */
    @Override
    public void afterPropertiesSet() {
        ModelProvider.add(Provider.LOCAL_AI, this);
    }
}

View File

@@ -0,0 +1,12 @@
package dev.langchain4j.provider;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
/**
 * Provider-specific factory for chat and embedding models. Implementations
 * in this package register themselves through {@code ModelProvider.add}.
 */
public interface ModelFactory {
/**
 * @param llmConfig chat model settings
 * @return a chat model; InMemoryModelFactory returns {@code null} here
 */
ChatLanguageModel createChatModel(ChatModelConfig llmConfig);
/**
 * @param embeddingModel embedding model settings
 * @return an embedding model built from {@code embeddingModel}
 */
EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModel);
}

View File

@@ -0,0 +1,65 @@
package dev.langchain4j.provider;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
import com.tencent.supersonic.common.config.ModelConfig;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import org.apache.commons.lang3.StringUtils;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
/**
 * Static registry that maps provider names to {@link ModelFactory} instances
 * and resolves chat and embedding models from configuration objects.
 *
 * <p>Factories register themselves through {@link #add(Provider, ModelFactory)}.
 * When a configuration is missing or incomplete the application-wide bean of
 * the requested model type is returned instead.
 */
public class ModelProvider {

    /** Registered factories keyed by {@code Provider.name()}. */
    private static final Map<String, ModelFactory> factories = new HashMap<>();

    /**
     * Registers (or replaces) the factory for a provider.
     *
     * @param provider     provider whose name becomes the lookup key
     * @param modelFactory factory that creates models for that provider
     */
    public static void add(Provider provider, ModelFactory modelFactory) {
        factories.put(provider.name(), modelFactory);
    }

    /**
     * Resolves a chat model.
     *
     * @param llmConfig chat model settings; may be {@code null}
     * @return the default {@link ChatLanguageModel} bean when the config is
     *         {@code null} or has a blank provider or base URL, otherwise the
     *         model created by the provider's factory
     * @throws RuntimeException if no factory is registered for the provider
     */
    public static ChatLanguageModel provideChatModel(ChatModelConfig llmConfig) {
        if (llmConfig == null || StringUtils.isBlank(llmConfig.getProvider())
                || StringUtils.isBlank(llmConfig.getBaseUrl())) {
            return ContextUtils.getBean(ChatLanguageModel.class);
        }
        ModelFactory modelFactory = lookupFactory(llmConfig.getProvider());
        if (modelFactory != null) {
            return modelFactory.createChatModel(llmConfig);
        }
        throw new RuntimeException("Unsupported ChatLanguageModel provider: " + llmConfig.getProvider());
    }

    /**
     * Resolves a chat model from the chat part of a {@link ModelConfig}.
     * Behaves exactly like {@link #provideChatModel(ChatModelConfig)} applied
     * to {@code modelConfig.getChatModel()}.
     *
     * @param modelConfig combined model settings; may be {@code null}
     * @return the resolved chat model
     * @throws RuntimeException if no factory is registered for the provider
     */
    public static ChatLanguageModel provideChatModelNew(ModelConfig modelConfig) {
        return provideChatModel(modelConfig == null ? null : modelConfig.getChatModel());
    }

    /**
     * Resolves an embedding model from the embedding part of a {@link ModelConfig}.
     *
     * @param modelConfig combined model settings; may be {@code null}
     * @return the default {@link EmbeddingModel} bean when the embedding config is
     *         missing or has a blank base URL or provider, otherwise the model
     *         created by the provider's factory
     * @throws RuntimeException if no factory is registered for the provider
     */
    public static EmbeddingModel provideEmbeddingModel(ModelConfig modelConfig) {
        // NOTE(review): InMemoryModelFactory and DashscopeModelFactory never read baseUrl
        // for embeddings, yet a blank baseUrl makes this method ignore the config and use
        // the default bean — confirm that this is intended for those providers.
        if (modelConfig == null || Objects.isNull(modelConfig.getEmbeddingModel())
                || StringUtils.isBlank(modelConfig.getEmbeddingModel().getBaseUrl())
                || StringUtils.isBlank(modelConfig.getEmbeddingModel().getProvider())) {
            return ContextUtils.getBean(EmbeddingModel.class);
        }
        EmbeddingModelConfig embeddingModel = modelConfig.getEmbeddingModel();
        ModelFactory modelFactory = lookupFactory(embeddingModel.getProvider());
        if (modelFactory != null) {
            return modelFactory.createEmbeddingModel(embeddingModel);
        }
        throw new RuntimeException("Unsupported EmbeddingModel provider: " + embeddingModel.getProvider());
    }

    /**
     * Finds the factory for a provider name, ignoring case.
     * Upper-casing uses {@code Locale.ROOT} so the lookup does not depend on the
     * JVM default locale (e.g. Turkish maps {@code i} to a dotted capital I,
     * which would never match {@code OPEN_AI}).
     */
    private static ModelFactory lookupFactory(String providerName) {
        return factories.get(providerName.toUpperCase(java.util.Locale.ROOT));
    }
}

View File

@@ -0,0 +1,42 @@
package dev.langchain4j.provider;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.ollama.OllamaChatModel;
import dev.langchain4j.model.ollama.OllamaEmbeddingModel;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Service;
import java.time.Duration;
/**
 * {@link ModelFactory} for Ollama; registers itself with
 * {@link ModelProvider} under {@code Provider.OLLAMA}.
 */
@Service
public class OllamaModelFactory implements ModelFactory, InitializingBean {

    /** Fallback timeout in seconds; mirrors the default declared on {@code ChatModelConfig.timeOut}. */
    private static final long DEFAULT_TIMEOUT_SECONDS = 60L;

    /**
     * Builds an Ollama chat model.
     *
     * @param chatModel supplies base URL, model name, temperature and timeout
     *                  (seconds); a {@code null} timeout falls back to 60 seconds
     *                  instead of failing on unboxing
     * @return the configured chat model
     */
    @Override
    public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
        Long configuredTimeout = chatModel.getTimeOut();
        long timeoutSeconds = configuredTimeout == null ? DEFAULT_TIMEOUT_SECONDS : configuredTimeout;
        return OllamaChatModel
                .builder()
                .baseUrl(chatModel.getBaseUrl())
                .modelName(chatModel.getModelName())
                .temperature(chatModel.getTemperature())
                .timeout(Duration.ofSeconds(timeoutSeconds))
                .build();
    }

    /**
     * Builds an Ollama embedding model; retry count and logging flags are
     * forwarded exactly as configured (possibly {@code null}).
     *
     * @param embeddingModelConfig embedding settings
     * @return the configured embedding model
     */
    @Override
    public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
        return OllamaEmbeddingModel.builder()
                .baseUrl(embeddingModelConfig.getBaseUrl())
                .modelName(embeddingModelConfig.getModelName())
                .maxRetries(embeddingModelConfig.getMaxRetries())
                .logRequests(embeddingModelConfig.getLogRequests())
                .logResponses(embeddingModelConfig.getLogResponses())
                .build();
    }

    /** Registers this factory under {@code Provider.OLLAMA}. */
    @Override
    public void afterPropertiesSet() {
        ModelProvider.add(Provider.OLLAMA, this);
    }
}

View File

@@ -0,0 +1,44 @@
package dev.langchain4j.provider;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.openai.OpenAiEmbeddingModel;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Service;
import java.time.Duration;
/**
 * {@link ModelFactory} that builds OpenAI-backed chat and embedding models and
 * registers itself with {@link ModelProvider} under {@link Provider#OPEN_AI}.
 */
@Service
public class OpenAiModelFactory implements ModelFactory, InitializingBean {

    /**
     * Registers this factory with {@link ModelProvider} for the OPEN_AI provider.
     */
    @Override
    public void afterPropertiesSet() {
        ModelProvider.add(Provider.OPEN_AI, this);
    }

    /**
     * Builds an OpenAI chat model from the supplied configuration.
     *
     * @param chatModel supplies the base URL, model name, API key (obtained through
     *                  {@code keyDecrypt()}), temperature and timeout (in seconds)
     * @return the configured chat model
     */
    @Override
    public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
        final String endpoint = chatModel.getBaseUrl();
        final String name = chatModel.getModelName();
        final Duration requestTimeout = Duration.ofSeconds(chatModel.getTimeOut());
        return OpenAiChatModel.builder()
                .timeout(requestTimeout)
                .temperature(chatModel.getTemperature())
                .apiKey(chatModel.keyDecrypt())
                .modelName(name)
                .baseUrl(endpoint)
                .build();
    }

    /**
     * Builds an OpenAI embedding model from the supplied configuration.
     *
     * @param embeddingModel supplies the base URL, API key, model name, retry count
     *                       and request/response logging flags
     * @return the configured embedding model
     */
    @Override
    public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModel) {
        final String endpoint = embeddingModel.getBaseUrl();
        final String name = embeddingModel.getModelName();
        // NOTE(review): the chat path reads its key via keyDecrypt() while this path
        // passes getApiKey() as-is - confirm embedding keys are stored unencrypted.
        return OpenAiEmbeddingModel.builder()
                .logResponses(embeddingModel.getLogResponses())
                .logRequests(embeddingModel.getLogRequests())
                .maxRetries(embeddingModel.getMaxRetries())
                .modelName(name)
                .apiKey(embeddingModel.getApiKey())
                .baseUrl(endpoint)
                .build();
    }
}

View File

@@ -0,0 +1,12 @@
package dev.langchain4j.provider;
/**
 * Identifiers of the model providers a {@code ModelFactory} can be registered under
 * via {@code ModelProvider.add(Provider, ModelFactory)}.
 *
 * <p>NOTE(review): ModelProvider looks factories up by the upper-cased provider
 * string from the configuration, so configured provider names presumably have to
 * match these constant names ignoring case - confirm against ModelProvider.add.
 */
public enum Provider {
OPEN_AI,
OLLAMA,
LOCAL_AI,
IN_MEMORY,
ZHIPU,
AZURE,
QIANFAN,
DASHSCOPE
}

View File

@@ -0,0 +1,34 @@
package dev.langchain4j.provider;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.qianfan.QianfanEmbeddingModel;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Service;
/**
 * {@link ModelFactory} for the Qianfan provider. Only embedding models can be
 * created; chat models are not supported by this factory. Registers itself with
 * {@link ModelProvider} under {@link Provider#QIANFAN}.
 */
@Service
public class QianfanModelFactory implements ModelFactory, InitializingBean {

    /**
     * Chat models are not available for this provider.
     *
     * <p>Fails fast instead of returning {@code null}: the result of this method is
     * handed straight to callers by {@code ModelProvider}, so a {@code null} would
     * only surface later as a NullPointerException far from the misconfiguration.
     *
     * @param chatModel the requested chat model configuration (unused)
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
        throw new UnsupportedOperationException(
                "Chat model creation is not supported for provider: " + Provider.QIANFAN);
    }

    /**
     * Builds a Qianfan embedding model from the supplied configuration.
     *
     * @param embeddingModelConfig supplies the base URL, API key, model name, retry
     *                             count and request/response logging flags
     * @return the configured embedding model
     */
    @Override
    public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
        return QianfanEmbeddingModel.builder()
                .baseUrl(embeddingModelConfig.getBaseUrl())
                .apiKey(embeddingModelConfig.getApiKey())
                .modelName(embeddingModelConfig.getModelName())
                .maxRetries(embeddingModelConfig.getMaxRetries())
                .logRequests(embeddingModelConfig.getLogRequests())
                .logResponses(embeddingModelConfig.getLogResponses())
                .build();
    }

    /**
     * Registers this factory with {@link ModelProvider} for the QIANFAN provider.
     */
    @Override
    public void afterPropertiesSet() {
        ModelProvider.add(Provider.QIANFAN, this);
    }
}

View File

@@ -0,0 +1,34 @@
package dev.langchain4j.provider;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.EmbeddingModelConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.zhipu.ZhipuAiEmbeddingModel;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.stereotype.Service;
/**
 * {@link ModelFactory} for the Zhipu provider. Only embedding models can be
 * created; chat models are not supported by this factory. Registers itself with
 * {@link ModelProvider} under {@link Provider#ZHIPU}.
 */
@Service
public class ZhipuModelFactory implements ModelFactory, InitializingBean {

    /**
     * Chat models are not available for this provider.
     *
     * <p>Fails fast instead of returning {@code null}: the result of this method is
     * handed straight to callers by {@code ModelProvider}, so a {@code null} would
     * only surface later as a NullPointerException far from the misconfiguration.
     *
     * @param chatModel the requested chat model configuration (unused)
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public ChatLanguageModel createChatModel(ChatModelConfig chatModel) {
        throw new UnsupportedOperationException(
                "Chat model creation is not supported for provider: " + Provider.ZHIPU);
    }

    /**
     * Builds a Zhipu embedding model from the supplied configuration.
     *
     * @param embeddingModelConfig supplies the base URL, API key, model name, retry
     *                             count and request/response logging flags
     * @return the configured embedding model
     */
    @Override
    public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) {
        return ZhipuAiEmbeddingModel.builder()
                .baseUrl(embeddingModelConfig.getBaseUrl())
                .apiKey(embeddingModelConfig.getApiKey())
                .model(embeddingModelConfig.getModelName())
                .maxRetries(embeddingModelConfig.getMaxRetries())
                .logRequests(embeddingModelConfig.getLogRequests())
                .logResponses(embeddingModelConfig.getLogResponses())
                .build();
    }

    /**
     * Registers this factory with {@link ModelProvider} for the ZHIPU provider.
     */
    @Override
    public void afterPropertiesSet() {
        ModelProvider.add(Provider.ZHIPU, this);
    }
}

View File

@@ -1,4 +1,3 @@
version: '3.8'
services: services:
chroma: chroma:
image: chromadb/chroma:0.5.3 image: chromadb/chroma:0.5.3

View File

@@ -3,12 +3,12 @@ package com.tencent.supersonic.headless.api.pojo.request;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.google.common.collect.Sets; import com.google.common.collect.Sets;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.PromptConfig; import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.common.config.LLMConfig;
import com.tencent.supersonic.headless.api.pojo.QueryDataType; import com.tencent.supersonic.headless.api.pojo.QueryDataType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum; import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
import lombok.Data; import lombok.Data;
@@ -27,7 +27,7 @@ public class QueryReq {
private MapModeEnum mapModeEnum = MapModeEnum.STRICT; private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
private SchemaMapInfo mapInfo = new SchemaMapInfo(); private SchemaMapInfo mapInfo = new SchemaMapInfo();
private QueryDataType queryDataType = QueryDataType.ALL; private QueryDataType queryDataType = QueryDataType.ALL;
private LLMConfig llmConfig; private ChatModelConfig llmConfig;
private PromptConfig promptConfig; private PromptConfig promptConfig;
private List<SqlExemplar> dynamicExemplars = Lists.newArrayList(); private List<SqlExemplar> dynamicExemplars = Lists.newArrayList();
} }

View File

@@ -2,10 +2,11 @@ package com.tencent.supersonic.headless.chat;
import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnore;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.ModelConfig;
import com.tencent.supersonic.common.config.PromptConfig; import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.config.LLMConfig;
import com.tencent.supersonic.headless.api.pojo.QueryDataType; import com.tencent.supersonic.headless.api.pojo.QueryDataType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
@@ -50,7 +51,8 @@ public class QueryContext {
@JsonIgnore @JsonIgnore
private WorkflowState workflowState; private WorkflowState workflowState;
private QueryDataType queryDataType = QueryDataType.ALL; private QueryDataType queryDataType = QueryDataType.ALL;
private LLMConfig llmConfig; private ModelConfig modelConfig;
private ChatModelConfig llmConfig;
private PromptConfig promptConfig; private PromptConfig promptConfig;
private List<SqlExemplar> dynamicExemplars; private List<SqlExemplar> dynamicExemplars;

View File

@@ -1,8 +1,5 @@
package com.tencent.supersonic.headless.chat.parser.llm; package com.tencent.supersonic.headless.chat.parser.llm;
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_LINKING_VALUE_ENABLE;
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_STRATEGY_TYPE;
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum; import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
@@ -17,6 +14,12 @@ import com.tencent.supersonic.headless.chat.parser.SatisfactionChecker;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.headless.chat.utils.ComponentFactory; import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.Comparator; import java.util.Comparator;
@@ -26,11 +29,9 @@ import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.tuple.Pair; import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_LINKING_VALUE_ENABLE;
import org.springframework.beans.factory.annotation.Autowired; import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_STRATEGY_TYPE;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;
@Slf4j @Slf4j
@Service @Service
@@ -101,6 +102,7 @@ public class LLMRequestService {
llmReq.setCurrentDate(DateUtils.getBeforeDate(0)); llmReq.setCurrentDate(DateUtils.getBeforeDate(0));
llmReq.setSqlGenType(LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE))); llmReq.setSqlGenType(LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE)));
llmReq.setModelConfig(queryCtx.getModelConfig());
llmReq.setLlmConfig(queryCtx.getLlmConfig()); llmReq.setLlmConfig(queryCtx.getLlmConfig());
llmReq.setPromptConfig(queryCtx.getPromptConfig()); llmReq.setPromptConfig(queryCtx.getPromptConfig());
llmReq.setDynamicExemplars(queryCtx.getDynamicExemplars()); llmReq.setDynamicExemplars(queryCtx.getDynamicExemplars());
@@ -118,7 +120,7 @@ public class LLMRequestService {
} }
protected List<String> getFieldNameList(QueryContext queryCtx, Long dataSetId, protected List<String> getFieldNameList(QueryContext queryCtx, Long dataSetId,
LLMParserConfig llmParserConfig) { LLMParserConfig llmParserConfig) {
Set<String> results = getTopNFieldNames(queryCtx, dataSetId, llmParserConfig); Set<String> results = getTopNFieldNames(queryCtx, dataSetId, llmParserConfig);

View File

@@ -1,10 +1,10 @@
package com.tencent.supersonic.headless.chat.parser.llm; package com.tencent.supersonic.headless.chat.parser.llm;
import com.tencent.supersonic.common.config.LLMConfig; import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.provider.ChatLanguageModelProvider; import dev.langchain4j.provider.ModelProvider;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.InitializingBean; import org.springframework.beans.factory.InitializingBean;
@@ -23,8 +23,8 @@ public abstract class SqlGenStrategy implements InitializingBean {
@Autowired @Autowired
protected PromptHelper promptHelper; protected PromptHelper promptHelper;
protected ChatLanguageModel getChatLanguageModel(LLMConfig llmConfig) { protected ChatLanguageModel getChatLanguageModel(ChatModelConfig llmConfig) {
return ChatLanguageModelProvider.provide(llmConfig); return ModelProvider.provideChatModel(llmConfig);
} }
abstract LLMResp generate(LLMReq llmReq); abstract LLMResp generate(LLMReq llmReq);

View File

@@ -2,7 +2,8 @@ package com.tencent.supersonic.headless.chat.query.llm.s2sql;
import com.fasterxml.jackson.annotation.JsonValue; import com.fasterxml.jackson.annotation.JsonValue;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.common.config.LLMConfig; import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.ModelConfig;
import com.tencent.supersonic.common.config.PromptConfig; import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.common.pojo.SqlExemplar; import com.tencent.supersonic.common.pojo.SqlExemplar;
@@ -27,8 +28,8 @@ public class LLMReq {
private SqlGenType sqlGenType; private SqlGenType sqlGenType;
private LLMConfig llmConfig; private ModelConfig modelConfig;
private ChatModelConfig llmConfig;
private PromptConfig promptConfig; private PromptConfig promptConfig;
private List<SqlExemplar> dynamicExemplars; private List<SqlExemplar> dynamicExemplars;

View File

@@ -1,6 +1,6 @@
spring: spring:
datasource: datasource:
url: jdbc:mysql://${DB_HOST}:3306/${DB_NAME}?useUnicode=true&characterEncoding=UTF-8&useSSL=false&allowMultiQueries=true&allowPublicKeyRetrieval=true url: jdbc:mysql://${DB_HOST}:${DB_PORT:3306}/${DB_NAME}?useUnicode=true&characterEncoding=UTF-8&useSSL=false&allowMultiQueries=true&allowPublicKeyRetrieval=true
username: ${DB_USERNAME} username: ${DB_USERNAME}
password: ${DB_PASSWORD} password: ${DB_PASSWORD}
driver-class-name: com.mysql.jdbc.Driver driver-class-name: com.mysql.jdbc.Driver

View File

@@ -14,11 +14,11 @@ langchain4j:
temperature: ${OPENAI_TEMPERATURE:0.0} temperature: ${OPENAI_TEMPERATURE:0.0}
timeout: ${OPENAI_TIMEOUT:PT60S} timeout: ${OPENAI_TIMEOUT:PT60S}
# embedding-model: # embedding-model:
# base-url: https://api.openai.com/v1 # base-url: https://api.openai.com/v1
# api-key: demo # api-key: demo
# model-name: text-embedding-3-small # model-name: text-embedding-3-small
# timeout: PT60S # timeout: PT60S
in-memory: in-memory:
embedding-model: embedding-model:

View File

@@ -8,7 +8,8 @@ import com.tencent.supersonic.chat.server.agent.AgentConfig;
import com.tencent.supersonic.chat.server.agent.AgentToolType; import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig; import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.agent.RuleParserTool; import com.tencent.supersonic.chat.server.agent.RuleParserTool;
import com.tencent.supersonic.common.config.LLMConfig; import com.tencent.supersonic.common.config.ChatModelConfig;
import com.tencent.supersonic.common.config.ModelConfig;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult; import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.util.DataUtils; import com.tencent.supersonic.util.DataUtils;
import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeAll;
@@ -85,7 +86,8 @@ public class Text2SQLEval extends BaseTest {
AgentConfig agentConfig = new AgentConfig(); AgentConfig agentConfig = new AgentConfig();
agentConfig.getTools().add(getLLMQueryTool()); agentConfig.getTools().add(getLLMQueryTool());
agent.setAgentConfig(JSONObject.toJSONString(agentConfig)); agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
agent.setLlmConfig(getLLMConfig(LLMType.GPT)); agent.setModelConfig(getLLMConfig(LLMType.GPT));
agent.setLlmConfig(getLLMConfig(LLMType.GPT).getChatModel());
MultiTurnConfig multiTurnConfig = new MultiTurnConfig(); MultiTurnConfig multiTurnConfig = new MultiTurnConfig();
multiTurnConfig.setEnableMultiTurn(enableMultiturn); multiTurnConfig.setEnableMultiTurn(enableMultiturn);
agent.setMultiTurnConfig(multiTurnConfig); agent.setMultiTurnConfig(multiTurnConfig);
@@ -108,7 +110,7 @@ public class Text2SQLEval extends BaseTest {
GLM GLM
} }
private static LLMConfig getLLMConfig(LLMType type) { private static ModelConfig getLLMConfig(LLMType type) {
String baseUrl; String baseUrl;
String apiKey; String apiKey;
String modelName; String modelName;
@@ -143,9 +145,16 @@ public class Text2SQLEval extends BaseTest {
modelName = "gpt-3.5-turbo"; modelName = "gpt-3.5-turbo";
temperature = 0.0; temperature = 0.0;
} }
ChatModelConfig chatModel = new ChatModelConfig();
chatModel.setModelName(modelName);
chatModel.setBaseUrl(baseUrl);
chatModel.setApiKey(apiKey);
chatModel.setTemperature(temperature);
chatModel.setProvider("open_ai");
return new LLMConfig("open_ai", ModelConfig modelConfig = new ModelConfig();
baseUrl, apiKey, modelName, temperature); modelConfig.setChatModel(chatModel);
return modelConfig;
} }
} }

View File

@@ -52,7 +52,7 @@
<glyph glyph-name="search" unicode="&#59337;" d="M454.981818 26.88a420.072727 420.072727 0 1 0 420.072727 420.072727 420.538182 420.538182 0 0 0-420.072727-420.072727z m0 781.847273a361.890909 361.890909 0 1 1 361.890909-361.890909A362.24 362.24 0 0 1 454.981818 808.727273zM959.650909-96.814545a28.974545 28.974545 0 0 0-20.596364 8.494545L705.745455 145.105455A29.090909 29.090909 0 0 0 746.821818 186.181818l233.425455-233.425454a29.090909 29.090909 0 0 0-20.596364-49.687273z" horiz-adv-x="1024" /> <glyph glyph-name="search" unicode="&#59337;" d="M454.981818 26.88a420.072727 420.072727 0 1 0 420.072727 420.072727 420.538182 420.538182 0 0 0-420.072727-420.072727z m0 781.847273a361.890909 361.890909 0 1 1 361.890909-361.890909A362.24 362.24 0 0 1 454.981818 808.727273zM959.650909-96.814545a28.974545 28.974545 0 0 0-20.596364 8.494545L705.745455 145.105455A29.090909 29.090909 0 0 0 746.821818 186.181818l233.425455-233.425454a29.090909 29.090909 0 0 0-20.596364-49.687273z" horiz-adv-x="1024" />
<glyph glyph-name="factory-color" unicode="&#59037;" d="M244.680556 205.021567h540.631929c21.08898 0 41.290852 18.548625 40.323097 40.323097s-17.722001 40.323097-40.323097 40.323098H244.680556c-21.08898 0-41.290852-18.548625-40.323098-40.323098s17.722001-40.323097 40.323098-40.323097zM1029.912394-17.985324l-0.161292 165.828739-0.181454 201.615487-0.181454 173.389319c0 28.226168 0.745977 56.613629 0 84.839797v1.209693c0 13.689692-7.540419 28.226168-19.980095 34.798833l-162.905314 85.464805L591.176932 863.134841c-17.601032 9.233989-35.02061 19.173633-52.863581 27.802776-14.3147 6.875088-27.580999 6.754119-41.875536-0.403231-2.237932-1.108885-4.415379-2.318578-6.63315-3.467787l-29.879415-15.665523-245.587825-129.033912c-63.710494-33.508494-128.22745-65.827457-191.272613-100.545644-1.008077-0.564523-2.016155-1.088724-3.084717-1.633085C7.580742 633.61577 0 618.998647 0 605.308956l0.120969-68.549266 0.282262-165.869062 0.362908-199.720301 0.282262-172.824796c0-28.226168-0.685493-56.452336 0.161292-84.678505v-1.209693c0-21.08898 18.528463-41.290852 40.323097-40.323097s40.323097 17.722001 40.323098 40.323097l-0.120969 68.549266-0.282262 165.808577-0.362908 199.720301-0.302423 172.824796c0 20.504295 0.625008 41.190044 0.463715 61.795147l140.969549 74.113853 257.160554 135.082377 38.306943 20.161549 139.114686-72.964645 253.914545-133.247676 38.024681-19.959933v-45.141708l0.161292-165.808577q0.120969-100.807744 0.201616-201.615487 0-86.69466 0.161292-173.389319c0-28.226168-0.725816-56.593467 0.100808-84.839797v-1.189532c0-21.08898 18.528463-41.290852 40.323098-40.323097s40.323097 17.722001 40.323097 40.323097q-0.080646 34.919802-0.100808 69.658151zM825.635582 74.2336v353.835181a40.927944 40.927944 0 0 1-40.323097 40.323097H244.680556a40.907782 40.907782 0 0 1-40.323098-40.323097v-511.538815c0-21.08898 18.548625-41.290852 40.323098-40.323098s40.323097 17.722001 40.323097 40.323098v117.380537h459.985735v-117.380537c0-21.08898 18.548625-41.290852 40.323097-40.323098s40.323097 
17.722001 40.323097 40.323098v155.001987c0.020162 0.887108 0.040323 1.875024 0 2.701647z m-223.510929 40.323098H285.003653v273.188985h459.985735v-273.188985z" horiz-adv-x="1030" /> <glyph glyph-name="modelFactory-color" unicode="&#59037;" d="M244.680556 205.021567h540.631929c21.08898 0 41.290852 18.548625 40.323097 40.323097s-17.722001 40.323097-40.323097 40.323098H244.680556c-21.08898 0-41.290852-18.548625-40.323098-40.323098s17.722001-40.323097 40.323098-40.323097zM1029.912394-17.985324l-0.161292 165.828739-0.181454 201.615487-0.181454 173.389319c0 28.226168 0.745977 56.613629 0 84.839797v1.209693c0 13.689692-7.540419 28.226168-19.980095 34.798833l-162.905314 85.464805L591.176932 863.134841c-17.601032 9.233989-35.02061 19.173633-52.863581 27.802776-14.3147 6.875088-27.580999 6.754119-41.875536-0.403231-2.237932-1.108885-4.415379-2.318578-6.63315-3.467787l-29.879415-15.665523-245.587825-129.033912c-63.710494-33.508494-128.22745-65.827457-191.272613-100.545644-1.008077-0.564523-2.016155-1.088724-3.084717-1.633085C7.580742 633.61577 0 618.998647 0 605.308956l0.120969-68.549266 0.282262-165.869062 0.362908-199.720301 0.282262-172.824796c0-28.226168-0.685493-56.452336 0.161292-84.678505v-1.209693c0-21.08898 18.528463-41.290852 40.323097-40.323097s40.323097 17.722001 40.323098 40.323097l-0.120969 68.549266-0.282262 165.808577-0.362908 199.720301-0.302423 172.824796c0 20.504295 0.625008 41.190044 0.463715 61.795147l140.969549 74.113853 257.160554 135.082377 38.306943 20.161549 139.114686-72.964645 253.914545-133.247676 38.024681-19.959933v-45.141708l0.161292-165.808577q0.120969-100.807744 0.201616-201.615487 0-86.69466 0.161292-173.389319c0-28.226168-0.725816-56.593467 0.100808-84.839797v-1.189532c0-21.08898 18.528463-41.290852 40.323098-40.323097s40.323097 17.722001 40.323097 40.323097q-0.080646 34.919802-0.100808 69.658151zM825.635582 74.2336v353.835181a40.927944 40.927944 0 0 1-40.323097 40.323097H244.680556a40.907782 40.907782 0 0 
1-40.323098-40.323097v-511.538815c0-21.08898 18.548625-41.290852 40.323098-40.323098s40.323097 17.722001 40.323097 40.323098v117.380537h459.985735v-117.380537c0-21.08898 18.548625-41.290852 40.323097-40.323098s40.323097 17.722001 40.323097 40.323098v155.001987c0.020162 0.887108 0.040323 1.875024 0 2.701647z m-223.510929 40.323098H285.003653v273.188985h459.985735v-273.188985z" horiz-adv-x="1030" />
<glyph glyph-name="portray-color" unicode="&#59038;" d="M1021.214089-75.545629a544.37731 544.37731 0 0 1-192.470554 267.792069 532.086853 532.086853 0 0 1-141.749936 75.701221A338.237371 338.237371 0 0 1 851.345985 559.001667C851.345985 744.817388 702.761356 896 520.063215 896S188.720491 744.817388 188.720491 559.001667a338.337293 338.337293 0 0 1 158.237134-287.256955A531.307458 531.307458 0 0 1 194.55596 192.546207 544.517202 544.517202 0 0 1 2.005468-74.626343a39.349447 39.349447 0 1 1 74.721981-24.70082c63.011074 190.432136 237.815345 318.392779 434.742438 318.392779a448.871467 448.871467 0 0 0 270.230176-89.930173 466.058122 466.058122 0 0 0 164.77206-229.261986 39.349447 39.349447 0 1 1 74.741966 24.580914zM267.419384 559.001667c0 142.409424 113.332002 258.29944 252.643831 258.29944s252.623846-115.91 252.623846-258.29944-113.332002-258.29944-252.623846-258.299439-252.643831 115.870031-252.643831 258.299439zM617.247754 2.453807l-197.40672-0.339736a23.981379 23.981379 0 0 0-20.663956 36.231867l100.761763 169.708228a23.981379 23.981379 0 0 0 41.447817-0.359721l96.664943-169.368491a23.981379 23.981379 0 0 0-20.803847-35.872147z" horiz-adv-x="1024" /> <glyph glyph-name="portray-color" unicode="&#59038;" d="M1021.214089-75.545629a544.37731 544.37731 0 0 1-192.470554 267.792069 532.086853 532.086853 0 0 1-141.749936 75.701221A338.237371 338.237371 0 0 1 851.345985 559.001667C851.345985 744.817388 702.761356 896 520.063215 896S188.720491 744.817388 188.720491 559.001667a338.337293 338.337293 0 0 1 158.237134-287.256955A531.307458 531.307458 0 0 1 194.55596 192.546207 544.517202 544.517202 0 0 1 2.005468-74.626343a39.349447 39.349447 0 1 1 74.721981-24.70082c63.011074 190.432136 237.815345 318.392779 434.742438 318.392779a448.871467 448.871467 0 0 0 270.230176-89.930173 466.058122 466.058122 0 0 0 164.77206-229.261986 39.349447 39.349447 0 1 1 74.741966 24.580914zM267.419384 559.001667c0 142.409424 113.332002 258.29944 252.643831 258.29944s252.623846-115.91 
252.623846-258.29944-113.332002-258.29944-252.623846-258.299439-252.643831 115.870031-252.643831 258.299439zM617.247754 2.453807l-197.40672-0.339736a23.981379 23.981379 0 0 0-20.663956 36.231867l100.761763 169.708228a23.981379 23.981379 0 0 0 41.447817-0.359721l96.664943-169.368491a23.981379 23.981379 0 0 0-20.803847-35.872147z" horiz-adv-x="1024" />

Before

Width:  |  Height:  |  Size: 97 KiB

After

Width:  |  Height:  |  Size: 96 KiB