mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
[improvement][chat]Adopt accept pattern to parsers and executors.
This commit is contained in:
@@ -5,5 +5,7 @@ import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
|||||||
|
|
||||||
public interface ChatQueryExecutor {
|
public interface ChatQueryExecutor {
|
||||||
|
|
||||||
|
boolean accept(ExecuteContext executeContext);
|
||||||
|
|
||||||
QueryResult execute(ExecuteContext executeContext);
|
QueryResult execute(ExecuteContext executeContext);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -37,11 +37,12 @@ public class PlainTextExecutor implements ChatQueryExecutor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public QueryResult execute(ExecuteContext executeContext) {
|
public boolean accept(ExecuteContext executeContext) {
|
||||||
if (!"PLAIN_TEXT".equals(executeContext.getParseInfo().getQueryMode())) {
|
return "PLAIN_TEXT".equals(executeContext.getParseInfo().getQueryMode());
|
||||||
return null;
|
}
|
||||||
}
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public QueryResult execute(ExecuteContext executeContext) {
|
||||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||||
Agent chatAgent = agentService.getAgent(executeContext.getAgent().getId());
|
Agent chatAgent = agentService.getAgent(executeContext.getAgent().getId());
|
||||||
ChatApp chatApp = chatAgent.getChatAppConfig().get(APP_KEY);
|
ChatApp chatApp = chatAgent.getChatAppConfig().get(APP_KEY);
|
||||||
|
|||||||
@@ -8,6 +8,11 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
|||||||
|
|
||||||
public class PluginExecutor implements ChatQueryExecutor {
|
public class PluginExecutor implements ChatQueryExecutor {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean accept(ExecuteContext executeContext) {
|
||||||
|
return PluginQueryManager.isPluginQuery(executeContext.getParseInfo().getQueryMode());
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public QueryResult execute(ExecuteContext executeContext) {
|
public QueryResult execute(ExecuteContext executeContext) {
|
||||||
SemanticParseInfo parseInfo = executeContext.getParseInfo();
|
SemanticParseInfo parseInfo = executeContext.getParseInfo();
|
||||||
|
|||||||
@@ -25,6 +25,11 @@ import java.util.Objects;
|
|||||||
|
|
||||||
public class SqlExecutor implements ChatQueryExecutor {
|
public class SqlExecutor implements ChatQueryExecutor {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean accept(ExecuteContext executeContext) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
@SneakyThrows
|
@SneakyThrows
|
||||||
@Override
|
@Override
|
||||||
public QueryResult execute(ExecuteContext executeContext) {
|
public QueryResult execute(ExecuteContext executeContext) {
|
||||||
|
|||||||
@@ -4,5 +4,7 @@ import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
|||||||
|
|
||||||
public interface ChatQueryParser {
|
public interface ChatQueryParser {
|
||||||
|
|
||||||
|
boolean accept(ParseContext parseContext);
|
||||||
|
|
||||||
void parse(ParseContext parseContext);
|
void parse(ParseContext parseContext);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,12 +14,12 @@ public class NL2PluginParser implements ChatQueryParser {
|
|||||||
private final List<PluginRecognizer> pluginRecognizers =
|
private final List<PluginRecognizer> pluginRecognizers =
|
||||||
ComponentFactory.getPluginRecognizers();
|
ComponentFactory.getPluginRecognizers();
|
||||||
|
|
||||||
|
public boolean accept(ParseContext parseContext) {
|
||||||
|
return parseContext.getAgent().containsPluginTool();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void parse(ParseContext parseContext) {
|
public void parse(ParseContext parseContext) {
|
||||||
if (!parseContext.getAgent().containsPluginTool()) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
pluginRecognizers.forEach(pluginRecognizer -> {
|
pluginRecognizers.forEach(pluginRecognizer -> {
|
||||||
pluginRecognizer.recognize(parseContext);
|
pluginRecognizer.recognize(parseContext);
|
||||||
log.info("{} recallResult:{}", pluginRecognizer.getClass().getSimpleName(),
|
log.info("{} recallResult:{}", pluginRecognizer.getClass().getSimpleName(),
|
||||||
|
|||||||
@@ -73,12 +73,12 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
.build());
|
.build());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public boolean accept(ParseContext parseContext) {
|
||||||
|
return parseContext.enableNL2SQL();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void parse(ParseContext parseContext) {
|
public void parse(ParseContext parseContext) {
|
||||||
if (!parseContext.enableNL2SQL()) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// first go with rule-based parsers unless the user has already selected one parse.
|
// first go with rule-based parsers unless the user has already selected one parse.
|
||||||
if (Objects.isNull(parseContext.getRequest().getSelectedParse())) {
|
if (Objects.isNull(parseContext.getRequest().getSelectedParse())) {
|
||||||
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
|
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
|
||||||
|
|||||||
@@ -6,12 +6,12 @@ import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
|||||||
|
|
||||||
public class PlainTextParser implements ChatQueryParser {
|
public class PlainTextParser implements ChatQueryParser {
|
||||||
|
|
||||||
|
public boolean accept(ParseContext parseContext) {
|
||||||
|
return !parseContext.getAgent().containsAnyTool();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void parse(ParseContext parseContext) {
|
public void parse(ParseContext parseContext) {
|
||||||
if (parseContext.getAgent().containsAnyTool()) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
SemanticParseInfo parseInfo = new SemanticParseInfo();
|
SemanticParseInfo parseInfo = new SemanticParseInfo();
|
||||||
parseInfo.setQueryMode("PLAIN_TEXT");
|
parseInfo.setQueryMode("PLAIN_TEXT");
|
||||||
parseInfo.setId(1);
|
parseInfo.setId(1);
|
||||||
|
|||||||
@@ -95,7 +95,11 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ParseContext parseContext = buildParseContext(chatParseReq, new ChatParseResp(queryId));
|
ParseContext parseContext = buildParseContext(chatParseReq, new ChatParseResp(queryId));
|
||||||
chatQueryParsers.forEach(p -> p.parse(parseContext));
|
for (ChatQueryParser parser : chatQueryParsers) {
|
||||||
|
if (parser.accept(parseContext)) {
|
||||||
|
parser.parse(parseContext);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for (ParseResultProcessor processor : parseResultProcessors) {
|
for (ParseResultProcessor processor : parseResultProcessors) {
|
||||||
if (processor.accept(parseContext)) {
|
if (processor.accept(parseContext)) {
|
||||||
@@ -116,9 +120,11 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
QueryResult queryResult = new QueryResult();
|
QueryResult queryResult = new QueryResult();
|
||||||
ExecuteContext executeContext = buildExecuteContext(chatExecuteReq);
|
ExecuteContext executeContext = buildExecuteContext(chatExecuteReq);
|
||||||
for (ChatQueryExecutor chatQueryExecutor : chatQueryExecutors) {
|
for (ChatQueryExecutor chatQueryExecutor : chatQueryExecutors) {
|
||||||
queryResult = chatQueryExecutor.execute(executeContext);
|
if (chatQueryExecutor.accept(executeContext)) {
|
||||||
if (queryResult != null) {
|
queryResult = chatQueryExecutor.execute(executeContext);
|
||||||
break;
|
if (queryResult != null) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -25,12 +25,12 @@ public class ChatModelServiceImpl extends ServiceImpl<ChatModelMapper, ChatModel
|
|||||||
@Override
|
@Override
|
||||||
public List<ChatModel> getChatModels(User user) {
|
public List<ChatModel> getChatModels(User user) {
|
||||||
return list().stream().map(this::convert).filter(chatModel -> {
|
return list().stream().map(this::convert).filter(chatModel -> {
|
||||||
if (chatModel.isPublic() || user.isSuperAdmin()
|
if (chatModel.isPublic() || user.isSuperAdmin()
|
||||||
|| chatModel.getCreatedBy().equals(user.getName())
|
|| chatModel.getCreatedBy().equals(user.getName())
|
||||||
|| chatModel.getViewers().contains(user.getName())) {
|
|| chatModel.getViewers().contains(user.getName())) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}).collect(Collectors.toList());
|
}).collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ public class ChatModelController {
|
|||||||
|
|
||||||
@RequestMapping("/getModelList")
|
@RequestMapping("/getModelList")
|
||||||
public List<ChatModel> getModelList(HttpServletRequest httpServletRequest,
|
public List<ChatModel> getModelList(HttpServletRequest httpServletRequest,
|
||||||
HttpServletResponse httpServletResponse) {
|
HttpServletResponse httpServletResponse) {
|
||||||
User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
|
User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
|
||||||
return chatModelService.getChatModels(user);
|
return chatModelService.getChatModels(user);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -96,8 +96,7 @@ public class DatabaseServiceImpl extends ServiceImpl<DatabaseDOMapper, DatabaseD
|
|||||||
databaseResp.setHasEditPermission(true);
|
databaseResp.setHasEditPermission(true);
|
||||||
databaseResp.setHasUsePermission(true);
|
databaseResp.setHasUsePermission(true);
|
||||||
}
|
}
|
||||||
if (databaseResp.getViewers().contains(user.getName())
|
if (databaseResp.getViewers().contains(user.getName()) || databaseResp.isPublic()) {
|
||||||
|| databaseResp.isPublic()) {
|
|
||||||
databaseResp.setHasUsePermission(true);
|
databaseResp.setHasUsePermission(true);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|||||||
Reference in New Issue
Block a user