(improvement)(Headless) Add workflow handling to the performParsing stage. (#955)

This commit is contained in:
lexluo09
2024-04-26 17:05:47 +08:00
committed by GitHub
parent 11c2e0505b
commit 83b80e35f0
5 changed files with 119 additions and 43 deletions

View File

@@ -0,0 +1,9 @@
package com.tencent.supersonic.headless.api.pojo.enums;

/**
 * Stages of the headless chat-query workflow, traversed in declaration order by the
 * workflow engine: MAPPING -> PARSING -> CORRECTING -> PROCESSING -> FINISHED.
 * A stage implementation may also set the state directly to rewind or abort the flow.
 */
public enum WorkflowState {
    /** Schema mapping: resolve query terms against dataset schema elements. */
    MAPPING,
    /** Semantic parsing: produce candidate semantic queries from the mapped terms. */
    PARSING,
    /** Correction: run semantic correctors over non-rule candidate queries. */
    CORRECTING,
    /** Result processing: post-process candidates into the parse response. */
    PROCESSING,
    /** Terminal state: the workflow loop exits. */
    FINISHED
}

View File

@@ -7,6 +7,7 @@ import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
import com.tencent.supersonic.headless.api.pojo.enums.WorkflowState;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import com.tencent.supersonic.headless.core.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.core.config.OptimizationConfig;
@@ -42,6 +43,8 @@ public class QueryContext {
private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
@JsonIgnore
private SemanticSchema semanticSchema;
@JsonIgnore
private WorkflowState workflowState;
public List<SemanticQuery> getCandidateQueries() {
OptimizationConfig optimizationConfig = ContextUtils.getBean(OptimizationConfig.class);

View File

@@ -0,0 +1,9 @@
package com.tencent.supersonic.headless.server.service;

import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.core.pojo.ChatContext;
import com.tencent.supersonic.headless.core.pojo.QueryContext;

/**
 * Drives a chat query through the parsing workflow state machine
 * (mapping -> parsing -> correcting -> processing).
 */
public interface WorkflowService {

    /**
     * Runs the full workflow for one query, mutating {@code queryCtx} (candidate
     * queries, workflow state) and filling {@code parseResult} during processing.
     *
     * @param queryCtx    per-request query context; its workflow state is driven by this call
     * @param chatCtx     multi-turn chat context used by the parsers
     * @param parseResult response object populated by the processing stage
     */
    void startWorkflow(QueryContext queryCtx, ChatContext chatCtx, ParseResp parseResult);
}

View File

@@ -41,26 +41,22 @@ import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.core.chat.corrector.GrammarCorrector;
import com.tencent.supersonic.headless.core.chat.corrector.SchemaCorrector;
import com.tencent.supersonic.headless.core.chat.corrector.SemanticCorrector;
import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult;
import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeService;
import com.tencent.supersonic.headless.core.chat.knowledge.SearchService;
import com.tencent.supersonic.headless.core.chat.knowledge.helper.HanlpHelper;
import com.tencent.supersonic.headless.core.chat.knowledge.helper.NatureHelper;
import com.tencent.supersonic.headless.core.chat.mapper.SchemaMapper;
import com.tencent.supersonic.headless.core.chat.parser.SemanticParser;
import com.tencent.supersonic.headless.core.chat.query.QueryManager;
import com.tencent.supersonic.headless.core.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.headless.core.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.headless.core.pojo.ChatContext;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO;
import com.tencent.supersonic.headless.server.processor.ResultProcessor;
import com.tencent.supersonic.headless.server.service.ChatContextService;
import com.tencent.supersonic.headless.server.service.ChatQueryService;
import com.tencent.supersonic.headless.server.service.DataSetService;
import com.tencent.supersonic.headless.server.service.QueryService;
import com.tencent.supersonic.headless.server.service.WorkflowService;
import com.tencent.supersonic.headless.server.utils.ComponentFactory;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
@@ -76,7 +72,6 @@ import net.sf.jsqlparser.expression.operators.relational.MinorThan;
import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
import net.sf.jsqlparser.schema.Column;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.collections.MapUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.BeanUtils;
@@ -106,17 +101,14 @@ public class ChatQueryServiceImpl implements ChatQueryService {
private QueryService queryService;
@Autowired
private DataSetService dataSetService;
private List<SchemaMapper> schemaMappers = ComponentFactory.getSchemaMappers();
private List<SemanticParser> semanticParsers = ComponentFactory.getSemanticParsers();
private List<SemanticCorrector> semanticCorrectors = ComponentFactory.getSemanticCorrectors();
private List<ResultProcessor> resultProcessors = ComponentFactory.getResultProcessors();
@Autowired
private WorkflowService workflowService;
@Override
public MapResp performMapping(QueryReq queryReq) {
MapResp mapResp = new MapResp();
QueryContext queryCtx = buildQueryContext(queryReq);
schemaMappers.forEach(mapper -> {
ComponentFactory.getSchemaMappers().forEach(mapper -> {
mapper.map(queryCtx);
});
SchemaMapInfo mapInfo = queryCtx.getMapInfo();
@@ -134,38 +126,8 @@ public class ChatQueryServiceImpl implements ChatQueryService {
// in order to support multi-turn conversation, chat context is needed
ChatContext chatCtx = chatContextService.getOrCreateContext(queryReq.getChatId());
// 1. mapper
if (Objects.isNull(queryReq.getMapInfo())
|| MapUtils.isEmpty(queryReq.getMapInfo().getDataSetElementMatches())) {
schemaMappers.forEach(mapper -> {
mapper.map(queryCtx);
});
}
workflowService.startWorkflow(queryCtx, chatCtx, parseResult);
// 2. parser
semanticParsers.forEach(parser -> {
parser.parse(queryCtx, chatCtx);
log.info("{} result:{}", parser.getClass().getSimpleName(), JsonUtil.toString(queryCtx));
});
// 3. corrector
List<SemanticQuery> candidateQueries = queryCtx.getCandidateQueries();
if (CollectionUtils.isNotEmpty(candidateQueries)) {
for (SemanticQuery semanticQuery : candidateQueries) {
// the rules are not being corrected.
if (semanticQuery instanceof RuleSemanticQuery) {
continue;
}
semanticCorrectors.forEach(corrector -> {
corrector.correct(queryCtx, semanticQuery.getParseInfo());
});
}
}
//4. processor
resultProcessors.forEach(processor -> {
processor.process(parseResult, queryCtx, chatCtx);
});
List<SemanticParseInfo> parseInfos = queryCtx.getCandidateQueries().stream()
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
parseResult.setSelectedParses(parseInfos);

View File

@@ -0,0 +1,93 @@
package com.tencent.supersonic.headless.server.service.impl;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.enums.WorkflowState;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.core.chat.corrector.SemanticCorrector;
import com.tencent.supersonic.headless.core.chat.mapper.SchemaMapper;
import com.tencent.supersonic.headless.core.chat.parser.SemanticParser;
import com.tencent.supersonic.headless.core.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.core.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.headless.core.pojo.ChatContext;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.server.processor.ResultProcessor;
import com.tencent.supersonic.headless.server.service.WorkflowService;
import com.tencent.supersonic.headless.server.utils.ComponentFactory;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.collections4.MapUtils;
import org.springframework.stereotype.Service;
import java.util.List;
import java.util.Objects;
@Service
@Slf4j
public class WorkflowServiceImpl implements WorkflowService {
private List<SchemaMapper> schemaMappers = ComponentFactory.getSchemaMappers();
private List<SemanticParser> semanticParsers = ComponentFactory.getSemanticParsers();
private List<SemanticCorrector> semanticCorrectors = ComponentFactory.getSemanticCorrectors();
private List<ResultProcessor> resultProcessors = ComponentFactory.getResultProcessors();
public void startWorkflow(QueryContext queryCtx, ChatContext chatCtx, ParseResp parseResult) {
queryCtx.setWorkflowState(WorkflowState.MAPPING);
while (queryCtx.getWorkflowState() != WorkflowState.FINISHED) {
switch (queryCtx.getWorkflowState()) {
case MAPPING:
performMapping(queryCtx);
queryCtx.setWorkflowState(WorkflowState.PARSING);
break;
case PARSING:
performParsing(queryCtx, chatCtx);
queryCtx.setWorkflowState(WorkflowState.CORRECTING);
break;
case CORRECTING:
performCorrecting(queryCtx);
queryCtx.setWorkflowState(WorkflowState.PROCESSING);
break;
case PROCESSING:
default:
performProcessing(queryCtx, chatCtx, parseResult);
queryCtx.setWorkflowState(WorkflowState.FINISHED);
break;
}
}
}
public void performMapping(QueryContext queryCtx) {
if (Objects.isNull(queryCtx.getMapInfo())
|| MapUtils.isEmpty(queryCtx.getMapInfo().getDataSetElementMatches())) {
schemaMappers.forEach(mapper -> mapper.map(queryCtx));
}
}
public void performParsing(QueryContext queryCtx, ChatContext chatCtx) {
semanticParsers.forEach(parser -> {
parser.parse(queryCtx, chatCtx);
log.info("{} result:{}", parser.getClass().getSimpleName(), JsonUtil.toString(queryCtx));
});
}
public void performCorrecting(QueryContext queryCtx) {
List<SemanticQuery> candidateQueries = queryCtx.getCandidateQueries();
if (CollectionUtils.isNotEmpty(candidateQueries)) {
for (SemanticQuery semanticQuery : candidateQueries) {
if (semanticQuery instanceof RuleSemanticQuery) {
continue;
}
for (SemanticCorrector corrector : semanticCorrectors) {
corrector.correct(queryCtx, semanticQuery.getParseInfo());
if (!WorkflowState.PARSING.equals(queryCtx.getWorkflowState())) {
break;
}
}
}
}
}
public void performProcessing(QueryContext queryCtx, ChatContext chatCtx, ParseResp parseResult) {
resultProcessors.forEach(processor -> {
processor.process(parseResult, queryCtx, chatCtx);
});
}
}