From 83b80e35f04b7370a757375ffaf935b937c09193 Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Fri, 26 Apr 2024 17:05:47 +0800 Subject: [PATCH] (improvement)(Headless) Add workflow handling to the performParsing stage. (#955) --- .../api/pojo/enums/WorkflowState.java | 9 ++ .../headless/core/pojo/QueryContext.java | 3 + .../server/service/WorkflowService.java | 9 ++ .../service/impl/ChatQueryServiceImpl.java | 48 +--------- .../service/impl/WorkflowServiceImpl.java | 93 +++++++++++++++++++ 5 files changed, 119 insertions(+), 43 deletions(-) create mode 100644 headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/WorkflowState.java create mode 100644 headless/server/src/main/java/com/tencent/supersonic/headless/server/service/WorkflowService.java create mode 100644 headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/WorkflowServiceImpl.java diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/WorkflowState.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/WorkflowState.java new file mode 100644 index 000000000..a49067e42 --- /dev/null +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/WorkflowState.java @@ -0,0 +1,9 @@ +package com.tencent.supersonic.headless.api.pojo.enums; + +public enum WorkflowState { + MAPPING, + PARSING, + CORRECTING, + PROCESSING, + FINISHED +} \ No newline at end of file diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/QueryContext.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/QueryContext.java index c7492ce1d..0ace7ecf1 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/QueryContext.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/QueryContext.java @@ -7,6 +7,7 @@ import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum; +import com.tencent.supersonic.headless.api.pojo.enums.WorkflowState; import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; import com.tencent.supersonic.headless.core.chat.query.SemanticQuery; import com.tencent.supersonic.headless.core.config.OptimizationConfig; @@ -42,6 +43,8 @@ public class QueryContext { private MapModeEnum mapModeEnum = MapModeEnum.STRICT; @JsonIgnore private SemanticSchema semanticSchema; + @JsonIgnore + private WorkflowState workflowState; public List getCandidateQueries() { OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/WorkflowService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/WorkflowService.java new file mode 100644 index 000000000..2edfafe49 --- /dev/null +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/WorkflowService.java @@ -0,0 +1,9 @@ +package com.tencent.supersonic.headless.server.service; + +import com.tencent.supersonic.headless.api.pojo.response.ParseResp; +import com.tencent.supersonic.headless.core.pojo.ChatContext; +import com.tencent.supersonic.headless.core.pojo.QueryContext; + +public interface WorkflowService { + void startWorkflow(QueryContext queryCtx, ChatContext chatCtx, ParseResp parseResult); +} \ No newline at end of file diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java index 67cad3174..7b9d8c7b0 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java @@ -41,26 +41,22 @@ import com.tencent.supersonic.headless.api.pojo.response.QueryState; import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp; import com.tencent.supersonic.headless.core.chat.corrector.GrammarCorrector; import com.tencent.supersonic.headless.core.chat.corrector.SchemaCorrector; -import com.tencent.supersonic.headless.core.chat.corrector.SemanticCorrector; import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult; import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeService; import com.tencent.supersonic.headless.core.chat.knowledge.SearchService; import com.tencent.supersonic.headless.core.chat.knowledge.helper.HanlpHelper; import com.tencent.supersonic.headless.core.chat.knowledge.helper.NatureHelper; -import com.tencent.supersonic.headless.core.chat.mapper.SchemaMapper; -import com.tencent.supersonic.headless.core.chat.parser.SemanticParser; import com.tencent.supersonic.headless.core.chat.query.QueryManager; import com.tencent.supersonic.headless.core.chat.query.SemanticQuery; import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlQuery; -import com.tencent.supersonic.headless.core.chat.query.rule.RuleSemanticQuery; import com.tencent.supersonic.headless.core.pojo.ChatContext; import com.tencent.supersonic.headless.core.pojo.QueryContext; import com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO; -import com.tencent.supersonic.headless.server.processor.ResultProcessor; import com.tencent.supersonic.headless.server.service.ChatContextService; import com.tencent.supersonic.headless.server.service.ChatQueryService; import com.tencent.supersonic.headless.server.service.DataSetService; import com.tencent.supersonic.headless.server.service.QueryService; +import com.tencent.supersonic.headless.server.service.WorkflowService; import com.tencent.supersonic.headless.server.utils.ComponentFactory; import lombok.extern.slf4j.Slf4j; import net.sf.jsqlparser.expression.Expression; @@ -76,7 +72,6 @@ import net.sf.jsqlparser.expression.operators.relational.MinorThan; import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals; import net.sf.jsqlparser.schema.Column; import org.apache.commons.collections.CollectionUtils; -import org.apache.commons.collections.MapUtils; import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.Pair; import org.springframework.beans.BeanUtils; @@ -106,17 +101,14 @@ public class ChatQueryServiceImpl implements ChatQueryService { private QueryService queryService; @Autowired private DataSetService dataSetService; - - private List schemaMappers = ComponentFactory.getSchemaMappers(); - private List semanticParsers = ComponentFactory.getSemanticParsers(); - private List semanticCorrectors = ComponentFactory.getSemanticCorrectors(); - private List resultProcessors = ComponentFactory.getResultProcessors(); + @Autowired + private WorkflowService workflowService; @Override public MapResp performMapping(QueryReq queryReq) { MapResp mapResp = new MapResp(); QueryContext queryCtx = buildQueryContext(queryReq); - schemaMappers.forEach(mapper -> { + ComponentFactory.getSchemaMappers().forEach(mapper -> { mapper.map(queryCtx); }); SchemaMapInfo mapInfo = queryCtx.getMapInfo(); @@ -134,38 +126,8 @@ public class ChatQueryServiceImpl implements ChatQueryService { // in order to support multi-turn conversation, chat context is needed ChatContext chatCtx = chatContextService.getOrCreateContext(queryReq.getChatId()); - // 1. mapper - if (Objects.isNull(queryReq.getMapInfo()) - || MapUtils.isEmpty(queryReq.getMapInfo().getDataSetElementMatches())) { - schemaMappers.forEach(mapper -> { - mapper.map(queryCtx); - }); - } + workflowService.startWorkflow(queryCtx, chatCtx, parseResult); - // 2. parser - semanticParsers.forEach(parser -> { - parser.parse(queryCtx, chatCtx); - log.info("{} result:{}", parser.getClass().getSimpleName(), JsonUtil.toString(queryCtx)); - }); - - // 3. corrector - List candidateQueries = queryCtx.getCandidateQueries(); - if (CollectionUtils.isNotEmpty(candidateQueries)) { - for (SemanticQuery semanticQuery : candidateQueries) { - // the rules are not being corrected. - if (semanticQuery instanceof RuleSemanticQuery) { - continue; - } - semanticCorrectors.forEach(corrector -> { - corrector.correct(queryCtx, semanticQuery.getParseInfo()); - }); - } - } - - //4. processor - resultProcessors.forEach(processor -> { - processor.process(parseResult, queryCtx, chatCtx); - }); List parseInfos = queryCtx.getCandidateQueries().stream() .map(SemanticQuery::getParseInfo).collect(Collectors.toList()); parseResult.setSelectedParses(parseInfos); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/WorkflowServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/WorkflowServiceImpl.java new file mode 100644 index 000000000..095c0f8a2 --- /dev/null +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/WorkflowServiceImpl.java @@ -0,0 +1,93 @@ +package com.tencent.supersonic.headless.server.service.impl; + +import com.tencent.supersonic.common.util.JsonUtil; +import com.tencent.supersonic.headless.api.pojo.enums.WorkflowState; +import com.tencent.supersonic.headless.api.pojo.response.ParseResp; +import com.tencent.supersonic.headless.core.chat.corrector.SemanticCorrector; +import com.tencent.supersonic.headless.core.chat.mapper.SchemaMapper; +import com.tencent.supersonic.headless.core.chat.parser.SemanticParser; +import com.tencent.supersonic.headless.core.chat.query.SemanticQuery; +import com.tencent.supersonic.headless.core.chat.query.rule.RuleSemanticQuery; +import com.tencent.supersonic.headless.core.pojo.ChatContext; +import com.tencent.supersonic.headless.core.pojo.QueryContext; +import com.tencent.supersonic.headless.server.processor.ResultProcessor; +import com.tencent.supersonic.headless.server.service.WorkflowService; +import com.tencent.supersonic.headless.server.utils.ComponentFactory; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.collections4.CollectionUtils; +import org.apache.commons.collections4.MapUtils; +import org.springframework.stereotype.Service; + +import java.util.List; +import java.util.Objects; + +@Service +@Slf4j +public class WorkflowServiceImpl implements WorkflowService { + private List schemaMappers = ComponentFactory.getSchemaMappers(); + private List semanticParsers = ComponentFactory.getSemanticParsers(); + private List semanticCorrectors = ComponentFactory.getSemanticCorrectors(); + private List resultProcessors = ComponentFactory.getResultProcessors(); + + public void startWorkflow(QueryContext queryCtx, ChatContext chatCtx, ParseResp parseResult) { + queryCtx.setWorkflowState(WorkflowState.MAPPING); + while (queryCtx.getWorkflowState() != WorkflowState.FINISHED) { + switch (queryCtx.getWorkflowState()) { + case MAPPING: + performMapping(queryCtx); + queryCtx.setWorkflowState(WorkflowState.PARSING); + break; + case PARSING: + performParsing(queryCtx, chatCtx); + queryCtx.setWorkflowState(WorkflowState.CORRECTING); + break; + case CORRECTING: + performCorrecting(queryCtx); + queryCtx.setWorkflowState(WorkflowState.PROCESSING); + break; + case PROCESSING: + default: + performProcessing(queryCtx, chatCtx, parseResult); + queryCtx.setWorkflowState(WorkflowState.FINISHED); + break; + } + } + } + + public void performMapping(QueryContext queryCtx) { + if (Objects.isNull(queryCtx.getMapInfo()) + || MapUtils.isEmpty(queryCtx.getMapInfo().getDataSetElementMatches())) { + schemaMappers.forEach(mapper -> mapper.map(queryCtx)); + } + } + + public void performParsing(QueryContext queryCtx, ChatContext chatCtx) { + semanticParsers.forEach(parser -> { + parser.parse(queryCtx, chatCtx); + log.info("{} result:{}", parser.getClass().getSimpleName(), JsonUtil.toString(queryCtx)); + }); + } + + public void performCorrecting(QueryContext queryCtx) { + List candidateQueries = queryCtx.getCandidateQueries(); + if (CollectionUtils.isNotEmpty(candidateQueries)) { + for (SemanticQuery semanticQuery : candidateQueries) { + if (semanticQuery instanceof RuleSemanticQuery) { + continue; + } + for (SemanticCorrector corrector : semanticCorrectors) { + corrector.correct(queryCtx, semanticQuery.getParseInfo()); + if (!WorkflowState.PARSING.equals(queryCtx.getWorkflowState())) { + break; + } + } + } + } + } + + public void performProcessing(QueryContext queryCtx, ChatContext chatCtx, ParseResp parseResult) { + resultProcessors.forEach(processor -> { + processor.process(parseResult, queryCtx, chatCtx); + }); + } +} \ No newline at end of file