From a23d1071a31acaf42b8e8b47e9d51000912884ea Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Thu, 26 Dec 2024 23:37:55 +0800 Subject: [PATCH] [improvement][chat] Optimize the logic for obtaining the generic thread pool (#1979) --- .../chat/server/parser/NL2SQLParser.java | 9 +++++-- .../server/service/impl/AgentServiceImpl.java | 8 +++--- .../chat/mapper/BaseMatchStrategy.java | 8 +++--- .../chat/mapper/EmbeddingMatchStrategy.java | 3 --- .../server/service/impl/ModelServiceImpl.java | 25 ++++++++++--------- .../server/service/ModelServiceImplTest.java | 6 ++--- 6 files changed, 33 insertions(+), 26 deletions(-) diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index ec88c2291..f3dee66f3 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -102,8 +102,13 @@ public class NL2SQLParser implements ChatQueryParser { } if (parseResp.getSelectedParses().isEmpty()) { - queryNLReq.setMapModeEnum(MapModeEnum.LOOSE); - doParse(queryNLReq, parseResp); + for (MapModeEnum mode : Lists.newArrayList(MapModeEnum.LOOSE)) { + queryNLReq.setMapModeEnum(mode); + doParse(queryNLReq, parseResp); + if (!parseResp.getSelectedParses().isEmpty()) { + break; + } + } } if (parseResp.getSelectedParses().isEmpty()) { diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java index f20fb7eeb..211728906 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java @@ -12,7 +12,6 @@ import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.chat.server.service.ChatQueryService; import com.tencent.supersonic.chat.server.service.MemoryService; import com.tencent.supersonic.common.config.ChatModel; -import com.tencent.supersonic.common.config.ThreadPoolConfig; import com.tencent.supersonic.common.pojo.ChatApp; import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.common.pojo.enums.AuthType; @@ -21,11 +20,13 @@ import com.tencent.supersonic.common.util.JsonUtil; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.BeanUtils; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.stereotype.Service; import org.springframework.util.CollectionUtils; import java.util.List; import java.util.Objects; +import java.util.concurrent.ThreadPoolExecutor; import java.util.stream.Collectors; @Slf4j @@ -42,7 +43,8 @@ public class AgentServiceImpl extends ServiceImpl implem private ChatModelService chatModelService; @Autowired - private ThreadPoolConfig threadPoolConfig; + @Qualifier("chatExecutor") + private ThreadPoolExecutor executor; @Override public List getAgents(User user, AuthType authType) { @@ -108,7 +110,7 @@ public class AgentServiceImpl extends ServiceImpl implem * @param agent */ private void executeAgentExamplesAsync(Agent agent) { - threadPoolConfig.getChatExecutor().execute(() -> doExecuteAgentExamples(agent)); + executor.execute(() -> doExecuteAgentExamples(agent)); } private synchronized void doExecuteAgentExamples(Agent agent) { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java index 0ba4ccc1d..21a6d2338 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java @@ -1,6 +1,5 @@ package com.tencent.supersonic.headless.chat.mapper; -import com.tencent.supersonic.common.config.ThreadPoolConfig; import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum; import com.tencent.supersonic.headless.api.pojo.response.S2Term; import com.tencent.supersonic.headless.chat.ChatQueryContext; @@ -9,6 +8,7 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.stereotype.Service; import java.util.HashMap; @@ -17,13 +17,15 @@ import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.concurrent.Callable; +import java.util.concurrent.ThreadPoolExecutor; @Service @Slf4j public abstract class BaseMatchStrategy implements MatchStrategy { @Autowired - protected ThreadPoolConfig threadPoolConfig; + @Qualifier("mapExecutor") + private ThreadPoolExecutor executor; @Override public Map> match(ChatQueryContext chatQueryContext, List terms, @@ -72,7 +74,7 @@ public abstract class BaseMatchStrategy implements MatchStr protected void executeTasks(List> tasks) { try { - threadPoolConfig.getMapExecutor().invokeAll(tasks); + executor.invokeAll(tasks); for (Callable future : tasks) { future.call(); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java index 958e9353d..e6e52c5d4 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java @@ -1,7 +1,6 @@ package com.tencent.supersonic.headless.chat.mapper; import com.google.common.collect.Lists; -import com.tencent.supersonic.common.config.ThreadPoolConfig; import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult; import com.tencent.supersonic.headless.chat.knowledge.MetaEmbeddingService; @@ -38,8 +37,6 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy @Autowired private MetaEmbeddingService metaEmbeddingService; - @Autowired - protected ThreadPoolConfig threadPoolConfig; @Override public List detectByBatch(ChatQueryContext chatQueryContext, diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java index 0dd3d62cf..8a0888b0e 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelServiceImpl.java @@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.server.service.impl; import com.google.common.collect.Lists; import com.tencent.supersonic.auth.api.authentication.service.UserService; -import com.tencent.supersonic.common.config.ThreadPoolConfig; import com.tencent.supersonic.common.pojo.ItemDateResp; import com.tencent.supersonic.common.pojo.ModelRela; import com.tencent.supersonic.common.pojo.User; @@ -52,6 +51,7 @@ import com.tencent.supersonic.headless.server.utils.NameCheckUtils; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.BeanUtils; +import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.context.annotation.Lazy; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; @@ -69,6 +69,7 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ThreadPoolExecutor; import java.util.stream.Collectors; @Service @@ -93,13 +94,13 @@ public class ModelServiceImpl implements ModelService { private final ModelRelaService modelRelaService; - private final ThreadPoolConfig threadPoolConfig; + private final ThreadPoolExecutor executor; public ModelServiceImpl(ModelRepository modelRepository, DatabaseService databaseService, - @Lazy DimensionService dimensionService, @Lazy MetricService metricService, - DomainService domainService, UserService userService, DataSetService dataSetService, - DateInfoRepository dateInfoRepository, ModelRelaService modelRelaService, - ThreadPoolConfig threadPoolConfig) { + @Lazy DimensionService dimensionService, @Lazy MetricService metricService, + DomainService domainService, UserService userService, DataSetService dataSetService, + DateInfoRepository dateInfoRepository, ModelRelaService modelRelaService, + @Qualifier("commonExecutor") ThreadPoolExecutor executor) { this.modelRepository = modelRepository; this.databaseService = databaseService; this.dimensionService = dimensionService; @@ -109,7 +110,7 @@ public class ModelServiceImpl implements ModelService { this.dataSetService = dataSetService; this.dateInfoRepository = dateInfoRepository; this.modelRelaService = modelRelaService; - this.threadPoolConfig = threadPoolConfig; + this.executor = executor; } @Override @@ -226,13 +227,13 @@ public class ModelServiceImpl implements ModelService { CompletableFuture.allOf(dbSchemas.stream() .map(dbSchema -> CompletableFuture.runAsync( () -> doBuild(modelBuildReq, dbSchema, dbSchemas, modelSchemaMap), - threadPoolConfig.getCommonExecutor())) + executor)) .toArray(CompletableFuture[]::new)).join(); return modelSchemaMap; } private void doBuild(ModelBuildReq modelBuildReq, DbSchema curSchema, List dbSchemas, - Map modelSchemaMap) { + Map modelSchemaMap) { ModelSchema modelSchema = new ModelSchema(); List semanticModellers = CoreComponentFactory.getSemanticModellers(); for (SemanticModeller semanticModeller : semanticModellers) { @@ -250,7 +251,7 @@ public class ModelServiceImpl implements ModelService { } private List convert(Map> dbColumnMap, - ModelBuildReq modelBuildReq) { + ModelBuildReq modelBuildReq) { return dbColumnMap.keySet().stream() .map(key -> convert(modelBuildReq, key, dbColumnMap.get(key))) .collect(Collectors.toList()); @@ -405,7 +406,7 @@ public class ModelServiceImpl implements ModelService { } public List getModelRespAuthInheritDomain(User user, Long domainId, - AuthType authType) { + AuthType authType) { List domainIds = domainService.getDomainAuthSet(user, authType).stream().filter(domainResp -> { if (domainId == null) { @@ -580,7 +581,7 @@ public class ModelServiceImpl implements ModelService { } public static boolean checkDataSetPermission(Set orgIds, User user, - ModelResp modelResp) { + ModelResp modelResp) { if (checkAdminPermission(orgIds, user, modelResp)) { return true; } diff --git a/headless/server/src/test/java/com/tencent/supersonic/headless/server/service/ModelServiceImplTest.java b/headless/server/src/test/java/com/tencent/supersonic/headless/server/service/ModelServiceImplTest.java index 2a8dfcb49..5a25efab5 100644 --- a/headless/server/src/test/java/com/tencent/supersonic/headless/server/service/ModelServiceImplTest.java +++ b/headless/server/src/test/java/com/tencent/supersonic/headless/server/service/ModelServiceImplTest.java @@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.server.service; import com.google.common.collect.Lists; import com.tencent.supersonic.auth.api.authentication.service.UserService; -import com.tencent.supersonic.common.config.ThreadPoolConfig; import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; import com.tencent.supersonic.common.pojo.enums.StatusEnum; @@ -26,6 +25,7 @@ import org.mockito.Mockito; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.ThreadPoolExecutor; import static org.mockito.Mockito.when; @@ -77,10 +77,10 @@ class ModelServiceImplTest { DateInfoRepository dateInfoRepository = Mockito.mock(DateInfoRepository.class); DataSetService viewService = Mockito.mock(DataSetService.class); ModelRelaService modelRelaService = Mockito.mock(ModelRelaService.class); - ThreadPoolConfig threadPoolConfig = Mockito.mock(ThreadPoolConfig.class); + ThreadPoolExecutor threadPoolExecutor = Mockito.mock(ThreadPoolExecutor.class); return new ModelServiceImpl(modelRepository, databaseService, dimensionService, metricService, domainService, userService, viewService, dateInfoRepository, - modelRelaService, threadPoolConfig); + modelRelaService, threadPoolExecutor); } private ModelReq mockModelReq() {