(improvement)(headless) Add support for the Ollama provider in the frontend and optimize the code (#1270)

Authored by lexluo09 on 2024-06-28 17:29:58 +08:00; committed by GitHub
parent 7564256b0a
commit 528491717b
15 changed files with 165 additions and 62 deletions

View File

@@ -7,7 +7,6 @@ import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryReposi
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.common.util.ContextUtils;
-import com.tencent.supersonic.common.util.S2ChatModelProvider;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
@@ -16,6 +15,7 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
+import dev.langchain4j.model.provider.ChatLanguageModelProvider;
import java.util.Collections;
import java.util.List;
@@ -46,7 +46,7 @@ public class PlainTextExecutor implements ChatExecutor {
        AgentService agentService = ContextUtils.getBean(AgentService.class);
        Agent chatAgent = agentService.getAgent(chatExecuteContext.getAgentId());
-       ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(chatAgent.getLlmConfig());
+       ChatLanguageModel chatLanguageModel = ChatLanguageModelProvider.provide(chatAgent.getLlmConfig());
        Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
        QueryResult result = new QueryResult();

View File

@@ -4,10 +4,10 @@ import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.MemoryService;
-import com.tencent.supersonic.common.util.S2ChatModelProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
+import dev.langchain4j.model.provider.ChatLanguageModelProvider;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -54,7 +54,7 @@ public class MemoryReviewTask {
        Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
        keyPipelineLog.info("MemoryReviewTask reqPrompt:{}", promptStr);
-       ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(chatAgent.getLlmConfig());
+       ChatLanguageModel chatLanguageModel = ChatLanguageModelProvider.provide(chatAgent.getLlmConfig());
        String response = chatLanguageModel.generate(prompt.toUserMessage()).content().text();
        keyPipelineLog.info("MemoryReviewTask modelResp:{}", response);

View File

@@ -8,7 +8,6 @@ import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.config.LLMConfig;
import com.tencent.supersonic.common.util.ContextUtils;
-import com.tencent.supersonic.common.util.S2ChatModelProvider;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
@@ -24,6 +23,7 @@ import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
+import dev.langchain4j.model.provider.ChatLanguageModelProvider;
import lombok.Builder;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
@@ -180,7 +180,7 @@ public class NL2SQLParser implements ChatParser {
        Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
        keyPipelineLog.info("NL2SQLParser reqPrompt:{}", promptStr);
-       ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(context.getLlmConfig());
+       ChatLanguageModel chatLanguageModel = ChatLanguageModelProvider.provide(context.getLlmConfig());
        Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
        String result = response.content().text();

View File

@@ -2,8 +2,8 @@ package com.tencent.supersonic.chat.server.util;
import com.tencent.supersonic.common.config.LLMConfig;
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
-import com.tencent.supersonic.common.util.S2ChatModelProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.provider.ChatLanguageModelProvider;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
@@ -14,7 +14,7 @@ public class LLMConnHelper {
            if (llmConfig == null || StringUtils.isBlank(llmConfig.getBaseUrl())) {
                return false;
            }
-           ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(llmConfig);
+           ChatLanguageModel chatLanguageModel = ChatLanguageModelProvider.provide(llmConfig);
            String response = chatLanguageModel.generate("Hi there");
            return StringUtils.isNotEmpty(response) ? true : false;
        } catch (Exception e) {

View File

@@ -1,9 +0,0 @@
package com.tencent.supersonic.common.pojo.enums;

public enum S2ModelProvider {
    OPEN_AI,
    HUGGING_FACE,
    LOCAL_AI,
    IN_PROCESS
}

View File

@@ -1,41 +0,0 @@
package com.tencent.supersonic.common.util;

import com.tencent.supersonic.common.config.LLMConfig;
import com.tencent.supersonic.common.pojo.enums.S2ModelProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.localai.LocalAiChatModel;
import dev.langchain4j.model.openai.OpenAiChatModel;
import org.apache.commons.lang3.StringUtils;

import java.time.Duration;

public class S2ChatModelProvider {

    public static ChatLanguageModel provide(LLMConfig llmConfig) {
        ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
        if (llmConfig == null || StringUtils.isBlank(llmConfig.getProvider())
                || StringUtils.isBlank(llmConfig.getBaseUrl())) {
            return chatLanguageModel;
        }
        if (S2ModelProvider.OPEN_AI.name().equalsIgnoreCase(llmConfig.getProvider())) {
            return OpenAiChatModel
                    .builder()
                    .baseUrl(llmConfig.getBaseUrl())
                    .modelName(llmConfig.getModelName())
                    .apiKey(llmConfig.keyDecrypt())
                    .temperature(llmConfig.getTemperature())
                    .timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
                    .build();
        } else if (S2ModelProvider.LOCAL_AI.name().equalsIgnoreCase(llmConfig.getProvider())) {
            return LocalAiChatModel
                    .builder()
                    .baseUrl(llmConfig.getBaseUrl())
                    .modelName(llmConfig.getModelName())
                    .temperature(llmConfig.getTemperature())
                    .timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
                    .build();
        }
        throw new RuntimeException("unsupported provider: " + llmConfig.getProvider());
    }
}

View File

@@ -0,0 +1,8 @@
package dev.langchain4j.model.provider;

import com.tencent.supersonic.common.config.LLMConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;

public interface ChatLanguageModelFactory {
    ChatLanguageModel create(LLMConfig llmConfig);
}

View File

@@ -0,0 +1,33 @@
package dev.langchain4j.model.provider;

import com.tencent.supersonic.common.config.LLMConfig;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.model.chat.ChatLanguageModel;
import org.apache.commons.lang3.StringUtils;

import java.util.HashMap;
import java.util.Map;

public class ChatLanguageModelProvider {

    private static final Map<String, ChatLanguageModelFactory> factories = new HashMap<>();

    static {
        factories.put(ModelProvider.OPEN_AI.name(), new OpenAiChatModelFactory());
        factories.put(ModelProvider.LOCAL_AI.name(), new LocalAiChatModelFactory());
        factories.put(ModelProvider.OLLAMA.name(), new OllamaChatModelFactory());
    }

    public static ChatLanguageModel provide(LLMConfig llmConfig) {
        if (llmConfig == null || StringUtils.isBlank(llmConfig.getProvider())
                || StringUtils.isBlank(llmConfig.getBaseUrl())) {
            return ContextUtils.getBean(ChatLanguageModel.class);
        }
        ChatLanguageModelFactory factory = factories.get(llmConfig.getProvider().toUpperCase());
        if (factory != null) {
            return factory.create(llmConfig);
        }
        throw new RuntimeException("Unsupported provider: " + llmConfig.getProvider());
    }
}
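
Taken together, the factory map replaces the if/else dispatch of the deleted S2ChatModelProvider: supporting a new provider now means one factory class plus one registration line in the static block. A minimal usage sketch follows; the LLMConfig setter names are assumed (Lombok-style, mirroring the getters the factories call) and are illustrative only:

    // Illustrative sketch; setter names are assumed from the getters used by the factories.
    LLMConfig llmConfig = new LLMConfig();
    llmConfig.setProvider("ollama");                 // matched case-insensitively via toUpperCase()
    llmConfig.setBaseUrl("http://localhost:11434");  // default Ollama endpoint
    llmConfig.setModelName("qwen:0.5b");
    llmConfig.setTemperature(0.0);
    llmConfig.setTimeOut(60L);                       // seconds, consumed via Duration.ofSeconds(...)

    // A blank provider or base-url falls back to the Spring-managed ChatLanguageModel bean.
    ChatLanguageModel model = ChatLanguageModelProvider.provide(llmConfig);
    String answer = model.generate("Hi there");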

View File

@@ -0,0 +1,20 @@
package dev.langchain4j.model.provider;

import com.tencent.supersonic.common.config.LLMConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.localai.LocalAiChatModel;

import java.time.Duration;

public class LocalAiChatModelFactory implements ChatLanguageModelFactory {
    @Override
    public ChatLanguageModel create(LLMConfig llmConfig) {
        return LocalAiChatModel
                .builder()
                .baseUrl(llmConfig.getBaseUrl())
                .modelName(llmConfig.getModelName())
                .temperature(llmConfig.getTemperature())
                .timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
                .build();
    }
}

View File

@@ -0,0 +1,9 @@
package dev.langchain4j.model.provider;

public enum ModelProvider {
    OPEN_AI,
    HUGGING_FACE,
    LOCAL_AI,
    IN_PROCESS,
    OLLAMA
}

View File

@@ -0,0 +1,20 @@
package dev.langchain4j.model.provider;

import com.tencent.supersonic.common.config.LLMConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.ollama.OllamaChatModel;

import java.time.Duration;

public class OllamaChatModelFactory implements ChatLanguageModelFactory {
    @Override
    public ChatLanguageModel create(LLMConfig llmConfig) {
        return OllamaChatModel
                .builder()
                .baseUrl(llmConfig.getBaseUrl())
                .modelName(llmConfig.getModelName())
                .temperature(llmConfig.getTemperature())
                .timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
                .build();
    }
}

View File

@@ -0,0 +1,21 @@
package dev.langchain4j.model.provider;

import com.tencent.supersonic.common.config.LLMConfig;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.openai.OpenAiChatModel;

import java.time.Duration;

public class OpenAiChatModelFactory implements ChatLanguageModelFactory {
    @Override
    public ChatLanguageModel create(LLMConfig llmConfig) {
        return OpenAiChatModel
                .builder()
                .baseUrl(llmConfig.getBaseUrl())
                .modelName(llmConfig.getModelName())
                .apiKey(llmConfig.keyDecrypt())
                .temperature(llmConfig.getTemperature())
                .timeout(Duration.ofSeconds(llmConfig.getTimeOut()))
                .build();
    }
}

View File

@@ -3,8 +3,8 @@ package com.tencent.supersonic.headless.chat.parser.llm;
import com.tencent.supersonic.common.config.LLMConfig;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
-import com.tencent.supersonic.common.util.S2ChatModelProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
+import dev.langchain4j.model.provider.ChatLanguageModelProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.InitializingBean;
@@ -24,7 +24,7 @@ public abstract class SqlGenStrategy implements InitializingBean {
    protected PromptHelper promptHelper;
    protected ChatLanguageModel getChatLanguageModel(LLMConfig llmConfig) {
-       return S2ChatModelProvider.provide(llmConfig);
+       return ChatLanguageModelProvider.provide(llmConfig);
    }
    abstract LLMResp generate(LLMReq llmReq);

View File

@@ -13,4 +13,25 @@ langchain4j:
  embedding-model:
    model-name: bge-small-zh
  embedding-store:
    persist-path: /tmp
+#  ollama:
+#    chat-model:
+#      base-url: http://localhost:11434
+#      api-key: demo
+#      model-name: qwen:0.5b
+#      temperature: 0.0
+#      timeout: PT60S
+#  chroma:
+#    embedding-store:
+#      baseUrl: http://0.0.0.0:8000
+#      timeout: 120s
+#  milvus:
+#    embedding-store:
+#      host: localhost
+#      port: 2379
+#      uri: http://0.0.0.0:19530
+#      token: demo
+#      dimension: 512
+#      timeout: 120s
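
To make Ollama the default chat model (the bean ChatLanguageModelProvider falls back to when an agent sets no explicit provider), the commented block above can be enabled. A minimal sketch, assuming these langchain4j properties bind to the default ChatLanguageModel bean as the surrounding section suggests:

    langchain4j:
      ollama:
        chat-model:
          base-url: http://localhost:11434
          model-name: qwen:0.5b
          temperature: 0.0
          timeout: PT60S   # ISO-8601 duration, i.e. 60 seconds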

View File

@@ -13,4 +13,25 @@ langchain4j:
  embedding-model:
    model-name: bge-small-zh
  embedding-store:
    persist-path: /tmp
+#  ollama:
+#    chat-model:
+#      base-url: http://localhost:11434
+#      api-key: demo
+#      model-name: qwen:0.5b
+#      temperature: 0.0
+#      timeout: PT60S
+#  chroma:
+#    embedding-store:
+#      baseUrl: http://0.0.0.0:8000
+#      timeout: 120s
+#  milvus:
+#    embedding-store:
+#      host: localhost
+#      port: 2379
+#      uri: http://0.0.0.0:19530
+#      token: demo
+#      dimension: 512
+#      timeout: 120s