mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 12:07:42 +00:00
update work to doXXX and integrate the SemanticCorrector code (#351)
This commit is contained in:
@@ -30,7 +30,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
|||||||
if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectS2SQL())) {
|
if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectS2SQL())) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
work(queryReq, semanticParseInfo);
|
doCorrect(queryReq, semanticParseInfo);
|
||||||
log.info("sqlCorrection:{} sql:{}", this.getClass().getSimpleName(), semanticParseInfo.getSqlInfo());
|
log.info("sqlCorrection:{} sql:{}", this.getClass().getSimpleName(), semanticParseInfo.getSqlInfo());
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error(String.format("correct error,sqlInfo:%s", semanticParseInfo.getSqlInfo()), e);
|
log.error(String.format("correct error,sqlInfo:%s", semanticParseInfo.getSqlInfo()), e);
|
||||||
@@ -38,7 +38,7 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public abstract void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo);
|
public abstract void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo);
|
||||||
|
|
||||||
protected Map<String, String> getFieldNameMap(Long modelId) {
|
protected Map<String, String> getFieldNameMap(Long modelId) {
|
||||||
|
|
||||||
|
|||||||
@@ -1,30 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.corrector;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
|
||||||
import java.util.Objects;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import net.sf.jsqlparser.expression.Expression;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class GlobalAfterCorrector extends BaseSemanticCorrector {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
|
||||||
|
|
||||||
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
|
||||||
if (!SqlParserSelectFunctionHelper.hasAggregateFunction(logicSql)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
Expression havingExpression = SqlParserSelectHelper.getHavingExpression(logicSql);
|
|
||||||
if (Objects.nonNull(havingExpression)) {
|
|
||||||
String replaceSql = SqlParserAddHelper.addFunctionToSelect(logicSql, havingExpression);
|
|
||||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceSql);
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -20,7 +20,7 @@ import org.springframework.util.CollectionUtils;
|
|||||||
public class GroupByCorrector extends BaseSemanticCorrector {
|
public class GroupByCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||||
|
|
||||||
addGroupByFields(semanticParseInfo);
|
addGroupByFields(semanticParseInfo);
|
||||||
|
|
||||||
|
|||||||
@@ -5,19 +5,30 @@ import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
|||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserAddHelper;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectFunctionHelper;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||||
|
import java.util.Objects;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import net.sf.jsqlparser.expression.Expression;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class HavingCorrector extends BaseSemanticCorrector {
|
public class HavingCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||||
|
|
||||||
//add aggregate to all metric
|
//add aggregate to all metric
|
||||||
|
addHaving(semanticParseInfo);
|
||||||
|
|
||||||
|
//add having expression filed to select
|
||||||
|
addHavingToSelect(semanticParseInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addHaving(SemanticParseInfo semanticParseInfo) {
|
||||||
Long modelId = semanticParseInfo.getModel().getModel();
|
Long modelId = semanticParseInfo.getModel().getModel();
|
||||||
|
|
||||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||||
@@ -32,4 +43,17 @@ public class HavingCorrector extends BaseSemanticCorrector {
|
|||||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(havingSql);
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(havingSql);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void addHavingToSelect(SemanticParseInfo semanticParseInfo) {
|
||||||
|
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
|
if (!SqlParserSelectFunctionHelper.hasAggregateFunction(logicSql)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Expression havingExpression = SqlParserSelectHelper.getHavingExpression(logicSql);
|
||||||
|
if (Objects.nonNull(havingExpression)) {
|
||||||
|
String replaceSql = SqlParserAddHelper.addFunctionToSelect(logicSql, havingExpression);
|
||||||
|
semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceSql);
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,10 +18,10 @@ import lombok.extern.slf4j.Slf4j;
|
|||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class GlobalBeforeCorrector extends BaseSemanticCorrector {
|
public class SchemaCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||||
|
|
||||||
replaceAlias(semanticParseInfo);
|
replaceAlias(semanticParseInfo);
|
||||||
|
|
||||||
@@ -11,7 +11,7 @@ import org.springframework.util.CollectionUtils;
|
|||||||
public class SelectCorrector extends BaseSemanticCorrector {
|
public class SelectCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||||
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
String logicSql = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||||
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(logicSql);
|
List<String> aggregateFields = SqlParserSelectHelper.getAggregateFields(logicSql);
|
||||||
List<String> selectFields = SqlParserSelectHelper.getSelectFields(logicSql);
|
List<String> selectFields = SqlParserSelectHelper.getSelectFields(logicSql);
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ import org.springframework.util.CollectionUtils;
|
|||||||
public class WhereCorrector extends BaseSemanticCorrector {
|
public class WhereCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void work(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
public void doCorrect(QueryReq queryReq, SemanticParseInfo semanticParseInfo) {
|
||||||
|
|
||||||
addDateIfNotExist(semanticParseInfo);
|
addDateIfNotExist(semanticParseInfo);
|
||||||
|
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ public abstract class BaseMapper implements SchemaMapper {
|
|||||||
log.info("before {},mapInfo:{}", simpleName, queryContext.getMapInfo().getModelElementMatches());
|
log.info("before {},mapInfo:{}", simpleName, queryContext.getMapInfo().getModelElementMatches());
|
||||||
|
|
||||||
try {
|
try {
|
||||||
work(queryContext);
|
doMap(queryContext);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("work error", e);
|
log.error("work error", e);
|
||||||
}
|
}
|
||||||
@@ -41,7 +41,7 @@ public abstract class BaseMapper implements SchemaMapper {
|
|||||||
log.info("after {},cost:{},mapInfo:{}", simpleName, cost, queryContext.getMapInfo().getModelElementMatches());
|
log.info("after {},cost:{},mapInfo:{}", simpleName, cost, queryContext.getMapInfo().getModelElementMatches());
|
||||||
}
|
}
|
||||||
|
|
||||||
public abstract void work(QueryContext queryContext);
|
public abstract void doMap(QueryContext queryContext);
|
||||||
|
|
||||||
|
|
||||||
public void addToSchemaMap(SchemaMapInfo schemaMap, Long modelId, SchemaElementMatch schemaElementMatch) {
|
public void addToSchemaMap(SchemaMapInfo schemaMap, Long modelId, SchemaElementMatch schemaElementMatch) {
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ import org.apache.commons.lang3.StringUtils;
|
|||||||
public class EmbeddingMapper extends BaseMapper {
|
public class EmbeddingMapper extends BaseMapper {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void work(QueryContext queryContext) {
|
public void doMap(QueryContext queryContext) {
|
||||||
//1. query from embedding by queryText
|
//1. query from embedding by queryText
|
||||||
|
|
||||||
String queryText = queryContext.getRequest().getQueryText();
|
String queryText = queryContext.getRequest().getQueryText();
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import org.springframework.util.CollectionUtils;
|
|||||||
public class EntityMapper extends BaseMapper {
|
public class EntityMapper extends BaseMapper {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void work(QueryContext queryContext) {
|
public void doMap(QueryContext queryContext) {
|
||||||
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
||||||
for (Long modelId : schemaMapInfo.getMatchedModels()) {
|
for (Long modelId : schemaMapInfo.getMatchedModels()) {
|
||||||
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(modelId);
|
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(modelId);
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ import org.springframework.util.CollectionUtils;
|
|||||||
public class FuzzyNameMapper extends BaseMapper {
|
public class FuzzyNameMapper extends BaseMapper {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void work(QueryContext queryContext) {
|
public void doMap(QueryContext queryContext) {
|
||||||
|
|
||||||
List<Term> terms = HanlpHelper.getTerms(queryContext.getRequest().getQueryText());
|
List<Term> terms = HanlpHelper.getTerms(queryContext.getRequest().getQueryText());
|
||||||
|
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ import org.apache.commons.collections.CollectionUtils;
|
|||||||
public class HanlpDictMapper extends BaseMapper {
|
public class HanlpDictMapper extends BaseMapper {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void work(QueryContext queryContext) {
|
public void doMap(QueryContext queryContext) {
|
||||||
|
|
||||||
String queryText = queryContext.getRequest().getQueryText();
|
String queryText = queryContext.getRequest().getQueryText();
|
||||||
List<Term> terms = HanlpHelper.getTerms(queryText);
|
List<Term> terms = HanlpHelper.getTerms(queryText);
|
||||||
|
|||||||
@@ -11,12 +11,11 @@ com.tencent.supersonic.chat.api.component.SemanticParser=\
|
|||||||
com.tencent.supersonic.chat.parser.plugin.function.FunctionBasedParser
|
com.tencent.supersonic.chat.parser.plugin.function.FunctionBasedParser
|
||||||
|
|
||||||
com.tencent.supersonic.chat.api.component.SemanticCorrector=\
|
com.tencent.supersonic.chat.api.component.SemanticCorrector=\
|
||||||
com.tencent.supersonic.chat.corrector.GlobalBeforeCorrector, \
|
com.tencent.supersonic.chat.corrector.SchemaCorrector, \
|
||||||
com.tencent.supersonic.chat.corrector.SelectCorrector, \
|
com.tencent.supersonic.chat.corrector.SelectCorrector, \
|
||||||
com.tencent.supersonic.chat.corrector.WhereCorrector, \
|
com.tencent.supersonic.chat.corrector.WhereCorrector, \
|
||||||
com.tencent.supersonic.chat.corrector.GroupByCorrector, \
|
com.tencent.supersonic.chat.corrector.GroupByCorrector, \
|
||||||
com.tencent.supersonic.chat.corrector.HavingCorrector, \
|
com.tencent.supersonic.chat.corrector.HavingCorrector
|
||||||
com.tencent.supersonic.chat.corrector.GlobalAfterCorrector
|
|
||||||
|
|
||||||
com.tencent.supersonic.chat.api.component.SemanticInterpreter=\
|
com.tencent.supersonic.chat.api.component.SemanticInterpreter=\
|
||||||
com.tencent.supersonic.knowledge.semantic.RemoteSemanticInterpreter
|
com.tencent.supersonic.knowledge.semantic.RemoteSemanticInterpreter
|
||||||
|
|||||||
@@ -12,12 +12,11 @@ com.tencent.supersonic.chat.api.component.SemanticParser=\
|
|||||||
com.tencent.supersonic.chat.parser.plugin.function.FunctionBasedParser
|
com.tencent.supersonic.chat.parser.plugin.function.FunctionBasedParser
|
||||||
|
|
||||||
com.tencent.supersonic.chat.api.component.SemanticCorrector=\
|
com.tencent.supersonic.chat.api.component.SemanticCorrector=\
|
||||||
com.tencent.supersonic.chat.corrector.GlobalBeforeCorrector, \
|
com.tencent.supersonic.chat.corrector.SchemaCorrector, \
|
||||||
com.tencent.supersonic.chat.corrector.SelectCorrector, \
|
com.tencent.supersonic.chat.corrector.SelectCorrector, \
|
||||||
com.tencent.supersonic.chat.corrector.WhereCorrector, \
|
com.tencent.supersonic.chat.corrector.WhereCorrector, \
|
||||||
com.tencent.supersonic.chat.corrector.GroupByCorrector, \
|
com.tencent.supersonic.chat.corrector.GroupByCorrector, \
|
||||||
com.tencent.supersonic.chat.corrector.HavingCorrector, \
|
com.tencent.supersonic.chat.corrector.HavingCorrector
|
||||||
com.tencent.supersonic.chat.corrector.GlobalAfterCorrector
|
|
||||||
|
|
||||||
com.tencent.supersonic.chat.api.component.SemanticInterpreter=\
|
com.tencent.supersonic.chat.api.component.SemanticInterpreter=\
|
||||||
com.tencent.supersonic.knowledge.semantic.LocalSemanticInterpreter
|
com.tencent.supersonic.knowledge.semantic.LocalSemanticInterpreter
|
||||||
|
|||||||
Reference in New Issue
Block a user