[improvement][chat]Adopt accept pattern to parsers and executors.
Some checks are pending
supersonic CentOS CI / build (21) (push) Waiting to run
supersonic mac CI / build (21) (push) Waiting to run
supersonic ubuntu CI / build (21) (push) Waiting to run
supersonic windows CI / build (21) (push) Waiting to run

This commit is contained in:
jerryjzhang
2025-03-11 00:27:06 +08:00
parent 93d585c0d5
commit b58e041e8d
12 changed files with 49 additions and 29 deletions

View File

@@ -5,5 +5,7 @@ import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
public interface ChatQueryExecutor { public interface ChatQueryExecutor {
boolean accept(ExecuteContext executeContext);
QueryResult execute(ExecuteContext executeContext); QueryResult execute(ExecuteContext executeContext);
} }

View File

@@ -37,11 +37,12 @@ public class PlainTextExecutor implements ChatQueryExecutor {
} }
@Override @Override
public QueryResult execute(ExecuteContext executeContext) { public boolean accept(ExecuteContext executeContext) {
if (!"PLAIN_TEXT".equals(executeContext.getParseInfo().getQueryMode())) { return "PLAIN_TEXT".equals(executeContext.getParseInfo().getQueryMode());
return null; }
}
@Override
public QueryResult execute(ExecuteContext executeContext) {
AgentService agentService = ContextUtils.getBean(AgentService.class); AgentService agentService = ContextUtils.getBean(AgentService.class);
Agent chatAgent = agentService.getAgent(executeContext.getAgent().getId()); Agent chatAgent = agentService.getAgent(executeContext.getAgent().getId());
ChatApp chatApp = chatAgent.getChatAppConfig().get(APP_KEY); ChatApp chatApp = chatAgent.getChatAppConfig().get(APP_KEY);

View File

@@ -8,6 +8,11 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
public class PluginExecutor implements ChatQueryExecutor { public class PluginExecutor implements ChatQueryExecutor {
@Override
public boolean accept(ExecuteContext executeContext) {
return PluginQueryManager.isPluginQuery(executeContext.getParseInfo().getQueryMode());
}
@Override @Override
public QueryResult execute(ExecuteContext executeContext) { public QueryResult execute(ExecuteContext executeContext) {
SemanticParseInfo parseInfo = executeContext.getParseInfo(); SemanticParseInfo parseInfo = executeContext.getParseInfo();

View File

@@ -25,6 +25,11 @@ import java.util.Objects;
public class SqlExecutor implements ChatQueryExecutor { public class SqlExecutor implements ChatQueryExecutor {
@Override
public boolean accept(ExecuteContext executeContext) {
return true;
}
@SneakyThrows @SneakyThrows
@Override @Override
public QueryResult execute(ExecuteContext executeContext) { public QueryResult execute(ExecuteContext executeContext) {

View File

@@ -4,5 +4,7 @@ import com.tencent.supersonic.chat.server.pojo.ParseContext;
public interface ChatQueryParser { public interface ChatQueryParser {
boolean accept(ParseContext parseContext);
void parse(ParseContext parseContext); void parse(ParseContext parseContext);
} }

View File

@@ -14,12 +14,12 @@ public class NL2PluginParser implements ChatQueryParser {
private final List<PluginRecognizer> pluginRecognizers = private final List<PluginRecognizer> pluginRecognizers =
ComponentFactory.getPluginRecognizers(); ComponentFactory.getPluginRecognizers();
public boolean accept(ParseContext parseContext) {
return parseContext.getAgent().containsPluginTool();
}
@Override @Override
public void parse(ParseContext parseContext) { public void parse(ParseContext parseContext) {
if (!parseContext.getAgent().containsPluginTool()) {
return;
}
pluginRecognizers.forEach(pluginRecognizer -> { pluginRecognizers.forEach(pluginRecognizer -> {
pluginRecognizer.recognize(parseContext); pluginRecognizer.recognize(parseContext);
log.info("{} recallResult:{}", pluginRecognizer.getClass().getSimpleName(), log.info("{} recallResult:{}", pluginRecognizer.getClass().getSimpleName(),

View File

@@ -73,12 +73,12 @@ public class NL2SQLParser implements ChatQueryParser {
.build()); .build());
} }
public boolean accept(ParseContext parseContext) {
return parseContext.enableNL2SQL();
}
@Override @Override
public void parse(ParseContext parseContext) { public void parse(ParseContext parseContext) {
if (!parseContext.enableNL2SQL()) {
return;
}
// first go with rule-based parsers unless the user has already selected one parse. // first go with rule-based parsers unless the user has already selected one parse.
if (Objects.isNull(parseContext.getRequest().getSelectedParse())) { if (Objects.isNull(parseContext.getRequest().getSelectedParse())) {
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext); QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);

View File

@@ -6,12 +6,12 @@ import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
public class PlainTextParser implements ChatQueryParser { public class PlainTextParser implements ChatQueryParser {
public boolean accept(ParseContext parseContext) {
return !parseContext.getAgent().containsAnyTool();
}
@Override @Override
public void parse(ParseContext parseContext) { public void parse(ParseContext parseContext) {
if (parseContext.getAgent().containsAnyTool()) {
return;
}
SemanticParseInfo parseInfo = new SemanticParseInfo(); SemanticParseInfo parseInfo = new SemanticParseInfo();
parseInfo.setQueryMode("PLAIN_TEXT"); parseInfo.setQueryMode("PLAIN_TEXT");
parseInfo.setId(1); parseInfo.setId(1);

View File

@@ -95,7 +95,11 @@ public class ChatQueryServiceImpl implements ChatQueryService {
} }
ParseContext parseContext = buildParseContext(chatParseReq, new ChatParseResp(queryId)); ParseContext parseContext = buildParseContext(chatParseReq, new ChatParseResp(queryId));
chatQueryParsers.forEach(p -> p.parse(parseContext)); for (ChatQueryParser parser : chatQueryParsers) {
if (parser.accept(parseContext)) {
parser.parse(parseContext);
}
}
for (ParseResultProcessor processor : parseResultProcessors) { for (ParseResultProcessor processor : parseResultProcessors) {
if (processor.accept(parseContext)) { if (processor.accept(parseContext)) {
@@ -116,9 +120,11 @@ public class ChatQueryServiceImpl implements ChatQueryService {
QueryResult queryResult = new QueryResult(); QueryResult queryResult = new QueryResult();
ExecuteContext executeContext = buildExecuteContext(chatExecuteReq); ExecuteContext executeContext = buildExecuteContext(chatExecuteReq);
for (ChatQueryExecutor chatQueryExecutor : chatQueryExecutors) { for (ChatQueryExecutor chatQueryExecutor : chatQueryExecutors) {
queryResult = chatQueryExecutor.execute(executeContext); if (chatQueryExecutor.accept(executeContext)) {
if (queryResult != null) { queryResult = chatQueryExecutor.execute(executeContext);
break; if (queryResult != null) {
break;
}
} }
} }

View File

@@ -25,12 +25,12 @@ public class ChatModelServiceImpl extends ServiceImpl<ChatModelMapper, ChatModel
@Override @Override
public List<ChatModel> getChatModels(User user) { public List<ChatModel> getChatModels(User user) {
return list().stream().map(this::convert).filter(chatModel -> { return list().stream().map(this::convert).filter(chatModel -> {
if (chatModel.isPublic() || user.isSuperAdmin() if (chatModel.isPublic() || user.isSuperAdmin()
|| chatModel.getCreatedBy().equals(user.getName()) || chatModel.getCreatedBy().equals(user.getName())
|| chatModel.getViewers().contains(user.getName())) { || chatModel.getViewers().contains(user.getName())) {
return true; return true;
} }
return false; return false;
}).collect(Collectors.toList()); }).collect(Collectors.toList());
} }

View File

@@ -53,7 +53,7 @@ public class ChatModelController {
@RequestMapping("/getModelList") @RequestMapping("/getModelList")
public List<ChatModel> getModelList(HttpServletRequest httpServletRequest, public List<ChatModel> getModelList(HttpServletRequest httpServletRequest,
HttpServletResponse httpServletResponse) { HttpServletResponse httpServletResponse) {
User user = UserHolder.findUser(httpServletRequest, httpServletResponse); User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
return chatModelService.getChatModels(user); return chatModelService.getChatModels(user);
} }

View File

@@ -96,8 +96,7 @@ public class DatabaseServiceImpl extends ServiceImpl<DatabaseDOMapper, DatabaseD
databaseResp.setHasEditPermission(true); databaseResp.setHasEditPermission(true);
databaseResp.setHasUsePermission(true); databaseResp.setHasUsePermission(true);
} }
if (databaseResp.getViewers().contains(user.getName()) if (databaseResp.getViewers().contains(user.getName()) || databaseResp.isPublic()) {
|| databaseResp.isPublic()) {
databaseResp.setHasUsePermission(true); databaseResp.setHasUsePermission(true);
} }
}); });