mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
update work to doXXX and integrate the SemanticCorrector code (#351)
This commit is contained in:
@@ -30,7 +30,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectS2SQL())) {
|
||||
return;
|
||||
}
|
||||
work(queryReq, semanticParseInfo);
|
||||
doCorrect(queryReq, semanticParseInfo);
|
||||
log.info("sqlCorrection:{} sql:{}", this.getClass().getSimpleName(), semanticParseInfo.getSqlInfo());
|
||||
} catch (Exception e) {
|
||||
log.error(String.format("correct error,sqlInfo:%s", semanticParseInfo.getSqlInfo()), e);
|
||||
@@ -38,7 +38,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
}
|
||||
|
||||
|
||||
public abstract void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo);
|
||||
public abstract void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo);
|
||||
|
||||
protected Map<String, String> getFieldNameMap(Long modelId) {
|
||||
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
package com.tencent.supersonic.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
|
||||
@Slf4j
|
||||
public class GlobalAfterCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||
|
||||
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
if (!SqlParserSelectFunctionHelper.hasAggregateFunction(logicSql)) {
|
||||
return;
|
||||
}
|
||||
Expression havingExpression = SqlParserSelectHelper.getHavingExpression(logicSql);
|
||||
if (Objects.nonNull(havingExpression)) {
|
||||
String replaceSql = SqlParserAddHelper.addFunctionToSelect(logicSql, havingExpression);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceSql);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -20,7 +20,7 @@ import org.springframework.util.CollectionUtils;
|
||||
public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||
|
||||
addGroupByFields(semanticParseInfo);
|
||||
|
||||
|
||||
@@ -5,19 +5,30 @@ import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public class HavingCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||
|
||||
//add aggregate to all metric
|
||||
addHaving(semanticParseInfo);
|
||||
|
||||
//add having expression filed to select
|
||||
addHavingToSelect(semanticParseInfo);
|
||||
}
|
||||
|
||||
private void addHaving(SemanticParseInfo semanticParseInfo) {
|
||||
Long modelId = semanticParseInfo.getModel().getModel();
|
||||
|
||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||
@@ -32,4 +43,17 @@ public class HavingCorrector extends BaseSemanticCorrector {
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(havingSql);
|
||||
}
|
||||
|
||||
private void addHavingToSelect(SemanticParseInfo semanticParseInfo) {
|
||||
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
if (!SqlParserSelectFunctionHelper.hasAggregateFunction(logicSql)) {
|
||||
return;
|
||||
}
|
||||
Expression havingExpression = SqlParserSelectHelper.getHavingExpression(logicSql);
|
||||
if (Objects.nonNull(havingExpression)) {
|
||||
String replaceSql = SqlParserAddHelper.addFunctionToSelect(logicSql, havingExpression);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceSql);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -18,10 +18,10 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public class GlobalBeforeCorrector extends BaseSemanticCorrector {
|
||||
public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||
|
||||
replaceAlias(semanticParseInfo);
|
||||
|
||||
@@ -11,7 +11,7 @@ import org.springframework.util.CollectionUtils;
|
||||
public class SelectCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(logicSql);
|
||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(logicSql);
|
||||
|
||||
@@ -32,7 +32,7 @@ import org.springframework.util.CollectionUtils;
|
||||
public class WhereCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||
|
||||
addDateIfNotExist(semanticParseInfo);
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ public abstract class BaseMapper implements SchemaMapper {
|
||||
log.info("before {},mapInfo:{}", simpleName, queryContext.getMapInfo().getModelElementMatches());
|
||||
|
||||
try {
|
||||
work(queryContext);
|
||||
doMap(queryContext);
|
||||
} catch (Exception e) {
|
||||
log.error("work error", e);
|
||||
}
|
||||
@@ -41,7 +41,7 @@ public abstract class BaseMapper implements SchemaMapper {
|
||||
log.info("after {},cost:{},mapInfo:{}", simpleName, cost, queryContext.getMapInfo().getModelElementMatches());
|
||||
}
|
||||
|
||||
public abstract void work(QueryContext queryContext);
|
||||
public abstract void doMap(QueryContext queryContext);
|
||||
|
||||
|
||||
public void addToSchemaMap(SchemaMapInfo schemaMap, Long modelId, SchemaElementMatch schemaElementMatch) {
|
||||
|
||||
@@ -21,7 +21,7 @@ import org.apache.commons.lang3.StringUtils;
|
||||
public class EmbeddingMapper extends BaseMapper {
|
||||
|
||||
@Override
|
||||
public void work(QueryContext queryContext) {
|
||||
public void doMap(QueryContext queryContext) {
|
||||
//1. query from embedding by queryText
|
||||
|
||||
String queryText = queryContext.getRequest().getQueryText();
|
||||
|
||||
@@ -22,7 +22,7 @@ import org.springframework.util.CollectionUtils;
|
||||
public class EntityMapper extends BaseMapper {
|
||||
|
||||
@Override
|
||||
public void work(QueryContext queryContext) {
|
||||
public void doMap(QueryContext queryContext) {
|
||||
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
||||
for (Long modelId : schemaMapInfo.getMatchedModels()) {
|
||||
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(modelId);
|
||||
|
||||
@@ -23,7 +23,7 @@ import org.springframework.util.CollectionUtils;
|
||||
public class FuzzyNameMapper extends BaseMapper {
|
||||
|
||||
@Override
|
||||
public void work(QueryContext queryContext) {
|
||||
public void doMap(QueryContext queryContext) {
|
||||
|
||||
List<Term> terms = HanlpHelper.getTerms(queryContext.getRequest().getQueryText());
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ import org.apache.commons.collections.CollectionUtils;
|
||||
public class HanlpDictMapper extends BaseMapper {
|
||||
|
||||
@Override
|
||||
public void work(QueryContext queryContext) {
|
||||
public void doMap(QueryContext queryContext) {
|
||||
|
||||
String queryText = queryContext.getRequest().getQueryText();
|
||||
List<Term> terms = HanlpHelper.getTerms(queryText);
|
||||
|
||||
@@ -11,12 +11,11 @@ com.tencent.supersonic.chat.api.component.SemanticParser=\
|
||||
com.tencent.supersonic.chat.parser.plugin.function.FunctionBasedParser
|
||||
|
||||
com.tencent.supersonic.chat.api.component.SemanticCorrector=\
|
||||
com.tencent.supersonic.chat.corrector.GlobalBeforeCorrector, \
|
||||
com.tencent.supersonic.chat.corrector.SchemaCorrector, \
|
||||
com.tencent.supersonic.chat.corrector.SelectCorrector, \
|
||||
com.tencent.supersonic.chat.corrector.WhereCorrector, \
|
||||
com.tencent.supersonic.chat.corrector.GroupByCorrector, \
|
||||
com.tencent.supersonic.chat.corrector.HavingCorrector, \
|
||||
com.tencent.supersonic.chat.corrector.GlobalAfterCorrector
|
||||
com.tencent.supersonic.chat.corrector.HavingCorrector
|
||||
|
||||
com.tencent.supersonic.chat.api.component.SemanticInterpreter=\
|
||||
com.tencent.supersonic.knowledge.semantic.RemoteSemanticInterpreter
|
||||
|
||||
@@ -12,12 +12,11 @@ com.tencent.supersonic.chat.api.component.SemanticParser=\
|
||||
com.tencent.supersonic.chat.parser.plugin.function.FunctionBasedParser
|
||||
|
||||
com.tencent.supersonic.chat.api.component.SemanticCorrector=\
|
||||
com.tencent.supersonic.chat.corrector.GlobalBeforeCorrector, \
|
||||
com.tencent.supersonic.chat.corrector.SchemaCorrector, \
|
||||
com.tencent.supersonic.chat.corrector.SelectCorrector, \
|
||||
com.tencent.supersonic.chat.corrector.WhereCorrector, \
|
||||
com.tencent.supersonic.chat.corrector.GroupByCorrector, \
|
||||
com.tencent.supersonic.chat.corrector.HavingCorrector, \
|
||||
com.tencent.supersonic.chat.corrector.GlobalAfterCorrector
|
||||
com.tencent.supersonic.chat.corrector.HavingCorrector
|
||||
|
||||
com.tencent.supersonic.chat.api.component.SemanticInterpreter=\
|
||||
com.tencent.supersonic.knowledge.semantic.LocalSemanticInterpreter
|
||||
|
||||
Reference in New Issue
Block a user