From d06439325357ad8f9a6b3d46a62ffcab1d304978 Mon Sep 17 00:00:00 2001 From: jerryjzhang Date: Wed, 16 Oct 2024 13:08:22 +0800 Subject: [PATCH] [fix][chat]Fix NPE problem. --- .../chat/server/executor/PlainTextExecutor.java | 4 ++-- .../chat/server/memory/MemoryReviewTask.java | 2 +- .../chat/server/parser/NL2SQLParser.java | 4 ++-- .../service/impl/ChatQueryServiceImpl.java | 9 +++++++-- .../chat/server/util/ModelConfigHelper.java | 12 +++++++++--- .../headless/chat/corrector/LLMSqlCorrector.java | 4 +++- .../tencent/supersonic/demo/SmallTalkDemo.java | 16 ++++++++++++++++ 7 files changed, 40 insertions(+), 11 deletions(-) diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java index 74d7f4bbc..ae2204e9f 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java @@ -25,7 +25,7 @@ import java.util.stream.Collectors; public class PlainTextExecutor implements ChatQueryExecutor { - private static final String APP_KEY = "SMALL_TALK"; + public static final String APP_KEY = "SMALL_TALK"; private static final String INSTRUCTION = "" + "#Role: You are a nice person to talk to." + "\n#Task: Respond quickly and nicely to the user." + "\n#Rules: 1.ALWAYS use the same language as the `#Current Input`." @@ -45,7 +45,7 @@ public class PlainTextExecutor implements ChatQueryExecutor { AgentService agentService = ContextUtils.getBean(AgentService.class); Agent chatAgent = agentService.getAgent(executeContext.getAgent().getId()); ChatApp chatApp = chatAgent.getChatAppConfig().get(APP_KEY); - if (!chatApp.isEnable()) { + if (Objects.isNull(chatApp) || !chatApp.isEnable()) { return null; } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java index 84564246b..94d1e22ec 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java @@ -75,7 +75,7 @@ public class MemoryReviewTask { } ChatApp chatApp = chatAgent.getChatAppConfig().get(APP_KEY); - if (!chatApp.isEnable()) { + if (Objects.isNull(chatApp) || !chatApp.isEnable()) { return; } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index 2e409824f..0406098d5 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -173,7 +173,7 @@ public class NL2SQLParser implements ChatQueryParser { private void processMultiTurn(ParseContext parseContext) { ChatApp chatApp = parseContext.getAgent().getChatAppConfig().get(APP_KEY_MULTI_TURN); - if (!chatApp.isEnable()) { + if (Objects.isNull(chatApp) || !chatApp.isEnable()) { return; } @@ -222,7 +222,7 @@ public class NL2SQLParser implements ChatQueryParser { List similarExemplars) { ChatApp chatApp = parseContext.getAgent().getChatAppConfig().get(APP_KEY_ERROR_MESSAGE); - if (!chatApp.isEnable()) { + if (Objects.isNull(chatApp) || !chatApp.isEnable()) { return errMsg; } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java index 5edb1b241..791a4b072 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java @@ -9,6 +9,7 @@ import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.executor.ChatQueryExecutor; import com.tencent.supersonic.chat.server.parser.ChatQueryParser; +import com.tencent.supersonic.chat.server.pojo.ChatModel; import com.tencent.supersonic.chat.server.pojo.ExecuteContext; import com.tencent.supersonic.chat.server.pojo.ParseContext; import com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor; @@ -171,8 +172,12 @@ public class ChatQueryServiceImpl implements ChatQueryService { ParseContext parseContext = new ParseContext(); BeanMapper.mapper(chatParseReq, parseContext); Agent agent = agentService.getAgent(chatParseReq.getAgentId()); - agent.getChatAppConfig().values().forEach(c -> c - .setChatModelConfig(chatModelService.getChatModel(c.getChatModelId()).getConfig())); + agent.getChatAppConfig().values().forEach(c -> { + ChatModel chatModel = chatModelService.getChatModel(c.getChatModelId()); + if (Objects.nonNull(chatModel)) { + c.setChatModelConfig(chatModelService.getChatModel(c.getChatModelId()).getConfig()); + } + }); parseContext.setAgent(agent); return parseContext; } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/ModelConfigHelper.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/ModelConfigHelper.java index 250726e29..899d020b6 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/ModelConfigHelper.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/ModelConfigHelper.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.chat.server.util; import com.tencent.supersonic.chat.server.agent.Agent; +import com.tencent.supersonic.chat.server.pojo.ChatModel; import com.tencent.supersonic.chat.server.service.ChatModelService; import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.ChatModelConfig; @@ -12,6 +13,8 @@ import dev.langchain4j.provider.ModelProvider; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; +import java.util.Objects; + @Slf4j public class ModelConfigHelper { public static boolean testConnection(ChatModelConfig modelConfig) { @@ -30,8 +33,11 @@ public class ModelConfigHelper { public static ChatModelConfig getChatModelConfig(ChatApp chatApp) { ChatModelService chatModelService = ContextUtils.getBean(ChatModelService.class); - ChatModelConfig chatModelConfig = - chatModelService.getChatModel(chatApp.getChatModelId()).getConfig(); - return chatModelConfig; + ChatModel chatModel = chatModelService.getChatModel(chatApp.getChatModelId()); + if (Objects.isNull(chatModel)) { + return null; + } + + return chatModel.getConfig(); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/LLMSqlCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/LLMSqlCorrector.java index b961f243a..6e16827cd 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/LLMSqlCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/LLMSqlCorrector.java @@ -20,6 +20,7 @@ import org.slf4j.LoggerFactory; import java.util.HashMap; import java.util.Map; +import java.util.Objects; @Slf4j public class LLMSqlCorrector extends BaseSemanticCorrector { @@ -61,7 +62,8 @@ public class LLMSqlCorrector extends BaseSemanticCorrector { @Override public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { ChatApp chatApp = chatQueryContext.getChatAppConfig().get(APP_KEY); - if (!chatQueryContext.getText2SQLType().enableLLM() || !chatApp.isEnable()) { + if (!chatQueryContext.getText2SQLType().enableLLM() || Objects.isNull(chatApp) + || !chatApp.isEnable()) { return; } diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/SmallTalkDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/SmallTalkDemo.java index 989cfec04..82f2d4c16 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/SmallTalkDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/SmallTalkDemo.java @@ -2,13 +2,21 @@ package com.tencent.supersonic.demo; import com.alibaba.fastjson.JSONObject; import com.google.common.collect.Lists; +import com.google.common.collect.Maps; import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.ToolConfig; +import com.tencent.supersonic.chat.server.executor.PlainTextExecutor; +import com.tencent.supersonic.chat.server.parser.PlainTextParser; +import com.tencent.supersonic.common.pojo.ChatApp; +import com.tencent.supersonic.common.pojo.enums.AppModule; +import com.tencent.supersonic.common.util.ChatAppManager; +import com.tencent.supersonic.headless.chat.parser.llm.OnePassSCSqlGenStrategy; import lombok.extern.slf4j.Slf4j; import org.springframework.core.annotation.Order; import org.springframework.stereotype.Component; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; @Component @@ -25,6 +33,14 @@ public class SmallTalkDemo extends S2BaseDemo { ToolConfig toolConfig = new ToolConfig(); agent.setToolConfig(JSONObject.toJSONString(toolConfig)); agent.setExamples(Lists.newArrayList("如何才能变帅", "如何才能赚更多钱", "如何才能世界和平")); + + // configure chat apps + Map chatAppConfig = + Maps.newHashMap(ChatAppManager.getAllApps(AppModule.CHAT)); + chatAppConfig.values().forEach(app -> app.setChatModelId(demoChatModel.getId())); + chatAppConfig.get(PlainTextExecutor.APP_KEY).setEnable(true); + chatAppConfig.get(OnePassSCSqlGenStrategy.APP_KEY).setEnable(false); + agent.setChatAppConfig(chatAppConfig); agentService.createAgent(agent, defaultUser); }