diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java index 952a3ce69..bba3a2007 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java @@ -4,8 +4,6 @@ package com.tencent.supersonic.chat.server.agent; import com.alibaba.fastjson.JSONObject; import com.google.common.collect.Lists; import com.google.common.collect.Sets; -import com.tencent.supersonic.common.config.ChatModelConfig; -import com.tencent.supersonic.common.config.EmbeddingStoreConfig; import com.tencent.supersonic.common.config.ModelConfig; import com.tencent.supersonic.common.config.PromptConfig; import com.tencent.supersonic.common.config.VisualConfig; @@ -35,9 +33,7 @@ public class Agent extends RecordInfo { private Integer status; private List examples; private String agentConfig; - private ChatModelConfig llmConfig; private ModelConfig modelConfig; - private EmbeddingStoreConfig embeddingStore; private PromptConfig promptConfig; private MultiTurnConfig multiTurnConfig; private VisualConfig visualConfig; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java index f7d0d635d..be849f7c6 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java @@ -46,7 +46,7 @@ public class PlainTextExecutor implements ChatExecutor { AgentService agentService = ContextUtils.getBean(AgentService.class); Agent chatAgent = agentService.getAgent(chatExecuteContext.getAgentId()); - ChatLanguageModel chatLanguageModel = ModelProvider.provideChatModel(chatAgent.getLlmConfig()); + ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(chatAgent.getModelConfig()); Response response = chatLanguageModel.generate(prompt.toUserMessage()); QueryResult result = new QueryResult(); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java index a09ac4f5d..b6205885d 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java @@ -56,8 +56,8 @@ public class MemoryReviewTask { Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP); keyPipelineLog.info("MemoryReviewTask reqPrompt:{}", promptStr); - ChatLanguageModel chatLanguageModel = ModelProvider.provideChatModel( - chatAgent.getLlmConfig()); + ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel( + chatAgent.getModelConfig()); if (Objects.nonNull(chatLanguageModel)) { String response = chatLanguageModel.generate(prompt.toUserMessage()).content().text(); keyPipelineLog.info("MemoryReviewTask modelResp:{}", response); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index 5a5ed244a..d438098f1 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -5,8 +5,8 @@ import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryReposi import com.tencent.supersonic.chat.server.plugin.PluginQueryManager; import com.tencent.supersonic.chat.server.pojo.ChatParseContext; import com.tencent.supersonic.chat.server.util.QueryReqConverter; -import com.tencent.supersonic.common.config.ChatModelConfig; import com.tencent.supersonic.common.config.EmbeddingConfig; +import com.tencent.supersonic.common.config.ModelConfig; import com.tencent.supersonic.common.pojo.SqlExemplar; import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl; import com.tencent.supersonic.common.util.ContextUtils; @@ -168,7 +168,7 @@ public class NL2SQLParser implements ChatParser { .curtSchema(curtMapStr) .histSchema(histMapStr) .histSQL(histSQL) - .llmConfig(queryTextReq.getLlmConfig()) + .modelConfig(queryTextReq.getModelConfig()) .build()); chatParseContext.setQueryText(rewrittenQuery); log.info("Last Query: {} Current Query: {}, Rewritten Query: {}", @@ -181,7 +181,7 @@ public class NL2SQLParser implements ChatParser { Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP); keyPipelineLog.info("NL2SQLParser reqPrompt:{}", promptStr); - ChatLanguageModel chatLanguageModel = ModelProvider.provideChatModel(context.getLlmConfig()); + ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(context.getModelConfig()); Response response = chatLanguageModel.generate(prompt.toUserMessage()); String result = response.content().text(); @@ -243,7 +243,7 @@ public class NL2SQLParser implements ChatParser { private String curtSchema; private String histSchema; private String histSQL; - private ChatModelConfig llmConfig; + private ModelConfig modelConfig; } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/AgentDO.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/AgentDO.java index 62314ff49..8be3968e3 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/AgentDO.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/AgentDO.java @@ -11,15 +11,18 @@ import java.util.Date; @TableName("s2_agent") public class AgentDO { /** + * */ @TableId(type = IdType.AUTO) private Integer id; /** + * */ private String name; /** + * */ private String description; @@ -29,35 +32,40 @@ public class AgentDO { private Integer status; /** + * */ private String examples; /** + * */ private String config; /** + * */ private String createdBy; /** + * */ private Date createdAt; /** + * */ private String updatedBy; /** + * */ private Date updatedAt; /** + * */ private Integer enableSearch; - - private String llmConfig; - + private String modelConfig; private String multiTurnConfig; private String visualConfig; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/AgentController.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/AgentController.java index eb71f8641..179bc28f7 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/AgentController.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/AgentController.java @@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.AgentToolType; import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.chat.server.util.LLMConnHelper; -import com.tencent.supersonic.common.config.ChatModelConfig; +import com.tencent.supersonic.common.config.ModelConfig; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.DeleteMapping; import org.springframework.web.bind.annotation.PathVariable; @@ -30,16 +30,16 @@ public class AgentController { @PostMapping public Agent createAgent(@RequestBody Agent agent, - HttpServletRequest httpServletRequest, - HttpServletResponse httpServletResponse) { + HttpServletRequest httpServletRequest, + HttpServletResponse httpServletResponse) { User user = UserHolder.findUser(httpServletRequest, httpServletResponse); return agentService.createAgent(agent, user); } @PutMapping public Agent updateAgent(@RequestBody Agent agent, - HttpServletRequest httpServletRequest, - HttpServletResponse httpServletResponse) { + HttpServletRequest httpServletRequest, + HttpServletResponse httpServletResponse) { User user = UserHolder.findUser(httpServletRequest, httpServletResponse); return agentService.updateAgent(agent, user); } @@ -51,8 +51,8 @@ public class AgentController { } @PostMapping("/testLLMConn") - public boolean testLLMConn(@RequestBody ChatModelConfig llmConfig) { - return LLMConnHelper.testConnection(llmConfig); + public boolean testLLMConn(@RequestBody ModelConfig modelConfig) { + return LLMConnHelper.testConnection(modelConfig); } @RequestMapping("/getAgentList") diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java index eef3acc18..737953230 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java @@ -12,7 +12,7 @@ import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.chat.server.service.ChatService; import com.tencent.supersonic.chat.server.service.MemoryService; import com.tencent.supersonic.chat.server.util.LLMConnHelper; -import com.tencent.supersonic.common.config.ChatModelConfig; +import com.tencent.supersonic.common.config.ModelConfig; import com.tencent.supersonic.common.config.PromptConfig; import com.tencent.supersonic.common.config.VisualConfig; import com.tencent.supersonic.common.util.JsonUtil; @@ -88,7 +88,7 @@ public class AgentServiceImpl extends ServiceImpl } private synchronized void doExecuteAgentExamples(Agent agent) { - if (!agent.containsLLMParserTool() || !LLMConnHelper.testConnection(agent.getLlmConfig()) + if (!agent.containsLLMParserTool() || !LLMConnHelper.testConnection(agent.getModelConfig()) || CollectionUtils.isEmpty(agent.getExamples())) { return; } @@ -122,7 +122,7 @@ public class AgentServiceImpl extends ServiceImpl BeanUtils.copyProperties(agentDO, agent); agent.setAgentConfig(agentDO.getConfig()); agent.setExamples(JsonUtil.toList(agentDO.getExamples(), String.class)); - agent.setLlmConfig(JsonUtil.toObject(agentDO.getLlmConfig(), ChatModelConfig.class)); + agent.setModelConfig(JsonUtil.toObject(agentDO.getModelConfig(), ModelConfig.class)); agent.setPromptConfig(JsonUtil.toObject(agentDO.getPromptConfig(), PromptConfig.class)); agent.setMultiTurnConfig(JsonUtil.toObject(agentDO.getMultiTurnConfig(), MultiTurnConfig.class)); agent.setVisualConfig(JsonUtil.toObject(agentDO.getVisualConfig(), VisualConfig.class)); @@ -134,7 +134,7 @@ public class AgentServiceImpl extends ServiceImpl BeanUtils.copyProperties(agent, agentDO); agentDO.setConfig(agent.getAgentConfig()); agentDO.setExamples(JsonUtil.toString(agent.getExamples())); - agentDO.setLlmConfig(JsonUtil.toString(agent.getLlmConfig())); + agentDO.setModelConfig(JsonUtil.toString(agent.getModelConfig())); agentDO.setMultiTurnConfig(JsonUtil.toString(agent.getMultiTurnConfig())); agentDO.setVisualConfig(JsonUtil.toString(agent.getVisualConfig())); agentDO.setPromptConfig(JsonUtil.toString(agent.getPromptConfig())); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/LLMConnHelper.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/LLMConnHelper.java index e65261dbf..50202ae50 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/LLMConnHelper.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/LLMConnHelper.java @@ -1,6 +1,6 @@ package com.tencent.supersonic.chat.server.util; -import com.tencent.supersonic.common.config.ChatModelConfig; +import com.tencent.supersonic.common.config.ModelConfig; import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.provider.ModelProvider; @@ -9,12 +9,13 @@ import org.apache.commons.lang3.StringUtils; @Slf4j public class LLMConnHelper { - public static boolean testConnection(ChatModelConfig chatModel) { + public static boolean testConnection(ModelConfig modelConfig) { try { - if (chatModel == null || StringUtils.isBlank(chatModel.getBaseUrl())) { + if (modelConfig == null || modelConfig.getChatModel() == null + || StringUtils.isBlank(modelConfig.getChatModel().getBaseUrl())) { return false; } - ChatLanguageModel chatLanguageModel = ModelProvider.provideChatModel(chatModel); + ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(modelConfig); String response = chatLanguageModel.generate("Hi there"); return StringUtils.isNotEmpty(response) ? true : false; } catch (Exception e) { diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java index 6a1203306..8cb1484c0 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java @@ -30,7 +30,7 @@ public class QueryReqConverter { && MapUtils.isNotEmpty(queryTextReq.getMapInfo().getDataSetElementMatches())) { queryTextReq.setMapInfo(queryTextReq.getMapInfo()); } - queryTextReq.setLlmConfig(agent.getLlmConfig()); + queryTextReq.setModelConfig(agent.getModelConfig()); queryTextReq.setPromptConfig(agent.getPromptConfig()); return queryTextReq; } diff --git a/common/src/main/java/dev/langchain4j/chroma/spring/ChromaEmbeddingStoreFactory.java b/common/src/main/java/dev/langchain4j/chroma/spring/ChromaEmbeddingStoreFactory.java index 5e47576cc..d22b19eed 100644 --- a/common/src/main/java/dev/langchain4j/chroma/spring/ChromaEmbeddingStoreFactory.java +++ b/common/src/main/java/dev/langchain4j/chroma/spring/ChromaEmbeddingStoreFactory.java @@ -1,15 +1,23 @@ package dev.langchain4j.chroma.spring; +import com.tencent.supersonic.common.config.EmbeddingStoreConfig; import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory; import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.chroma.ChromaEmbeddingStore; import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.BeanUtils; + +import java.time.Duration; @Slf4j public class ChromaEmbeddingStoreFactory extends BaseEmbeddingStoreFactory { private Properties properties; + public ChromaEmbeddingStoreFactory(EmbeddingStoreConfig storeConfig) { + this(createPropertiesFromConfig(storeConfig)); + } + public ChromaEmbeddingStoreFactory(Properties properties) { this.properties = properties; } @@ -23,4 +31,13 @@ public class ChromaEmbeddingStoreFactory extends BaseEmbeddingStoreFactory { .timeout(storeProperties.getTimeout()) .build(); } + + private static Properties createPropertiesFromConfig(EmbeddingStoreConfig storeConfig) { + Properties properties = new Properties(); + EmbeddingStoreProperties embeddingStore = new EmbeddingStoreProperties(); + BeanUtils.copyProperties(storeConfig, embeddingStore); + embeddingStore.setTimeout(Duration.ofSeconds(storeConfig.getTimeOut())); + properties.setEmbeddingStore(embeddingStore); + return properties; + } } \ No newline at end of file diff --git a/common/src/main/java/dev/langchain4j/chroma/spring/EmbeddingStoreProperties.java b/common/src/main/java/dev/langchain4j/chroma/spring/EmbeddingStoreProperties.java index 0603dcdaa..b30bdb252 100644 --- a/common/src/main/java/dev/langchain4j/chroma/spring/EmbeddingStoreProperties.java +++ b/common/src/main/java/dev/langchain4j/chroma/spring/EmbeddingStoreProperties.java @@ -7,7 +7,7 @@ import java.time.Duration; @Getter @Setter -class EmbeddingStoreProperties { +public class EmbeddingStoreProperties { private String baseUrl; private Duration timeout; diff --git a/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryEmbeddingStoreFactory.java b/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryEmbeddingStoreFactory.java index 05bee26ba..21e0690cf 100644 --- a/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryEmbeddingStoreFactory.java +++ b/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryEmbeddingStoreFactory.java @@ -1,6 +1,7 @@ package dev.langchain4j.inmemory.spring; import com.tencent.supersonic.common.config.EmbeddingConfig; +import com.tencent.supersonic.common.config.EmbeddingStoreConfig; import com.tencent.supersonic.common.util.ContextUtils; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory; @@ -9,6 +10,7 @@ import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections4.MapUtils; import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.BeanUtils; import java.nio.file.Files; import java.nio.file.Path; @@ -23,11 +25,22 @@ public class InMemoryEmbeddingStoreFactory extends BaseEmbeddingStoreFactory { public static final String PERSISTENT_FILE_PRE = "InMemory."; private Properties properties; + public InMemoryEmbeddingStoreFactory(EmbeddingStoreConfig storeConfig) { + this(createPropertiesFromConfig(storeConfig)); + } public InMemoryEmbeddingStoreFactory(Properties properties) { this.properties = properties; } + private static Properties createPropertiesFromConfig(EmbeddingStoreConfig storeConfig) { + Properties properties = new Properties(); + EmbeddingStoreProperties embeddingStore = new EmbeddingStoreProperties(); + BeanUtils.copyProperties(storeConfig, embeddingStore); + properties.setEmbeddingStore(embeddingStore); + return properties; + } + @Override public synchronized EmbeddingStore createEmbeddingStore(String collectionName) { InMemoryEmbeddingStore embeddingStore = reloadFromPersistFile(collectionName); diff --git a/common/src/main/java/dev/langchain4j/milvus/spring/MilvusEmbeddingStoreFactory.java b/common/src/main/java/dev/langchain4j/milvus/spring/MilvusEmbeddingStoreFactory.java index 873c5f129..b3d86af5b 100644 --- a/common/src/main/java/dev/langchain4j/milvus/spring/MilvusEmbeddingStoreFactory.java +++ b/common/src/main/java/dev/langchain4j/milvus/spring/MilvusEmbeddingStoreFactory.java @@ -1,17 +1,32 @@ package dev.langchain4j.milvus.spring; +import com.tencent.supersonic.common.config.EmbeddingStoreConfig; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.store.embedding.BaseEmbeddingStoreFactory; import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.milvus.MilvusEmbeddingStore; +import org.springframework.beans.BeanUtils; public class MilvusEmbeddingStoreFactory extends BaseEmbeddingStoreFactory { private final Properties properties; + public MilvusEmbeddingStoreFactory(EmbeddingStoreConfig storeConfig) { + this(createPropertiesFromConfig(storeConfig)); + } + public MilvusEmbeddingStoreFactory(Properties properties) { this.properties = properties; } + private static Properties createPropertiesFromConfig(EmbeddingStoreConfig storeConfig) { + Properties properties = new Properties(); + EmbeddingStoreProperties embeddingStore = new EmbeddingStoreProperties(); + BeanUtils.copyProperties(storeConfig, embeddingStore); + embeddingStore.setUri(storeConfig.getBaseUrl()); + properties.setEmbeddingStore(embeddingStore); + return properties; + } + @Override public EmbeddingStore createEmbeddingStore(String collectionName) { EmbeddingStoreProperties storeProperties = properties.getEmbeddingStore(); diff --git a/common/src/main/java/dev/langchain4j/provider/AzureModelFactory.java b/common/src/main/java/dev/langchain4j/provider/AzureModelFactory.java index d32cb030f..ecf381090 100644 --- a/common/src/main/java/dev/langchain4j/provider/AzureModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/AzureModelFactory.java @@ -13,15 +13,16 @@ import java.time.Duration; @Service public class AzureModelFactory implements ModelFactory, InitializingBean { + public static final String PROVIDER = "AZURE"; @Override - public ChatLanguageModel createChatModel(ChatModelConfig chatModel) { + public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { AzureOpenAiChatModel.Builder builder = AzureOpenAiChatModel.builder() - .endpoint(chatModel.getBaseUrl()) - .apiKey(chatModel.getApiKey()) - .deploymentName(chatModel.getModelName()) - .temperature(chatModel.getTemperature()) - .timeout(Duration.ofSeconds(chatModel.getTimeOut() == null ? 0L : chatModel.getTimeOut())); + .endpoint(modelConfig.getBaseUrl()) + .apiKey(modelConfig.getApiKey()) + .deploymentName(modelConfig.getModelName()) + .temperature(modelConfig.getTemperature()) + .timeout(Duration.ofSeconds(modelConfig.getTimeOut() == null ? 0L : modelConfig.getTimeOut())); return builder.build(); } @@ -39,6 +40,6 @@ public class AzureModelFactory implements ModelFactory, InitializingBean { @Override public void afterPropertiesSet() { - ModelProvider.add(Provider.AZURE, this); + ModelProvider.add(PROVIDER, this); } } diff --git a/common/src/main/java/dev/langchain4j/provider/DashscopeModelFactory.java b/common/src/main/java/dev/langchain4j/provider/DashscopeModelFactory.java index 8cb0b952a..6fcce7d8a 100644 --- a/common/src/main/java/dev/langchain4j/provider/DashscopeModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/DashscopeModelFactory.java @@ -11,14 +11,16 @@ import org.springframework.stereotype.Service; @Service public class DashscopeModelFactory implements ModelFactory, InitializingBean { + public static final String PROVIDER = "DASHSCOPE"; + @Override - public ChatLanguageModel createChatModel(ChatModelConfig chatModel) { + public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { return QwenChatModel.builder() - .baseUrl(chatModel.getBaseUrl()) - .apiKey(chatModel.getApiKey()) - .modelName(chatModel.getModelName()) - .temperature(chatModel.getTemperature() == null ? 0L : - chatModel.getTemperature().floatValue()) + .baseUrl(modelConfig.getBaseUrl()) + .apiKey(modelConfig.getApiKey()) + .modelName(modelConfig.getModelName()) + .temperature(modelConfig.getTemperature() == null ? 0L : + modelConfig.getTemperature().floatValue()) .build(); } @@ -32,6 +34,6 @@ public class DashscopeModelFactory implements ModelFactory, InitializingBean { @Override public void afterPropertiesSet() { - ModelProvider.add(Provider.DASHSCOPE, this); + ModelProvider.add(PROVIDER, this); } } diff --git a/common/src/main/java/dev/langchain4j/provider/InMemoryModelFactory.java b/common/src/main/java/dev/langchain4j/provider/InMemoryModelFactory.java index 5505ca032..444b0d836 100644 --- a/common/src/main/java/dev/langchain4j/provider/InMemoryModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/InMemoryModelFactory.java @@ -16,9 +16,11 @@ import static dev.langchain4j.inmemory.spring.InMemoryAutoConfig.BGE_SMALL_ZH; @Service public class InMemoryModelFactory implements ModelFactory, InitializingBean { + public static final String PROVIDER = "IN_MEMORY"; + @Override - public ChatLanguageModel createChatModel(ChatModelConfig chatModel) { - return null; + public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { + throw new UnsupportedOperationException("Not supported yet."); } @Override @@ -40,6 +42,6 @@ public class InMemoryModelFactory implements ModelFactory, InitializingBean { @Override public void afterPropertiesSet() { - ModelProvider.add(Provider.IN_MEMORY, this); + ModelProvider.add(PROVIDER, this); } } diff --git a/common/src/main/java/dev/langchain4j/provider/LocalAiModelFactory.java b/common/src/main/java/dev/langchain4j/provider/LocalAiModelFactory.java index ccd0a2abd..496035061 100644 --- a/common/src/main/java/dev/langchain4j/provider/LocalAiModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/LocalAiModelFactory.java @@ -13,14 +13,16 @@ import java.time.Duration; @Service public class LocalAiModelFactory implements ModelFactory, InitializingBean { + public static final String PROVIDER = "LOCAL_AI"; + @Override - public ChatLanguageModel createChatModel(ChatModelConfig chatModel) { + public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { return LocalAiChatModel .builder() - .baseUrl(chatModel.getBaseUrl()) - .modelName(chatModel.getModelName()) - .temperature(chatModel.getTemperature()) - .timeout(Duration.ofSeconds(chatModel.getTimeOut())) + .baseUrl(modelConfig.getBaseUrl()) + .modelName(modelConfig.getModelName()) + .temperature(modelConfig.getTemperature()) + .timeout(Duration.ofSeconds(modelConfig.getTimeOut())) .build(); } @@ -34,6 +36,6 @@ public class LocalAiModelFactory implements ModelFactory, InitializingBean { @Override public void afterPropertiesSet() { - ModelProvider.add(Provider.LOCAL_AI, this); + ModelProvider.add(PROVIDER, this); } } \ No newline at end of file diff --git a/common/src/main/java/dev/langchain4j/provider/ModelFactory.java b/common/src/main/java/dev/langchain4j/provider/ModelFactory.java index 0c89b1ebc..660ae13d6 100644 --- a/common/src/main/java/dev/langchain4j/provider/ModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/ModelFactory.java @@ -6,7 +6,7 @@ import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.embedding.EmbeddingModel; public interface ModelFactory { - ChatLanguageModel createChatModel(ChatModelConfig llmConfig); + ChatLanguageModel createChatModel(ChatModelConfig modelConfig); EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModel); } diff --git a/common/src/main/java/dev/langchain4j/provider/ModelProvider.java b/common/src/main/java/dev/langchain4j/provider/ModelProvider.java index 5ef928001..cb0d7332a 100644 --- a/common/src/main/java/dev/langchain4j/provider/ModelProvider.java +++ b/common/src/main/java/dev/langchain4j/provider/ModelProvider.java @@ -15,24 +15,11 @@ import java.util.Objects; public class ModelProvider { private static final Map factories = new HashMap<>(); - public static void add(Provider provider, ModelFactory modelFactory) { - factories.put(provider.name(), modelFactory); + public static void add(String provider, ModelFactory modelFactory) { + factories.put(provider, modelFactory); } - public static ChatLanguageModel provideChatModel(ChatModelConfig llmConfig) { - if (llmConfig == null || StringUtils.isBlank(llmConfig.getProvider()) - || StringUtils.isBlank(llmConfig.getBaseUrl())) { - return ContextUtils.getBean(ChatLanguageModel.class); - } - ModelFactory modelFactory = factories.get(llmConfig.getProvider().toUpperCase()); - if (modelFactory != null) { - return modelFactory.createChatModel(llmConfig); - } - - throw new RuntimeException("Unsupported ChatLanguageModel provider: " + llmConfig.getProvider()); - } - - public static ChatLanguageModel provideChatModelNew(ModelConfig modelConfig) { + public static ChatLanguageModel getChatModel(ModelConfig modelConfig) { if (modelConfig == null || modelConfig.getChatModel() == null || StringUtils.isBlank(modelConfig.getChatModel().getProvider()) || StringUtils.isBlank(modelConfig.getChatModel().getBaseUrl())) { @@ -47,7 +34,7 @@ public class ModelProvider { throw new RuntimeException("Unsupported ChatLanguageModel provider: " + chatModel.getProvider()); } - public static EmbeddingModel provideEmbeddingModel(ModelConfig modelConfig) { + public static EmbeddingModel getEmbeddingModel(ModelConfig modelConfig) { if (modelConfig == null || Objects.isNull(modelConfig.getEmbeddingModel()) || StringUtils.isBlank(modelConfig.getEmbeddingModel().getBaseUrl()) || StringUtils.isBlank(modelConfig.getEmbeddingModel().getProvider())) { diff --git a/common/src/main/java/dev/langchain4j/provider/OllamaModelFactory.java b/common/src/main/java/dev/langchain4j/provider/OllamaModelFactory.java index 489fd03db..e027eae83 100644 --- a/common/src/main/java/dev/langchain4j/provider/OllamaModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/OllamaModelFactory.java @@ -13,14 +13,16 @@ import java.time.Duration; @Service public class OllamaModelFactory implements ModelFactory, InitializingBean { + public static final String PROVIDER = "OLLAMA"; + @Override - public ChatLanguageModel createChatModel(ChatModelConfig chatModel) { + public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { return OllamaChatModel .builder() - .baseUrl(chatModel.getBaseUrl()) - .modelName(chatModel.getModelName()) - .temperature(chatModel.getTemperature()) - .timeout(Duration.ofSeconds(chatModel.getTimeOut())) + .baseUrl(modelConfig.getBaseUrl()) + .modelName(modelConfig.getModelName()) + .temperature(modelConfig.getTemperature()) + .timeout(Duration.ofSeconds(modelConfig.getTimeOut())) .build(); } @@ -37,6 +39,6 @@ public class OllamaModelFactory implements ModelFactory, InitializingBean { @Override public void afterPropertiesSet() { - ModelProvider.add(Provider.OLLAMA, this); + ModelProvider.add(PROVIDER, this); } } diff --git a/common/src/main/java/dev/langchain4j/provider/OpenAiModelFactory.java b/common/src/main/java/dev/langchain4j/provider/OpenAiModelFactory.java index 352c388d2..b48bae951 100644 --- a/common/src/main/java/dev/langchain4j/provider/OpenAiModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/OpenAiModelFactory.java @@ -13,15 +13,17 @@ import java.time.Duration; @Service public class OpenAiModelFactory implements ModelFactory, InitializingBean { + public static final String PROVIDER = "OPEN_AI"; + @Override - public ChatLanguageModel createChatModel(ChatModelConfig chatModel) { + public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { return OpenAiChatModel .builder() - .baseUrl(chatModel.getBaseUrl()) - .modelName(chatModel.getModelName()) - .apiKey(chatModel.keyDecrypt()) - .temperature(chatModel.getTemperature()) - .timeout(Duration.ofSeconds(chatModel.getTimeOut())) + .baseUrl(modelConfig.getBaseUrl()) + .modelName(modelConfig.getModelName()) + .apiKey(modelConfig.keyDecrypt()) + .temperature(modelConfig.getTemperature()) + .timeout(Duration.ofSeconds(modelConfig.getTimeOut())) .build(); } @@ -39,6 +41,6 @@ public class OpenAiModelFactory implements ModelFactory, InitializingBean { @Override public void afterPropertiesSet() { - ModelProvider.add(Provider.OPEN_AI, this); + ModelProvider.add(PROVIDER, this); } } diff --git a/common/src/main/java/dev/langchain4j/provider/Provider.java b/common/src/main/java/dev/langchain4j/provider/Provider.java deleted file mode 100644 index 8e7b2466e..000000000 --- a/common/src/main/java/dev/langchain4j/provider/Provider.java +++ /dev/null @@ -1,12 +0,0 @@ -package dev.langchain4j.provider; - -public enum Provider { - OPEN_AI, - OLLAMA, - LOCAL_AI, - IN_MEMORY, - ZHIPU, - AZURE, - QIANFAN, - DASHSCOPE -} diff --git a/common/src/main/java/dev/langchain4j/provider/QianfanModelFactory.java b/common/src/main/java/dev/langchain4j/provider/QianfanModelFactory.java index dee6567d4..9051e243d 100644 --- a/common/src/main/java/dev/langchain4j/provider/QianfanModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/QianfanModelFactory.java @@ -10,8 +10,11 @@ import org.springframework.stereotype.Service; @Service public class QianfanModelFactory implements ModelFactory, InitializingBean { + + public static final String PROVIDER = "QIANFAN"; + @Override - public ChatLanguageModel createChatModel(ChatModelConfig chatModel) { + public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { return null; } @@ -29,6 +32,6 @@ public class QianfanModelFactory implements ModelFactory, InitializingBean { @Override public void afterPropertiesSet() { - ModelProvider.add(Provider.QIANFAN, this); + ModelProvider.add(PROVIDER, this); } } diff --git a/common/src/main/java/dev/langchain4j/provider/ZhipuModelFactory.java b/common/src/main/java/dev/langchain4j/provider/ZhipuModelFactory.java index 640dd4bf1..f2c4bfa0b 100644 --- a/common/src/main/java/dev/langchain4j/provider/ZhipuModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/ZhipuModelFactory.java @@ -10,8 +10,10 @@ import org.springframework.stereotype.Service; @Service public class ZhipuModelFactory implements ModelFactory, InitializingBean { + public static final String PROVIDER = "ZHIPU"; + @Override - public ChatLanguageModel createChatModel(ChatModelConfig chatModel) { + public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { return null; } @@ -29,6 +31,6 @@ public class ZhipuModelFactory implements ModelFactory, InitializingBean { @Override public void afterPropertiesSet() { - ModelProvider.add(Provider.ZHIPU, this); + ModelProvider.add(PROVIDER, this); } } diff --git a/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreFactoryProvider.java b/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreFactoryProvider.java new file mode 100644 index 000000000..f34f825e4 --- /dev/null +++ b/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreFactoryProvider.java @@ -0,0 +1,26 @@ +package dev.langchain4j.store.embedding; + +import com.tencent.supersonic.common.config.EmbeddingStoreConfig; +import com.tencent.supersonic.common.util.ContextUtils; +import dev.langchain4j.chroma.spring.ChromaEmbeddingStoreFactory; +import dev.langchain4j.inmemory.spring.InMemoryEmbeddingStoreFactory; +import dev.langchain4j.milvus.spring.MilvusEmbeddingStoreFactory; +import org.apache.commons.lang3.StringUtils; + +public class EmbeddingStoreFactoryProvider { + public static EmbeddingStoreFactory getFactory(EmbeddingStoreConfig storeConfig) { + if (storeConfig == null || StringUtils.isBlank(storeConfig.getProvider())) { + return ContextUtils.getBean(EmbeddingStoreFactory.class); + } + if (EmbeddingStoreType.CHROMA.name().equalsIgnoreCase(storeConfig.getProvider())) { + return new ChromaEmbeddingStoreFactory(storeConfig); + } + if (EmbeddingStoreType.MILVUS.name().equalsIgnoreCase(storeConfig.getProvider())) { + return new MilvusEmbeddingStoreFactory(storeConfig); + } + if (EmbeddingStoreType.IN_MEMORY.name().equalsIgnoreCase(storeConfig.getProvider())) { + return new InMemoryEmbeddingStoreFactory(storeConfig); + } + throw new RuntimeException("Unsupported EmbeddingStore provider: " + storeConfig.getProvider()); + } +} \ No newline at end of file diff --git a/common/src/main/java/dev/langchain4j/provider/EmbeddingStoreType.java b/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreType.java similarity index 64% rename from common/src/main/java/dev/langchain4j/provider/EmbeddingStoreType.java rename to common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreType.java index ce6db95e4..7e2e0e3b3 100644 --- a/common/src/main/java/dev/langchain4j/provider/EmbeddingStoreType.java +++ b/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreType.java @@ -1,4 +1,4 @@ -package dev.langchain4j.provider; +package dev.langchain4j.store.embedding; public enum EmbeddingStoreType { IN_MEMORY, diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryTextReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryTextReq.java index ba6c32073..ca8851141 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryTextReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryTextReq.java @@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.api.pojo.request; import com.google.common.collect.Lists; import com.google.common.collect.Sets; import com.tencent.supersonic.auth.api.authentication.pojo.User; -import com.tencent.supersonic.common.config.ChatModelConfig; +import com.tencent.supersonic.common.config.ModelConfig; import com.tencent.supersonic.common.config.PromptConfig; import com.tencent.supersonic.common.pojo.SqlExemplar; import com.tencent.supersonic.common.pojo.enums.Text2SQLType; @@ -27,7 +27,7 @@ public class QueryTextReq { private MapModeEnum mapModeEnum = MapModeEnum.STRICT; private SchemaMapInfo mapInfo = new SchemaMapInfo(); private QueryDataType queryDataType = QueryDataType.ALL; - private ChatModelConfig llmConfig; + private ModelConfig modelConfig; private PromptConfig promptConfig; private List dynamicExemplars = Lists.newArrayList(); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/QueryContext.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/QueryContext.java index 5aa5ae4a6..fdd1bd114 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/QueryContext.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/QueryContext.java @@ -2,15 +2,14 @@ package com.tencent.supersonic.headless.chat; import com.fasterxml.jackson.annotation.JsonIgnore; import com.tencent.supersonic.auth.api.authentication.pojo.User; -import com.tencent.supersonic.common.config.ChatModelConfig; import com.tencent.supersonic.common.config.ModelConfig; import com.tencent.supersonic.common.config.PromptConfig; +import com.tencent.supersonic.common.pojo.SqlExemplar; import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.api.pojo.QueryDataType; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; -import com.tencent.supersonic.common.pojo.SqlExemplar; import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum; import com.tencent.supersonic.headless.api.pojo.enums.WorkflowState; import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; @@ -21,7 +20,6 @@ import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; - import java.util.ArrayList; import java.util.Comparator; import java.util.List; @@ -52,7 +50,6 @@ public class QueryContext { private WorkflowState workflowState; private QueryDataType queryDataType = QueryDataType.ALL; private ModelConfig modelConfig; - private ChatModelConfig llmConfig; private PromptConfig promptConfig; private List dynamicExemplars; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java index 3dbd829dd..227a45ad0 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java @@ -103,7 +103,6 @@ public class LLMRequestService { llmReq.setCurrentDate(DateUtils.getBeforeDate(0)); llmReq.setSqlGenType(LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE))); llmReq.setModelConfig(queryCtx.getModelConfig()); - llmReq.setLlmConfig(queryCtx.getLlmConfig()); llmReq.setPromptConfig(queryCtx.getPromptConfig()); llmReq.setDynamicExemplars(queryCtx.getDynamicExemplars()); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java index ad8ae6858..f9c33e83c 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java @@ -55,7 +55,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy { Map prompt2Output = new ConcurrentHashMap<>(); prompt2Exemplar.keySet().parallelStream().forEach(prompt -> { keyPipelineLog.info("OnePassSCSqlGenStrategy reqPrompt:\n{}", prompt.toUserMessage()); - ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig()); + ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getModelConfig()); Response response = chatLanguageModel.generate(prompt.toUserMessage()); String result = response.content().text(); prompt2Output.put(prompt, result); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/SqlGenStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/SqlGenStrategy.java index 180ac03f7..795ec6c73 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/SqlGenStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/SqlGenStrategy.java @@ -1,6 +1,6 @@ package com.tencent.supersonic.headless.chat.parser.llm; -import com.tencent.supersonic.common.config.ChatModelConfig; +import com.tencent.supersonic.common.config.ModelConfig; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp; import dev.langchain4j.model.chat.ChatLanguageModel; @@ -23,8 +23,8 @@ public abstract class SqlGenStrategy implements InitializingBean { @Autowired protected PromptHelper promptHelper; - protected ChatLanguageModel getChatLanguageModel(ChatModelConfig llmConfig) { - return ModelProvider.provideChatModel(llmConfig); + protected ChatLanguageModel getChatLanguageModel(ModelConfig modelConfig) { + return ModelProvider.getChatModel(modelConfig); } abstract LLMResp generate(LLMReq llmReq); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java index 28855f67e..65d155030 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java @@ -2,11 +2,10 @@ package com.tencent.supersonic.headless.chat.query.llm.s2sql; import com.fasterxml.jackson.annotation.JsonValue; import com.google.common.collect.Lists; -import com.tencent.supersonic.common.config.ChatModelConfig; import com.tencent.supersonic.common.config.ModelConfig; import com.tencent.supersonic.common.config.PromptConfig; -import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.common.pojo.SqlExemplar; +import com.tencent.supersonic.headless.api.pojo.SchemaElement; import lombok.Data; import java.util.List; @@ -29,7 +28,6 @@ public class LLMReq { private SqlGenType sqlGenType; private ModelConfig modelConfig; - private ChatModelConfig llmConfig; private PromptConfig promptConfig; private List dynamicExemplars; diff --git a/launchers/standalone/src/main/resources/config.update/sql-update.sql b/launchers/standalone/src/main/resources/config.update/sql-update.sql index 37f090a0c..bc746360e 100644 --- a/launchers/standalone/src/main/resources/config.update/sql-update.sql +++ b/launchers/standalone/src/main/resources/config.update/sql-update.sql @@ -350,4 +350,7 @@ CREATE TABLE IF NOT EXISTS `s2_chat_memory` ( ) ENGINE=InnoDB DEFAULT CHARSET=utf8; --20240705 -alter table s2_agent add column `prompt_config` varchar(6000) COLLATE utf8_unicode_ci DEFAULT NULL COMMENT '提示词配置'; \ No newline at end of file +alter table s2_agent add column `prompt_config` varchar(6000) COLLATE utf8_unicode_ci DEFAULT NULL COMMENT '提示词配置'; + +--20240707 +alter table s2_agent add model_config varchar(6000) null; \ No newline at end of file diff --git a/launchers/standalone/src/main/resources/db/schema-h2.sql b/launchers/standalone/src/main/resources/db/schema-h2.sql index e115da152..8dda9fb36 100644 --- a/launchers/standalone/src/main/resources/db/schema-h2.sql +++ b/launchers/standalone/src/main/resources/db/schema-h2.sql @@ -374,6 +374,7 @@ CREATE TABLE IF NOT EXISTS s2_agent examples varchar(500) null, config varchar(2000) null, llm_config varchar(2000) null, + model_config varchar(6000) null, prompt_config varchar(5000) null, multi_turn_config varchar(2000) null, visual_config varchar(2000) null, diff --git a/launchers/standalone/src/main/resources/db/schema-mysql.sql b/launchers/standalone/src/main/resources/db/schema-mysql.sql index 51c7f2796..827f54d73 100644 --- a/launchers/standalone/src/main/resources/db/schema-mysql.sql +++ b/launchers/standalone/src/main/resources/db/schema-mysql.sql @@ -73,6 +73,7 @@ CREATE TABLE `s2_agent` ( `model` varchar(100) COLLATE utf8_unicode_ci DEFAULT NULL, `config` varchar(6000) COLLATE utf8_unicode_ci DEFAULT NULL, `llm_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL, + `model_config` varchar(6000) COLLATE utf8_unicode_ci DEFAULT NULL, `multi_turn_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL, `visual_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL, `created_by` varchar(100) COLLATE utf8_unicode_ci DEFAULT NULL, diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java b/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java index 0313701b2..2d2ff9090 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java @@ -87,7 +87,6 @@ public class Text2SQLEval extends BaseTest { agentConfig.getTools().add(getLLMQueryTool()); agent.setAgentConfig(JSONObject.toJSONString(agentConfig)); agent.setModelConfig(getLLMConfig(LLMType.GPT)); - agent.setLlmConfig(getLLMConfig(LLMType.GPT).getChatModel()); MultiTurnConfig multiTurnConfig = new MultiTurnConfig(); multiTurnConfig.setEnableMultiTurn(enableMultiturn); agent.setMultiTurnConfig(multiTurnConfig); diff --git a/launchers/standalone/src/test/resources/db/schema-h2.sql b/launchers/standalone/src/test/resources/db/schema-h2.sql index 8ba9c4ba5..6d5823f35 100644 --- a/launchers/standalone/src/test/resources/db/schema-h2.sql +++ b/launchers/standalone/src/test/resources/db/schema-h2.sql @@ -374,6 +374,7 @@ CREATE TABLE IF NOT EXISTS s2_agent examples varchar(500) null, config varchar(2000) null, llm_config varchar(2000) null, + model_config varchar(6000) null, prompt_config varchar(5000) null, multi_turn_config varchar(2000) null, visual_config varchar(2000) null, @@ -383,7 +384,7 @@ CREATE TABLE IF NOT EXISTS s2_agent updated_at TIMESTAMP null, enable_search int null, PRIMARY KEY (`id`) -); COMMENT ON TABLE s2_agent IS 'agent information table'; + ); COMMENT ON TABLE s2_agent IS 'agent information table'; -------demo for semantic and chat