mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
[improvement][chat] Optimize the logic for obtaining the generic thread pool (#1979)
This commit is contained in:
@@ -102,8 +102,13 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
}
|
||||
|
||||
if (parseResp.getSelectedParses().isEmpty()) {
|
||||
queryNLReq.setMapModeEnum(MapModeEnum.LOOSE);
|
||||
doParse(queryNLReq, parseResp);
|
||||
for (MapModeEnum mode : Lists.newArrayList(MapModeEnum.LOOSE)) {
|
||||
queryNLReq.setMapModeEnum(mode);
|
||||
doParse(queryNLReq, parseResp);
|
||||
if (!parseResp.getSelectedParses().isEmpty()) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (parseResp.getSelectedParses().isEmpty()) {
|
||||
|
||||
@@ -12,7 +12,6 @@ import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatQueryService;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import com.tencent.supersonic.common.config.ChatModel;
|
||||
import com.tencent.supersonic.common.config.ThreadPoolConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.AuthType;
|
||||
@@ -21,11 +20,13 @@ import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Qualifier;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.ThreadPoolExecutor;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
@@ -42,7 +43,8 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
|
||||
private ChatModelService chatModelService;
|
||||
|
||||
@Autowired
|
||||
private ThreadPoolConfig threadPoolConfig;
|
||||
@Qualifier("chatExecutor")
|
||||
private ThreadPoolExecutor executor;
|
||||
|
||||
@Override
|
||||
public List<Agent> getAgents(User user, AuthType authType) {
|
||||
@@ -108,7 +110,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
|
||||
* @param agent
|
||||
*/
|
||||
private void executeAgentExamplesAsync(Agent agent) {
|
||||
threadPoolConfig.getChatExecutor().execute(() -> doExecuteAgentExamples(agent));
|
||||
executor.execute(() -> doExecuteAgentExamples(agent));
|
||||
}
|
||||
|
||||
private synchronized void doExecuteAgentExamples(Agent agent) {
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.headless.chat.mapper;
|
||||
|
||||
import com.tencent.supersonic.common.config.ThreadPoolConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
@@ -9,6 +8,7 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Qualifier;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.HashMap;
|
||||
@@ -17,13 +17,15 @@ import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.ThreadPoolExecutor;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public abstract class BaseMatchStrategy<T extends MapResult> implements MatchStrategy<T> {
|
||||
|
||||
@Autowired
|
||||
protected ThreadPoolConfig threadPoolConfig;
|
||||
@Qualifier("mapExecutor")
|
||||
private ThreadPoolExecutor executor;
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<T>> match(ChatQueryContext chatQueryContext, List<S2Term> terms,
|
||||
@@ -72,7 +74,7 @@ public abstract class BaseMatchStrategy<T extends MapResult> implements MatchStr
|
||||
|
||||
protected void executeTasks(List<Callable<Void>> tasks) {
|
||||
try {
|
||||
threadPoolConfig.getMapExecutor().invokeAll(tasks);
|
||||
executor.invokeAll(tasks);
|
||||
for (Callable<Void> future : tasks) {
|
||||
future.call();
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package com.tencent.supersonic.headless.chat.mapper;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.config.ThreadPoolConfig;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.MetaEmbeddingService;
|
||||
@@ -38,8 +37,6 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
|
||||
|
||||
@Autowired
|
||||
private MetaEmbeddingService metaEmbeddingService;
|
||||
@Autowired
|
||||
protected ThreadPoolConfig threadPoolConfig;
|
||||
|
||||
@Override
|
||||
public List<EmbeddingResult> detectByBatch(ChatQueryContext chatQueryContext,
|
||||
|
||||
@@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.server.service.impl;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.service.UserService;
|
||||
import com.tencent.supersonic.common.config.ThreadPoolConfig;
|
||||
import com.tencent.supersonic.common.pojo.ItemDateResp;
|
||||
import com.tencent.supersonic.common.pojo.ModelRela;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
@@ -52,6 +51,7 @@ import com.tencent.supersonic.headless.server.utils.NameCheckUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Qualifier;
|
||||
import org.springframework.context.annotation.Lazy;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.transaction.annotation.Transactional;
|
||||
@@ -69,6 +69,7 @@ import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.ThreadPoolExecutor;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Service
|
||||
@@ -93,13 +94,13 @@ public class ModelServiceImpl implements ModelService {
|
||||
|
||||
private final ModelRelaService modelRelaService;
|
||||
|
||||
private final ThreadPoolConfig threadPoolConfig;
|
||||
private final ThreadPoolExecutor executor;
|
||||
|
||||
public ModelServiceImpl(ModelRepository modelRepository, DatabaseService databaseService,
|
||||
@Lazy DimensionService dimensionService, @Lazy MetricService metricService,
|
||||
DomainService domainService, UserService userService, DataSetService dataSetService,
|
||||
DateInfoRepository dateInfoRepository, ModelRelaService modelRelaService,
|
||||
ThreadPoolConfig threadPoolConfig) {
|
||||
@Lazy DimensionService dimensionService, @Lazy MetricService metricService,
|
||||
DomainService domainService, UserService userService, DataSetService dataSetService,
|
||||
DateInfoRepository dateInfoRepository, ModelRelaService modelRelaService,
|
||||
@Qualifier("commonExecutor") ThreadPoolExecutor executor) {
|
||||
this.modelRepository = modelRepository;
|
||||
this.databaseService = databaseService;
|
||||
this.dimensionService = dimensionService;
|
||||
@@ -109,7 +110,7 @@ public class ModelServiceImpl implements ModelService {
|
||||
this.dataSetService = dataSetService;
|
||||
this.dateInfoRepository = dateInfoRepository;
|
||||
this.modelRelaService = modelRelaService;
|
||||
this.threadPoolConfig = threadPoolConfig;
|
||||
this.executor = executor;
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -226,13 +227,13 @@ public class ModelServiceImpl implements ModelService {
|
||||
CompletableFuture.allOf(dbSchemas.stream()
|
||||
.map(dbSchema -> CompletableFuture.runAsync(
|
||||
() -> doBuild(modelBuildReq, dbSchema, dbSchemas, modelSchemaMap),
|
||||
threadPoolConfig.getCommonExecutor()))
|
||||
executor))
|
||||
.toArray(CompletableFuture[]::new)).join();
|
||||
return modelSchemaMap;
|
||||
}
|
||||
|
||||
private void doBuild(ModelBuildReq modelBuildReq, DbSchema curSchema, List<DbSchema> dbSchemas,
|
||||
Map<String, ModelSchema> modelSchemaMap) {
|
||||
Map<String, ModelSchema> modelSchemaMap) {
|
||||
ModelSchema modelSchema = new ModelSchema();
|
||||
List<SemanticModeller> semanticModellers = CoreComponentFactory.getSemanticModellers();
|
||||
for (SemanticModeller semanticModeller : semanticModellers) {
|
||||
@@ -250,7 +251,7 @@ public class ModelServiceImpl implements ModelService {
|
||||
}
|
||||
|
||||
private List<DbSchema> convert(Map<String, List<DBColumn>> dbColumnMap,
|
||||
ModelBuildReq modelBuildReq) {
|
||||
ModelBuildReq modelBuildReq) {
|
||||
return dbColumnMap.keySet().stream()
|
||||
.map(key -> convert(modelBuildReq, key, dbColumnMap.get(key)))
|
||||
.collect(Collectors.toList());
|
||||
@@ -405,7 +406,7 @@ public class ModelServiceImpl implements ModelService {
|
||||
}
|
||||
|
||||
public List<ModelResp> getModelRespAuthInheritDomain(User user, Long domainId,
|
||||
AuthType authType) {
|
||||
AuthType authType) {
|
||||
List<Long> domainIds =
|
||||
domainService.getDomainAuthSet(user, authType).stream().filter(domainResp -> {
|
||||
if (domainId == null) {
|
||||
@@ -580,7 +581,7 @@ public class ModelServiceImpl implements ModelService {
|
||||
}
|
||||
|
||||
public static boolean checkDataSetPermission(Set<String> orgIds, User user,
|
||||
ModelResp modelResp) {
|
||||
ModelResp modelResp) {
|
||||
if (checkAdminPermission(orgIds, user, modelResp)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.server.service;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.service.UserService;
|
||||
import com.tencent.supersonic.common.config.ThreadPoolConfig;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
|
||||
@@ -26,6 +25,7 @@ import org.mockito.Mockito;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ThreadPoolExecutor;
|
||||
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
@@ -77,10 +77,10 @@ class ModelServiceImplTest {
|
||||
DateInfoRepository dateInfoRepository = Mockito.mock(DateInfoRepository.class);
|
||||
DataSetService viewService = Mockito.mock(DataSetService.class);
|
||||
ModelRelaService modelRelaService = Mockito.mock(ModelRelaService.class);
|
||||
ThreadPoolConfig threadPoolConfig = Mockito.mock(ThreadPoolConfig.class);
|
||||
ThreadPoolExecutor threadPoolExecutor = Mockito.mock(ThreadPoolExecutor.class);
|
||||
return new ModelServiceImpl(modelRepository, databaseService, dimensionService,
|
||||
metricService, domainService, userService, viewService, dateInfoRepository,
|
||||
modelRelaService, threadPoolConfig);
|
||||
modelRelaService, threadPoolExecutor);
|
||||
}
|
||||
|
||||
private ModelReq mockModelReq() {
|
||||
|
||||
Reference in New Issue
Block a user