diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java index 0374f10a3..36da61d57 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java @@ -1,9 +1,11 @@ package com.tencent.supersonic.chat.server.agent; import com.alibaba.fastjson.JSONObject; +import com.google.common.collect.Lists; import com.tencent.supersonic.chat.server.memory.MemoryReviewTask; import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.RecordInfo; +import com.tencent.supersonic.common.pojo.User; import lombok.Data; import org.springframework.util.CollectionUtils; @@ -12,6 +14,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.Function; import java.util.stream.Collectors; @Data @@ -33,6 +36,8 @@ public class Agent extends RecordInfo { private String toolConfig; private Map chatAppConfig = Collections.emptyMap(); private VisualConfig visualConfig; + private List admins = Lists.newArrayList(); + private List viewers = Lists.newArrayList(); public List getTools(AgentToolType type) { Map map = JSONObject.parseObject(toolConfig, Map.class); @@ -105,4 +110,9 @@ public class Agent extends RecordInfo { .filter(dataSetIds -> !CollectionUtils.isEmpty(dataSetIds)) .flatMap(Collection::stream).collect(Collectors.toSet()); } + + public boolean contains(User user, Function> list) { + return list.apply(this).contains(user.getName()); + } + } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/AgentDO.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/AgentDO.java index 58645c621..a71596e82 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/AgentDO.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/AgentDO.java @@ -40,4 +40,8 @@ public class AgentDO { private String chatModelConfig; private String visualConfig; + + private String admin; + + private String viewer; } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/AgentController.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/AgentController.java index 0cb4ddee0..ff3e01b18 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/AgentController.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/AgentController.java @@ -8,6 +8,7 @@ import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.AgentToolType; import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.common.pojo.User; +import com.tencent.supersonic.common.pojo.enums.AuthType; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.DeleteMapping; import org.springframework.web.bind.annotation.PathVariable; @@ -15,6 +16,7 @@ import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.PutMapping; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.annotation.RestController; import java.util.List; @@ -48,8 +50,11 @@ public class AgentController { } @RequestMapping("/getAgentList") - public List getAgentList() { - return agentService.getAgents(); + public List getAgentList( + @RequestParam(value = "authType", required = false) AuthType authType, + HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) { + User user = UserHolder.findUser(httpServletRequest, httpServletResponse); + return agentService.getAgents(user, authType); } @RequestMapping("/getToolTypes") diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/AgentService.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/AgentService.java index 147ff2615..d5d24fcb7 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/AgentService.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/AgentService.java @@ -2,10 +2,12 @@ package com.tencent.supersonic.chat.server.service; import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.common.pojo.User; +import com.tencent.supersonic.common.pojo.enums.AuthType; import java.util.List; public interface AgentService { + List getAgents(User user, AuthType authType); List getAgents(); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java index c048c59fc..591d3a21b 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java @@ -14,6 +14,7 @@ import com.tencent.supersonic.chat.server.service.MemoryService; import com.tencent.supersonic.common.config.ChatModel; import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.User; +import com.tencent.supersonic.common.pojo.enums.AuthType; import com.tencent.supersonic.common.service.ChatModelService; import com.tencent.supersonic.common.util.JsonUtil; import lombok.extern.slf4j.Slf4j; @@ -43,6 +44,27 @@ public class AgentServiceImpl extends ServiceImpl implem private ExecutorService executorService = Executors.newFixedThreadPool(1); + @Override + public List getAgents(User user, AuthType authType) { + return getAgentDOList().stream().map(this::convert) + .filter(agent -> filterByAuth(agent, user, authType)).collect(Collectors.toList()); + } + + private boolean filterByAuth(Agent agent, User user, AuthType authType) { + if (user.isSuperAdmin() || user.getName().equals(agent.getCreatedBy())) { + return true; + } + authType = authType == null ? AuthType.VIEWER : authType; + switch (authType) { + case ADMIN: + return agent.contains(user, Agent::getAdmins); + case VIEWER: + default: + return agent.contains(user, Agent::getAdmins) + || agent.contains(user, Agent::getViewers); + } + } + @Override public List getAgents() { return getAgentDOList().stream().map(this::convert).collect(Collectors.toList()); @@ -135,6 +157,8 @@ public class AgentServiceImpl extends ServiceImpl implem c.setChatModelConfig(chatModelService.getChatModel(c.getChatModelId()).getConfig()); } }); + agent.setAdmins(JsonUtil.toList(agentDO.getAdmin(), String.class)); + agent.setViewers(JsonUtil.toList(agentDO.getViewer(), String.class)); return agent; } @@ -145,6 +169,8 @@ public class AgentServiceImpl extends ServiceImpl implem agentDO.setExamples(JsonUtil.toString(agent.getExamples())); agentDO.setChatModelConfig(JsonUtil.toString(agent.getChatAppConfig())); agentDO.setVisualConfig(JsonUtil.toString(agent.getVisualConfig())); + agentDO.setAdmin(JsonUtil.toString(agent.getAdmins())); + agentDO.setViewer(JsonUtil.toString(agent.getViewers())); if (agentDO.getStatus() == null) { agentDO.setStatus(1); } diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/AuthType.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/AuthType.java index 0c505a120..3df7e526b 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/AuthType.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/AuthType.java @@ -1,5 +1,5 @@ package com.tencent.supersonic.common.pojo.enums; public enum AuthType { - VISIBLE, ADMIN + VIEWER, ADMIN } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/S2DataPermissionAspect.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/S2DataPermissionAspect.java index fdd2750c0..2cb623acc 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/S2DataPermissionAspect.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/S2DataPermissionAspect.java @@ -260,9 +260,8 @@ public class S2DataPermissionAspect { } public void checkModelVisible(User user, Set modelIds) { - List modelListVisible = - modelService.getModelListWithAuth(user, null, AuthType.VISIBLE).stream() - .map(ModelResp::getId).collect(Collectors.toList()); + List modelListVisible = modelService.getModelListWithAuth(user, null, AuthType.VIEWER) + .stream().map(ModelResp::getId).collect(Collectors.toList()); List modelIdCopied = new ArrayList<>(modelIds); modelIdCopied.removeAll(modelListVisible); if (!CollectionUtils.isEmpty(modelIdCopied)) { diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DomainServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DomainServiceImpl.java index fd9f69d43..be75ac482 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DomainServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DomainServiceImpl.java @@ -126,7 +126,7 @@ public class DomainServiceImpl implements DomainService { return domainWithAuth.stream().peek(domainResp -> domainResp.setHasEditPermission(true)) .collect(Collectors.toSet()); } - if (authTypeEnum.equals(AuthType.VISIBLE)) { + if (authTypeEnum.equals(AuthType.VIEWER)) { domainWithAuth = domainResps.stream() .filter(domainResp -> checkViewPermission(orgIds, user, domainResp)) .collect(Collectors.toSet()); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java index f59d7e64e..e1c64a445 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java @@ -428,7 +428,7 @@ public class ModelServiceImpl implements ModelService { .filter(modelResp -> checkAdminPermission(orgIds, user, modelResp)) .collect(Collectors.toList()); } - if (authTypeEnum.equals(AuthType.VISIBLE)) { + if (authTypeEnum.equals(AuthType.VIEWER)) { modelWithAuth = modelResps.stream() .filter(domainResp -> checkDataSetPermission(orgIds, user, domainResp)) .collect(Collectors.toList()); diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2SingerDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2SingerDemo.java index 3659659b7..1b40be2d5 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2SingerDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2SingerDemo.java @@ -167,6 +167,8 @@ public class S2SingerDemo extends S2BaseDemo { Maps.newHashMap(ChatAppManager.getAllApps(AppModule.CHAT)); chatAppConfig.values().forEach(app -> app.setChatModelId(demoChatModel.getId())); agent.setChatAppConfig(chatAppConfig); + agent.setAdmins(Lists.newArrayList("alice")); + agent.setViewers(Lists.newArrayList("tom", "jack")); agentService.createAgent(agent, defaultUser); } } diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2SmallTalkDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2SmallTalkDemo.java index 943ed9c3f..7964d0bf8 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2SmallTalkDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2SmallTalkDemo.java @@ -40,6 +40,8 @@ public class S2SmallTalkDemo extends S2BaseDemo { chatAppConfig.get(PlainTextExecutor.APP_KEY).setEnable(true); chatAppConfig.get(OnePassSCSqlGenStrategy.APP_KEY).setEnable(false); agent.setChatAppConfig(chatAppConfig); + agent.setAdmins(Lists.newArrayList("jack")); + agent.setViewers(Lists.newArrayList("alice", "tom")); agentService.createAgent(agent, defaultUser); } diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java index 29748cb4c..87b282372 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java @@ -162,6 +162,8 @@ public class S2VisitsDemo extends S2BaseDemo { Maps.newHashMap(ChatAppManager.getAllApps(AppModule.CHAT)); chatAppConfig.values().forEach(app -> app.setChatModelId(demoChatModel.getId())); agent.setChatAppConfig(chatAppConfig); + agent.setAdmins(Lists.newArrayList("tom")); + agent.setViewers(Lists.newArrayList("alice", "jack")); Agent agentCreated = agentService.createAgent(agent, defaultUser); return agentCreated.getId(); } diff --git a/launchers/standalone/src/main/resources/config.update/sql-update.sql b/launchers/standalone/src/main/resources/config.update/sql-update.sql index 0cefd0fb6..17c43e127 100644 --- a/launchers/standalone/src/main/resources/config.update/sql-update.sql +++ b/launchers/standalone/src/main/resources/config.update/sql-update.sql @@ -393,4 +393,8 @@ ALTER TABLE s2_agent DROP COLUMN `multi_turn_config`; ALTER TABLE s2_agent DROP COLUMN `enable_memory_review`; --20241012 -alter table s2_agent add column `enable_feedback` tinyint DEFAULT 1; \ No newline at end of file +alter table s2_agent add column `enable_feedback` tinyint DEFAULT 1; + +--20241116 +alter table s2_agent add column `admin` varchar(1000); +alter table s2_agent add column `viewer` varchar(1000); \ No newline at end of file diff --git a/launchers/standalone/src/main/resources/db/schema-h2.sql b/launchers/standalone/src/main/resources/db/schema-h2.sql index 71985491c..4da6f27c2 100644 --- a/launchers/standalone/src/main/resources/db/schema-h2.sql +++ b/launchers/standalone/src/main/resources/db/schema-h2.sql @@ -398,6 +398,8 @@ CREATE TABLE IF NOT EXISTS s2_agent updated_at TIMESTAMP null, enable_search int null, enable_feedback int null, + admin varchar(1000), + viewer varchar(1000), PRIMARY KEY (`id`) ); COMMENT ON TABLE s2_agent IS 'agent information table'; diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/SchemaAuthTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/SchemaAuthTest.java index d813ef0aa..661262d50 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/SchemaAuthTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/SchemaAuthTest.java @@ -51,7 +51,7 @@ public class SchemaAuthTest extends BaseTest { public void test_getVisibleModelList_alice() { User user = DataUtils.getUserAlice(); List modelResps = - modelService.getModelListWithAuth(user, null, AuthType.VISIBLE); + modelService.getModelListWithAuth(user, null, AuthType.VIEWER); List expectedModelBizNames = Lists.newArrayList("user_department", "singer"); Assertions.assertEquals(expectedModelBizNames, modelResps.stream().map(ModelResp::getBizName).collect(Collectors.toList())); diff --git a/launchers/standalone/src/test/resources/db/schema-h2.sql b/launchers/standalone/src/test/resources/db/schema-h2.sql index 7439debac..8584c6239 100644 --- a/launchers/standalone/src/test/resources/db/schema-h2.sql +++ b/launchers/standalone/src/test/resources/db/schema-h2.sql @@ -398,6 +398,8 @@ CREATE TABLE IF NOT EXISTS s2_agent updated_at TIMESTAMP null, enable_search int null, enable_feedback int null, + admin varchar(1000), + viewer varchar(1000), PRIMARY KEY (`id`) ); COMMENT ON TABLE s2_agent IS 'agent information table';