From 8a8370164fad328715dcadbfe34edd64510c4ec9 Mon Sep 17 00:00:00 2001 From: LXW <1264174498@qq.com> Date: Wed, 13 Mar 2024 10:22:16 +0800 Subject: [PATCH] (improvement)(Chat) Determine whether to enable LLM based on agent information (#810) Co-authored-by: jolunoluo --- .../tencent/supersonic/chat/server/agent/Agent.java | 4 ++++ .../chat/server/service/impl/ChatServiceImpl.java | 3 +++ .../headless/api/pojo/request/QueryReq.java | 1 + .../server/service/impl/ChatQueryServiceImpl.java | 12 +----------- 4 files changed, 9 insertions(+), 11 deletions(-) diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java index 0b81f4057..830cf3a7f 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java @@ -66,6 +66,10 @@ public class Agent extends RecordInfo { .collect(Collectors.toList()); } + public boolean containsLLMParserTool() { + return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM)); + } + public Set getDataSetIds() { Set dataSetIds = getDataSetIds(null); if (containsAllModel(dataSetIds)) { diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatServiceImpl.java index 16f3f8d17..a8bfe2495 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatServiceImpl.java @@ -106,6 +106,9 @@ public class ChatServiceImpl implements ChatService { if (agent == null) { return queryReq; } + if (agent.containsLLMParserTool()) { + queryReq.setEnableLLM(true); + } queryReq.setDataSetIds(agent.getDataSetIds()); return queryReq; } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryReq.java index 8f3f8e9c9..fa4831305 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryReq.java @@ -14,4 +14,5 @@ public class QueryReq { private User user; private QueryFilters queryFilters; private boolean saveAnswer = true; + private boolean enableLLM; } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java index 278b012a2..03be4bdfa 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java @@ -75,7 +75,6 @@ import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.Pair; import org.springframework.beans.BeanUtils; import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; import java.util.ArrayList; @@ -102,9 +101,6 @@ public class ChatQueryServiceImpl implements ChatQueryService { @Autowired private DataSetService dataSetService; - @Value("${time.threshold: 100}") - private Integer timeThreshold; - private List schemaMappers = ComponentFactory.getSchemaMappers(); private List semanticParsers = ComponentFactory.getSemanticParsers(); private List semanticCorrectors = ComponentFactory.getSemanticCorrectors(); @@ -118,22 +114,15 @@ public class ChatQueryServiceImpl implements ChatQueryService { // in order to support multi-turn conversation, chat context is needed ChatContext chatCtx = chatContextService.getOrCreateContext(queryReq.getChatId()); - List timeCostDOList = new ArrayList<>(); // 1. mapper schemaMappers.forEach(mapper -> { - long startTime = System.currentTimeMillis(); mapper.map(queryCtx); - timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime)) - .interfaceName(mapper.getClass().getSimpleName()).type(CostType.MAPPER.getType()).build()); }); // 2. parser semanticParsers.forEach(parser -> { - long startTime = System.currentTimeMillis(); parser.parse(queryCtx, chatCtx); - timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime)) - .interfaceName(parser.getClass().getSimpleName()).type(CostType.PARSER.getType()).build()); log.debug("{} result:{}", parser.getClass().getSimpleName(), JsonUtil.toString(queryCtx)); }); @@ -168,6 +157,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { .candidateQueries(new ArrayList<>()) .mapInfo(new SchemaMapInfo()) .modelIdToDataSetIds(modelIdToDataSetIds) + .enableLLM(queryReq.isEnableLLM()) .build(); BeanUtils.copyProperties(queryReq, queryCtx); return queryCtx;