diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/agent/Agent.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/agent/Agent.java index e9746c396..7f74b798e 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/agent/Agent.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/agent/Agent.java @@ -22,7 +22,9 @@ public class Agent extends RecordInfo { private String name; private String description; - //0 offline, 1 online + /** + * 0 offline, 1 online + */ private Integer status; private List examples; private String agentConfig; @@ -49,7 +51,7 @@ public class Agent extends RecordInfo { return enableSearch != null && enableSearch == 1; } - public boolean containsAllModel(Set detectModelIds) { + public static boolean containsAllModel(Set detectModelIds) { return !CollectionUtils.isEmpty(detectModelIds) && detectModelIds.contains(-1L); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/MapperHelper.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/MapperHelper.java index d297f22d7..7c0e01bbc 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/MapperHelper.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/mapper/MapperHelper.java @@ -85,9 +85,12 @@ public class MapperHelper { public Set getModelIds(QueryReq request, Agent agent) { Long modelId = request.getModelId(); - Set detectModelIds = agent.getModelIds(null); + Set detectModelIds = new HashSet<>(); + if (Objects.nonNull(agent)) { + detectModelIds = agent.getModelIds(null); + } //contains all - if (agent.containsAllModel(detectModelIds)) { + if (Agent.containsAllModel(detectModelIds)) { if (Objects.nonNull(modelId) && modelId > 0) { Set result = new HashSet<>(); result.add(modelId); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/LLMRequestService.java b/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/LLMRequestService.java index 4d7cba0cc..990bd5863 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/LLMRequestService.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/core/parser/sql/llm/LLMRequestService.java @@ -61,8 +61,12 @@ public class LLMRequestService { } public ModelCluster getModelCluster(QueryContext queryCtx, ChatContext chatCtx) { - Set distinctModelIds = queryCtx.getAgent().getModelIds(AgentToolType.NL2SQL_LLM); - if (queryCtx.getAgent().containsAllModel(distinctModelIds)) { + Agent agent = queryCtx.getAgent(); + Set distinctModelIds = new HashSet<>(); + if (Objects.nonNull(agent)) { + distinctModelIds = agent.getModelIds(AgentToolType.NL2SQL_LLM); + } + if (Agent.containsAllModel(distinctModelIds)) { distinctModelIds = new HashSet<>(); } ModelResolver modelResolver = ComponentFactory.getModelResolver(); @@ -77,7 +81,7 @@ public class LLMRequestService { Optional llmParserTool = commonAgentTools.stream() .filter(tool -> { List modelIds = tool.getModelIds(); - if (agent.containsAllModel(new HashSet<>(modelIds))) { + if (Agent.containsAllModel(new HashSet<>(modelIds))) { return true; } for (Long modelId : modelIdSet) {