From 49ebb70cb364e8e2bd29e910744d1150cfc94b59 Mon Sep 17 00:00:00 2001 From: LXW <1264174498@qq.com> Date: Mon, 27 Nov 2023 15:02:20 +0800 Subject: [PATCH] (improvement)(chat) Modify the query types supported by agent rule-type tools to metric type and tag types (#424) Co-authored-by: jolunoluo --- .../supersonic/chat/agent/tool/RuleQueryTool.java | 2 ++ .../supersonic/chat/parser/rule/AgentCheckParser.java | 10 ++++++++++ .../tencent/supersonic/chat/query/QueryManager.java | 2 +- .../supersonic/chat/query/rule/RuleSemanticQuery.java | 6 +++--- .../chat/responder/parse/EntityInfoParseResponder.java | 2 +- .../java/com/tencent/supersonic/ConfigureDemo.java | 9 +++------ 6 files changed, 20 insertions(+), 11 deletions(-) diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/RuleQueryTool.java b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/RuleQueryTool.java index 51ec62936..03c4e4ef5 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/RuleQueryTool.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/RuleQueryTool.java @@ -12,6 +12,8 @@ public class RuleQueryTool extends CommonAgentTool { private List queryModes; + private List queryTypes; + public boolean isContainsAllModel() { return CollectionUtils.isNotEmpty(modelIds) && modelIds.contains(-1L); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/rule/AgentCheckParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/rule/AgentCheckParser.java index 2199bd902..29bff4be0 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/rule/AgentCheckParser.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/rule/AgentCheckParser.java @@ -9,7 +9,9 @@ import com.tencent.supersonic.chat.api.component.SemanticParser; import com.tencent.supersonic.chat.api.component.SemanticQuery; import com.tencent.supersonic.chat.api.pojo.ChatContext; import com.tencent.supersonic.chat.api.pojo.QueryContext; +import com.tencent.supersonic.chat.query.QueryManager; import com.tencent.supersonic.chat.service.AgentService; +import com.tencent.supersonic.common.pojo.QueryType; import com.tencent.supersonic.common.util.ContextUtils; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; @@ -46,6 +48,14 @@ public class AgentCheckParser implements SemanticParser { && !tool.getQueryModes().contains(query.getQueryMode())) { return true; } + if (CollectionUtils.isNotEmpty(tool.getQueryTypes())) { + if (QueryManager.isTagQuery(query.getQueryMode())) { + return !tool.getQueryTypes().contains(QueryType.TAG.name()); + } + if (QueryManager.isMetricQuery(query.getQueryMode())) { + return !tool.getQueryTypes().contains(QueryType.METRIC.name()); + } + } if (CollectionUtils.isEmpty(tool.getModelIds())) { return true; } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/QueryManager.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/QueryManager.java index 3a018455c..45be646b4 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/QueryManager.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/QueryManager.java @@ -79,7 +79,7 @@ public class QueryManager { return ruleQueryMap.get(queryMode) instanceof MetricSemanticQuery; } - public static boolean isEntityQuery(String queryMode) { + public static boolean isTagQuery(String queryMode) { if (queryMode == null || !ruleQueryMap.containsKey(queryMode)) { return false; } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/RuleSemanticQuery.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/RuleSemanticQuery.java index a62a6850b..d225224bc 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/RuleSemanticQuery.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/RuleSemanticQuery.java @@ -51,7 +51,7 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery { } public List match(List candidateElementMatches, - QueryContext queryCtx) { + QueryContext queryCtx) { return queryMatcher.match(candidateElementMatches); } @@ -76,8 +76,8 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery { return; } - if ((QueryManager.isEntityQuery(queryParseInfo.getQueryMode()) - && QueryManager.isEntityQuery(chatParseInfo.getQueryMode())) + if ((QueryManager.isTagQuery(queryParseInfo.getQueryMode()) + && QueryManager.isTagQuery(chatParseInfo.getQueryMode())) || (QueryManager.isMetricQuery(queryParseInfo.getQueryMode()) && QueryManager.isMetricQuery(chatParseInfo.getQueryMode()))) { // inherit date info from context diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/EntityInfoParseResponder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/EntityInfoParseResponder.java index 103760ad6..168f7327f 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/EntityInfoParseResponder.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/responder/parse/EntityInfoParseResponder.java @@ -35,7 +35,7 @@ public class EntityInfoParseResponder implements ParseResponder { //1. set entity info SemanticService semanticService = ContextUtils.getBean(SemanticService.class); EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, queryReq.getUser()); - if (QueryManager.isEntityQuery(queryMode) + if (QueryManager.isTagQuery(queryMode) || QueryManager.isMetricQuery(queryMode)) { parseInfo.setEntityInfo(entityInfo); } diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/ConfigureDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/ConfigureDemo.java index 5b5cec48b..baa536bf1 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/ConfigureDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/ConfigureDemo.java @@ -27,6 +27,7 @@ import com.tencent.supersonic.chat.service.ChatService; import com.tencent.supersonic.chat.service.ConfigService; import com.tencent.supersonic.chat.service.PluginService; import com.tencent.supersonic.chat.service.QueryService; +import com.tencent.supersonic.common.pojo.QueryType; import com.tencent.supersonic.common.pojo.SysParameter; import com.tencent.supersonic.common.service.SysParameterService; import com.tencent.supersonic.common.util.JsonUtil; @@ -263,10 +264,7 @@ public class ConfigureDemo implements ApplicationListener ruleQueryTool.setType(AgentToolType.RULE); ruleQueryTool.setId("0"); ruleQueryTool.setModelIds(Lists.newArrayList(-1L)); - ruleQueryTool.setQueryModes(Lists.newArrayList( - "METRIC_ENTITY", "METRIC_FILTER", "METRIC_GROUPBY", - "METRIC_MODEL", "METRIC_ORDERBY" - )); + ruleQueryTool.setQueryTypes(Lists.newArrayList(QueryType.METRIC.name())); agentConfig.getTools().add(ruleQueryTool); LLMParserTool llmParserTool = new LLMParserTool(); @@ -292,8 +290,7 @@ public class ConfigureDemo implements ApplicationListener ruleQueryTool.setId("0"); ruleQueryTool.setType(AgentToolType.RULE); ruleQueryTool.setModelIds(Lists.newArrayList(-1L)); - ruleQueryTool.setQueryModes(Lists.newArrayList( - "TAG_DETAIL", "TAG_LIST_FILTER", "TAG_ID")); + ruleQueryTool.setQueryTypes(Lists.newArrayList(QueryType.TAG.name())); agentConfig.getTools().add(ruleQueryTool); LLMParserTool llmParserTool = new LLMParserTool();