Mirror of https://github.com/tencentmusic/supersonic.git, synced 2025-12-12 04:27:39 +00:00.
(improvement)(headless) Add support for the Ollama provider in the frontend and optimize the code (#1270)
…/PlainTextExecutor.java

@@ -7,7 +7,6 @@ import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryReposi
 import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
 import com.tencent.supersonic.chat.server.service.AgentService;
 import com.tencent.supersonic.common.util.ContextUtils;
-import com.tencent.supersonic.common.util.S2ChatModelProvider;
 import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
 import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
 import com.tencent.supersonic.headless.api.pojo.response.QueryState;
@@ -16,6 +15,7 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.input.Prompt;
 import dev.langchain4j.model.input.PromptTemplate;
 import dev.langchain4j.model.output.Response;
+import dev.langchain4j.model.provider.ChatLanguageModelProvider;
 import java.util.Collections;
 import java.util.List;
@@ -46,7 +46,7 @@ public class PlainTextExecutor implements ChatExecutor {
         AgentService agentService = ContextUtils.getBean(AgentService.class);
         Agent chatAgent = agentService.getAgent(chatExecuteContext.getAgentId());

-        ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(chatAgent.getLlmConfig());
+        ChatLanguageModel chatLanguageModel = ChatLanguageModelProvider.provide(chatAgent.getLlmConfig());
         Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());

         QueryResult result = new QueryResult();
…/MemoryReviewTask.java

@@ -4,10 +4,10 @@ import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
 import com.tencent.supersonic.chat.server.agent.Agent;
 import com.tencent.supersonic.chat.server.service.AgentService;
 import com.tencent.supersonic.chat.server.service.MemoryService;
-import com.tencent.supersonic.common.util.S2ChatModelProvider;
 import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.input.Prompt;
 import dev.langchain4j.model.input.PromptTemplate;
+import dev.langchain4j.model.provider.ChatLanguageModelProvider;
 import lombok.extern.slf4j.Slf4j;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -54,7 +54,7 @@ public class MemoryReviewTask {
         Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);

         keyPipelineLog.info("MemoryReviewTask reqPrompt:{}", promptStr);
-        ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(chatAgent.getLlmConfig());
+        ChatLanguageModel chatLanguageModel = ChatLanguageModelProvider.provide(chatAgent.getLlmConfig());
         String response = chatLanguageModel.generate(prompt.toUserMessage()).content().text();
         keyPipelineLog.info("MemoryReviewTask modelResp:{}", response);

…/NL2SQLParser.java

@@ -8,7 +8,6 @@ import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
 import com.tencent.supersonic.chat.server.util.QueryReqConverter;
 import com.tencent.supersonic.common.config.LLMConfig;
 import com.tencent.supersonic.common.util.ContextUtils;
-import com.tencent.supersonic.common.util.S2ChatModelProvider;
 import com.tencent.supersonic.headless.api.pojo.SchemaElement;
 import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
 import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
@@ -24,6 +23,7 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
 import dev.langchain4j.model.input.Prompt;
 import dev.langchain4j.model.input.PromptTemplate;
 import dev.langchain4j.model.output.Response;
+import dev.langchain4j.model.provider.ChatLanguageModelProvider;
 import lombok.Builder;
 import lombok.Data;
 import lombok.extern.slf4j.Slf4j;
@@ -180,7 +180,7 @@ public class NL2SQLParser implements ChatParser {
         Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
         keyPipelineLog.info("NL2SQLParser reqPrompt:{}", promptStr);

-        ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(context.getLlmConfig());
+        ChatLanguageModel chatLanguageModel = ChatLanguageModelProvider.provide(context.getLlmConfig());
         Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());

         String result = response.content().text();
…/LLMConnHelper.java

@@ -2,8 +2,8 @@ package com.tencent.supersonic.chat.server.util;

 import com.tencent.supersonic.common.config.LLMConfig;
 import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
-import com.tencent.supersonic.common.util.S2ChatModelProvider;
 import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.provider.ChatLanguageModelProvider;
 import lombok.extern.slf4j.Slf4j;
 import org.apache.commons.lang3.StringUtils;

@@ -14,7 +14,7 @@ public class LLMConnHelper {
            if (llmConfig == null || StringUtils.isBlank(llmConfig.getBaseUrl())) {
                return false;
            }
-           ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(llmConfig);
+           ChatLanguageModel chatLanguageModel = ChatLanguageModelProvider.provide(llmConfig);
            String response = chatLanguageModel.generate("Hi there");
            return StringUtils.isNotEmpty(response) ? true : false;
        } catch (Exception e) {
…/S2ModelProvider.java (deleted)

@@ -1,9 +0,0 @@
-package com.tencent.supersonic.common.pojo.enums;
-
-public enum S2ModelProvider {
-
-    OPEN_AI,
-    HUGGING_FACE,
-    LOCAL_AI,
-    IN_PROCESS
-}
…/S2ChatModelProvider.java (deleted)

@@ -1,41 +0,0 @@
-package com.tencent.supersonic.common.util;
-
-import com.tencent.supersonic.common.config.LLMConfig;
-import com.tencent.supersonic.common.pojo.enums.S2ModelProvider;
-import dev.langchain4j.model.chat.ChatLanguageModel;
-import dev.langchain4j.model.localai.LocalAiChatModel;
-import dev.langchain4j.model.openai.OpenAiChatModel;
-import org.apache.commons.lang3.StringUtils;
-
-import java.time.Duration;
-
-public class S2ChatModelProvider {
-
-    public static ChatLanguageModel provide(LLMConfig llmConfig) {
-        ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
-        if (llmConfig == null || StringUtils.isBlank(llmConfig.getProvider())
-                || StringUtils.isBlank(llmConfig.getBaseUrl())) {
-            return chatLanguageModel;
-        }
-        if (S2ModelProvider.OPEN_AI.name().equalsIgnoreCase(llmConfig.getProvider())) {
-            return OpenAiChatModel
-                    .builder()
-                    .baseUrl(llmConfig.getBaseUrl())
-                    .modelName(llmConfig.getModelName())
-                    .apiKey(llmConfig.keyDecrypt())
-                    .temperature(llmConfig.getTemperature())
-                    .timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
-                    .build();
-        } else if (S2ModelProvider.LOCAL_AI.name().equalsIgnoreCase(llmConfig.getProvider())) {
-            return LocalAiChatModel
-                    .builder()
-                    .baseUrl(llmConfig.getBaseUrl())
-                    .modelName(llmConfig.getModelName())
-                    .temperature(llmConfig.getTemperature())
-                    .timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
-                    .build();
-        }
-        throw new RuntimeException("unsupported provider: " + llmConfig.getProvider());
-    }
-
-}
…/ChatLanguageModelFactory.java (new)

@@ -0,0 +1,8 @@
+package dev.langchain4j.model.provider;
+
+import com.tencent.supersonic.common.config.LLMConfig;
+import dev.langchain4j.model.chat.ChatLanguageModel;
+
+public interface ChatLanguageModelFactory {
+    ChatLanguageModel create(LLMConfig llmConfig);
+}
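Adding a further backend later would only need one more implementation of this interface plus a registration line in the static block of ChatLanguageModelProvider below. A minimal sketch, not part of this commit; the class name and the pinned model name are invented for illustration:

    package dev.langchain4j.model.provider;

    import com.tencent.supersonic.common.config.LLMConfig;
    import dev.langchain4j.model.chat.ChatLanguageModel;
    import dev.langchain4j.model.localai.LocalAiChatModel;

    import java.time.Duration;

    // Hypothetical example factory: routes to a LocalAI-compatible endpoint
    // but pins the model name, ignoring whatever the per-agent config supplies.
    public class PinnedModelChatModelFactory implements ChatLanguageModelFactory {
        @Override
        public ChatLanguageModel create(LLMConfig llmConfig) {
            return LocalAiChatModel
                    .builder()
                    .baseUrl(llmConfig.getBaseUrl())
                    .modelName("qwen:0.5b") // pinned; invented for this sketch
                    .temperature(llmConfig.getTemperature())
                    .timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
                    .build();
        }
    }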
…/ChatLanguageModelProvider.java (new)

@@ -0,0 +1,33 @@
+package dev.langchain4j.model.provider;
+
+import com.tencent.supersonic.common.config.LLMConfig;
+import com.tencent.supersonic.common.util.ContextUtils;
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import org.apache.commons.lang3.StringUtils;
+
+import java.util.HashMap;
+import java.util.Map;
+
+public class ChatLanguageModelProvider {
+    private static final Map<String, ChatLanguageModelFactory> factories = new HashMap<>();
+
+    static {
+        factories.put(ModelProvider.OPEN_AI.name(), new OpenAiChatModelFactory());
+        factories.put(ModelProvider.LOCAL_AI.name(), new LocalAiChatModelFactory());
+        factories.put(ModelProvider.OLLAMA.name(), new OllamaChatModelFactory());
+    }
+
+    public static ChatLanguageModel provide(LLMConfig llmConfig) {
+        if (llmConfig == null || StringUtils.isBlank(llmConfig.getProvider())
+                || StringUtils.isBlank(llmConfig.getBaseUrl())) {
+            return ContextUtils.getBean(ChatLanguageModel.class);
+        }
+
+        ChatLanguageModelFactory factory = factories.get(llmConfig.getProvider().toUpperCase());
+        if (factory != null) {
+            return factory.create(llmConfig);
+        }
+
+        throw new RuntimeException("Unsupported provider: " + llmConfig.getProvider());
+    }
+}
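For orientation, a minimal usage sketch of the new registry, not part of this commit. It assumes LLMConfig exposes setters mirroring the getters used above (an assumption; only the getters appear in this diff) and that an Ollama server is reachable locally with the model pulled:

    // Hypothetical usage sketch; setter names are assumed from the getters above.
    LLMConfig llmConfig = new LLMConfig();
    llmConfig.setProvider("ollama");                // matched case-insensitively via toUpperCase()
    llmConfig.setBaseUrl("http://localhost:11434"); // assumed local Ollama endpoint
    llmConfig.setModelName("qwen:0.5b");
    llmConfig.setTemperature(0.0);
    llmConfig.setTimeOut(60L);

    ChatLanguageModel model = ChatLanguageModelProvider.provide(llmConfig);
    String answer = model.generate("Hi there");     // generate(String), as used in LLMConnHelper

With a null or incomplete config, provide() instead falls back to the Spring-managed ChatLanguageModel bean, so callers never have to distinguish the two cases.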
…/LocalAiChatModelFactory.java (new)

@@ -0,0 +1,20 @@
+package dev.langchain4j.model.provider;
+
+import com.tencent.supersonic.common.config.LLMConfig;
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.localai.LocalAiChatModel;
+
+import java.time.Duration;
+
+public class LocalAiChatModelFactory implements ChatLanguageModelFactory {
+    @Override
+    public ChatLanguageModel create(LLMConfig llmConfig) {
+        return LocalAiChatModel
+                .builder()
+                .baseUrl(llmConfig.getBaseUrl())
+                .modelName(llmConfig.getModelName())
+                .temperature(llmConfig.getTemperature())
+                .timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
+                .build();
+    }
+}
…/ModelProvider.java (new)

@@ -0,0 +1,9 @@
+package dev.langchain4j.model.provider;
+
+public enum ModelProvider {
+    OPEN_AI,
+    HUGGING_FACE,
+    LOCAL_AI,
+    IN_PROCESS,
+    OLLAMA
+}
…/OllamaChatModelFactory.java (new)

@@ -0,0 +1,20 @@
+package dev.langchain4j.model.provider;
+
+import com.tencent.supersonic.common.config.LLMConfig;
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.ollama.OllamaChatModel;
+
+import java.time.Duration;
+
+public class OllamaChatModelFactory implements ChatLanguageModelFactory {
+    @Override
+    public ChatLanguageModel create(LLMConfig llmConfig) {
+        return OllamaChatModel
+                .builder()
+                .baseUrl(llmConfig.getBaseUrl())
+                .modelName(llmConfig.getModelName())
+                .temperature(llmConfig.getTemperature())
+                .timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
+                .build();
+    }
+}
…/OpenAiChatModelFactory.java (new)

@@ -0,0 +1,21 @@
+package dev.langchain4j.model.provider;
+
+import com.tencent.supersonic.common.config.LLMConfig;
+import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.openai.OpenAiChatModel;
+
+import java.time.Duration;
+
+public class OpenAiChatModelFactory implements ChatLanguageModelFactory {
+    @Override
+    public ChatLanguageModel create(LLMConfig llmConfig) {
+        return OpenAiChatModel
+                .builder()
+                .baseUrl(llmConfig.getBaseUrl())
+                .modelName(llmConfig.getModelName())
+                .apiKey(llmConfig.keyDecrypt())
+                .temperature(llmConfig.getTemperature())
+                .timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
+                .build();
+    }
+}
…/SqlGenStrategy.java

@@ -3,8 +3,8 @@ package com.tencent.supersonic.headless.chat.parser.llm;
 import com.tencent.supersonic.common.config.LLMConfig;
 import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
 import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
-import com.tencent.supersonic.common.util.S2ChatModelProvider;
 import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.provider.ChatLanguageModelProvider;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.springframework.beans.factory.InitializingBean;
@@ -24,7 +24,7 @@ public abstract class SqlGenStrategy implements InitializingBean {
     protected PromptHelper promptHelper;

     protected ChatLanguageModel getChatLanguageModel(LLMConfig llmConfig) {
-        return S2ChatModelProvider.provide(llmConfig);
+        return ChatLanguageModelProvider.provide(llmConfig);
     }

     abstract LLMResp generate(LLMReq llmReq);
application configuration (YAML), first of two files receiving the identical change

@@ -13,4 +13,25 @@ langchain4j:
   embedding-model:
     model-name: bge-small-zh
   embedding-store:
     persist-path: /tmp
+#  ollama:
+#    chat-model:
+#      base-url: http://localhost:11434
+#      api-key: demo
+#      model-name: qwen:0.5b
+#      temperature: 0.0
+#      timeout: PT60S
+
+#  chroma:
+#    embedding-store:
+#      baseUrl: http://0.0.0.0:8000
+#      timeout: 120s
+
+#  milvus:
+#    embedding-store:
+#      host: localhost
+#      port: 2379
+#      uri: http://0.0.0.0:19530
+#      token: demo
+#      dimension: 512
+#      timeout: 120s
application configuration (YAML), second file, identical change

@@ -13,4 +13,25 @@ langchain4j:
   embedding-model:
     model-name: bge-small-zh
   embedding-store:
     persist-path: /tmp
+#  ollama:
+#    chat-model:
+#      base-url: http://localhost:11434
+#      api-key: demo
+#      model-name: qwen:0.5b
+#      temperature: 0.0
+#      timeout: PT60S
+
+#  chroma:
+#    embedding-store:
+#      baseUrl: http://0.0.0.0:8000
+#      timeout: 120s
+
+#  milvus:
+#    embedding-store:
+#      host: localhost
+#      port: 2379
+#      uri: http://0.0.0.0:19530
+#      token: demo
+#      dimension: 512
+#      timeout: 120s
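These commented blocks configure the default Spring-managed model through the langchain4j starter. They tie into the new provider's fallback path: when an agent's LLMConfig carries no provider or base-url, ChatLanguageModelProvider.provide() returns the Spring bean instead of building a model itself. A one-line hypothetical illustration:

    // Hypothetical illustration (not in the commit): with no per-agent config,
    // provide() returns the Spring-managed ChatLanguageModel built from the
    // langchain4j YAML above (e.g. the ollama block, once uncommented).
    ChatLanguageModel defaultModel = ChatLanguageModelProvider.provide(null);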