mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-20 06:34:55 +00:00
(improvement)(Chat) Simplify processor in Headless and Chat (#822)
Co-authored-by: jolunoluo
This commit is contained in:
@@ -70,6 +70,11 @@ public class Agent extends RecordInfo {
|
||||
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM));
|
||||
}
|
||||
|
||||
public boolean containsNL2SQLTool() {
|
||||
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM))
|
||||
|| !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_RULE));
|
||||
}
|
||||
|
||||
public Set<Long> getDataSetIds() {
|
||||
Set<Long> dataSetIds = getDataSetIds(null);
|
||||
if (containsAllModel(dataSetIds)) {
|
||||
|
||||
@@ -15,6 +15,7 @@ public class PluginExecutor implements ChatExecutor {
|
||||
return null;
|
||||
}
|
||||
PluginSemanticQuery query = PluginQueryManager.getPluginQuery(parseInfo.getQueryMode());
|
||||
query.setParseInfo(parseInfo);
|
||||
return query.build();
|
||||
}
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.server.util.ComponentFactory;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import java.util.List;
|
||||
|
||||
public class Text2PluginParser implements ChatParser {
|
||||
public class NL2PluginParser implements ChatParser {
|
||||
|
||||
private final List<PluginRecognizer> pluginRecognizers = ComponentFactory.getPluginRecognizers();
|
||||
|
||||
@@ -7,10 +7,13 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.server.service.ChatQueryService;
|
||||
|
||||
public class Text2SqlParser implements ChatParser {
|
||||
public class NL2SQLParser implements ChatParser {
|
||||
|
||||
@Override
|
||||
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
if (!chatParseContext.enableNL2SQL()) {
|
||||
return;
|
||||
}
|
||||
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
|
||||
ParseResp text2SqlParseResp = chatQueryService.performParsing(queryReq);
|
||||
@@ -12,7 +12,6 @@ import com.tencent.supersonic.chat.server.plugin.event.PluginAddEvent;
|
||||
import com.tencent.supersonic.chat.server.plugin.event.PluginDelEvent;
|
||||
import com.tencent.supersonic.chat.server.plugin.event.PluginUpdateEvent;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.PluginService;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.util.ComponentFactory;
|
||||
@@ -54,9 +53,7 @@ public class PluginManager {
|
||||
|
||||
public static List<Plugin> getPluginAgentCanSupport(ChatParseContext chatParseContext) {
|
||||
PluginService pluginService = ContextUtils.getBean(PluginService.class);
|
||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||
Agent agent = agentService.getAgent(chatParseContext.getAgentId());
|
||||
|
||||
Agent agent = chatParseContext.getAgent();
|
||||
List<Plugin> plugins = pluginService.getPluginList();
|
||||
if (Objects.isNull(agent)) {
|
||||
return plugins;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.chat.server.pojo;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||
import lombok.Data;
|
||||
@@ -9,9 +10,16 @@ import lombok.Data;
|
||||
public class ChatParseContext {
|
||||
private String queryText;
|
||||
private Integer chatId;
|
||||
private Integer agentId;
|
||||
private Agent agent;
|
||||
private User user;
|
||||
private QueryFilters queryFilters;
|
||||
private boolean saveAnswer = true;
|
||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||
|
||||
public boolean enableNL2SQL() {
|
||||
if (agent == null) {
|
||||
return true;
|
||||
}
|
||||
return agent.containsNL2SQLTool();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
package com.tencent.supersonic.chat.server.processor.parse;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.EntityInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.core.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.headless.server.service.impl.SemanticService;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* EntityInfoProcessor fills core attributes of an entity so that
|
||||
* users get to know which entity is parsed out.
|
||||
*/
|
||||
public class EntityInfoProcessor implements ParseResultProcessor {
|
||||
|
||||
@Override
|
||||
public void process(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
List<SemanticParseInfo> selectedParses = parseResp.getSelectedParses();
|
||||
if (CollectionUtils.isEmpty(selectedParses)) {
|
||||
return;
|
||||
}
|
||||
selectedParses.forEach(parseInfo -> {
|
||||
String queryMode = parseInfo.getQueryMode();
|
||||
if (QueryManager.containsRuleQuery(queryMode)) {
|
||||
return;
|
||||
}
|
||||
//1. set entity info
|
||||
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
||||
DataSetSchema dataSetSchema = semanticService.getDataSetSchema(parseInfo.getDataSetId());
|
||||
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, chatParseContext.getUser());
|
||||
if (QueryManager.isTagQuery(queryMode)
|
||||
|| QueryManager.isMetricQuery(queryMode)) {
|
||||
parseInfo.setEntityInfo(entityInfo);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1,28 +0,0 @@
|
||||
package com.tencent.supersonic.chat.server.processor.parse;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* RespBuildProcessor fill response object with parsing results.
|
||||
**/
|
||||
@Slf4j
|
||||
public class RespBuildProcessor implements ParseResultProcessor {
|
||||
|
||||
@Override
|
||||
public void process(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
parseResp.setChatId(chatParseContext.getChatId());
|
||||
parseResp.setQueryText(chatParseContext.getQueryText());
|
||||
List<SemanticParseInfo> parseInfos = parseResp.getSelectedParses();
|
||||
if (CollectionUtils.isNotEmpty(parseInfos)) {
|
||||
parseResp.setState(ParseResp.ParseState.COMPLETED);
|
||||
} else {
|
||||
parseResp.setState(ParseResp.ParseState.FAILED);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package com.tencent.supersonic.chat.server.processor.parse;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
/**
|
||||
* TimeCostProcessor adds time cost of parsing.
|
||||
**/
|
||||
@Slf4j
|
||||
public class TimeCostProcessor implements ParseResultProcessor {
|
||||
|
||||
@Override
|
||||
public void process(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
long parseStartTime = parseResp.getParseTimeCost().getParseStartTime();
|
||||
parseResp.getParseTimeCost().setParseTime(
|
||||
System.currentTimeMillis() - parseStartTime - parseResp.getParseTimeCost().getSqlTime());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.executor.ChatExecutor;
|
||||
import com.tencent.supersonic.chat.server.parser.ChatParser;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatDO;
|
||||
@@ -18,10 +19,12 @@ import com.tencent.supersonic.chat.server.persistence.repository.ChatRepository;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatService;
|
||||
import com.tencent.supersonic.chat.server.util.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||
import com.tencent.supersonic.common.util.BeanMapper;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
@@ -73,7 +76,7 @@ public class ChatServiceImpl implements ChatService {
|
||||
|
||||
@Override
|
||||
public ParseResp performParsing(ChatParseReq chatParseReq) {
|
||||
ParseResp parseResp = new ParseResp();
|
||||
ParseResp parseResp = new ParseResp(chatParseReq.getChatId(), chatParseReq.getQueryText());
|
||||
ChatParseContext chatParseContext = buildParseContext(chatParseReq);
|
||||
for (ChatParser chatParser : chatParsers) {
|
||||
chatParser.parse(chatParseContext, parseResp);
|
||||
@@ -102,6 +105,9 @@ public class ChatServiceImpl implements ChatService {
|
||||
private ChatParseContext buildParseContext(ChatParseReq chatParseReq) {
|
||||
ChatParseContext chatParseContext = new ChatParseContext();
|
||||
BeanMapper.mapper(chatParseReq, chatParseContext);
|
||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||
Agent agent = agentService.getAgent(chatParseReq.getAgentId());
|
||||
chatParseContext.setAgent(agent);
|
||||
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||
MapResp mapResp = chatQueryService.performMapping(queryReq);
|
||||
chatParseContext.setMapInfo(mapResp.getMapInfo());
|
||||
|
||||
@@ -2,9 +2,7 @@ package com.tencent.supersonic.chat.server.util;
|
||||
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.common.util.BeanMapper;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
||||
|
||||
public class QueryReqConverter {
|
||||
@@ -12,11 +10,7 @@ public class QueryReqConverter {
|
||||
public static QueryReq buildText2SqlQueryReq(ChatParseContext chatParseContext) {
|
||||
QueryReq queryReq = new QueryReq();
|
||||
BeanMapper.mapper(chatParseContext, queryReq);
|
||||
if (chatParseContext.getAgentId() == null) {
|
||||
return queryReq;
|
||||
}
|
||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||
Agent agent = agentService.getAgent(chatParseContext.getAgentId());
|
||||
Agent agent = chatParseContext.getAgent();
|
||||
if (agent == null) {
|
||||
return queryReq;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user