update work to doXXX and integrate the SemanticCorrector code (#351)

This commit is contained in:
lexluo09
2023-11-09 22:17:13 +08:00
committed by GitHub
parent 7d33c49db8
commit e0088e8f8f
14 changed files with 42 additions and 50 deletions

View File

@@ -30,7 +30,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectS2SQL())) { if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectS2SQL())) {
return; return;
} }
work(queryReq, semanticParseInfo); doCorrect(queryReq, semanticParseInfo);
log.info("sqlCorrection:{} sql:{}", this.getClass().getSimpleName(), semanticParseInfo.getSqlInfo()); log.info("sqlCorrection:{} sql:{}", this.getClass().getSimpleName(), semanticParseInfo.getSqlInfo());
} catch (Exception e) { } catch (Exception e) {
log.error(String.format("correct error,sqlInfo:%s", semanticParseInfo.getSqlInfo()), e); log.error(String.format("correct error,sqlInfo:%s", semanticParseInfo.getSqlInfo()), e);
@@ -38,7 +38,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
} }
public abstract void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo); public abstract void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo);
protected Map<String, String> getFieldNameMap(Long modelId) { protected Map<String, String> getFieldNameMap(Long modelId) {

View File

@@ -1,30 +0,0 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
@Slf4j
public class GlobalAfterCorrector extends BaseSemanticCorrector {
@Override
public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
if (!SqlParserSelectFunctionHelper.hasAggregateFunction(logicSql)) {
return;
}
Expression havingExpression = SqlParserSelectHelper.getHavingExpression(logicSql);
if (Objects.nonNull(havingExpression)) {
String replaceSql = SqlParserAddHelper.addFunctionToSelect(logicSql, havingExpression);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceSql);
}
return;
}
}

View File

@@ -20,7 +20,7 @@ import org.springframework.util.CollectionUtils;
public class GroupByCorrector extends BaseSemanticCorrector { public class GroupByCorrector extends BaseSemanticCorrector {
@Override @Override
public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) { public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
addGroupByFields(semanticParseInfo); addGroupByFields(semanticParseInfo);

View File

@@ -5,19 +5,30 @@ import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.knowledge.service.SchemaService; import com.tencent.supersonic.knowledge.service.SchemaService;
import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
@Slf4j @Slf4j
public class HavingCorrector extends BaseSemanticCorrector { public class HavingCorrector extends BaseSemanticCorrector {
@Override @Override
public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) { public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
//add aggregate to all metric //add aggregate to all metric
addHaving(semanticParseInfo);
//add having expression filed to select
addHavingToSelect(semanticParseInfo);
}
private void addHaving(SemanticParseInfo semanticParseInfo) {
Long modelId = semanticParseInfo.getModel().getModel(); Long modelId = semanticParseInfo.getModel().getModel();
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
@@ -32,4 +43,17 @@ public class HavingCorrector extends BaseSemanticCorrector {
semanticParseInfo.getSqlInfo().setCorrectS2SQL(havingSql); semanticParseInfo.getSqlInfo().setCorrectS2SQL(havingSql);
} }
private void addHavingToSelect(SemanticParseInfo semanticParseInfo) {
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
if (!SqlParserSelectFunctionHelper.hasAggregateFunction(logicSql)) {
return;
}
Expression havingExpression = SqlParserSelectHelper.getHavingExpression(logicSql);
if (Objects.nonNull(havingExpression)) {
String replaceSql = SqlParserAddHelper.addFunctionToSelect(logicSql, havingExpression);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceSql);
}
return;
}
} }

View File

@@ -18,10 +18,10 @@ import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
@Slf4j @Slf4j
public class GlobalBeforeCorrector extends BaseSemanticCorrector { public class SchemaCorrector extends BaseSemanticCorrector {
@Override @Override
public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) { public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
replaceAlias(semanticParseInfo); replaceAlias(semanticParseInfo);

View File

@@ -11,7 +11,7 @@ import org.springframework.util.CollectionUtils;
public class SelectCorrector extends BaseSemanticCorrector { public class SelectCorrector extends BaseSemanticCorrector {
@Override @Override
public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) { public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL(); String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(logicSql); List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(logicSql);
List<String> selectFields = SqlParserSelectHelper.getSelectFields(logicSql); List<String> selectFields = SqlParserSelectHelper.getSelectFields(logicSql);

View File

@@ -32,7 +32,7 @@ import org.springframework.util.CollectionUtils;
public class WhereCorrector extends BaseSemanticCorrector { public class WhereCorrector extends BaseSemanticCorrector {
@Override @Override
public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) { public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
addDateIfNotExist(semanticParseInfo); addDateIfNotExist(semanticParseInfo);

View File

@@ -32,7 +32,7 @@ public abstract class BaseMapper implements SchemaMapper {
log.info("before {},mapInfo:{}", simpleName, queryContext.getMapInfo().getModelElementMatches()); log.info("before {},mapInfo:{}", simpleName, queryContext.getMapInfo().getModelElementMatches());
try { try {
work(queryContext); doMap(queryContext);
} catch (Exception e) { } catch (Exception e) {
log.error("work error", e); log.error("work error", e);
} }
@@ -41,7 +41,7 @@ public abstract class BaseMapper implements SchemaMapper {
log.info("after {},cost:{},mapInfo:{}", simpleName, cost, queryContext.getMapInfo().getModelElementMatches()); log.info("after {},cost:{},mapInfo:{}", simpleName, cost, queryContext.getMapInfo().getModelElementMatches());
} }
public abstract void work(QueryContext queryContext); public abstract void doMap(QueryContext queryContext);
public void addToSchemaMap(SchemaMapInfo schemaMap, Long modelId, SchemaElementMatch schemaElementMatch) { public void addToSchemaMap(SchemaMapInfo schemaMap, Long modelId, SchemaElementMatch schemaElementMatch) {

View File

@@ -21,7 +21,7 @@ import org.apache.commons.lang3.StringUtils;
public class EmbeddingMapper extends BaseMapper { public class EmbeddingMapper extends BaseMapper {
@Override @Override
public void work(QueryContext queryContext) { public void doMap(QueryContext queryContext) {
//1. query from embedding by queryText //1. query from embedding by queryText
String queryText = queryContext.getRequest().getQueryText(); String queryText = queryContext.getRequest().getQueryText();

View File

@@ -22,7 +22,7 @@ import org.springframework.util.CollectionUtils;
public class EntityMapper extends BaseMapper { public class EntityMapper extends BaseMapper {
@Override @Override
public void work(QueryContext queryContext) { public void doMap(QueryContext queryContext) {
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo(); SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
for (Long modelId : schemaMapInfo.getMatchedModels()) { for (Long modelId : schemaMapInfo.getMatchedModels()) {
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(modelId); List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(modelId);

View File

@@ -23,7 +23,7 @@ import org.springframework.util.CollectionUtils;
public class FuzzyNameMapper extends BaseMapper { public class FuzzyNameMapper extends BaseMapper {
@Override @Override
public void work(QueryContext queryContext) { public void doMap(QueryContext queryContext) {
List<Term> terms = HanlpHelper.getTerms(queryContext.getRequest().getQueryText()); List<Term> terms = HanlpHelper.getTerms(queryContext.getRequest().getQueryText());

View File

@@ -25,7 +25,7 @@ import org.apache.commons.collections.CollectionUtils;
public class HanlpDictMapper extends BaseMapper { public class HanlpDictMapper extends BaseMapper {
@Override @Override
public void work(QueryContext queryContext) { public void doMap(QueryContext queryContext) {
String queryText = queryContext.getRequest().getQueryText(); String queryText = queryContext.getRequest().getQueryText();
List<Term> terms = HanlpHelper.getTerms(queryText); List<Term> terms = HanlpHelper.getTerms(queryText);

View File

@@ -11,12 +11,11 @@ com.tencent.supersonic.chat.api.component.SemanticParser=\
com.tencent.supersonic.chat.parser.plugin.function.FunctionBasedParser com.tencent.supersonic.chat.parser.plugin.function.FunctionBasedParser
com.tencent.supersonic.chat.api.component.SemanticCorrector=\ com.tencent.supersonic.chat.api.component.SemanticCorrector=\
com.tencent.supersonic.chat.corrector.GlobalBeforeCorrector, \ com.tencent.supersonic.chat.corrector.SchemaCorrector, \
com.tencent.supersonic.chat.corrector.SelectCorrector, \ com.tencent.supersonic.chat.corrector.SelectCorrector, \
com.tencent.supersonic.chat.corrector.WhereCorrector, \ com.tencent.supersonic.chat.corrector.WhereCorrector, \
com.tencent.supersonic.chat.corrector.GroupByCorrector, \ com.tencent.supersonic.chat.corrector.GroupByCorrector, \
com.tencent.supersonic.chat.corrector.HavingCorrector, \ com.tencent.supersonic.chat.corrector.HavingCorrector
com.tencent.supersonic.chat.corrector.GlobalAfterCorrector
com.tencent.supersonic.chat.api.component.SemanticInterpreter=\ com.tencent.supersonic.chat.api.component.SemanticInterpreter=\
com.tencent.supersonic.knowledge.semantic.RemoteSemanticInterpreter com.tencent.supersonic.knowledge.semantic.RemoteSemanticInterpreter

View File

@@ -12,12 +12,11 @@ com.tencent.supersonic.chat.api.component.SemanticParser=\
com.tencent.supersonic.chat.parser.plugin.function.FunctionBasedParser com.tencent.supersonic.chat.parser.plugin.function.FunctionBasedParser
com.tencent.supersonic.chat.api.component.SemanticCorrector=\ com.tencent.supersonic.chat.api.component.SemanticCorrector=\
com.tencent.supersonic.chat.corrector.GlobalBeforeCorrector, \ com.tencent.supersonic.chat.corrector.SchemaCorrector, \
com.tencent.supersonic.chat.corrector.SelectCorrector, \ com.tencent.supersonic.chat.corrector.SelectCorrector, \
com.tencent.supersonic.chat.corrector.WhereCorrector, \ com.tencent.supersonic.chat.corrector.WhereCorrector, \
com.tencent.supersonic.chat.corrector.GroupByCorrector, \ com.tencent.supersonic.chat.corrector.GroupByCorrector, \
com.tencent.supersonic.chat.corrector.HavingCorrector, \ com.tencent.supersonic.chat.corrector.HavingCorrector
com.tencent.supersonic.chat.corrector.GlobalAfterCorrector
com.tencent.supersonic.chat.api.component.SemanticInterpreter=\ com.tencent.supersonic.chat.api.component.SemanticInterpreter=\
com.tencent.supersonic.knowledge.semantic.LocalSemanticInterpreter com.tencent.supersonic.knowledge.semantic.LocalSemanticInterpreter