(improvement)(headless)Add explicit TRANSLATING stage and rename several classes by the way.

This commit is contained in:
jerryjzhang
2024-07-07 09:30:32 +08:00
parent 4d7bfe07aa
commit 64786cb0ef
56 changed files with 519 additions and 538 deletions

View File

@@ -49,7 +49,7 @@ import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.QueryContext;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.corrector.GrammarCorrector;
import com.tencent.supersonic.headless.chat.corrector.SchemaCorrector;
import com.tencent.supersonic.headless.chat.knowledge.HanlpMapResult;
@@ -63,7 +63,7 @@ import com.tencent.supersonic.headless.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
import com.tencent.supersonic.headless.server.utils.WorkflowEngine;
import com.tencent.supersonic.headless.server.utils.ChatWorkflowEngine;
import com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO;
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
import com.tencent.supersonic.headless.server.utils.ComponentFactory;
@@ -115,12 +115,12 @@ public class ChatQueryServiceImpl implements ChatQueryService {
@Autowired
private DataSetService dataSetService;
@Autowired
private WorkflowEngine workflowEngine;
private ChatWorkflowEngine chatWorkflowEngine;
@Override
public MapResp performMapping(QueryTextReq queryTextReq) {
MapResp mapResp = new MapResp();
QueryContext queryCtx = buildQueryContext(queryTextReq);
ChatQueryContext queryCtx = buildQueryContext(queryTextReq);
ComponentFactory.getSchemaMappers().forEach(mapper -> {
mapper.map(queryCtx);
});
@@ -148,12 +148,12 @@ public class ChatQueryServiceImpl implements ChatQueryService {
public ParseResp performParsing(QueryTextReq queryTextReq) {
ParseResp parseResult = new ParseResp(queryTextReq.getChatId(), queryTextReq.getQueryText());
// build queryContext and chatContext
QueryContext queryCtx = buildQueryContext(queryTextReq);
ChatQueryContext queryCtx = buildQueryContext(queryTextReq);
// in order to support multi-turn conversation, chat context is needed
ChatContext chatCtx = chatContextService.getOrCreateContext(queryTextReq.getChatId());
workflowEngine.startWorkflow(queryCtx, chatCtx, parseResult);
chatWorkflowEngine.execute(queryCtx, chatCtx, parseResult);
List<SemanticParseInfo> parseInfos = queryCtx.getCandidateQueries().stream()
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
@@ -161,11 +161,11 @@ public class ChatQueryServiceImpl implements ChatQueryService {
return parseResult;
}
public QueryContext buildQueryContext(QueryTextReq queryTextReq) {
public ChatQueryContext buildQueryContext(QueryTextReq queryTextReq) {
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
Map<Long, List<Long>> modelIdToDataSetIds = dataSetService.getModelIdToDataSetIds();
QueryContext queryCtx = QueryContext.builder()
ChatQueryContext queryCtx = ChatQueryContext.builder()
.queryFilters(queryTextReq.getQueryFilters())
.semanticSchema(semanticSchema)
.candidateQueries(new ArrayList<>())
@@ -612,7 +612,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
}
private SemanticParseInfo correctSqlReq(QuerySqlReq querySqlReq, User user) {
QueryContext queryCtx = new QueryContext();
ChatQueryContext queryCtx = new ChatQueryContext();
SemanticSchema semanticSchema = schemaService.getSemanticSchema();
queryCtx.setSemanticSchema(semanticSchema);
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();

View File

@@ -11,7 +11,7 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import com.tencent.supersonic.headless.api.pojo.request.QueryTextReq;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
import com.tencent.supersonic.headless.chat.QueryContext;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.knowledge.DataSetInfoStat;
import com.tencent.supersonic.headless.chat.knowledge.DictWord;
import com.tencent.supersonic.headless.chat.knowledge.HanlpMapResult;
@@ -78,12 +78,12 @@ public class RetrieveServiceImpl implements RetrieveService {
log.debug("hanlp parse result: {}", originals);
Set<Long> dataSetIds = queryTextReq.getDataSetIds();
QueryContext queryContext = new QueryContext();
BeanUtils.copyProperties(queryTextReq, queryContext);
queryContext.setModelIdToDataSetIds(dataSetService.getModelIdToDataSetIds());
ChatQueryContext chatQueryContext = new ChatQueryContext();
BeanUtils.copyProperties(queryTextReq, chatQueryContext);
chatQueryContext.setModelIdToDataSetIds(dataSetService.getModelIdToDataSetIds());
Map<MatchText, List<HanlpMapResult>> regTextMap =
searchMatchStrategy.match(queryContext, originals, dataSetIds);
searchMatchStrategy.match(chatQueryContext, originals, dataSetIds);
regTextMap.entrySet().stream().forEach(m -> HanlpHelper.transLetterOriginal(m.getValue()));

View File

@@ -15,7 +15,7 @@ import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.QueryContext;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.server.web.service.SchemaService;
import lombok.extern.slf4j.Slf4j;
@@ -40,8 +40,8 @@ import java.util.stream.Collectors;
public class ParseInfoProcessor implements ResultProcessor {
@Override
public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
List<SemanticQuery> candidateQueries = queryContext.getCandidateQueries();
public void process(ParseResp parseResp, ChatQueryContext chatQueryContext, ChatContext chatContext) {
List<SemanticQuery> candidateQueries = chatQueryContext.getCandidateQueries();
if (CollectionUtils.isEmpty(candidateQueries)) {
return;
}

View File

@@ -2,13 +2,13 @@ package com.tencent.supersonic.headless.server.processor;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.QueryContext;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
/**
* A ParseResultProcessor wraps things up before returning results to users in parse stage.
*/
public interface ResultProcessor {
void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext);
void process(ParseResp parseResp, ChatQueryContext chatQueryContext, ChatContext chatContext);
}

View File

@@ -1,88 +0,0 @@
package com.tencent.supersonic.headless.server.processor;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod;
import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.TranslateResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.QueryContext;
import com.tencent.supersonic.headless.chat.query.QueryManager;
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
/**
* SqlInfoProcessor adds intermediate S2SQL and final SQL to the parsing results
* so that technical users could verify SQL by themselves.
**/
@Slf4j
public class SqlInfoProcessor implements ResultProcessor {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
@Override
public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
long start = System.currentTimeMillis();
List<SemanticQuery> semanticQueries = queryContext.getCandidateQueries();
if (CollectionUtils.isEmpty(semanticQueries)) {
return;
}
List<SemanticParseInfo> selectedParses = semanticQueries.stream().map(SemanticQuery::getParseInfo)
.collect(Collectors.toList());
addSqlInfo(queryContext, selectedParses);
parseResp.getParseTimeCost().setSqlTime(System.currentTimeMillis() - start);
}
private void addSqlInfo(QueryContext queryContext, List<SemanticParseInfo> semanticParseInfos) {
if (CollectionUtils.isEmpty(semanticParseInfos)) {
return;
}
semanticParseInfos.forEach(parseInfo -> {
try {
addSqlInfo(queryContext, parseInfo);
} catch (Exception e) {
log.warn("get sql info failed:{}", parseInfo, e);
}
});
}
private void addSqlInfo(QueryContext queryContext, SemanticParseInfo parseInfo) throws Exception {
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
if (Objects.isNull(semanticQuery)) {
return;
}
semanticQuery.setParseInfo(parseInfo);
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
SemanticLayerService queryService = ContextUtils.getBean(SemanticLayerService.class);
TranslateSqlReq<Object> translateSqlReq = TranslateSqlReq.builder().queryReq(semanticQueryReq)
.queryTypeEnum(QueryMethod.SQL).build();
TranslateResp explain = queryService.translate(translateSqlReq, queryContext.getUser());
String querySql = explain.getSql();
if (StringUtils.isBlank(querySql)) {
return;
}
SqlInfo sqlInfo = parseInfo.getSqlInfo();
if (semanticQuery instanceof LLMSqlQuery) {
keyPipelineLog.info("SqlInfoProcessor results:\n"
+ "Parsed S2SQL: {}\nCorrected S2SQL: {}\nQuery SQL: {}",
StringUtils.normalizeSpace(sqlInfo.getS2SQL()),
StringUtils.normalizeSpace(sqlInfo.getCorrectS2SQL()),
StringUtils.normalizeSpace(querySql));
}
sqlInfo.setQuerySQL(querySql);
sqlInfo.setSourceId(explain.getSourceId());
}
}

View File

@@ -0,0 +1,148 @@
package com.tencent.supersonic.headless.server.utils;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.api.pojo.enums.ChatWorkflowState;
import com.tencent.supersonic.headless.api.pojo.enums.QueryMethod;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.request.TranslateSqlReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.TranslateResp;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.corrector.SemanticCorrector;
import com.tencent.supersonic.headless.chat.mapper.SchemaMapper;
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
import com.tencent.supersonic.headless.chat.query.QueryManager;
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
import com.tencent.supersonic.headless.server.processor.ResultProcessor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.collections.MapUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
@Service
@Slf4j
public class ChatWorkflowEngine {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
private List<SchemaMapper> schemaMappers = ComponentFactory.getSchemaMappers();
private List<SemanticParser> semanticParsers = ComponentFactory.getSemanticParsers();
private List<SemanticCorrector> semanticCorrectors = ComponentFactory.getSemanticCorrectors();
private List<ResultProcessor> resultProcessors = ComponentFactory.getResultProcessors();
public void execute(ChatQueryContext queryCtx, ChatContext chatCtx, ParseResp parseResult) {
queryCtx.setChatWorkflowState(ChatWorkflowState.MAPPING);
while (queryCtx.getChatWorkflowState() != ChatWorkflowState.FINISHED) {
switch (queryCtx.getChatWorkflowState()) {
case MAPPING:
performMapping(queryCtx);
queryCtx.setChatWorkflowState(ChatWorkflowState.PARSING);
break;
case PARSING:
performParsing(queryCtx, chatCtx);
queryCtx.setChatWorkflowState(ChatWorkflowState.CORRECTING);
break;
case CORRECTING:
performCorrecting(queryCtx);
queryCtx.setChatWorkflowState(ChatWorkflowState.TRANSLATING);
break;
case TRANSLATING:
long start = System.currentTimeMillis();
performTranslating(queryCtx);
parseResult.getParseTimeCost().setSqlTime(System.currentTimeMillis() - start);
queryCtx.setChatWorkflowState(ChatWorkflowState.PROCESSING);
break;
case PROCESSING:
default:
performProcessing(queryCtx, chatCtx, parseResult);
queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED);
break;
}
}
}
public void performMapping(ChatQueryContext queryCtx) {
if (Objects.isNull(queryCtx.getMapInfo())
|| MapUtils.isEmpty(queryCtx.getMapInfo().getDataSetElementMatches())) {
schemaMappers.forEach(mapper -> mapper.map(queryCtx));
}
}
public void performParsing(ChatQueryContext queryCtx, ChatContext chatCtx) {
semanticParsers.forEach(parser -> {
parser.parse(queryCtx, chatCtx);
log.debug("{} result:{}", parser.getClass().getSimpleName(), JsonUtil.toString(queryCtx));
});
}
public void performCorrecting(ChatQueryContext queryCtx) {
List<SemanticQuery> candidateQueries = queryCtx.getCandidateQueries();
if (CollectionUtils.isNotEmpty(candidateQueries)) {
for (SemanticQuery semanticQuery : candidateQueries) {
if (semanticQuery instanceof RuleSemanticQuery) {
continue;
}
for (SemanticCorrector corrector : semanticCorrectors) {
corrector.correct(queryCtx, semanticQuery.getParseInfo());
if (!ChatWorkflowState.CORRECTING.equals(queryCtx.getChatWorkflowState())) {
break;
}
}
}
}
}
public void performProcessing(ChatQueryContext queryCtx, ChatContext chatCtx, ParseResp parseResult) {
resultProcessors.forEach(processor -> {
processor.process(parseResult, queryCtx, chatCtx);
});
}
private void performTranslating(ChatQueryContext chatQueryContext) {
List<SemanticParseInfo> semanticParseInfos = chatQueryContext.getCandidateQueries().stream()
.map(SemanticQuery::getParseInfo)
.collect(Collectors.toList());
semanticParseInfos.forEach(parseInfo -> {
try {
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
if (Objects.isNull(semanticQuery)) {
return;
}
semanticQuery.setParseInfo(parseInfo);
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
SemanticLayerService queryService = ContextUtils.getBean(SemanticLayerService.class);
TranslateSqlReq<Object> translateSqlReq = TranslateSqlReq.builder().queryReq(semanticQueryReq)
.queryTypeEnum(QueryMethod.SQL).build();
TranslateResp explain = queryService.translate(translateSqlReq, chatQueryContext.getUser());
String querySql = explain.getSql();
if (StringUtils.isBlank(querySql)) {
return;
}
SqlInfo sqlInfo = parseInfo.getSqlInfo();
sqlInfo.setQuerySQL(querySql);
sqlInfo.setSourceId(explain.getSourceId());
keyPipelineLog.info("SqlInfoProcessor results:\n"
+ "Parsed S2SQL: {}\nCorrected S2SQL: {}\nQuery SQL: {}",
StringUtils.normalizeSpace(sqlInfo.getS2SQL()),
StringUtils.normalizeSpace(sqlInfo.getCorrectS2SQL()),
StringUtils.normalizeSpace(querySql));
} catch (Exception e) {
log.warn("get sql info failed:{}", parseInfo, e);
}
});
}
}

View File

@@ -1,91 +0,0 @@
package com.tencent.supersonic.headless.server.utils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.enums.WorkflowState;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.QueryContext;
import com.tencent.supersonic.headless.chat.corrector.SemanticCorrector;
import com.tencent.supersonic.headless.chat.mapper.SchemaMapper;
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.headless.server.processor.ResultProcessor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.collections.MapUtils;
import org.springframework.stereotype.Service;
import java.util.List;
import java.util.Objects;
@Service
@Slf4j
public class WorkflowEngine {
private List<SchemaMapper> schemaMappers = ComponentFactory.getSchemaMappers();
private List<SemanticParser> semanticParsers = ComponentFactory.getSemanticParsers();
private List<SemanticCorrector> semanticCorrectors = ComponentFactory.getSemanticCorrectors();
private List<ResultProcessor> resultProcessors = ComponentFactory.getResultProcessors();
public void startWorkflow(QueryContext queryCtx, ChatContext chatCtx, ParseResp parseResult) {
queryCtx.setWorkflowState(WorkflowState.MAPPING);
while (queryCtx.getWorkflowState() != WorkflowState.FINISHED) {
switch (queryCtx.getWorkflowState()) {
case MAPPING:
performMapping(queryCtx);
queryCtx.setWorkflowState(WorkflowState.PARSING);
break;
case PARSING:
performParsing(queryCtx, chatCtx);
queryCtx.setWorkflowState(WorkflowState.CORRECTING);
break;
case CORRECTING:
performCorrecting(queryCtx);
queryCtx.setWorkflowState(WorkflowState.PROCESSING);
break;
case PROCESSING:
default:
performProcessing(queryCtx, chatCtx, parseResult);
queryCtx.setWorkflowState(WorkflowState.FINISHED);
break;
}
}
}
public void performMapping(QueryContext queryCtx) {
if (Objects.isNull(queryCtx.getMapInfo())
|| MapUtils.isEmpty(queryCtx.getMapInfo().getDataSetElementMatches())) {
schemaMappers.forEach(mapper -> mapper.map(queryCtx));
}
}
public void performParsing(QueryContext queryCtx, ChatContext chatCtx) {
semanticParsers.forEach(parser -> {
parser.parse(queryCtx, chatCtx);
log.debug("{} result:{}", parser.getClass().getSimpleName(), JsonUtil.toString(queryCtx));
});
}
public void performCorrecting(QueryContext queryCtx) {
List<SemanticQuery> candidateQueries = queryCtx.getCandidateQueries();
if (CollectionUtils.isNotEmpty(candidateQueries)) {
for (SemanticQuery semanticQuery : candidateQueries) {
if (semanticQuery instanceof RuleSemanticQuery) {
continue;
}
for (SemanticCorrector corrector : semanticCorrectors) {
corrector.correct(queryCtx, semanticQuery.getParseInfo());
if (!WorkflowState.CORRECTING.equals(queryCtx.getWorkflowState())) {
break;
}
}
}
}
}
public void performProcessing(QueryContext queryCtx, ChatContext chatCtx, ParseResp parseResult) {
resultProcessors.forEach(processor -> {
processor.process(parseResult, queryCtx, chatCtx);
});
}
}