(improvement)(chat) Make it compatible with the case when the agent is empty. (#596)

This commit is contained in:
lexluo09
2024-01-04 21:39:56 +08:00
committed by GitHub
parent ade96c3adc
commit 602b9547b8
3 changed files with 16 additions and 7 deletions

View File

@@ -22,7 +22,9 @@ public class Agent extends RecordInfo {
private String name;
private String description;
//0 offline, 1 online
/**
* 0 offline, 1 online
*/
private Integer status;
private List<String> examples;
private String agentConfig;
@@ -49,7 +51,7 @@ public class Agent extends RecordInfo {
return enableSearch != null && enableSearch == 1;
}
public boolean containsAllModel(Set<Long> detectModelIds) {
public static boolean containsAllModel(Set<Long> detectModelIds) {
return !CollectionUtils.isEmpty(detectModelIds) && detectModelIds.contains(-1L);
}

View File

@@ -85,9 +85,12 @@ public class MapperHelper {
public Set<Long> getModelIds(QueryReq request, Agent agent) {
Long modelId = request.getModelId();
Set<Long> detectModelIds = agent.getModelIds(null);
Set<Long> detectModelIds = new HashSet<>();
if (Objects.nonNull(agent)) {
detectModelIds = agent.getModelIds(null);
}
//contains all
if (agent.containsAllModel(detectModelIds)) {
if (Agent.containsAllModel(detectModelIds)) {
if (Objects.nonNull(modelId) && modelId > 0) {
Set<Long> result = new HashSet<>();
result.add(modelId);

View File

@@ -61,8 +61,12 @@ public class LLMRequestService {
}
public ModelCluster getModelCluster(QueryContext queryCtx, ChatContext chatCtx) {
Set<Long> distinctModelIds = queryCtx.getAgent().getModelIds(AgentToolType.NL2SQL_LLM);
if (queryCtx.getAgent().containsAllModel(distinctModelIds)) {
Agent agent = queryCtx.getAgent();
Set<Long> distinctModelIds = new HashSet<>();
if (Objects.nonNull(agent)) {
distinctModelIds = agent.getModelIds(AgentToolType.NL2SQL_LLM);
}
if (Agent.containsAllModel(distinctModelIds)) {
distinctModelIds = new HashSet<>();
}
ModelResolver modelResolver = ComponentFactory.getModelResolver();
@@ -77,7 +81,7 @@ public class LLMRequestService {
Optional<NL2SQLTool> llmParserTool = commonAgentTools.stream()
.filter(tool -> {
List<Long> modelIds = tool.getModelIds();
if (agent.containsAllModel(new HashSet<>(modelIds))) {
if (Agent.containsAllModel(new HashSet<>(modelIds))) {
return true;
}
for (Long modelId : modelIdSet) {