mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
[improvement][Chat] Support agent permission management #1143
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
package com.tencent.supersonic.chat.server.agent;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.server.memory.MemoryReviewTask;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.RecordInfo;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import lombok.Data;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@@ -12,6 +14,7 @@ import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Data
|
||||
@@ -33,6 +36,8 @@ public class Agent extends RecordInfo {
|
||||
private String toolConfig;
|
||||
private Map<String, ChatApp> chatAppConfig = Collections.emptyMap();
|
||||
private VisualConfig visualConfig;
|
||||
private List<String> admins = Lists.newArrayList();
|
||||
private List<String> viewers = Lists.newArrayList();
|
||||
|
||||
public List<String> getTools(AgentToolType type) {
|
||||
Map<String, Object> map = JSONObject.parseObject(toolConfig, Map.class);
|
||||
@@ -105,4 +110,9 @@ public class Agent extends RecordInfo {
|
||||
.filter(dataSetIds -> !CollectionUtils.isEmpty(dataSetIds))
|
||||
.flatMap(Collection::stream).collect(Collectors.toSet());
|
||||
}
|
||||
|
||||
public boolean contains(User user, Function<Agent, List<String>> list) {
|
||||
return list.apply(this).contains(user.getName());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -40,4 +40,8 @@ public class AgentDO {
|
||||
private String chatModelConfig;
|
||||
|
||||
private String visualConfig;
|
||||
|
||||
private String admin;
|
||||
|
||||
private String viewer;
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.AuthType;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.web.bind.annotation.DeleteMapping;
|
||||
import org.springframework.web.bind.annotation.PathVariable;
|
||||
@@ -15,6 +16,7 @@ import org.springframework.web.bind.annotation.PostMapping;
|
||||
import org.springframework.web.bind.annotation.PutMapping;
|
||||
import org.springframework.web.bind.annotation.RequestBody;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RequestParam;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
import java.util.List;
|
||||
@@ -48,8 +50,11 @@ public class AgentController {
|
||||
}
|
||||
|
||||
@RequestMapping("/getAgentList")
|
||||
public List<Agent> getAgentList() {
|
||||
return agentService.getAgents();
|
||||
public List<Agent> getAgentList(
|
||||
@RequestParam(value = "authType", required = false) AuthType authType,
|
||||
HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
|
||||
User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
|
||||
return agentService.getAgents(user, authType);
|
||||
}
|
||||
|
||||
@RequestMapping("/getToolTypes")
|
||||
|
||||
@@ -2,10 +2,12 @@ package com.tencent.supersonic.chat.server.service;
|
||||
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.AuthType;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface AgentService {
|
||||
List<Agent> getAgents(User user, AuthType authType);
|
||||
|
||||
List<Agent> getAgents();
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import com.tencent.supersonic.common.config.ChatModel;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.AuthType;
|
||||
import com.tencent.supersonic.common.service.ChatModelService;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -43,6 +44,27 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
|
||||
|
||||
private ExecutorService executorService = Executors.newFixedThreadPool(1);
|
||||
|
||||
@Override
|
||||
public List<Agent> getAgents(User user, AuthType authType) {
|
||||
return getAgentDOList().stream().map(this::convert)
|
||||
.filter(agent -> filterByAuth(agent, user, authType)).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private boolean filterByAuth(Agent agent, User user, AuthType authType) {
|
||||
if (user.isSuperAdmin() || user.getName().equals(agent.getCreatedBy())) {
|
||||
return true;
|
||||
}
|
||||
authType = authType == null ? AuthType.VIEWER : authType;
|
||||
switch (authType) {
|
||||
case ADMIN:
|
||||
return agent.contains(user, Agent::getAdmins);
|
||||
case VIEWER:
|
||||
default:
|
||||
return agent.contains(user, Agent::getAdmins)
|
||||
|| agent.contains(user, Agent::getViewers);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Agent> getAgents() {
|
||||
return getAgentDOList().stream().map(this::convert).collect(Collectors.toList());
|
||||
@@ -135,6 +157,8 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
|
||||
c.setChatModelConfig(chatModelService.getChatModel(c.getChatModelId()).getConfig());
|
||||
}
|
||||
});
|
||||
agent.setAdmins(JsonUtil.toList(agentDO.getAdmin(), String.class));
|
||||
agent.setViewers(JsonUtil.toList(agentDO.getViewer(), String.class));
|
||||
return agent;
|
||||
}
|
||||
|
||||
@@ -145,6 +169,8 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
|
||||
agentDO.setExamples(JsonUtil.toString(agent.getExamples()));
|
||||
agentDO.setChatModelConfig(JsonUtil.toString(agent.getChatAppConfig()));
|
||||
agentDO.setVisualConfig(JsonUtil.toString(agent.getVisualConfig()));
|
||||
agentDO.setAdmin(JsonUtil.toString(agent.getAdmins()));
|
||||
agentDO.setViewer(JsonUtil.toString(agent.getViewers()));
|
||||
if (agentDO.getStatus() == null) {
|
||||
agentDO.setStatus(1);
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
package com.tencent.supersonic.common.pojo.enums;
|
||||
|
||||
public enum AuthType {
|
||||
VISIBLE, ADMIN
|
||||
VIEWER, ADMIN
|
||||
}
|
||||
|
||||
@@ -260,9 +260,8 @@ public class S2DataPermissionAspect {
|
||||
}
|
||||
|
||||
public void checkModelVisible(User user, Set<Long> modelIds) {
|
||||
List<Long> modelListVisible =
|
||||
modelService.getModelListWithAuth(user, null, AuthType.VISIBLE).stream()
|
||||
.map(ModelResp::getId).collect(Collectors.toList());
|
||||
List<Long> modelListVisible = modelService.getModelListWithAuth(user, null, AuthType.VIEWER)
|
||||
.stream().map(ModelResp::getId).collect(Collectors.toList());
|
||||
List<Long> modelIdCopied = new ArrayList<>(modelIds);
|
||||
modelIdCopied.removeAll(modelListVisible);
|
||||
if (!CollectionUtils.isEmpty(modelIdCopied)) {
|
||||
|
||||
@@ -126,7 +126,7 @@ public class DomainServiceImpl implements DomainService {
|
||||
return domainWithAuth.stream().peek(domainResp -> domainResp.setHasEditPermission(true))
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
if (authTypeEnum.equals(AuthType.VISIBLE)) {
|
||||
if (authTypeEnum.equals(AuthType.VIEWER)) {
|
||||
domainWithAuth = domainResps.stream()
|
||||
.filter(domainResp -> checkViewPermission(orgIds, user, domainResp))
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
@@ -428,7 +428,7 @@ public class ModelServiceImpl implements ModelService {
|
||||
.filter(modelResp -> checkAdminPermission(orgIds, user, modelResp))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
if (authTypeEnum.equals(AuthType.VISIBLE)) {
|
||||
if (authTypeEnum.equals(AuthType.VIEWER)) {
|
||||
modelWithAuth = modelResps.stream()
|
||||
.filter(domainResp -> checkDataSetPermission(orgIds, user, domainResp))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
@@ -167,6 +167,8 @@ public class S2SingerDemo extends S2BaseDemo {
|
||||
Maps.newHashMap(ChatAppManager.getAllApps(AppModule.CHAT));
|
||||
chatAppConfig.values().forEach(app -> app.setChatModelId(demoChatModel.getId()));
|
||||
agent.setChatAppConfig(chatAppConfig);
|
||||
agent.setAdmins(Lists.newArrayList("alice"));
|
||||
agent.setViewers(Lists.newArrayList("tom", "jack"));
|
||||
agentService.createAgent(agent, defaultUser);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,6 +40,8 @@ public class S2SmallTalkDemo extends S2BaseDemo {
|
||||
chatAppConfig.get(PlainTextExecutor.APP_KEY).setEnable(true);
|
||||
chatAppConfig.get(OnePassSCSqlGenStrategy.APP_KEY).setEnable(false);
|
||||
agent.setChatAppConfig(chatAppConfig);
|
||||
agent.setAdmins(Lists.newArrayList("jack"));
|
||||
agent.setViewers(Lists.newArrayList("alice", "tom"));
|
||||
agentService.createAgent(agent, defaultUser);
|
||||
}
|
||||
|
||||
|
||||
@@ -162,6 +162,8 @@ public class S2VisitsDemo extends S2BaseDemo {
|
||||
Maps.newHashMap(ChatAppManager.getAllApps(AppModule.CHAT));
|
||||
chatAppConfig.values().forEach(app -> app.setChatModelId(demoChatModel.getId()));
|
||||
agent.setChatAppConfig(chatAppConfig);
|
||||
agent.setAdmins(Lists.newArrayList("tom"));
|
||||
agent.setViewers(Lists.newArrayList("alice", "jack"));
|
||||
Agent agentCreated = agentService.createAgent(agent, defaultUser);
|
||||
return agentCreated.getId();
|
||||
}
|
||||
|
||||
@@ -393,4 +393,8 @@ ALTER TABLE s2_agent DROP COLUMN `multi_turn_config`;
|
||||
ALTER TABLE s2_agent DROP COLUMN `enable_memory_review`;
|
||||
|
||||
--20241012
|
||||
alter table s2_agent add column `enable_feedback` tinyint DEFAULT 1;
|
||||
alter table s2_agent add column `enable_feedback` tinyint DEFAULT 1;
|
||||
|
||||
--20241116
|
||||
alter table s2_agent add column `admin` varchar(1000);
|
||||
alter table s2_agent add column `viewer` varchar(1000);
|
||||
@@ -398,6 +398,8 @@ CREATE TABLE IF NOT EXISTS s2_agent
|
||||
updated_at TIMESTAMP null,
|
||||
enable_search int null,
|
||||
enable_feedback int null,
|
||||
admin varchar(1000),
|
||||
viewer varchar(1000),
|
||||
PRIMARY KEY (`id`)
|
||||
); COMMENT ON TABLE s2_agent IS 'agent information table';
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ public class SchemaAuthTest extends BaseTest {
|
||||
public void test_getVisibleModelList_alice() {
|
||||
User user = DataUtils.getUserAlice();
|
||||
List<ModelResp> modelResps =
|
||||
modelService.getModelListWithAuth(user, null, AuthType.VISIBLE);
|
||||
modelService.getModelListWithAuth(user, null, AuthType.VIEWER);
|
||||
List<String> expectedModelBizNames = Lists.newArrayList("user_department", "singer");
|
||||
Assertions.assertEquals(expectedModelBizNames,
|
||||
modelResps.stream().map(ModelResp::getBizName).collect(Collectors.toList()));
|
||||
|
||||
@@ -398,6 +398,8 @@ CREATE TABLE IF NOT EXISTS s2_agent
|
||||
updated_at TIMESTAMP null,
|
||||
enable_search int null,
|
||||
enable_feedback int null,
|
||||
admin varchar(1000),
|
||||
viewer varchar(1000),
|
||||
PRIMARY KEY (`id`)
|
||||
); COMMENT ON TABLE s2_agent IS 'agent information table';
|
||||
|
||||
|
||||
Reference in New Issue
Block a user