[improvement][Chat] Support agent permission management #1143

This commit is contained in:
lxwcodemonkey
2024-11-16 21:44:50 +08:00
parent e8c9855163
commit 36d221ab74
16 changed files with 70 additions and 10 deletions

View File

@@ -1,9 +1,11 @@
package com.tencent.supersonic.chat.server.agent;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.server.memory.MemoryReviewTask;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.RecordInfo;
import com.tencent.supersonic.common.pojo.User;
import lombok.Data;
import org.springframework.util.CollectionUtils;
@@ -12,6 +14,7 @@ import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
@Data
@@ -33,6 +36,8 @@ public class Agent extends RecordInfo {
private String toolConfig;
private Map<String, ChatApp> chatAppConfig = Collections.emptyMap();
private VisualConfig visualConfig;
private List<String> admins = Lists.newArrayList();
private List<String> viewers = Lists.newArrayList();
public List<String> getTools(AgentToolType type) {
Map<String, Object> map = JSONObject.parseObject(toolConfig, Map.class);
@@ -105,4 +110,9 @@ public class Agent extends RecordInfo {
.filter(dataSetIds -> !CollectionUtils.isEmpty(dataSetIds))
.flatMap(Collection::stream).collect(Collectors.toSet());
}
public boolean contains(User user, Function<Agent, List<String>> list) {
return list.apply(this).contains(user.getName());
}
}

View File

@@ -40,4 +40,8 @@ public class AgentDO {
private String chatModelConfig;
private String visualConfig;
private String admin;
private String viewer;
}

View File

@@ -8,6 +8,7 @@ import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.pojo.enums.AuthType;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.DeleteMapping;
import org.springframework.web.bind.annotation.PathVariable;
@@ -15,6 +16,7 @@ import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.PutMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import java.util.List;
@@ -48,8 +50,11 @@ public class AgentController {
}
@RequestMapping("/getAgentList")
public List<Agent> getAgentList() {
return agentService.getAgents();
public List<Agent> getAgentList(
@RequestParam(value = "authType", required = false) AuthType authType,
HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
return agentService.getAgents(user, authType);
}
@RequestMapping("/getToolTypes")

View File

@@ -2,10 +2,12 @@ package com.tencent.supersonic.chat.server.service;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.pojo.enums.AuthType;
import java.util.List;
public interface AgentService {
List<Agent> getAgents(User user, AuthType authType);
List<Agent> getAgents();

View File

@@ -14,6 +14,7 @@ import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.common.config.ChatModel;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.pojo.enums.AuthType;
import com.tencent.supersonic.common.service.ChatModelService;
import com.tencent.supersonic.common.util.JsonUtil;
import lombok.extern.slf4j.Slf4j;
@@ -43,6 +44,27 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
private ExecutorService executorService = Executors.newFixedThreadPool(1);
@Override
public List<Agent> getAgents(User user, AuthType authType) {
return getAgentDOList().stream().map(this::convert)
.filter(agent -> filterByAuth(agent, user, authType)).collect(Collectors.toList());
}
private boolean filterByAuth(Agent agent, User user, AuthType authType) {
if (user.isSuperAdmin() || user.getName().equals(agent.getCreatedBy())) {
return true;
}
authType = authType == null ? AuthType.VIEWER : authType;
switch (authType) {
case ADMIN:
return agent.contains(user, Agent::getAdmins);
case VIEWER:
default:
return agent.contains(user, Agent::getAdmins)
|| agent.contains(user, Agent::getViewers);
}
}
@Override
public List<Agent> getAgents() {
return getAgentDOList().stream().map(this::convert).collect(Collectors.toList());
@@ -135,6 +157,8 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
c.setChatModelConfig(chatModelService.getChatModel(c.getChatModelId()).getConfig());
}
});
agent.setAdmins(JsonUtil.toList(agentDO.getAdmin(), String.class));
agent.setViewers(JsonUtil.toList(agentDO.getViewer(), String.class));
return agent;
}
@@ -145,6 +169,8 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
agentDO.setExamples(JsonUtil.toString(agent.getExamples()));
agentDO.setChatModelConfig(JsonUtil.toString(agent.getChatAppConfig()));
agentDO.setVisualConfig(JsonUtil.toString(agent.getVisualConfig()));
agentDO.setAdmin(JsonUtil.toString(agent.getAdmins()));
agentDO.setViewer(JsonUtil.toString(agent.getViewers()));
if (agentDO.getStatus() == null) {
agentDO.setStatus(1);
}

View File

@@ -1,5 +1,5 @@
package com.tencent.supersonic.common.pojo.enums;
public enum AuthType {
VISIBLE, ADMIN
VIEWER, ADMIN
}

View File

@@ -260,9 +260,8 @@ public class S2DataPermissionAspect {
}
public void checkModelVisible(User user, Set<Long> modelIds) {
List<Long> modelListVisible =
modelService.getModelListWithAuth(user, null, AuthType.VISIBLE).stream()
.map(ModelResp::getId).collect(Collectors.toList());
List<Long> modelListVisible = modelService.getModelListWithAuth(user, null, AuthType.VIEWER)
.stream().map(ModelResp::getId).collect(Collectors.toList());
List<Long> modelIdCopied = new ArrayList<>(modelIds);
modelIdCopied.removeAll(modelListVisible);
if (!CollectionUtils.isEmpty(modelIdCopied)) {

View File

@@ -126,7 +126,7 @@ public class DomainServiceImpl implements DomainService {
return domainWithAuth.stream().peek(domainResp -> domainResp.setHasEditPermission(true))
.collect(Collectors.toSet());
}
if (authTypeEnum.equals(AuthType.VISIBLE)) {
if (authTypeEnum.equals(AuthType.VIEWER)) {
domainWithAuth = domainResps.stream()
.filter(domainResp -> checkViewPermission(orgIds, user, domainResp))
.collect(Collectors.toSet());

View File

@@ -428,7 +428,7 @@ public class ModelServiceImpl implements ModelService {
.filter(modelResp -> checkAdminPermission(orgIds, user, modelResp))
.collect(Collectors.toList());
}
if (authTypeEnum.equals(AuthType.VISIBLE)) {
if (authTypeEnum.equals(AuthType.VIEWER)) {
modelWithAuth = modelResps.stream()
.filter(domainResp -> checkDataSetPermission(orgIds, user, domainResp))
.collect(Collectors.toList());

View File

@@ -167,6 +167,8 @@ public class S2SingerDemo extends S2BaseDemo {
Maps.newHashMap(ChatAppManager.getAllApps(AppModule.CHAT));
chatAppConfig.values().forEach(app -> app.setChatModelId(demoChatModel.getId()));
agent.setChatAppConfig(chatAppConfig);
agent.setAdmins(Lists.newArrayList("alice"));
agent.setViewers(Lists.newArrayList("tom", "jack"));
agentService.createAgent(agent, defaultUser);
}
}

View File

@@ -40,6 +40,8 @@ public class S2SmallTalkDemo extends S2BaseDemo {
chatAppConfig.get(PlainTextExecutor.APP_KEY).setEnable(true);
chatAppConfig.get(OnePassSCSqlGenStrategy.APP_KEY).setEnable(false);
agent.setChatAppConfig(chatAppConfig);
agent.setAdmins(Lists.newArrayList("jack"));
agent.setViewers(Lists.newArrayList("alice", "tom"));
agentService.createAgent(agent, defaultUser);
}

View File

@@ -162,6 +162,8 @@ public class S2VisitsDemo extends S2BaseDemo {
Maps.newHashMap(ChatAppManager.getAllApps(AppModule.CHAT));
chatAppConfig.values().forEach(app -> app.setChatModelId(demoChatModel.getId()));
agent.setChatAppConfig(chatAppConfig);
agent.setAdmins(Lists.newArrayList("tom"));
agent.setViewers(Lists.newArrayList("alice", "jack"));
Agent agentCreated = agentService.createAgent(agent, defaultUser);
return agentCreated.getId();
}

View File

@@ -393,4 +393,8 @@ ALTER TABLE s2_agent DROP COLUMN `multi_turn_config`;
ALTER TABLE s2_agent DROP COLUMN `enable_memory_review`;
--20241012
alter table s2_agent add column `enable_feedback` tinyint DEFAULT 1;
alter table s2_agent add column `enable_feedback` tinyint DEFAULT 1;
--20241116
alter table s2_agent add column `admin` varchar(1000);
alter table s2_agent add column `viewer` varchar(1000);

View File

@@ -398,6 +398,8 @@ CREATE TABLE IF NOT EXISTS s2_agent
updated_at TIMESTAMP null,
enable_search int null,
enable_feedback int null,
admin varchar(1000),
viewer varchar(1000),
PRIMARY KEY (`id`)
); COMMENT ON TABLE s2_agent IS 'agent information table';

View File

@@ -51,7 +51,7 @@ public class SchemaAuthTest extends BaseTest {
public void test_getVisibleModelList_alice() {
User user = DataUtils.getUserAlice();
List<ModelResp> modelResps =
modelService.getModelListWithAuth(user, null, AuthType.VISIBLE);
modelService.getModelListWithAuth(user, null, AuthType.VIEWER);
List<String> expectedModelBizNames = Lists.newArrayList("user_department", "singer");
Assertions.assertEquals(expectedModelBizNames,
modelResps.stream().map(ModelResp::getBizName).collect(Collectors.toList()));

View File

@@ -398,6 +398,8 @@ CREATE TABLE IF NOT EXISTS s2_agent
updated_at TIMESTAMP null,
enable_search int null,
enable_feedback int null,
admin varchar(1000),
viewer varchar(1000),
PRIMARY KEY (`id`)
); COMMENT ON TABLE s2_agent IS 'agent information table';