diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java
index 8a86069d1..dbd748171 100644
--- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java
+++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java
@@ -7,7 +7,6 @@ import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryReposi
 import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
 import com.tencent.supersonic.chat.server.service.AgentService;
 import com.tencent.supersonic.common.util.ContextUtils;
-import com.tencent.supersonic.common.util.S2ChatModelProvider;
 import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
 import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
 import com.tencent.supersonic.headless.api.pojo.response.QueryState;
@@ -16,6 +15,7 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.input.Prompt;
 import dev.langchain4j.model.input.PromptTemplate;
 import dev.langchain4j.model.output.Response;
+import dev.langchain4j.model.provider.ChatLanguageModelProvider;
 
 import java.util.Collections;
 import java.util.List;
@@ -46,7 +46,7 @@ public class PlainTextExecutor implements ChatExecutor {
         AgentService agentService = ContextUtils.getBean(AgentService.class);
         Agent chatAgent = agentService.getAgent(chatExecuteContext.getAgentId());
 
-        ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(chatAgent.getLlmConfig());
+        ChatLanguageModel chatLanguageModel = ChatLanguageModelProvider.provide(chatAgent.getLlmConfig());
         Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
 
         QueryResult result = new QueryResult();
diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java
index 94f382cf0..d4968c202 100644
--- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java
+++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java
@@ -4,10 +4,10 @@ import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
 import com.tencent.supersonic.chat.server.agent.Agent;
 import com.tencent.supersonic.chat.server.service.AgentService;
 import com.tencent.supersonic.chat.server.service.MemoryService;
-import com.tencent.supersonic.common.util.S2ChatModelProvider;
 import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.input.Prompt;
 import dev.langchain4j.model.input.PromptTemplate;
+import dev.langchain4j.model.provider.ChatLanguageModelProvider;
 import lombok.extern.slf4j.Slf4j;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -54,7 +54,7 @@ public class MemoryReviewTask {
         Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
         keyPipelineLog.info("MemoryReviewTask reqPrompt:{}", promptStr);
 
-        ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(chatAgent.getLlmConfig());
+        ChatLanguageModel chatLanguageModel = ChatLanguageModelProvider.provide(chatAgent.getLlmConfig());
         String response = chatLanguageModel.generate(prompt.toUserMessage()).content().text();
         keyPipelineLog.info("MemoryReviewTask modelResp:{}", response);
 
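With `S2ChatModelProvider` gone, every call site resolves its model through the new `dev.langchain4j.model.provider.ChatLanguageModelProvider` entry point. A minimal sketch of the resulting pattern, assuming `LLMConfig` has a no-arg constructor (the wiring below is illustrative, not lifted from the diff):

```java
import com.tencent.supersonic.common.config.LLMConfig;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.provider.ChatLanguageModelProvider;

public class ProviderCallSiteSketch {
    public static void main(String[] args) {
        // Hypothetical config; the real call sites pass chatAgent.getLlmConfig().
        LLMConfig llmConfig = new LLMConfig();

        // With a blank provider/baseUrl, provide() falls back to the Spring-managed
        // ChatLanguageModel bean (so this branch needs a running Spring context).
        ChatLanguageModel model = ChatLanguageModelProvider.provide(llmConfig);

        // Same generate/Response<AiMessage> flow as PlainTextExecutor above.
        String answer = model.generate(UserMessage.from("Hi there")).content().text();
        System.out.println(answer);
    }
}
```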
diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java
index cb8628cf1..ab13a879a 100644
--- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java
+++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java
@@ -8,7 +8,6 @@ import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
 import com.tencent.supersonic.chat.server.util.QueryReqConverter;
 import com.tencent.supersonic.common.config.LLMConfig;
 import com.tencent.supersonic.common.util.ContextUtils;
-import com.tencent.supersonic.common.util.S2ChatModelProvider;
 import com.tencent.supersonic.headless.api.pojo.SchemaElement;
 import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
 import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
@@ -24,6 +23,7 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.input.Prompt;
 import dev.langchain4j.model.input.PromptTemplate;
 import dev.langchain4j.model.output.Response;
+import dev.langchain4j.model.provider.ChatLanguageModelProvider;
 import lombok.Builder;
 import lombok.Data;
 import lombok.extern.slf4j.Slf4j;
@@ -180,7 +180,7 @@ public class NL2SQLParser implements ChatParser {
         Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
         keyPipelineLog.info("NL2SQLParser reqPrompt:{}", promptStr);
 
-        ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(context.getLlmConfig());
+        ChatLanguageModel chatLanguageModel = ChatLanguageModelProvider.provide(context.getLlmConfig());
         Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
         String result = response.content().text();
 
diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/LLMConnHelper.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/LLMConnHelper.java
index a49311d41..5b78affb3 100644
--- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/LLMConnHelper.java
+++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/LLMConnHelper.java
@@ -2,8 +2,8 @@ package com.tencent.supersonic.chat.server.util;
 
 import com.tencent.supersonic.common.config.LLMConfig;
 import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
-import com.tencent.supersonic.common.util.S2ChatModelProvider;
 import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.provider.ChatLanguageModelProvider;
 import lombok.extern.slf4j.Slf4j;
 import org.apache.commons.lang3.StringUtils;
 
@@ -14,7 +14,7 @@ public class LLMConnHelper {
             if (llmConfig == null || StringUtils.isBlank(llmConfig.getBaseUrl())) {
                 return false;
             }
-            ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(llmConfig);
+            ChatLanguageModel chatLanguageModel = ChatLanguageModelProvider.provide(llmConfig);
            String response = chatLanguageModel.generate("Hi there");
            return StringUtils.isNotEmpty(response) ? true : false;
        } catch (Exception e) {
diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/S2ModelProvider.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/S2ModelProvider.java
deleted file mode 100644
index ba1065886..000000000
--- a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/S2ModelProvider.java
+++ /dev/null
@@ -1,9 +0,0 @@
-package com.tencent.supersonic.common.pojo.enums;
-
-public enum S2ModelProvider {
-
-    OPEN_AI,
-    HUGGING_FACE,
-    LOCAL_AI,
-    IN_PROCESS
-}
\ No newline at end of file
diff --git a/common/src/main/java/com/tencent/supersonic/common/util/S2ChatModelProvider.java b/common/src/main/java/com/tencent/supersonic/common/util/S2ChatModelProvider.java
index e31292c49..e69de29bb 100644
--- a/common/src/main/java/com/tencent/supersonic/common/util/S2ChatModelProvider.java
+++ b/common/src/main/java/com/tencent/supersonic/common/util/S2ChatModelProvider.java
@@ -1,41 +0,0 @@
-package com.tencent.supersonic.common.util;
-
-import com.tencent.supersonic.common.config.LLMConfig;
-import com.tencent.supersonic.common.pojo.enums.S2ModelProvider;
-import dev.langchain4j.model.chat.ChatLanguageModel;
-import dev.langchain4j.model.localai.LocalAiChatModel;
-import dev.langchain4j.model.openai.OpenAiChatModel;
-import org.apache.commons.lang3.StringUtils;
-
-import java.time.Duration;
-
-public class S2ChatModelProvider {
-
-    public static ChatLanguageModel provide(LLMConfig llmConfig) {
-        ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
-        if (llmConfig == null || StringUtils.isBlank(llmConfig.getProvider())
-                || StringUtils.isBlank(llmConfig.getBaseUrl())) {
-            return chatLanguageModel;
-        }
-        if (S2ModelProvider.OPEN_AI.name().equalsIgnoreCase(llmConfig.getProvider())) {
-            return OpenAiChatModel
-                    .builder()
-                    .baseUrl(llmConfig.getBaseUrl())
-                    .modelName(llmConfig.getModelName())
-                    .apiKey(llmConfig.keyDecrypt())
-                    .temperature(llmConfig.getTemperature())
-                    .timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
-                    .build();
-        } else if (S2ModelProvider.LOCAL_AI.name().equalsIgnoreCase(llmConfig.getProvider())) {
-            return LocalAiChatModel
-                    .builder()
-                    .baseUrl(llmConfig.getBaseUrl())
-                    .modelName(llmConfig.getModelName())
-                    .temperature(llmConfig.getTemperature())
-                    .timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
-                    .build();
-        }
-        throw new RuntimeException("unsupported provider: " + llmConfig.getProvider());
-    }
-
-}
diff --git a/common/src/main/java/dev/langchain4j/model/provider/ChatLanguageModelFactory.java b/common/src/main/java/dev/langchain4j/model/provider/ChatLanguageModelFactory.java
new file mode 100644
index 000000000..ef081dfff
--- /dev/null
+++ b/common/src/main/java/dev/langchain4j/model/provider/ChatLanguageModelFactory.java
@@ -0,0 +1,8 @@
+package dev.langchain4j.model.provider;
+
+import com.tencent.supersonic.common.config.LLMConfig;
+import dev.langchain4j.model.chat.ChatLanguageModel;
+
+public interface ChatLanguageModelFactory {
+    ChatLanguageModel create(LLMConfig llmConfig);
+}
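`ChatLanguageModelFactory` is the extension seam: adding a provider now means one small factory class plus a map entry in `ChatLanguageModelProvider`, instead of another branch in a growing if/else chain. A hedged sketch of what a custom factory could look like, reusing only builder calls already present in this diff (the class itself is hypothetical, not part of the PR):

```java
package dev.langchain4j.model.provider;

import com.tencent.supersonic.common.config.LLMConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiChatModel;

import java.time.Duration;

// Hypothetical factory for an OpenAI-compatible gateway; shown only to
// illustrate the seam the interface creates.
public class GatewayChatModelFactory implements ChatLanguageModelFactory {
    @Override
    public ChatLanguageModel create(LLMConfig llmConfig) {
        return OpenAiChatModel.builder()
                .baseUrl(llmConfig.getBaseUrl())
                .modelName(llmConfig.getModelName())
                .apiKey(llmConfig.keyDecrypt())
                // pinned here, unlike the stock factories which pass llmConfig.getTemperature()
                .temperature(0.0)
                .timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
                .build();
    }
}
```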
diff --git a/common/src/main/java/dev/langchain4j/model/provider/ChatLanguageModelProvider.java b/common/src/main/java/dev/langchain4j/model/provider/ChatLanguageModelProvider.java
new file mode 100644
index 000000000..99febb596
--- /dev/null
+++ b/common/src/main/java/dev/langchain4j/model/provider/ChatLanguageModelProvider.java
@@ -0,0 +1,33 @@
+package dev.langchain4j.model.provider;
+
+import com.tencent.supersonic.common.config.LLMConfig;
+import com.tencent.supersonic.common.util.ContextUtils;
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import org.apache.commons.lang3.StringUtils;
+
+import java.util.HashMap;
+import java.util.Map;
+
+public class ChatLanguageModelProvider {
+    private static final Map<String, ChatLanguageModelFactory> factories = new HashMap<>();
+
+    static {
+        factories.put(ModelProvider.OPEN_AI.name(), new OpenAiChatModelFactory());
+        factories.put(ModelProvider.LOCAL_AI.name(), new LocalAiChatModelFactory());
+        factories.put(ModelProvider.OLLAMA.name(), new OllamaChatModelFactory());
+    }
+
+    public static ChatLanguageModel provide(LLMConfig llmConfig) {
+        if (llmConfig == null || StringUtils.isBlank(llmConfig.getProvider())
+                || StringUtils.isBlank(llmConfig.getBaseUrl())) {
+            return ContextUtils.getBean(ChatLanguageModel.class);
+        }
+
+        ChatLanguageModelFactory factory = factories.get(llmConfig.getProvider().toUpperCase());
+        if (factory != null) {
+            return factory.create(llmConfig);
+        }
+
+        throw new RuntimeException("Unsupported provider: " + llmConfig.getProvider());
+    }
+}
\ No newline at end of file
diff --git a/common/src/main/java/dev/langchain4j/model/provider/LocalAiChatModelFactory.java b/common/src/main/java/dev/langchain4j/model/provider/LocalAiChatModelFactory.java
new file mode 100644
index 000000000..66099c888
--- /dev/null
+++ b/common/src/main/java/dev/langchain4j/model/provider/LocalAiChatModelFactory.java
@@ -0,0 +1,20 @@
+package dev.langchain4j.model.provider;
+
+import com.tencent.supersonic.common.config.LLMConfig;
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.localai.LocalAiChatModel;
+
+import java.time.Duration;
+
+public class LocalAiChatModelFactory implements ChatLanguageModelFactory {
+    @Override
+    public ChatLanguageModel create(LLMConfig llmConfig) {
+        return LocalAiChatModel
+                .builder()
+                .baseUrl(llmConfig.getBaseUrl())
+                .modelName(llmConfig.getModelName())
+                .temperature(llmConfig.getTemperature())
+                .timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
+                .build();
+    }
+}
\ No newline at end of file
diff --git a/common/src/main/java/dev/langchain4j/model/provider/ModelProvider.java b/common/src/main/java/dev/langchain4j/model/provider/ModelProvider.java
new file mode 100644
index 000000000..33c88845e
--- /dev/null
+++ b/common/src/main/java/dev/langchain4j/model/provider/ModelProvider.java
@@ -0,0 +1,9 @@
+package dev.langchain4j.model.provider;
+
+public enum ModelProvider {
+    OPEN_AI,
+    HUGGING_FACE,
+    LOCAL_AI,
+    IN_PROCESS,
+    OLLAMA
+}
diff --git a/common/src/main/java/dev/langchain4j/model/provider/OllamaChatModelFactory.java b/common/src/main/java/dev/langchain4j/model/provider/OllamaChatModelFactory.java
new file mode 100644
index 000000000..9cb89c75f
--- /dev/null
+++ b/common/src/main/java/dev/langchain4j/model/provider/OllamaChatModelFactory.java
@@ -0,0 +1,20 @@
+package dev.langchain4j.model.provider;
+
+import com.tencent.supersonic.common.config.LLMConfig;
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.ollama.OllamaChatModel;
+
+import java.time.Duration;
+
+public class OllamaChatModelFactory implements ChatLanguageModelFactory {
+    @Override
+    public ChatLanguageModel create(LLMConfig llmConfig) {
+        return OllamaChatModel
+                .builder()
+                .baseUrl(llmConfig.getBaseUrl())
+                .modelName(llmConfig.getModelName())
+                .temperature(llmConfig.getTemperature())
+                .timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
+                .build();
+    }
+}
\ No newline at end of file
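Dispatch in `provide` is a case-insensitive match against `ModelProvider` names: the provider string is upper-cased and looked up in the static map, so `ollama`, `Ollama`, and `OLLAMA` all reach `OllamaChatModelFactory`, while `HUGGING_FACE` and `IN_PROCESS` are declared in the enum but have no registered factory and fall through to the `RuntimeException`. A sketch of that behavior, assuming `LLMConfig` exposes setters matching the getters used in the factories:

```java
import com.tencent.supersonic.common.config.LLMConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.provider.ChatLanguageModelProvider;

public class ProviderDispatchSketch {
    public static void main(String[] args) {
        LLMConfig config = new LLMConfig();          // hypothetical no-arg construction
        config.setProvider("ollama");                // matched via toUpperCase()
        config.setBaseUrl("http://localhost:11434"); // same endpoint as the sample YAML below
        config.setModelName("qwen:0.5b");
        config.setTemperature(0.0);                  // assumed Double field
        config.setTimeOut(60L);                      // assumed seconds, fed into Duration.ofSeconds

        // factories.get("OLLAMA") -> OllamaChatModelFactory
        ChatLanguageModel model = ChatLanguageModelProvider.provide(config);

        // By contrast, a provider without a registered factory throws:
        // config.setProvider("HUGGING_FACE");
        // ChatLanguageModelProvider.provide(config); // RuntimeException: Unsupported provider
    }
}
```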
diff --git a/common/src/main/java/dev/langchain4j/model/provider/OpenAiChatModelFactory.java b/common/src/main/java/dev/langchain4j/model/provider/OpenAiChatModelFactory.java
new file mode 100644
index 000000000..c21653de8
--- /dev/null
+++ b/common/src/main/java/dev/langchain4j/model/provider/OpenAiChatModelFactory.java
@@ -0,0 +1,21 @@
+package dev.langchain4j.model.provider;
+
+import com.tencent.supersonic.common.config.LLMConfig;
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.openai.OpenAiChatModel;
+
+import java.time.Duration;
+
+public class OpenAiChatModelFactory implements ChatLanguageModelFactory {
+    @Override
+    public ChatLanguageModel create(LLMConfig llmConfig) {
+        return OpenAiChatModel
+                .builder()
+                .baseUrl(llmConfig.getBaseUrl())
+                .modelName(llmConfig.getModelName())
+                .apiKey(llmConfig.keyDecrypt())
+                .temperature(llmConfig.getTemperature())
+                .timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
+                .build();
+    }
+}
\ No newline at end of file
diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/SqlGenStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/SqlGenStrategy.java
index 861b0c7db..03343544f 100644
--- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/SqlGenStrategy.java
+++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/SqlGenStrategy.java
@@ -3,8 +3,8 @@ package com.tencent.supersonic.headless.chat.parser.llm;
 import com.tencent.supersonic.common.config.LLMConfig;
 import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
 import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
-import com.tencent.supersonic.common.util.S2ChatModelProvider;
 import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.provider.ChatLanguageModelProvider;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.springframework.beans.factory.InitializingBean;
@@ -24,7 +24,7 @@ public abstract class SqlGenStrategy implements InitializingBean {
     protected PromptHelper promptHelper;
 
     protected ChatLanguageModel getChatLanguageModel(LLMConfig llmConfig) {
-        return S2ChatModelProvider.provide(llmConfig);
+        return ChatLanguageModelProvider.provide(llmConfig);
     }
 
     abstract LLMResp generate(LLMReq llmReq);
diff --git a/launchers/standalone/src/main/resources/langchain4j-config.yaml b/launchers/standalone/src/main/resources/langchain4j-config.yaml
index 703656283..38e2da7a9 100644
--- a/launchers/standalone/src/main/resources/langchain4j-config.yaml
+++ b/launchers/standalone/src/main/resources/langchain4j-config.yaml
@@ -13,4 +13,25 @@ langchain4j:
   embedding-model:
     model-name: bge-small-zh
   embedding-store:
-    persist-path: /tmp
\ No newline at end of file
+    persist-path: /tmp
+#  ollama:
+#    chat-model:
+#      base-url: http://localhost:11434
+#      api-key: demo
+#      model-name: qwen:0.5b
+#      temperature: 0.0
+#      timeout: PT60S
+
+#  chroma:
+#    embedding-store:
+#      baseUrl: http://0.0.0.0:8000
+#      timeout: 120s
+
+#  milvus:
+#    embedding-store:
+#      host: localhost
+#      port: 2379
+#      uri: http://0.0.0.0:19530
+#      token: demo
+#      dimension: 512
+#      timeout: 120s
\ No newline at end of file
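One detail worth noting in the sample blocks: the `ollama` timeout is written in ISO-8601 (`PT60S`), while the `chroma`/`milvus` samples use the simple style (`120s`). Assuming these properties are bound through Spring Boot's property binder, both styles convert to a `java.time.Duration`, but only the ISO form is parseable by `Duration.parse` directly; a quick check:

```java
import java.time.Duration;

public class TimeoutFormatCheck {
    public static void main(String[] args) {
        // ISO-8601, as in the ollama sample: java.time parses it directly.
        Duration iso = Duration.parse("PT60S");
        System.out.println(iso.getSeconds()); // 60

        // "120s" (the chroma/milvus samples) is Spring Boot's simple style; it is
        // handled by the property binder, not by Duration.parse, which would
        // throw DateTimeParseException on this input.
    }
}
```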
diff --git a/launchers/standalone/src/test/resources/langchain4j-config.yaml b/launchers/standalone/src/test/resources/langchain4j-config.yaml
index 703656283..38e2da7a9 100644
--- a/launchers/standalone/src/test/resources/langchain4j-config.yaml
+++ b/launchers/standalone/src/test/resources/langchain4j-config.yaml
@@ -13,4 +13,25 @@ langchain4j:
   embedding-model:
     model-name: bge-small-zh
   embedding-store:
-    persist-path: /tmp
\ No newline at end of file
+    persist-path: /tmp
+#  ollama:
+#    chat-model:
+#      base-url: http://localhost:11434
+#      api-key: demo
+#      model-name: qwen:0.5b
+#      temperature: 0.0
+#      timeout: PT60S
+
+#  chroma:
+#    embedding-store:
+#      baseUrl: http://0.0.0.0:8000
+#      timeout: 120s
+
+#  milvus:
+#    embedding-store:
+#      host: localhost
+#      port: 2379
+#      uri: http://0.0.0.0:19530
+#      token: demo
+#      dimension: 512
+#      timeout: 120s
\ No newline at end of file