(improvement)(headless&chat)Execute against SemanticLayerService instead of ChatQueryService in chat-server and rename several classes by the way.

This commit is contained in:
jerryjzhang
2024-07-06 23:32:59 +08:00
parent 6db6aaf98d
commit e0e77a3b64
26 changed files with 185 additions and 176 deletions

View File

@@ -5,25 +5,32 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext; import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
import com.tencent.supersonic.chat.server.service.MemoryService; import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.chat.server.util.ResultFormatter; import com.tencent.supersonic.chat.server.util.ResultFormatter;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq; import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult; import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.api.pojo.response.QueryState; import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService; import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
import com.tencent.supersonic.headless.server.web.service.ChatContextService;
import lombok.SneakyThrows; import lombok.SneakyThrows;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import java.util.ArrayList;
import java.util.Date; import java.util.Date;
import java.util.List;
import java.util.Map;
public class SqlExecutor implements ChatExecutor { public class SqlExecutor implements ChatExecutor {
@SneakyThrows @SneakyThrows
@Override @Override
public QueryResult execute(ChatExecuteContext chatExecuteContext) { public QueryResult execute(ChatExecuteContext chatExecuteContext) {
ExecuteQueryReq executeQueryReq = buildExecuteReq(chatExecuteContext); QueryResult queryResult = doExecute(chatExecuteContext);
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
QueryResult queryResult = chatQueryService.performExecution(executeQueryReq);
if (queryResult != null) { if (queryResult != null) {
String textResult = ResultFormatter.transform2TextNew(queryResult.getQueryColumns(), String textResult = ResultFormatter.transform2TextNew(queryResult.getQueryColumns(),
queryResult.getQueryResults()); queryResult.getQueryResults());
@@ -48,16 +55,43 @@ public class SqlExecutor implements ChatExecutor {
return queryResult; return queryResult;
} }
private ExecuteQueryReq buildExecuteReq(ChatExecuteContext chatExecuteContext) { @SneakyThrows
SemanticParseInfo parseInfo = chatExecuteContext.getParseInfo(); private QueryResult doExecute(ChatExecuteContext chatExecuteContext) {
return ExecuteQueryReq.builder() SemanticLayerService semanticLayer = ContextUtils.getBean(SemanticLayerService.class);
.queryId(chatExecuteContext.getQueryId()) ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
.chatId(chatExecuteContext.getChatId())
.queryText(chatExecuteContext.getQueryText()) ChatContext chatCtx = chatContextService.getOrCreateContext(chatExecuteContext.getChatId());
.parseInfo(parseInfo)
.saveAnswer(chatExecuteContext.isSaveAnswer()) QuerySqlReq sqlReq = QuerySqlReq.builder()
.user(chatExecuteContext.getUser()) .sql(chatExecuteContext.getParseInfo().getSqlInfo().getCorrectS2SQL())
.build(); .build();
sqlReq.setSqlInfo(chatExecuteContext.getParseInfo().getSqlInfo());
sqlReq.setDataSetId(chatExecuteContext.getParseInfo().getDataSetId());
long startTime = System.currentTimeMillis();
SemanticQueryResp queryResp = semanticLayer.queryByReq(sqlReq, chatExecuteContext.getUser());
QueryResult queryResult = new QueryResult();
if (queryResp != null) {
queryResult.setQueryAuthorization(queryResp.getQueryAuthorization());
List<Map<String, Object>> resultList = queryResp == null ? new ArrayList<>()
: queryResp.getResultList();
List<QueryColumn> columns = queryResp == null ? new ArrayList<>() : queryResp.getColumns();
queryResult.setQueryTimeCost(System.currentTimeMillis() - startTime);
queryResult.setQuerySql(queryResp.getSql());
queryResult.setQueryResults(resultList);
queryResult.setQueryColumns(columns);
queryResult.setQueryMode(chatExecuteContext.getParseInfo().getQueryMode());
queryResult.setQueryState(QueryState.SUCCESS);
chatCtx.setParseInfo(chatExecuteContext.getParseInfo());
chatContextService.updateContext(chatCtx);
} else {
queryResult.setQueryState(QueryState.INVALID);
queryResult.setQueryMode(chatExecuteContext.getParseInfo().getQueryMode());
}
return queryResult;
} }
public String buildSchemaStr(SemanticParseInfo parseInfo) { public String buildSchemaStr(SemanticParseInfo parseInfo) {

View File

@@ -15,7 +15,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq; import com.tencent.supersonic.headless.api.pojo.request.QueryTextReq;
import com.tencent.supersonic.headless.api.pojo.response.MapResp; import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService; import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
@@ -69,11 +69,11 @@ public class NL2SQLParser implements ChatParser {
} }
processMultiTurn(chatParseContext); processMultiTurn(chatParseContext);
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext); QueryTextReq queryTextReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
addDynamicExemplars(chatParseContext.getAgent().getId(), queryReq); addDynamicExemplars(chatParseContext.getAgent().getId(), queryTextReq);
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class); ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
ParseResp text2SqlParseResp = chatQueryService.performParsing(queryReq); ParseResp text2SqlParseResp = chatQueryService.performParsing(queryTextReq);
if (!ParseResp.ParseState.FAILED.equals(text2SqlParseResp.getState())) { if (!ParseResp.ParseState.FAILED.equals(text2SqlParseResp.getState())) {
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses()); parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
} }
@@ -149,8 +149,8 @@ public class NL2SQLParser implements ChatParser {
// derive mapping result of current question and parsing result of last question. // derive mapping result of current question and parsing result of last question.
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class); ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext); QueryTextReq queryTextReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
MapResp currentMapResult = chatQueryService.performMapping(queryReq); MapResp currentMapResult = chatQueryService.performMapping(queryTextReq);
List<ParseResp> historyParseResults = getHistoryParseResult(chatParseContext.getChatId(), 1); List<ParseResp> historyParseResults = getHistoryParseResult(chatParseContext.getChatId(), 1);
if (historyParseResults.size() == 0) { if (historyParseResults.size() == 0) {
@@ -168,7 +168,7 @@ public class NL2SQLParser implements ChatParser {
.curtSchema(curtMapStr) .curtSchema(curtMapStr)
.histSchema(histMapStr) .histSchema(histMapStr)
.histSQL(histSQL) .histSQL(histSQL)
.llmConfig(queryReq.getLlmConfig()) .llmConfig(queryTextReq.getLlmConfig())
.build()); .build());
chatParseContext.setQueryText(rewrittenQuery); chatParseContext.setQueryText(rewrittenQuery);
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}", log.info("Last Query: {} Current Query: {}, Rewritten Query: {}",
@@ -225,13 +225,13 @@ public class NL2SQLParser implements ChatParser {
return contextualList; return contextualList;
} }
private void addDynamicExemplars(Integer agentId, QueryReq queryReq) { private void addDynamicExemplars(Integer agentId, QueryTextReq queryTextReq) {
ExemplarServiceImpl exemplarManager = ContextUtils.getBean(ExemplarServiceImpl.class); ExemplarServiceImpl exemplarManager = ContextUtils.getBean(ExemplarServiceImpl.class);
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class); EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId); String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId);
List<SqlExemplar> exemplars = exemplarManager.recallExemplars(memoryCollectionName, List<SqlExemplar> exemplars = exemplarManager.recallExemplars(memoryCollectionName,
queryReq.getQueryText(), 5); queryTextReq.getQueryText(), 5);
queryReq.getDynamicExemplars().addAll(exemplars); queryTextReq.getDynamicExemplars().addAll(exemplars);
} }
@Builder @Builder

View File

@@ -10,7 +10,7 @@ import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException; import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq; import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq; import com.tencent.supersonic.headless.api.pojo.request.QueryTextReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.BeanUtils; import org.springframework.beans.BeanUtils;
@@ -76,7 +76,7 @@ public class ChatQueryController {
} }
@PostMapping("queryContext") @PostMapping("queryContext")
public Object queryContext(@RequestBody QueryReq queryCtx, public Object queryContext(@RequestBody QueryTextReq queryCtx,
HttpServletRequest request, HttpServletResponse response) { HttpServletRequest request, HttpServletResponse response) {
queryCtx.setUser(UserHolder.findUser(request, response)); queryCtx.setUser(UserHolder.findUser(request, response));
return chatService.queryContext(queryCtx.getChatId()); return chatService.queryContext(queryCtx.getChatId());

View File

@@ -21,7 +21,7 @@ import com.tencent.supersonic.common.util.BeanMapper;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq; import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq; import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq; import com.tencent.supersonic.headless.api.pojo.request.QueryTextReq;
import com.tencent.supersonic.headless.api.pojo.response.MapResp; import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult; import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
@@ -60,8 +60,8 @@ public class ChatServiceImpl implements ChatService {
if (!agent.enableSearch()) { if (!agent.enableSearch()) {
return Lists.newArrayList(); return Lists.newArrayList();
} }
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext); QueryTextReq queryTextReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
return retrieveService.retrieve(queryReq); return retrieveService.retrieve(queryTextReq);
} }
@Override @Override
@@ -137,8 +137,8 @@ public class ChatServiceImpl implements ChatService {
} }
private void supplyMapInfo(ChatParseContext chatParseContext) { private void supplyMapInfo(ChatParseContext chatParseContext) {
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext); QueryTextReq queryTextReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
MapResp mapResp = chatQueryService.performMapping(queryReq); MapResp mapResp = chatQueryService.performMapping(queryTextReq);
chatParseContext.setMapInfo(mapResp.getMapInfo()); chatParseContext.setMapInfo(mapResp.getMapInfo());
} }

View File

@@ -4,35 +4,35 @@ import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext; import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.common.util.BeanMapper; import com.tencent.supersonic.common.util.BeanMapper;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq; import com.tencent.supersonic.headless.api.pojo.request.QueryTextReq;
import org.apache.commons.collections.MapUtils; import org.apache.commons.collections.MapUtils;
import java.util.Objects; import java.util.Objects;
public class QueryReqConverter { public class QueryReqConverter {
public static QueryReq buildText2SqlQueryReq(ChatParseContext chatParseContext) { public static QueryTextReq buildText2SqlQueryReq(ChatParseContext chatParseContext) {
QueryReq queryReq = new QueryReq(); QueryTextReq queryTextReq = new QueryTextReq();
BeanMapper.mapper(chatParseContext, queryReq); BeanMapper.mapper(chatParseContext, queryTextReq);
Agent agent = chatParseContext.getAgent(); Agent agent = chatParseContext.getAgent();
if (agent == null) { if (agent == null) {
return queryReq; return queryTextReq;
} }
if (agent.containsLLMParserTool() && agent.containsRuleTool()) { if (agent.containsLLMParserTool() && agent.containsRuleTool()) {
queryReq.setText2SQLType(Text2SQLType.RULE_AND_LLM); queryTextReq.setText2SQLType(Text2SQLType.RULE_AND_LLM);
} else if (agent.containsLLMParserTool()) { } else if (agent.containsLLMParserTool()) {
queryReq.setText2SQLType(Text2SQLType.ONLY_LLM); queryTextReq.setText2SQLType(Text2SQLType.ONLY_LLM);
} else if (agent.containsRuleTool()) { } else if (agent.containsRuleTool()) {
queryReq.setText2SQLType(Text2SQLType.ONLY_RULE); queryTextReq.setText2SQLType(Text2SQLType.ONLY_RULE);
} }
queryReq.setDataSetIds(agent.getDataSetIds()); queryTextReq.setDataSetIds(agent.getDataSetIds());
if (Objects.nonNull(queryReq.getMapInfo()) if (Objects.nonNull(queryTextReq.getMapInfo())
&& MapUtils.isNotEmpty(queryReq.getMapInfo().getDataSetElementMatches())) { && MapUtils.isNotEmpty(queryTextReq.getMapInfo().getDataSetElementMatches())) {
queryReq.setMapInfo(queryReq.getMapInfo()); queryTextReq.setMapInfo(queryTextReq.getMapInfo());
} }
queryReq.setLlmConfig(agent.getLlmConfig()); queryTextReq.setLlmConfig(agent.getLlmConfig());
queryReq.setPromptConfig(agent.getPromptConfig()); queryTextReq.setPromptConfig(agent.getPromptConfig());
return queryReq; return queryTextReq;
} }
} }

View File

@@ -2,12 +2,14 @@ package com.tencent.supersonic.common.config;
import com.tencent.supersonic.common.util.AESEncryptionUtil; import com.tencent.supersonic.common.util.AESEncryptionUtil;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor; import lombok.NoArgsConstructor;
import java.io.Serializable; import java.io.Serializable;
@Data @Data
@Builder
@AllArgsConstructor @AllArgsConstructor
@NoArgsConstructor @NoArgsConstructor
public class ChatModelConfig implements Serializable { public class ChatModelConfig implements Serializable {

View File

@@ -4,6 +4,7 @@ public enum WorkflowState {
MAPPING, MAPPING,
PARSING, PARSING,
CORRECTING, CORRECTING,
TRANSLATING,
PROCESSING, PROCESSING,
FINISHED FINISHED
} }

View File

@@ -1,6 +1,9 @@
package com.tencent.supersonic.headless.api.pojo.request; package com.tencent.supersonic.headless.api.pojo.request;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.ToString; import lombok.ToString;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
@@ -8,8 +11,13 @@ import java.util.Objects;
@Data @Data
@ToString @ToString
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class QuerySqlReq extends SemanticQueryReq { public class QuerySqlReq extends SemanticQueryReq {
private String sql; private String sql;
private Integer limit = 1000; private Integer limit = 1000;
@Override @Override

View File

@@ -155,7 +155,7 @@ public class QueryStructReq extends SemanticQueryReq {
String sql = null; String sql = null;
try { try {
sql = buildSql(this, isBizName); sql = buildSql(this, isBizName);
} catch (Exception e) { } catch (JSQLParserException e) {
log.error("buildSql error", e); log.error("buildSql error", e);
} }
@@ -164,7 +164,7 @@ public class QueryStructReq extends SemanticQueryReq {
result.setDataSetId(this.getDataSetId()); result.setDataSetId(this.getDataSetId());
result.setModelIds(this.getModelIdSet()); result.setModelIds(this.getModelIdSet());
result.setParams(new ArrayList<>()); result.setParams(new ArrayList<>());
result.setSqlInfo(this.getSqlInfo()); result.getSqlInfo().setCorrectS2SQL(sql);
return result; return result;
} }

View File

@@ -16,7 +16,7 @@ import java.util.List;
import java.util.Set; import java.util.Set;
@Data @Data
public class QueryReq { public class QueryTextReq {
private String queryText; private String queryText;
private Integer chatId; private Integer chatId;
private Set<Long> dataSetIds = Sets.newHashSet(); private Set<Long> dataSetIds = Sets.newHashSet();

View File

@@ -12,7 +12,7 @@ import lombok.ToString;
@Builder @Builder
@AllArgsConstructor @AllArgsConstructor
@NoArgsConstructor @NoArgsConstructor
public class ExplainSqlReq<T> { public class TranslateSqlReq<T> {
private QueryMethod queryTypeEnum; private QueryMethod queryTypeEnum;

View File

@@ -13,7 +13,7 @@ import java.io.Serializable;
@Builder @Builder
@AllArgsConstructor @AllArgsConstructor
@NoArgsConstructor @NoArgsConstructor
public class ExplainResp implements Serializable { public class TranslateResp implements Serializable {
private String sql; private String sql;

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.headless.chat.parser.llm; package com.tencent.supersonic.headless.chat.parser.llm;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq; import com.tencent.supersonic.headless.api.pojo.request.QueryTextReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
@@ -22,7 +22,7 @@ public class ParseResult {
private LLMResp llmResp; private LLMResp llmResp;
private QueryReq request; private QueryTextReq request;
private List<LLMReq.ElementValue> linkingValues; private List<LLMReq.ElementValue> linkingValues;
} }

View File

@@ -1,8 +1,7 @@
package com.tencent.supersonic.headless.server.facade.rest; package com.tencent.supersonic.headless.server.facade.rest;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq; import com.tencent.supersonic.headless.api.pojo.request.QueryTextReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
import com.tencent.supersonic.headless.api.pojo.response.MapResp; import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService; import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
import com.tencent.supersonic.headless.server.facade.service.RetrieveService; import com.tencent.supersonic.headless.server.facade.service.RetrieveService;
@@ -27,35 +26,27 @@ public class ChatQueryApiController {
private RetrieveService retrieveService; private RetrieveService retrieveService;
@PostMapping("/chat/search") @PostMapping("/chat/search")
public Object search(@RequestBody QueryReq queryReq, public Object search(@RequestBody QueryTextReq queryTextReq,
HttpServletRequest request, HttpServletRequest request,
HttpServletResponse response) throws Exception { HttpServletResponse response) throws Exception {
queryReq.setUser(UserHolder.findUser(request, response)); queryTextReq.setUser(UserHolder.findUser(request, response));
return retrieveService.retrieve(queryReq); return retrieveService.retrieve(queryTextReq);
} }
@PostMapping("/chat/map") @PostMapping("/chat/map")
public MapResp map(@RequestBody QueryReq queryReq, public MapResp map(@RequestBody QueryTextReq queryTextReq,
HttpServletRequest request, HttpServletRequest request,
HttpServletResponse response) { HttpServletResponse response) {
queryReq.setUser(UserHolder.findUser(request, response)); queryTextReq.setUser(UserHolder.findUser(request, response));
return chatQueryService.performMapping(queryReq); return chatQueryService.performMapping(queryTextReq);
} }
@PostMapping("/chat/parse") @PostMapping("/chat/parse")
public Object parse(@RequestBody QueryReq queryReq, public Object parse(@RequestBody QueryTextReq queryTextReq,
HttpServletRequest request, HttpServletRequest request,
HttpServletResponse response) throws Exception { HttpServletResponse response) throws Exception {
queryReq.setUser(UserHolder.findUser(request, response)); queryTextReq.setUser(UserHolder.findUser(request, response));
return chatQueryService.performParsing(queryReq); return chatQueryService.performParsing(queryTextReq);
}
@PostMapping("/chat/execute")
public Object execute(@RequestBody ExecuteQueryReq executeQueryReq,
HttpServletRequest request,
HttpServletResponse response) throws Exception {
executeQueryReq.setUser(UserHolder.findUser(request, response));
return chatQueryService.performExecution(executeQueryReq);
} }
} }

View File

@@ -7,7 +7,7 @@ import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq; import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq; import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryMapReq; import com.tencent.supersonic.headless.api.pojo.request.QueryMapReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq; import com.tencent.supersonic.headless.api.pojo.request.QueryTextReq;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq; import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp; import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp;
import com.tencent.supersonic.headless.api.pojo.response.MapResp; import com.tencent.supersonic.headless.api.pojo.response.MapResp;
@@ -19,12 +19,13 @@ import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
*/ */
public interface ChatQueryService { public interface ChatQueryService {
MapResp performMapping(QueryReq queryReq); MapResp performMapping(QueryTextReq queryTextReq);
MapInfoResp map(QueryMapReq queryMapReq); MapInfoResp map(QueryMapReq queryMapReq);
ParseResp performParsing(QueryReq queryReq); ParseResp performParsing(QueryTextReq queryTextReq);
@Deprecated
QueryResult performExecution(ExecuteQueryReq queryReq) throws Exception; QueryResult performExecution(ExecuteQueryReq queryReq) throws Exception;
SemanticParseInfo queryContext(Integer chatId); SemanticParseInfo queryContext(Integer chatId);

View File

@@ -1,12 +1,12 @@
package com.tencent.supersonic.headless.server.facade.service; package com.tencent.supersonic.headless.server.facade.service;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq; import com.tencent.supersonic.headless.api.pojo.request.QueryTextReq;
import com.tencent.supersonic.headless.api.pojo.response.SearchResult; import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
import java.util.List; import java.util.List;
public interface RetrieveService { public interface RetrieveService {
List<SearchResult> retrieve(QueryReq queryCtx); List<SearchResult> retrieve(QueryTextReq queryCtx);
} }

View File

@@ -4,10 +4,10 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.EntityInfo; import com.tencent.supersonic.headless.api.pojo.EntityInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.ExplainSqlReq; import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq; import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.ExplainResp; import com.tencent.supersonic.headless.api.pojo.response.TranslateResp;
import com.tencent.supersonic.headless.api.pojo.response.ItemResp; import com.tencent.supersonic.headless.api.pojo.response.ItemResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
@@ -24,7 +24,7 @@ public interface SemanticLayerService {
SemanticQueryResp queryDimValue(QueryDimValueReq queryDimValueReq, User user); SemanticQueryResp queryDimValue(QueryDimValueReq queryDimValueReq, User user);
<T> ExplainResp explain(ExplainSqlReq<T> explainSqlReq, User user) throws Exception; <T> TranslateResp translate(TranslateSqlReq<T> translateSqlReq, User user) throws Exception;
EntityInfo getEntityInfo(SemanticParseInfo parseInfo, DataSetSchema dataSetSchema, User user); EntityInfo getEntityInfo(SemanticParseInfo parseInfo, DataSetSchema dataSetSchema, User user);

View File

@@ -29,19 +29,19 @@ import com.tencent.supersonic.headless.api.pojo.enums.CostType;
import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod; import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq; import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq; import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
import com.tencent.supersonic.headless.api.pojo.request.ExplainSqlReq; import com.tencent.supersonic.headless.api.pojo.request.QueryTextReq;
import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq; import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq; import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import com.tencent.supersonic.headless.api.pojo.request.QueryMapReq; import com.tencent.supersonic.headless.api.pojo.request.QueryMapReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq; import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.DataSetMapInfo; import com.tencent.supersonic.headless.api.pojo.response.DataSetMapInfo;
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp; import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp; import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
import com.tencent.supersonic.headless.api.pojo.response.ExplainResp; import com.tencent.supersonic.headless.api.pojo.response.TranslateResp;
import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp; import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp;
import com.tencent.supersonic.headless.api.pojo.response.MapResp; import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
@@ -63,13 +63,13 @@ import com.tencent.supersonic.headless.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService; import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService; import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
import com.tencent.supersonic.headless.server.utils.WorkflowEngine;
import com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO; import com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO;
import com.tencent.supersonic.headless.server.pojo.MetaFilter; import com.tencent.supersonic.headless.server.pojo.MetaFilter;
import com.tencent.supersonic.headless.server.utils.ComponentFactory; import com.tencent.supersonic.headless.server.utils.ComponentFactory;
import com.tencent.supersonic.headless.server.web.service.ChatContextService; import com.tencent.supersonic.headless.server.web.service.ChatContextService;
import com.tencent.supersonic.headless.server.web.service.DataSetService; import com.tencent.supersonic.headless.server.web.service.DataSetService;
import com.tencent.supersonic.headless.server.web.service.SchemaService; import com.tencent.supersonic.headless.server.web.service.SchemaService;
import com.tencent.supersonic.headless.server.web.service.WorkflowService;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.LongValue; import net.sf.jsqlparser.expression.LongValue;
@@ -115,45 +115,45 @@ public class ChatQueryServiceImpl implements ChatQueryService {
@Autowired @Autowired
private DataSetService dataSetService; private DataSetService dataSetService;
@Autowired @Autowired
private WorkflowService workflowService; private WorkflowEngine workflowEngine;
@Override @Override
public MapResp performMapping(QueryReq queryReq) { public MapResp performMapping(QueryTextReq queryTextReq) {
MapResp mapResp = new MapResp(); MapResp mapResp = new MapResp();
QueryContext queryCtx = buildQueryContext(queryReq); QueryContext queryCtx = buildQueryContext(queryTextReq);
ComponentFactory.getSchemaMappers().forEach(mapper -> { ComponentFactory.getSchemaMappers().forEach(mapper -> {
mapper.map(queryCtx); mapper.map(queryCtx);
}); });
SchemaMapInfo mapInfo = queryCtx.getMapInfo(); SchemaMapInfo mapInfo = queryCtx.getMapInfo();
mapResp.setMapInfo(mapInfo); mapResp.setMapInfo(mapInfo);
mapResp.setQueryText(queryReq.getQueryText()); mapResp.setQueryText(queryTextReq.getQueryText());
return mapResp; return mapResp;
} }
@Override @Override
public MapInfoResp map(QueryMapReq queryMapReq) { public MapInfoResp map(QueryMapReq queryMapReq) {
QueryReq queryReq = new QueryReq(); QueryTextReq queryTextReq = new QueryTextReq();
BeanUtils.copyProperties(queryMapReq, queryReq); BeanUtils.copyProperties(queryMapReq, queryTextReq);
List<DataSetResp> dataSets = dataSetService.getDataSets(queryMapReq.getDataSetNames(), queryMapReq.getUser()); List<DataSetResp> dataSets = dataSetService.getDataSets(queryMapReq.getDataSetNames(), queryMapReq.getUser());
Set<Long> dataSetIds = dataSets.stream().map(SchemaItem::getId).collect(Collectors.toSet()); Set<Long> dataSetIds = dataSets.stream().map(SchemaItem::getId).collect(Collectors.toSet());
queryReq.setDataSetIds(dataSetIds); queryTextReq.setDataSetIds(dataSetIds);
MapResp mapResp = performMapping(queryReq); MapResp mapResp = performMapping(queryTextReq);
dataSetIds.retainAll(mapResp.getMapInfo().getDataSetElementMatches().keySet()); dataSetIds.retainAll(mapResp.getMapInfo().getDataSetElementMatches().keySet());
return convert(mapResp, queryMapReq.getTopN(), dataSetIds); return convert(mapResp, queryMapReq.getTopN(), dataSetIds);
} }
@Override @Override
public ParseResp performParsing(QueryReq queryReq) { public ParseResp performParsing(QueryTextReq queryTextReq) {
ParseResp parseResult = new ParseResp(queryReq.getChatId(), queryReq.getQueryText()); ParseResp parseResult = new ParseResp(queryTextReq.getChatId(), queryTextReq.getQueryText());
// build queryContext and chatContext // build queryContext and chatContext
QueryContext queryCtx = buildQueryContext(queryReq); QueryContext queryCtx = buildQueryContext(queryTextReq);
// in order to support multi-turn conversation, chat context is needed // in order to support multi-turn conversation, chat context is needed
ChatContext chatCtx = chatContextService.getOrCreateContext(queryReq.getChatId()); ChatContext chatCtx = chatContextService.getOrCreateContext(queryTextReq.getChatId());
workflowService.startWorkflow(queryCtx, chatCtx, parseResult); workflowEngine.startWorkflow(queryCtx, chatCtx, parseResult);
List<SemanticParseInfo> parseInfos = queryCtx.getCandidateQueries().stream() List<SemanticParseInfo> parseInfos = queryCtx.getCandidateQueries().stream()
.map(SemanticQuery::getParseInfo).collect(Collectors.toList()); .map(SemanticQuery::getParseInfo).collect(Collectors.toList());
@@ -161,25 +161,26 @@ public class ChatQueryServiceImpl implements ChatQueryService {
return parseResult; return parseResult;
} }
public QueryContext buildQueryContext(QueryReq queryReq) { public QueryContext buildQueryContext(QueryTextReq queryTextReq) {
SemanticSchema semanticSchema = schemaService.getSemanticSchema(); SemanticSchema semanticSchema = schemaService.getSemanticSchema();
Map<Long, List<Long>> modelIdToDataSetIds = dataSetService.getModelIdToDataSetIds(); Map<Long, List<Long>> modelIdToDataSetIds = dataSetService.getModelIdToDataSetIds();
QueryContext queryCtx = QueryContext.builder() QueryContext queryCtx = QueryContext.builder()
.queryFilters(queryReq.getQueryFilters()) .queryFilters(queryTextReq.getQueryFilters())
.semanticSchema(semanticSchema) .semanticSchema(semanticSchema)
.candidateQueries(new ArrayList<>()) .candidateQueries(new ArrayList<>())
.mapInfo(new SchemaMapInfo()) .mapInfo(new SchemaMapInfo())
.modelIdToDataSetIds(modelIdToDataSetIds) .modelIdToDataSetIds(modelIdToDataSetIds)
.text2SQLType(queryReq.getText2SQLType()) .text2SQLType(queryTextReq.getText2SQLType())
.mapModeEnum(queryReq.getMapModeEnum()) .mapModeEnum(queryTextReq.getMapModeEnum())
.dataSetIds(queryReq.getDataSetIds()) .dataSetIds(queryTextReq.getDataSetIds())
.build(); .build();
BeanUtils.copyProperties(queryReq, queryCtx); BeanUtils.copyProperties(queryTextReq, queryCtx);
return queryCtx; return queryCtx;
} }
@Override @Override
@Deprecated
public QueryResult performExecution(ExecuteQueryReq queryReq) throws Exception { public QueryResult performExecution(ExecuteQueryReq queryReq) throws Exception {
List<StatisticsDO> timeCostDOList = new ArrayList<>(); List<StatisticsDO> timeCostDOList = new ArrayList<>();
SemanticParseInfo parseInfo = queryReq.getParseInfo(); SemanticParseInfo parseInfo = queryReq.getParseInfo();
@@ -263,9 +264,9 @@ public class ChatQueryServiceImpl implements ChatQueryService {
parseInfo.getSqlInfo().setCorrectS2SQL(correctorSql); parseInfo.getSqlInfo().setCorrectS2SQL(correctorSql);
semanticQuery.setParseInfo(parseInfo); semanticQuery.setParseInfo(parseInfo);
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq(); SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
ExplainSqlReq<Object> explainSqlReq = ExplainSqlReq.builder().queryReq(semanticQueryReq) TranslateSqlReq<Object> translateSqlReq = TranslateSqlReq.builder().queryReq(semanticQueryReq)
.queryTypeEnum(QueryMethod.SQL).build(); .queryTypeEnum(QueryMethod.SQL).build();
ExplainResp explain = semanticLayerService.explain(explainSqlReq, user); TranslateResp explain = semanticLayerService.translate(translateSqlReq, user);
if (StringUtils.isNotBlank(explain.getSql())) { if (StringUtils.isNotBlank(explain.getSql())) {
parseInfo.getSqlInfo().setQuerySQL(explain.getSql()); parseInfo.getSqlInfo().setQuerySQL(explain.getSql());
parseInfo.getSqlInfo().setSourceId(explain.getSourceId()); parseInfo.getSqlInfo().setSourceId(explain.getSourceId());
@@ -277,9 +278,9 @@ public class ChatQueryServiceImpl implements ChatQueryService {
validFilter(semanticQuery.getParseInfo().getMetricFilters()); validFilter(semanticQuery.getParseInfo().getMetricFilters());
//init s2sql //init s2sql
semanticQuery.initS2Sql(semanticSchema, user); semanticQuery.initS2Sql(semanticSchema, user);
QueryReq queryReq = new QueryReq(); QueryTextReq queryTextReq = new QueryTextReq();
queryReq.setQueryFilters(new QueryFilters()); queryTextReq.setQueryFilters(new QueryFilters());
queryReq.setUser(user); queryTextReq.setUser(user);
} }
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq(); SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
QueryResult queryResult = doExecution(semanticQueryReq, semanticQuery.getParseInfo(), user); QueryResult queryResult = doExecution(semanticQueryReq, semanticQuery.getParseInfo(), user);

View File

@@ -8,7 +8,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq; import com.tencent.supersonic.headless.api.pojo.request.QueryTextReq;
import com.tencent.supersonic.headless.api.pojo.response.S2Term; import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.api.pojo.response.SearchResult; import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
import com.tencent.supersonic.headless.chat.QueryContext; import com.tencent.supersonic.headless.chat.QueryContext;
@@ -64,9 +64,9 @@ public class RetrieveServiceImpl implements RetrieveService {
@Autowired @Autowired
private SearchMatchStrategy searchMatchStrategy; private SearchMatchStrategy searchMatchStrategy;
@Override @Override
public List<SearchResult> retrieve(QueryReq queryReq) { public List<SearchResult> retrieve(QueryTextReq queryTextReq) {
String queryText = queryReq.getQueryText(); String queryText = queryTextReq.getQueryText();
// 1.get meta info // 1.get meta info
SemanticSchema semanticSchemaDb = schemaService.getSemanticSchema(); SemanticSchema semanticSchemaDb = schemaService.getSemanticSchema();
List<SchemaElement> metricsDb = semanticSchemaDb.getMetrics(); List<SchemaElement> metricsDb = semanticSchemaDb.getMetrics();
@@ -76,10 +76,10 @@ public class RetrieveServiceImpl implements RetrieveService {
// 2.detect by segment // 2.detect by segment
List<S2Term> originals = knowledgeBaseService.getTerms(queryText, modelIdToDataSetIds); List<S2Term> originals = knowledgeBaseService.getTerms(queryText, modelIdToDataSetIds);
log.debug("hanlp parse result: {}", originals); log.debug("hanlp parse result: {}", originals);
Set<Long> dataSetIds = queryReq.getDataSetIds(); Set<Long> dataSetIds = queryTextReq.getDataSetIds();
QueryContext queryContext = new QueryContext(); QueryContext queryContext = new QueryContext();
BeanUtils.copyProperties(queryReq, queryContext); BeanUtils.copyProperties(queryTextReq, queryContext);
queryContext.setModelIdToDataSetIds(dataSetService.getModelIdToDataSetIds()); queryContext.setModelIdToDataSetIds(dataSetService.getModelIdToDataSetIds());
Map<MatchText, List<HanlpMapResult>> regTextMap = Map<MatchText, List<HanlpMapResult>> regTextMap =
@@ -100,12 +100,12 @@ public class RetrieveServiceImpl implements RetrieveService {
return Lists.newArrayList(); return Lists.newArrayList();
} }
Map.Entry<MatchText, List<HanlpMapResult>> searchTextEntry = mostSimilarSearchResult.get(); Map.Entry<MatchText, List<HanlpMapResult>> searchTextEntry = mostSimilarSearchResult.get();
log.debug("searchTextEntry:{},queryReq:{}", searchTextEntry, queryReq); log.debug("searchTextEntry:{},queryTextReq:{}", searchTextEntry, queryTextReq);
Set<SearchResult> searchResults = new LinkedHashSet(); Set<SearchResult> searchResults = new LinkedHashSet();
DataSetInfoStat dataSetInfoStat = NatureHelper.getDataSetStat(originals); DataSetInfoStat dataSetInfoStat = NatureHelper.getDataSetStat(originals);
List<Long> possibleDataSets = getPossibleDataSets(queryReq, originals, dataSetInfoStat, dataSetIds); List<Long> possibleDataSets = getPossibleDataSets(queryTextReq, originals, dataSetInfoStat, dataSetIds);
// 5.1 priority dimension metric // 5.1 priority dimension metric
boolean existMetricAndDimension = searchMetricAndDimension(new HashSet<>(possibleDataSets), dataSetIdToName, boolean existMetricAndDimension = searchMetricAndDimension(new HashSet<>(possibleDataSets), dataSetIdToName,
@@ -120,14 +120,14 @@ public class RetrieveServiceImpl implements RetrieveService {
Set<SearchResult> searchResultSet = searchDimensionValue(metricsDb, dataSetIdToName, Set<SearchResult> searchResultSet = searchDimensionValue(metricsDb, dataSetIdToName,
dataSetInfoStat.getMetricDataSetCount(), existMetricAndDimension, dataSetInfoStat.getMetricDataSetCount(), existMetricAndDimension,
matchText, natureToNameMap, natureToNameEntry, queryReq.getQueryFilters()); matchText, natureToNameMap, natureToNameEntry, queryTextReq.getQueryFilters());
searchResults.addAll(searchResultSet); searchResults.addAll(searchResultSet);
} }
return searchResults.stream().limit(RESULT_SIZE).collect(Collectors.toList()); return searchResults.stream().limit(RESULT_SIZE).collect(Collectors.toList());
} }
private List<Long> getPossibleDataSets(QueryReq queryCtx, List<S2Term> originals, private List<Long> getPossibleDataSets(QueryTextReq queryCtx, List<S2Term> originals,
DataSetInfoStat dataSetInfoStat, Set<Long> dataSetIds) { DataSetInfoStat dataSetInfoStat, Set<Long> dataSetIds) {
if (CollectionUtils.isNotEmpty(dataSetIds)) { if (CollectionUtils.isNotEmpty(dataSetIds)) {
return new ArrayList<>(dataSetIds); return new ArrayList<>(dataSetIds);

View File

@@ -19,7 +19,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig; import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig; import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.request.ExplainSqlReq; import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq; import com.tencent.supersonic.headless.api.pojo.request.QueryDimValueReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq; import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
@@ -28,7 +28,7 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.request.SchemaFilterReq; import com.tencent.supersonic.headless.api.pojo.request.SchemaFilterReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp; import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
import com.tencent.supersonic.headless.api.pojo.response.ExplainResp; import com.tencent.supersonic.headless.api.pojo.response.TranslateResp;
import com.tencent.supersonic.headless.api.pojo.response.ItemResp; import com.tencent.supersonic.headless.api.pojo.response.ItemResp;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp; import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
@@ -125,7 +125,7 @@ public class S2SemanticLayerService implements SemanticLayerService {
StatUtils.get().setUseResultCache(false); StatUtils.get().setUseResultCache(false);
//3 query //3 query
QueryStatement queryStatement = buildQueryStatement(queryReq, user); QueryStatement queryStatement = buildQueryStatement(queryReq, user);
SemanticQueryResp result = query(queryStatement); SemanticQueryResp result = doQuery(queryStatement);
//4 reset cache and set stateInfo //4 reset cache and set stateInfo
Boolean setCacheSuccess = queryCache.put(cacheKey, result); Boolean setCacheSuccess = queryCache.put(cacheKey, result);
if (setCacheSuccess) { if (setCacheSuccess) {
@@ -228,8 +228,8 @@ public class S2SemanticLayerService implements SemanticLayerService {
} }
@Override @Override
public <T> ExplainResp explain(ExplainSqlReq<T> explainSqlReq, User user) throws Exception { public <T> TranslateResp translate(TranslateSqlReq<T> translateSqlReq, User user) throws Exception {
T queryReq = explainSqlReq.getQueryReq(); T queryReq = translateSqlReq.getQueryReq();
QueryStatement queryStatement = buildQueryStatement((SemanticQueryReq) queryReq, user); QueryStatement queryStatement = buildQueryStatement((SemanticQueryReq) queryReq, user);
semanticTranslator.translate(queryStatement); semanticTranslator.translate(queryStatement);
@@ -239,7 +239,7 @@ public class S2SemanticLayerService implements SemanticLayerService {
sql = queryStatement.getSql(); sql = queryStatement.getSql();
sorceId = queryStatement.getSourceId(); sorceId = queryStatement.getSourceId();
} }
return ExplainResp.builder().sql(sql).sourceId(sorceId).build(); return TranslateResp.builder().sql(sql).sourceId(sorceId).build();
} }
public List<ItemResp> getDomainDataSetTree() { public List<ItemResp> getDomainDataSetTree() {
@@ -268,7 +268,7 @@ public class S2SemanticLayerService implements SemanticLayerService {
return querySqlReq; return querySqlReq;
} }
private SemanticQueryResp query(QueryStatement queryStatement) { private SemanticQueryResp doQuery(QueryStatement queryStatement) {
SemanticQueryResp semanticQueryResp = null; SemanticQueryResp semanticQueryResp = null;
try { try {
//1 translate //1 translate

View File

@@ -4,9 +4,9 @@ import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SqlInfo; import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod; import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod;
import com.tencent.supersonic.headless.api.pojo.request.ExplainSqlReq; import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.ExplainResp; import com.tencent.supersonic.headless.api.pojo.response.TranslateResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.chat.ChatContext; import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.QueryContext; import com.tencent.supersonic.headless.chat.QueryContext;
@@ -24,8 +24,8 @@ import java.util.Objects;
import java.util.stream.Collectors; import java.util.stream.Collectors;
/** /**
* SqlInfoProcessor adds S2SQL to the parsing results so that * SqlInfoProcessor adds intermediate S2SQL and final SQL to the parsing results
* technical users could verify SQL by themselves. * so that technical users could verify SQL by themselves.
**/ **/
@Slf4j @Slf4j
public class SqlInfoProcessor implements ResultProcessor { public class SqlInfoProcessor implements ResultProcessor {
@@ -66,9 +66,9 @@ public class SqlInfoProcessor implements ResultProcessor {
semanticQuery.setParseInfo(parseInfo); semanticQuery.setParseInfo(parseInfo);
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq(); SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
SemanticLayerService queryService = ContextUtils.getBean(SemanticLayerService.class); SemanticLayerService queryService = ContextUtils.getBean(SemanticLayerService.class);
ExplainSqlReq<Object> explainSqlReq = ExplainSqlReq.builder().queryReq(semanticQueryReq) TranslateSqlReq<Object> translateSqlReq = TranslateSqlReq.builder().queryReq(semanticQueryReq)
.queryTypeEnum(QueryMethod.SQL).build(); .queryTypeEnum(QueryMethod.SQL).build();
ExplainResp explain = queryService.explain(explainSqlReq, queryContext.getUser()); TranslateResp explain = queryService.translate(translateSqlReq, queryContext.getUser());
String querySql = explain.getSql(); String querySql = explain.getSql();
if (StringUtils.isBlank(querySql)) { if (StringUtils.isBlank(querySql)) {
return; return;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.headless.server.web.service.impl; package com.tencent.supersonic.headless.server.utils;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.enums.WorkflowState; import com.tencent.supersonic.headless.api.pojo.enums.WorkflowState;
@@ -11,8 +11,6 @@ import com.tencent.supersonic.headless.chat.parser.SemanticParser;
import com.tencent.supersonic.headless.chat.query.SemanticQuery; import com.tencent.supersonic.headless.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery; import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.headless.server.processor.ResultProcessor; import com.tencent.supersonic.headless.server.processor.ResultProcessor;
import com.tencent.supersonic.headless.server.web.service.WorkflowService;
import com.tencent.supersonic.headless.server.utils.ComponentFactory;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils; import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.collections.MapUtils; import org.apache.commons.collections.MapUtils;
@@ -23,7 +21,7 @@ import java.util.Objects;
@Service @Service
@Slf4j @Slf4j
public class WorkflowServiceImpl implements WorkflowService { public class WorkflowEngine {
private List<SchemaMapper> schemaMappers = ComponentFactory.getSchemaMappers(); private List<SchemaMapper> schemaMappers = ComponentFactory.getSchemaMappers();
private List<SemanticParser> semanticParsers = ComponentFactory.getSemanticParsers(); private List<SemanticParser> semanticParsers = ComponentFactory.getSemanticParsers();
private List<SemanticCorrector> semanticCorrectors = ComponentFactory.getSemanticCorrectors(); private List<SemanticCorrector> semanticCorrectors = ComponentFactory.getSemanticCorrectors();

View File

@@ -1,9 +0,0 @@
package com.tencent.supersonic.headless.server.web.service;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.QueryContext;
public interface WorkflowService {
void startWorkflow(QueryContext queryCtx, ChatContext chatCtx, ParseResp parseResult);
}

View File

@@ -22,9 +22,9 @@ import java.util.Arrays;
import java.util.List; import java.util.List;
/** /**
* QueryReqBuilderTest * QueryTextReqBuilderTest
*/ */
class QueryReqBuilderTest { class QueryTextReqBuilderTest {
@Test @Test
void buildS2SQLReq() { void buildS2SQLReq() {

View File

@@ -5,21 +5,19 @@ import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult; import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricFilterQuery; import com.tencent.supersonic.headless.chat.query.rule.metric.MetricFilterQuery;
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricGroupByQuery; import com.tencent.supersonic.headless.chat.query.rule.metric.MetricGroupByQuery;
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricModelQuery; import com.tencent.supersonic.headless.chat.query.rule.metric.MetricModelQuery;
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricTopNQuery; import com.tencent.supersonic.headless.chat.query.rule.metric.MetricTopNQuery;
import com.tencent.supersonic.util.DataUtils; import com.tencent.supersonic.util.DataUtils;
import org.junit.Assert; import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import java.text.DateFormat; import java.text.DateFormat;
import java.text.SimpleDateFormat; import java.text.SimpleDateFormat;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.stream.Collectors;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE; import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM; import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM;
@@ -27,6 +25,8 @@ import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM;
public class MetricTest extends BaseTest { public class MetricTest extends BaseTest {
private int chatId = 10;
@Test @Test
public void testMetricFilter() throws Exception { public void testMetricFilter() throws Exception {
QueryResult actualResult = submitNewChat("alice的访问次数", DataUtils.metricAgentId); QueryResult actualResult = submitNewChat("alice的访问次数", DataUtils.metricAgentId);
@@ -51,16 +51,6 @@ public class MetricTest extends BaseTest {
assertQueryResult(expectedResult, actualResult); assertQueryResult(expectedResult, actualResult);
} }
@Test
public void testMetricFilterWithAgent() {
//agent only support METRIC_ENTITY, METRIC_FILTER
ParseResp parseResp = submitParseWithAgent("alice的访问次数", DataUtils.getMetricAgent().getId());
Assert.assertNotNull(parseResp.getSelectedParses());
List<String> queryModes = parseResp.getSelectedParses().stream()
.map(SemanticParseInfo::getQueryMode).collect(Collectors.toList());
Assert.assertTrue(queryModes.contains("METRIC_FILTER"));
}
@Test @Test
public void testMetricDomain() throws Exception { public void testMetricDomain() throws Exception {
QueryResult actualResult = submitNewChat("超音数的访问次数", DataUtils.metricAgentId); QueryResult actualResult = submitNewChat("超音数的访问次数", DataUtils.metricAgentId);
@@ -80,18 +70,9 @@ public class MetricTest extends BaseTest {
assertQueryResult(expectedResult, actualResult); assertQueryResult(expectedResult, actualResult);
} }
@Test
public void testMetricModelWithAgent() {
//agent only support METRIC_ENTITY, METRIC_FILTER
ParseResp parseResp = submitParseWithAgent("超音数的访问次数", DataUtils.getMetricAgent().getId());
List<String> queryModes = parseResp.getSelectedParses().stream()
.map(SemanticParseInfo::getQueryMode).collect(Collectors.toList());
Assert.assertTrue(queryModes.contains("METRIC_MODEL"));
}
@Test @Test
public void testMetricGroupBy() throws Exception { public void testMetricGroupBy() throws Exception {
QueryResult actualResult = submitNewChat("超音数各部门的访问次数", DataUtils.metricAgentId); QueryResult actualResult = submitNewChat("近7天超音数各部门的访问次数", DataUtils.metricAgentId);
QueryResult expectedResult = new QueryResult(); QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo(); SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
@@ -137,6 +118,7 @@ public class MetricTest extends BaseTest {
} }
@Test @Test
@Order(3)
public void testMetricTopN() throws Exception { public void testMetricTopN() throws Exception {
QueryResult actualResult = submitNewChat("近3天访问次数最多的用户", DataUtils.metricAgentId); QueryResult actualResult = submitNewChat("近3天访问次数最多的用户", DataUtils.metricAgentId);
@@ -160,7 +142,7 @@ public class MetricTest extends BaseTest {
@Test @Test
public void testMetricGroupBySum() throws Exception { public void testMetricGroupBySum() throws Exception {
QueryResult actualResult = submitNewChat("超音数各部门的访问次数总和", DataUtils.metricAgentId); QueryResult actualResult = submitNewChat("近7天超音数各部门的访问次数总和", DataUtils.metricAgentId);
QueryResult expectedResult = new QueryResult(); QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo(); SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
expectedResult.setChatContext(expectedParseInfo); expectedResult.setChatContext(expectedParseInfo);

View File

@@ -3,10 +3,10 @@ package com.tencent.supersonic.headless;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod; import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod;
import com.tencent.supersonic.headless.api.pojo.request.ExplainSqlReq; import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq; import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq; import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.response.ExplainResp; import com.tencent.supersonic.headless.api.pojo.response.TranslateResp;
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder; import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
import com.tencent.supersonic.util.DataUtils; import com.tencent.supersonic.util.DataUtils;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@@ -16,16 +16,16 @@ import java.util.Arrays;
import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
public class ExplainTest extends BaseTest { public class TranslateTest extends BaseTest {
@Test @Test
public void testSqlExplain() throws Exception { public void testSqlExplain() throws Exception {
String sql = "SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 "; String sql = "SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 ";
ExplainSqlReq<QuerySqlReq> explainSqlReq = ExplainSqlReq.<QuerySqlReq>builder() TranslateSqlReq<QuerySqlReq> translateSqlReq = TranslateSqlReq.<QuerySqlReq>builder()
.queryTypeEnum(QueryMethod.SQL) .queryTypeEnum(QueryMethod.SQL)
.queryReq(QueryReqBuilder.buildS2SQLReq(sql, DataUtils.getMetricAgentView())) .queryReq(QueryReqBuilder.buildS2SQLReq(sql, DataUtils.getMetricAgentView()))
.build(); .build();
ExplainResp explain = semanticLayerService.explain(explainSqlReq, User.getFakeUser()); TranslateResp explain = semanticLayerService.translate(translateSqlReq, User.getFakeUser());
assertNotNull(explain); assertNotNull(explain);
assertNotNull(explain.getSql()); assertNotNull(explain.getSql());
assertTrue(explain.getSql().contains("department")); assertTrue(explain.getSql().contains("department"));
@@ -35,11 +35,11 @@ public class ExplainTest extends BaseTest {
@Test @Test
public void testStructExplain() throws Exception { public void testStructExplain() throws Exception {
QueryStructReq queryStructReq = buildQueryStructReq(Arrays.asList("department")); QueryStructReq queryStructReq = buildQueryStructReq(Arrays.asList("department"));
ExplainSqlReq<QueryStructReq> explainSqlReq = ExplainSqlReq.<QueryStructReq>builder() TranslateSqlReq<QueryStructReq> translateSqlReq = TranslateSqlReq.<QueryStructReq>builder()
.queryTypeEnum(QueryMethod.STRUCT) .queryTypeEnum(QueryMethod.STRUCT)
.queryReq(queryStructReq) .queryReq(queryStructReq)
.build(); .build();
ExplainResp explain = semanticLayerService.explain(explainSqlReq, User.getFakeUser()); TranslateResp explain = semanticLayerService.translate(translateSqlReq, User.getFakeUser());
assertNotNull(explain); assertNotNull(explain);
assertNotNull(explain.getSql()); assertNotNull(explain.getSql());
assertTrue(explain.getSql().contains("department")); assertTrue(explain.getSql().contains("department"));