From b34268c23685fb8d7abbdc2a6b68ca52daa29d52 Mon Sep 17 00:00:00 2001 From: LXW <1264174498@qq.com> Date: Fri, 28 Jun 2024 15:51:31 +0800 Subject: [PATCH] (improvement)(Chat) Add encryption and decryption for llm api key (#1267) * (improvement)(Chat) Add encryption and decryption for llm api key * (improvement)(Chat) Change Plugin to ChatPlugin --------- Co-authored-by: lxwcodemonkey --- .../plugin/{Plugin.java => ChatPlugin.java} | 2 +- .../chat/server/plugin/PluginManager.java | 26 +++--- .../chat/server/plugin/PluginParseResult.java | 2 +- .../server/plugin/PluginRecallResult.java | 2 +- .../plugin/build/webpage/WebPageQuery.java | 4 +- .../build/webservice/WebServiceQuery.java | 4 +- .../server/plugin/event/PluginAddEvent.java | 8 +- .../server/plugin/event/PluginDelEvent.java | 8 +- .../plugin/event/PluginUpdateEvent.java | 12 +-- .../plugin/recognize/PluginRecognizer.java | 8 +- .../embedding/EmbeddingRecallRecognizer.java | 10 +-- .../chat/server/rest/PluginController.java | 14 ++-- .../chat/server/service/PluginService.java | 18 ++-- .../service/impl/PluginServiceImpl.java | 40 ++++----- .../chat/server/util/LLMConnHelper.java | 9 +- .../supersonic/common/config/LLMConfig.java | 6 ++ .../supersonic/common/util/AESUtil.java | 84 +++++++++++++++++++ .../common/util/S2ChatModelProvider.java | 2 +- .../tencent/supersonic/demo/S2VisitsDemo.java | 4 +- 19 files changed, 180 insertions(+), 83 deletions(-) rename chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/{Plugin.java => ChatPlugin.java} (96%) create mode 100644 common/src/main/java/com/tencent/supersonic/common/util/AESUtil.java diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/Plugin.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/ChatPlugin.java similarity index 96% rename from chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/Plugin.java rename to chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/ChatPlugin.java index 85f1234fe..42b0a3938 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/Plugin.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/ChatPlugin.java @@ -10,7 +10,7 @@ import org.apache.commons.lang3.StringUtils; import java.util.List; @Data -public class Plugin extends RecordInfo { +public class ChatPlugin extends RecordInfo { private Long id; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginManager.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginManager.java index a8c8b9734..d14904200 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginManager.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginManager.java @@ -52,10 +52,10 @@ public class PluginManager { @Autowired private EmbeddingService embeddingService; - public static List getPluginAgentCanSupport(ChatParseContext chatParseContext) { + public static List getPluginAgentCanSupport(ChatParseContext chatParseContext) { PluginService pluginService = ContextUtils.getBean(PluginService.class); Agent agent = chatParseContext.getAgent(); - List plugins = pluginService.getPluginList(); + List plugins = pluginService.getPluginList(); if (Objects.isNull(agent)) { return plugins; } @@ -67,7 +67,7 @@ public class PluginManager { plugins = plugins.stream().filter(plugin -> pluginIds.contains(plugin.getId())) .collect(Collectors.toList()); log.info("plugins witch can be supported by cur agent :{} {}", agent.getName(), - plugins.stream().map(Plugin::getName).collect(Collectors.toList())); + plugins.stream().map(ChatPlugin::getName).collect(Collectors.toList())); return plugins; } @@ -85,7 +85,7 @@ public class PluginManager { @EventListener public void addPlugin(PluginAddEvent pluginAddEvent) { - Plugin plugin = pluginAddEvent.getPlugin(); + ChatPlugin plugin = pluginAddEvent.getPlugin(); if (CollectionUtils.isNotEmpty(plugin.getExampleQuestionList())) { requestEmbeddingPluginAdd(convert(Lists.newArrayList(plugin))); } @@ -93,8 +93,8 @@ public class PluginManager { @EventListener public void updatePlugin(PluginUpdateEvent pluginUpdateEvent) { - Plugin oldPlugin = pluginUpdateEvent.getOldPlugin(); - Plugin newPlugin = pluginUpdateEvent.getNewPlugin(); + ChatPlugin oldPlugin = pluginUpdateEvent.getOldPlugin(); + ChatPlugin newPlugin = pluginUpdateEvent.getNewPlugin(); if (CollectionUtils.isNotEmpty(oldPlugin.getExampleQuestionList())) { requestEmbeddingPluginDelete(getEmbeddingId(Lists.newArrayList(oldPlugin))); } @@ -105,7 +105,7 @@ public class PluginManager { @EventListener public void delPlugin(PluginDelEvent pluginDelEvent) { - Plugin plugin = pluginDelEvent.getPlugin(); + ChatPlugin plugin = pluginDelEvent.getPlugin(); if (CollectionUtils.isNotEmpty(plugin.getExampleQuestionList())) { requestEmbeddingPluginDelete(getEmbeddingId(Lists.newArrayList(plugin))); } @@ -155,9 +155,9 @@ public class PluginManager { throw new RuntimeException("get embedding result failed"); } - public List convert(List plugins) { + public List convert(List plugins) { List queries = Lists.newArrayList(); - for (Plugin plugin : plugins) { + for (ChatPlugin plugin : plugins) { List exampleQuestions = plugin.getExampleQuestionList(); int num = 0; for (String pattern : exampleQuestions) { @@ -170,7 +170,7 @@ public class PluginManager { return queries; } - private Set getEmbeddingId(List plugins) { + private Set getEmbeddingId(List plugins) { Set embeddingIdSet = new HashSet<>(); for (TextSegment query : convert(plugins)) { TextSegmentConvert.addQueryId(query, TextSegmentConvert.getQueryId(query)); @@ -191,7 +191,7 @@ public class PluginManager { return String.valueOf(Integer.parseInt(id) / 1000); } - public static Pair> resolve(Plugin plugin, ChatParseContext chatParseContext) { + public static Pair> resolve(ChatPlugin plugin, ChatParseContext chatParseContext) { SchemaMapInfo schemaMapInfo = chatParseContext.getMapInfo(); Set pluginMatchedModel = getPluginMatchedModel(plugin, chatParseContext); if (CollectionUtils.isEmpty(pluginMatchedModel) && !plugin.isContainsAllModel()) { @@ -245,7 +245,7 @@ public class PluginManager { .collect(Collectors.toSet()); } - private static List getSemanticOption(Plugin plugin) { + private static List getSemanticOption(ChatPlugin plugin) { WebBase webBase = JSONObject.parseObject(plugin.getConfig(), WebBase.class); if (Objects.isNull(webBase)) { return null; @@ -259,7 +259,7 @@ public class PluginManager { .collect(Collectors.toList()); } - private static Set getPluginMatchedModel(Plugin plugin, ChatParseContext chatParseContext) { + private static Set getPluginMatchedModel(ChatPlugin plugin, ChatParseContext chatParseContext) { Set matchedDataSets = chatParseContext.getMapInfo().getMatchedDataSetInfos(); if (plugin.isContainsAllModel()) { return Sets.newHashSet(plugin.getDefaultMode()); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginParseResult.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginParseResult.java index 8cb7b2170..7930136be 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginParseResult.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginParseResult.java @@ -6,7 +6,7 @@ import lombok.Data; @Data public class PluginParseResult { - private Plugin plugin; + private ChatPlugin plugin; private QueryFilters queryFilters; private double distance; private String queryText; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginRecallResult.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginRecallResult.java index f44402d80..906f71d92 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginRecallResult.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginRecallResult.java @@ -13,7 +13,7 @@ import java.util.Set; @NoArgsConstructor public class PluginRecallResult { - private Plugin plugin; + private ChatPlugin plugin; private Set dataSetIds; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/webpage/WebPageQuery.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/webpage/WebPageQuery.java index 9597ca547..04e3bb346 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/webpage/WebPageQuery.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/webpage/WebPageQuery.java @@ -1,6 +1,6 @@ package com.tencent.supersonic.chat.server.plugin.build.webpage; -import com.tencent.supersonic.chat.server.plugin.Plugin; +import com.tencent.supersonic.chat.server.plugin.ChatPlugin; import com.tencent.supersonic.chat.server.plugin.PluginParseResult; import com.tencent.supersonic.chat.server.plugin.PluginQueryManager; import com.tencent.supersonic.chat.server.plugin.build.PluginSemanticQuery; @@ -25,7 +25,7 @@ public class WebPageQuery extends PluginSemanticQuery { } protected WebPageResp buildResponse(PluginParseResult pluginParseResult) { - Plugin plugin = pluginParseResult.getPlugin(); + ChatPlugin plugin = pluginParseResult.getPlugin(); WebPageResp webPageResponse = new WebPageResp(); webPageResponse.setName(plugin.getName()); webPageResponse.setPluginId(plugin.getId()); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/webservice/WebServiceQuery.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/webservice/WebServiceQuery.java index 688fed85b..16b761ba3 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/webservice/WebServiceQuery.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/webservice/WebServiceQuery.java @@ -1,7 +1,7 @@ package com.tencent.supersonic.chat.server.plugin.build.webservice; import com.alibaba.fastjson.JSON; -import com.tencent.supersonic.chat.server.plugin.Plugin; +import com.tencent.supersonic.chat.server.plugin.ChatPlugin; import com.tencent.supersonic.chat.server.plugin.PluginParseResult; import com.tencent.supersonic.chat.server.plugin.PluginQueryManager; import com.tencent.supersonic.chat.server.plugin.build.ParamOption; @@ -71,7 +71,7 @@ public class WebServiceQuery extends PluginSemanticQuery { protected WebServiceResp buildResponse(PluginParseResult pluginParseResult) { WebServiceResp webServiceResponse = new WebServiceResp(); - Plugin plugin = pluginParseResult.getPlugin(); + ChatPlugin plugin = pluginParseResult.getPlugin(); WebBase webBase = fillWebBaseResult(JsonUtil.toObject(plugin.getConfig(), WebBase.class), pluginParseResult); webServiceResponse.setWebBase(webBase); List paramOptions = webBase.getParamOptions(); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/event/PluginAddEvent.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/event/PluginAddEvent.java index 41d493994..446085193 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/event/PluginAddEvent.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/event/PluginAddEvent.java @@ -1,18 +1,18 @@ package com.tencent.supersonic.chat.server.plugin.event; -import com.tencent.supersonic.chat.server.plugin.Plugin; +import com.tencent.supersonic.chat.server.plugin.ChatPlugin; import org.springframework.context.ApplicationEvent; public class PluginAddEvent extends ApplicationEvent { - private Plugin plugin; + private ChatPlugin plugin; - public PluginAddEvent(Object source, Plugin plugin) { + public PluginAddEvent(Object source, ChatPlugin plugin) { super(source); this.plugin = plugin; } - public Plugin getPlugin() { + public ChatPlugin getPlugin() { return plugin; } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/event/PluginDelEvent.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/event/PluginDelEvent.java index fb16cb7a4..e68c961d7 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/event/PluginDelEvent.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/event/PluginDelEvent.java @@ -1,19 +1,19 @@ package com.tencent.supersonic.chat.server.plugin.event; -import com.tencent.supersonic.chat.server.plugin.Plugin; +import com.tencent.supersonic.chat.server.plugin.ChatPlugin; import org.springframework.context.ApplicationEvent; public class PluginDelEvent extends ApplicationEvent { - private Plugin plugin; + private ChatPlugin plugin; - public PluginDelEvent(Object source, Plugin plugin) { + public PluginDelEvent(Object source, ChatPlugin plugin) { super(source); this.plugin = plugin; } - public Plugin getPlugin() { + public ChatPlugin getPlugin() { return plugin; } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/event/PluginUpdateEvent.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/event/PluginUpdateEvent.java index dda1174ee..c2bc6b785 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/event/PluginUpdateEvent.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/event/PluginUpdateEvent.java @@ -1,25 +1,25 @@ package com.tencent.supersonic.chat.server.plugin.event; -import com.tencent.supersonic.chat.server.plugin.Plugin; +import com.tencent.supersonic.chat.server.plugin.ChatPlugin; import org.springframework.context.ApplicationEvent; public class PluginUpdateEvent extends ApplicationEvent { - private Plugin oldPlugin; + private ChatPlugin oldPlugin; - private Plugin newPlugin; + private ChatPlugin newPlugin; - public PluginUpdateEvent(Object source, Plugin oldPlugin, Plugin newPlugin) { + public PluginUpdateEvent(Object source, ChatPlugin oldPlugin, ChatPlugin newPlugin) { super(source); this.oldPlugin = oldPlugin; this.newPlugin = newPlugin; } - public Plugin getOldPlugin() { + public ChatPlugin getOldPlugin() { return oldPlugin; } - public Plugin getNewPlugin() { + public ChatPlugin getNewPlugin() { return newPlugin; } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/PluginRecognizer.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/PluginRecognizer.java index 2a18adae0..dacbe9d40 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/PluginRecognizer.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/PluginRecognizer.java @@ -2,7 +2,7 @@ package com.tencent.supersonic.chat.server.plugin.recognize; import com.google.common.collect.Lists; import com.google.common.collect.Sets; -import com.tencent.supersonic.chat.server.plugin.Plugin; +import com.tencent.supersonic.chat.server.plugin.ChatPlugin; import com.tencent.supersonic.chat.server.plugin.PluginManager; import com.tencent.supersonic.chat.server.plugin.PluginParseResult; import com.tencent.supersonic.chat.server.plugin.PluginRecallResult; @@ -45,7 +45,7 @@ public abstract class PluginRecognizer { public void buildQuery(ChatParseContext chatParseContext, ParseResp parseResp, PluginRecallResult pluginRecallResult) { - Plugin plugin = pluginRecallResult.getPlugin(); + ChatPlugin plugin = pluginRecallResult.getPlugin(); Set dataSetIds = pluginRecallResult.getDataSetIds(); if (plugin.isContainsAllModel()) { dataSetIds = Sets.newHashSet(-1L); @@ -59,11 +59,11 @@ public abstract class PluginRecognizer { } } - protected List getPluginList(ChatParseContext chatParseContext) { + protected List getPluginList(ChatParseContext chatParseContext) { return PluginManager.getPluginAgentCanSupport(chatParseContext); } - protected SemanticParseInfo buildSemanticParseInfo(Long dataSetId, Plugin plugin, + protected SemanticParseInfo buildSemanticParseInfo(Long dataSetId, ChatPlugin plugin, ChatParseContext chatParseContext, double distance) { List schemaElementMatches = chatParseContext.getMapInfo().getMatchedElements(dataSetId); QueryFilters queryFilters = chatParseContext.getQueryFilters(); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/embedding/EmbeddingRecallRecognizer.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/embedding/EmbeddingRecallRecognizer.java index a1b31d890..286888f29 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/embedding/EmbeddingRecallRecognizer.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/embedding/EmbeddingRecallRecognizer.java @@ -2,7 +2,7 @@ package com.tencent.supersonic.chat.server.plugin.recognize.embedding; import com.google.common.collect.Lists; import com.tencent.supersonic.chat.server.plugin.ParseMode; -import com.tencent.supersonic.chat.server.plugin.Plugin; +import com.tencent.supersonic.chat.server.plugin.ChatPlugin; import com.tencent.supersonic.chat.server.plugin.PluginManager; import com.tencent.supersonic.chat.server.plugin.PluginRecallResult; import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer; @@ -27,7 +27,7 @@ import java.util.stream.Collectors; public class EmbeddingRecallRecognizer extends PluginRecognizer { public boolean checkPreCondition(ChatParseContext chatParseContext) { - List plugins = getPluginList(chatParseContext); + List plugins = getPluginList(chatParseContext); return !CollectionUtils.isEmpty(plugins); } @@ -37,10 +37,10 @@ public class EmbeddingRecallRecognizer extends PluginRecognizer { if (CollectionUtils.isEmpty(embeddingRetrievals)) { return null; } - List plugins = getPluginList(chatParseContext); - Map pluginMap = plugins.stream().collect(Collectors.toMap(Plugin::getId, p -> p)); + List plugins = getPluginList(chatParseContext); + Map pluginMap = plugins.stream().collect(Collectors.toMap(ChatPlugin::getId, p -> p)); for (Retrieval embeddingRetrieval : embeddingRetrievals) { - Plugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId())); + ChatPlugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId())); if (plugin == null) { continue; } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/PluginController.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/PluginController.java index d12e1db81..51ff49426 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/PluginController.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/PluginController.java @@ -3,7 +3,7 @@ package com.tencent.supersonic.chat.server.rest; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; import com.tencent.supersonic.chat.api.pojo.request.PluginQueryReq; -import com.tencent.supersonic.chat.server.plugin.Plugin; +import com.tencent.supersonic.chat.server.plugin.ChatPlugin; import com.tencent.supersonic.chat.server.service.PluginService; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.RestController; @@ -25,7 +25,7 @@ public class PluginController { protected PluginService pluginService; @PostMapping - public boolean createPlugin(@RequestBody Plugin plugin, + public boolean createPlugin(@RequestBody ChatPlugin plugin, HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) { User user = UserHolder.findUser(httpServletRequest, httpServletResponse); @@ -34,7 +34,7 @@ public class PluginController { } @PutMapping - public boolean updatePlugin(@RequestBody Plugin plugin, + public boolean updatePlugin(@RequestBody ChatPlugin plugin, HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) { User user = UserHolder.findUser(httpServletRequest, httpServletResponse); @@ -49,14 +49,14 @@ public class PluginController { } @RequestMapping("/getPluginList") - public List getPluginList() { + public List getPluginList() { return pluginService.getPluginList(); } @PostMapping("/query") - List query(@RequestBody PluginQueryReq pluginQueryReq, - HttpServletRequest httpServletRequest, - HttpServletResponse httpServletResponse) { + List query(@RequestBody PluginQueryReq pluginQueryReq, + HttpServletRequest httpServletRequest, + HttpServletResponse httpServletResponse) { User user = UserHolder.findUser(httpServletRequest, httpServletResponse); return pluginService.queryWithAuthCheck(pluginQueryReq, user); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/PluginService.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/PluginService.java index 0bdd2361e..9aa803404 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/PluginService.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/PluginService.java @@ -3,7 +3,7 @@ package com.tencent.supersonic.chat.server.service; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.chat.api.pojo.request.PluginQueryReq; -import com.tencent.supersonic.chat.server.plugin.Plugin; +import com.tencent.supersonic.chat.server.plugin.ChatPlugin; import java.util.List; import java.util.Map; @@ -11,22 +11,22 @@ import java.util.Optional; public interface PluginService { - void createPlugin(Plugin plugin, User user); + void createPlugin(ChatPlugin plugin, User user); - void updatePlugin(Plugin plugin, User user); + void updatePlugin(ChatPlugin plugin, User user); void deletePlugin(Long id); - List getPluginList(); + List getPluginList(); - List fetchPluginDOs(String queryText, String type); + List fetchPluginDOs(String queryText, String type); - List query(PluginQueryReq pluginQueryReq); + List query(PluginQueryReq pluginQueryReq); - Optional getPluginByName(String name); + Optional getPluginByName(String name); - List queryWithAuthCheck(PluginQueryReq pluginQueryReq, User user); + List queryWithAuthCheck(PluginQueryReq pluginQueryReq, User user); - Map getNameToPlugin(); + Map getNameToPlugin(); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/PluginServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/PluginServiceImpl.java index ee02f0016..def5b2e99 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/PluginServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/PluginServiceImpl.java @@ -4,7 +4,7 @@ import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import com.google.common.collect.Lists; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.chat.api.pojo.request.PluginQueryReq; -import com.tencent.supersonic.chat.server.plugin.Plugin; +import com.tencent.supersonic.chat.server.plugin.ChatPlugin; import com.tencent.supersonic.chat.server.plugin.PluginParseConfig; import com.tencent.supersonic.chat.server.plugin.event.PluginAddEvent; import com.tencent.supersonic.chat.server.plugin.event.PluginDelEvent; @@ -43,19 +43,19 @@ public class PluginServiceImpl implements PluginService { } @Override - public synchronized void createPlugin(Plugin plugin, User user) { + public synchronized void createPlugin(ChatPlugin plugin, User user) { PluginDO pluginDO = convert(plugin, user); pluginRepository.createPlugin(pluginDO); //compatible with H2 db - List plugins = getPluginList(); + List plugins = getPluginList(); publisher.publishEvent(new PluginAddEvent(this, plugins.get(plugins.size() - 1))); } @Override - public void updatePlugin(Plugin plugin, User user) { + public void updatePlugin(ChatPlugin plugin, User user) { Long id = plugin.getId(); PluginDO pluginDO = pluginRepository.getPlugin(id); - Plugin oldPlugin = convert(pluginDO); + ChatPlugin oldPlugin = convert(pluginDO); convert(plugin, pluginDO, user); pluginRepository.updatePlugin(pluginDO); publisher.publishEvent(new PluginUpdateEvent(this, oldPlugin, plugin)); @@ -71,8 +71,8 @@ public class PluginServiceImpl implements PluginService { } @Override - public List getPluginList() { - List plugins = Lists.newArrayList(); + public List getPluginList() { + List plugins = Lists.newArrayList(); List pluginDOS = pluginRepository.getPlugins(); if (CollectionUtils.isEmpty(pluginDOS)) { return plugins; @@ -81,13 +81,13 @@ public class PluginServiceImpl implements PluginService { } @Override - public List fetchPluginDOs(String queryText, String type) { + public List fetchPluginDOs(String queryText, String type) { List pluginDOS = pluginRepository.fetchPluginDOs(queryText, type); return convertList(pluginDOS); } @Override - public List query(PluginQueryReq pluginQueryReq) { + public List query(PluginQueryReq pluginQueryReq) { QueryWrapper queryWrapper = new QueryWrapper<>(); if (StringUtils.isNotBlank(pluginQueryReq.getType())) { @@ -120,7 +120,7 @@ public class PluginServiceImpl implements PluginService { } @Override - public Optional getPluginByName(String name) { + public Optional getPluginByName(String name) { log.info("name:{}", name); return getPluginList().stream() .filter(plugin -> { @@ -133,7 +133,7 @@ public class PluginServiceImpl implements PluginService { .findFirst(); } - private PluginParseConfig getPluginParseConfig(Plugin plugin) { + private PluginParseConfig getPluginParseConfig(ChatPlugin plugin) { if (StringUtils.isBlank(plugin.getParseModeConfig())) { return null; } @@ -149,13 +149,13 @@ public class PluginServiceImpl implements PluginService { } @Override - public List queryWithAuthCheck(PluginQueryReq pluginQueryReq, User user) { + public List queryWithAuthCheck(PluginQueryReq pluginQueryReq, User user) { return authCheck(query(pluginQueryReq), user); } @Override - public Map getNameToPlugin() { - List pluginList = getPluginList(); + public Map getNameToPlugin() { + List pluginList = getPluginList(); return pluginList.stream() .filter(plugin -> { @@ -173,12 +173,12 @@ public class PluginServiceImpl implements PluginService { } //todo - private List authCheck(List plugins, User user) { + private List authCheck(List plugins, User user) { return plugins; } - public Plugin convert(PluginDO pluginDO) { - Plugin plugin = new Plugin(); + public ChatPlugin convert(PluginDO pluginDO) { + ChatPlugin plugin = new ChatPlugin(); BeanUtils.copyProperties(pluginDO, plugin); if (StringUtils.isNotBlank(pluginDO.getDataSet())) { plugin.setDataSetList(Arrays.stream(pluginDO.getDataSet().split(",")) @@ -187,7 +187,7 @@ public class PluginServiceImpl implements PluginService { return plugin; } - public PluginDO convert(Plugin plugin, User user) { + public PluginDO convert(ChatPlugin plugin, User user) { PluginDO pluginDO = new PluginDO(); BeanUtils.copyProperties(plugin, pluginDO); pluginDO.setCreatedAt(new Date()); @@ -198,7 +198,7 @@ public class PluginServiceImpl implements PluginService { return pluginDO; } - public PluginDO convert(Plugin plugin, PluginDO pluginDO, User user) { + public PluginDO convert(ChatPlugin plugin, PluginDO pluginDO, User user) { BeanUtils.copyProperties(plugin, pluginDO); pluginDO.setUpdatedAt(new Date()); pluginDO.setUpdatedBy(user.getName()); @@ -206,7 +206,7 @@ public class PluginServiceImpl implements PluginService { return pluginDO; } - public List convertList(List pluginDOS) { + public List convertList(List pluginDOS) { if (!CollectionUtils.isEmpty(pluginDOS)) { return pluginDOS.stream().map(this::convert).collect(Collectors.toList()); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/LLMConnHelper.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/LLMConnHelper.java index a013c857d..a49311d41 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/LLMConnHelper.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/LLMConnHelper.java @@ -1,18 +1,25 @@ package com.tencent.supersonic.chat.server.util; import com.tencent.supersonic.common.config.LLMConfig; +import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException; import com.tencent.supersonic.common.util.S2ChatModelProvider; import dev.langchain4j.model.chat.ChatLanguageModel; +import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; +@Slf4j public class LLMConnHelper { public static boolean testConnection(LLMConfig llmConfig) { try { + if (llmConfig == null || StringUtils.isBlank(llmConfig.getBaseUrl())) { + return false; + } ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(llmConfig); String response = chatLanguageModel.generate("Hi there"); return StringUtils.isNotEmpty(response) ? true : false; } catch (Exception e) { - return false; + log.warn("connect llm failed:", e); + throw new InvalidArgumentException(e.getMessage()); } } } diff --git a/common/src/main/java/com/tencent/supersonic/common/config/LLMConfig.java b/common/src/main/java/com/tencent/supersonic/common/config/LLMConfig.java index 94477f716..1754b24db 100644 --- a/common/src/main/java/com/tencent/supersonic/common/config/LLMConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/config/LLMConfig.java @@ -1,5 +1,6 @@ package com.tencent.supersonic.common.config; +import com.tencent.supersonic.common.util.AESUtil; import lombok.AllArgsConstructor; import lombok.Data; import lombok.NoArgsConstructor; @@ -36,4 +37,9 @@ public class LLMConfig { this.modelName = modelName; this.temperature = temperature; } + + public String keyDecrypt() { + return AESUtil.aesDecrypt(apiKey); + } + } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/AESUtil.java b/common/src/main/java/com/tencent/supersonic/common/util/AESUtil.java new file mode 100644 index 000000000..d519fc537 --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/util/AESUtil.java @@ -0,0 +1,84 @@ +package com.tencent.supersonic.common.util; + + +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.codec.binary.Base64; +import org.apache.commons.lang3.StringUtils; +import sun.misc.BASE64Decoder; + +import javax.crypto.Cipher; +import javax.crypto.KeyGenerator; +import javax.crypto.spec.SecretKeySpec; + +@Slf4j +public class AESUtil { + + private static final String KEY = "9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08"; + //算法 + private static final String ALGORITHMSTR = "AES/ECB/PKCS5Padding"; + + public static String aesDecrypt(String encrypt) { + try { + return aesDecrypt(encrypt, KEY); + } catch (Exception e) { + log.warn("content decrypt failed:{}", encrypt); + return encrypt; + } + } + + private static String aesDecrypt(String encryptStr, String decryptKey) throws Exception { + return StringUtils.isEmpty(encryptStr) ? null : aesDecryptByBytes(base64Decode(encryptStr), decryptKey); + } + + private static String base64Encode(byte[] bytes) { + return Base64.encodeBase64String(bytes); + } + + private static byte[] base64Decode(String base64Code) throws Exception { + return StringUtils.isEmpty(base64Code) ? null : new BASE64Decoder().decodeBuffer(base64Code); + } + + private static byte[] aesEncryptToBytes(String content, String encryptKey) throws Exception { + KeyGenerator kgen = KeyGenerator.getInstance("AES"); + kgen.init(128); + Cipher cipher = Cipher.getInstance(ALGORITHMSTR); + cipher.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(hexStringToByteArray(encryptKey), "AES")); + + return cipher.doFinal(content.getBytes("utf-8")); + } + + private static String aesEncrypt(String content, String encryptKey) throws Exception { + return base64Encode(aesEncryptToBytes(content, encryptKey)); + } + + private static String aesDecryptByBytes(byte[] encryptBytes, String decryptKey) throws Exception { + KeyGenerator kgen = KeyGenerator.getInstance("AES"); + kgen.init(128); + + Cipher cipher = Cipher.getInstance(ALGORITHMSTR); + cipher.init(Cipher.DECRYPT_MODE, new SecretKeySpec(hexStringToByteArray(decryptKey), "AES")); + byte[] decryptBytes = cipher.doFinal(encryptBytes); + return new String(decryptBytes); + } + + public static byte[] hexStringToByteArray(String hexString) { + int len = hexString.length(); + byte[] byteArray = new byte[len / 2]; + for (int i = 0; i < len; i += 2) { + byteArray[i / 2] = (byte) ((Character.digit(hexString.charAt(i), 16) << 4) + + Character.digit(hexString.charAt(i + 1), 16)); + } + return byteArray; + } + + public static void main(String[] args) throws Exception { + String content = "123"; + System.out.println("before encrypt:" + content); + System.out.println("key:" + KEY); + String encrypt = aesEncrypt(content, KEY); + System.out.println("after encrypt:" + encrypt); + String decrypt = aesDecrypt(encrypt); + System.out.println("after decrypt:" + decrypt); + } + +} \ No newline at end of file diff --git a/common/src/main/java/com/tencent/supersonic/common/util/S2ChatModelProvider.java b/common/src/main/java/com/tencent/supersonic/common/util/S2ChatModelProvider.java index 7a9ce310b..e31292c49 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/S2ChatModelProvider.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/S2ChatModelProvider.java @@ -22,7 +22,7 @@ public class S2ChatModelProvider { .builder() .baseUrl(llmConfig.getBaseUrl()) .modelName(llmConfig.getModelName()) - .apiKey(llmConfig.getApiKey()) + .apiKey(llmConfig.keyDecrypt()) .temperature(llmConfig.getTemperature()) .timeout(Duration.ofSeconds(llmConfig.getTimeOut())) .build(); diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java index 0b7a3a862..5bce0a209 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java @@ -11,7 +11,7 @@ import com.tencent.supersonic.chat.server.agent.AgentToolType; import com.tencent.supersonic.chat.server.agent.LLMParserTool; import com.tencent.supersonic.chat.server.agent.MultiTurnConfig; import com.tencent.supersonic.chat.server.agent.RuleParserTool; -import com.tencent.supersonic.chat.server.plugin.Plugin; +import com.tencent.supersonic.chat.server.plugin.ChatPlugin; import com.tencent.supersonic.chat.server.plugin.PluginParseConfig; import com.tencent.supersonic.chat.server.plugin.build.WebBase; import com.tencent.supersonic.common.pojo.JoinCondition; @@ -514,7 +514,7 @@ public class S2VisitsDemo extends S2BaseDemo { } private void addPlugin(DataSetResp s2DataSet) { - Plugin plugin1 = new Plugin(); + ChatPlugin plugin1 = new ChatPlugin(); plugin1.setType("WEB_PAGE"); plugin1.setDataSetList(Arrays.asList(s2DataSet.getId())); plugin1.setPattern("用于分析超音数的流量概况,包含UV、PV等核心指标的追踪。P.S. 仅作为示例展示,无实际看板");