update work to doXXX and integrate the SemanticCorrector code (#351)

This commit is contained in:
lexluo09
2023-11-09 22:17:13 +08:00
committed by GitHub
parent 7d33c49db8
commit e0088e8f8f
14 changed files with 42 additions and 50 deletions

View File

@@ -30,7 +30,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectS2SQL())) {
return;
}
work(queryReq, semanticParseInfo);
doCorrect(queryReq, semanticParseInfo);
log.info("sqlCorrection:{} sql:{}", this.getClass().getSimpleName(), semanticParseInfo.getSqlInfo());
} catch (Exception e) {
log.error(String.format("correct error,sqlInfo:%s", semanticParseInfo.getSqlInfo()), e);
@@ -38,7 +38,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
}
public abstract void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo);
public abstract void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo);
protected Map<String, String> getFieldNameMap(Long modelId) {

View File

@@ -1,30 +0,0 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
@Slf4j
public class GlobalAfterCorrector extends BaseSemanticCorrector {
@Override
public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
if (!SqlParserSelectFunctionHelper.hasAggregateFunction(logicSql)) {
return;
}
Expression havingExpression = SqlParserSelectHelper.getHavingExpression(logicSql);
if (Objects.nonNull(havingExpression)) {
String replaceSql = SqlParserAddHelper.addFunctionToSelect(logicSql, havingExpression);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceSql);
}
return;
}
}

View File

@@ -20,7 +20,7 @@ import org.springframework.util.CollectionUtils;
public class GroupByCorrector extends BaseSemanticCorrector {
@Override
public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
addGroupByFields(semanticParseInfo);

View File

@@ -5,19 +5,30 @@ import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import org.springframework.util.CollectionUtils;
@Slf4j
public class HavingCorrector extends BaseSemanticCorrector {
@Override
public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
//add aggregate to all metric
addHaving(semanticParseInfo);
//add having expression filed to select
addHavingToSelect(semanticParseInfo);
}
private void addHaving(SemanticParseInfo semanticParseInfo) {
Long modelId = semanticParseInfo.getModel().getModel();
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
@@ -32,4 +43,17 @@ public class HavingCorrector extends BaseSemanticCorrector {
semanticParseInfo.getSqlInfo().setCorrectS2SQL(havingSql);
}
private void addHavingToSelect(SemanticParseInfo semanticParseInfo) {
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
if (!SqlParserSelectFunctionHelper.hasAggregateFunction(logicSql)) {
return;
}
Expression havingExpression = SqlParserSelectHelper.getHavingExpression(logicSql);
if (Objects.nonNull(havingExpression)) {
String replaceSql = SqlParserAddHelper.addFunctionToSelect(logicSql, havingExpression);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceSql);
}
return;
}
}

View File

@@ -18,10 +18,10 @@ import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
@Slf4j
public class GlobalBeforeCorrector extends BaseSemanticCorrector {
public class SchemaCorrector extends BaseSemanticCorrector {
@Override
public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
replaceAlias(semanticParseInfo);

View File

@@ -11,7 +11,7 @@ import org.springframework.util.CollectionUtils;
public class SelectCorrector extends BaseSemanticCorrector {
@Override
public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(logicSql);
List<String> selectFields = SqlParserSelectHelper.getSelectFields(logicSql);

View File

@@ -32,7 +32,7 @@ import org.springframework.util.CollectionUtils;
public class WhereCorrector extends BaseSemanticCorrector {
@Override
public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
addDateIfNotExist(semanticParseInfo);

View File

@@ -32,7 +32,7 @@ public abstract class BaseMapper implements SchemaMapper {
log.info("before {},mapInfo:{}", simpleName, queryContext.getMapInfo().getModelElementMatches());
try {
work(queryContext);
doMap(queryContext);
} catch (Exception e) {
log.error("work error", e);
}
@@ -41,7 +41,7 @@ public abstract class BaseMapper implements SchemaMapper {
log.info("after {},cost:{},mapInfo:{}", simpleName, cost, queryContext.getMapInfo().getModelElementMatches());
}
public abstract void work(QueryContext queryContext);
public abstract void doMap(QueryContext queryContext);
public void addToSchemaMap(SchemaMapInfo schemaMap, Long modelId, SchemaElementMatch schemaElementMatch) {

View File

@@ -21,7 +21,7 @@ import org.apache.commons.lang3.StringUtils;
public class EmbeddingMapper extends BaseMapper {
@Override
public void work(QueryContext queryContext) {
public void doMap(QueryContext queryContext) {
//1. query from embedding by queryText
String queryText = queryContext.getRequest().getQueryText();

View File

@@ -22,7 +22,7 @@ import org.springframework.util.CollectionUtils;
public class EntityMapper extends BaseMapper {
@Override
public void work(QueryContext queryContext) {
public void doMap(QueryContext queryContext) {
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
for (Long modelId : schemaMapInfo.getMatchedModels()) {
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(modelId);

View File

@@ -23,7 +23,7 @@ import org.springframework.util.CollectionUtils;
public class FuzzyNameMapper extends BaseMapper {
@Override
public void work(QueryContext queryContext) {
public void doMap(QueryContext queryContext) {
List<Term> terms = HanlpHelper.getTerms(queryContext.getRequest().getQueryText());

View File

@@ -25,7 +25,7 @@ import org.apache.commons.collections.CollectionUtils;
public class HanlpDictMapper extends BaseMapper {
@Override
public void work(QueryContext queryContext) {
public void doMap(QueryContext queryContext) {
String queryText = queryContext.getRequest().getQueryText();
List<Term> terms = HanlpHelper.getTerms(queryText);

View File

@@ -11,12 +11,11 @@ com.tencent.supersonic.chat.api.component.SemanticParser=\
com.tencent.supersonic.chat.parser.plugin.function.FunctionBasedParser
com.tencent.supersonic.chat.api.component.SemanticCorrector=\
com.tencent.supersonic.chat.corrector.GlobalBeforeCorrector, \
com.tencent.supersonic.chat.corrector.SchemaCorrector, \
com.tencent.supersonic.chat.corrector.SelectCorrector, \
com.tencent.supersonic.chat.corrector.WhereCorrector, \
com.tencent.supersonic.chat.corrector.GroupByCorrector, \
com.tencent.supersonic.chat.corrector.HavingCorrector, \
com.tencent.supersonic.chat.corrector.GlobalAfterCorrector
com.tencent.supersonic.chat.corrector.HavingCorrector
com.tencent.supersonic.chat.api.component.SemanticInterpreter=\
com.tencent.supersonic.knowledge.semantic.RemoteSemanticInterpreter

View File

@@ -12,12 +12,11 @@ com.tencent.supersonic.chat.api.component.SemanticParser=\
com.tencent.supersonic.chat.parser.plugin.function.FunctionBasedParser
com.tencent.supersonic.chat.api.component.SemanticCorrector=\
com.tencent.supersonic.chat.corrector.GlobalBeforeCorrector, \
com.tencent.supersonic.chat.corrector.SchemaCorrector, \
com.tencent.supersonic.chat.corrector.SelectCorrector, \
com.tencent.supersonic.chat.corrector.WhereCorrector, \
com.tencent.supersonic.chat.corrector.GroupByCorrector, \
com.tencent.supersonic.chat.corrector.HavingCorrector, \
com.tencent.supersonic.chat.corrector.GlobalAfterCorrector
com.tencent.supersonic.chat.corrector.HavingCorrector
com.tencent.supersonic.chat.api.component.SemanticInterpreter=\
com.tencent.supersonic.knowledge.semantic.LocalSemanticInterpreter