(improvement)(chat) filter nature throw detectModelIds in mapper and add loginSql (#144)

This commit is contained in:
lexluo09
2023-09-25 21:56:47 +08:00
committed by GitHub
parent 0774c35589
commit 5c4e80c8f8
12 changed files with 65 additions and 37 deletions

View File

@@ -37,7 +37,8 @@ public class SemanticParseInfo {
private List<SchemaElementMatch> elementMatches = new ArrayList<>(); private List<SchemaElementMatch> elementMatches = new ArrayList<>();
private Map<String, Object> properties = new HashMap<>(); private Map<String, Object> properties = new HashMap<>();
private EntityInfo entityInfo; private EntityInfo entityInfo;
private String sql; private String logicSql;
private String querySql;
public Long getModelId() { public Long getModelId() {
return model != null ? model.getId() : 0L; return model != null ? model.getId() : 0L;

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.component.SemanticCorrector; import com.tencent.supersonic.chat.api.component.SemanticCorrector;
import com.tencent.supersonic.chat.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.service.SchemaService; import com.tencent.supersonic.knowledge.service.SchemaService;
@@ -14,7 +15,14 @@ import lombok.extern.slf4j.Slf4j;
@Slf4j @Slf4j
public abstract class BaseSemanticCorrector implements SemanticCorrector { public abstract class BaseSemanticCorrector implements SemanticCorrector {
public static final String DATE_FIELD = "数据日期"; public static final String DATE_FIELD = "数据日期";
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
semanticCorrectInfo.setPreSql(semanticCorrectInfo.getSql());
}
protected Map<String, String> getFieldToBizName(Long modelId) { protected Map<String, String> getFieldToBizName(Long modelId) {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();

View File

@@ -21,6 +21,7 @@ public class GlobalCorrector extends BaseSemanticCorrector {
@Override @Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) { public void correct(SemanticCorrectInfo semanticCorrectInfo) {
super.correct(semanticCorrectInfo);
replaceAlias(semanticCorrectInfo); replaceAlias(semanticCorrectInfo);
@@ -74,9 +75,8 @@ public class GlobalCorrector extends BaseSemanticCorrector {
Collectors.groupingBy(ElementValue::getFieldValue, Collectors.groupingBy(ElementValue::getFieldValue,
Collectors.mapping(ElementValue::getFieldName, Collectors.toSet()))); Collectors.mapping(ElementValue::getFieldName, Collectors.toSet())));
String preSql = semanticCorrectInfo.getSql(); String sql = SqlParserUpdateHelper.replaceFieldNameByValue(semanticCorrectInfo.getSql(),
semanticCorrectInfo.setPreSql(preSql); fieldValueToFieldNames);
String sql = SqlParserUpdateHelper.replaceFieldNameByValue(preSql, fieldValueToFieldNames);
semanticCorrectInfo.setSql(sql); semanticCorrectInfo.setSql(sql);
} }
} }

View File

@@ -17,31 +17,30 @@ public class SelectCorrector extends BaseSemanticCorrector {
@Override @Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) { public void correct(SemanticCorrectInfo semanticCorrectInfo) {
String preSql = semanticCorrectInfo.getSql(); super.correct(semanticCorrectInfo);
String sql = semanticCorrectInfo.getSql();
if (SqlParserSelectHelper.hasAggregateFunction(preSql)) { if (SqlParserSelectHelper.hasAggregateFunction(sql)) {
Expression havingExpression = SqlParserSelectHelper.getHavingExpression(preSql); Expression havingExpression = SqlParserSelectHelper.getHavingExpression(sql);
if (Objects.nonNull(havingExpression)) { if (Objects.nonNull(havingExpression)) {
String replaceSql = SqlParserUpdateHelper.addFunctionToSelect(preSql, havingExpression); String replaceSql = SqlParserUpdateHelper.addFunctionToSelect(sql, havingExpression);
semanticCorrectInfo.setPreSql(preSql);
semanticCorrectInfo.setSql(replaceSql); semanticCorrectInfo.setSql(replaceSql);
} }
return; return;
} }
Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(preSql)); Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(sql));
Set<String> whereFields = new HashSet<>(SqlParserSelectHelper.getWhereFields(preSql)); Set<String> whereFields = new HashSet<>(SqlParserSelectHelper.getWhereFields(sql));
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(whereFields)) { if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(whereFields)) {
return; return;
} }
whereFields.addAll(SqlParserSelectHelper.getOrderByFields(preSql)); whereFields.addAll(SqlParserSelectHelper.getOrderByFields(sql));
whereFields.removeAll(selectFields); whereFields.removeAll(selectFields);
whereFields.remove(TimeDimensionEnum.DAY.getName()); whereFields.remove(TimeDimensionEnum.DAY.getName());
whereFields.remove(TimeDimensionEnum.WEEK.getName()); whereFields.remove(TimeDimensionEnum.WEEK.getName());
whereFields.remove(TimeDimensionEnum.MONTH.getName()); whereFields.remove(TimeDimensionEnum.MONTH.getName());
String replaceFields = SqlParserUpdateHelper.addFieldsToSelect(preSql, new ArrayList<>(whereFields)); String replaceFields = SqlParserUpdateHelper.addFieldsToSelect(sql, new ArrayList<>(whereFields));
semanticCorrectInfo.setPreSql(preSql);
semanticCorrectInfo.setSql(replaceFields); semanticCorrectInfo.setSql(replaceFields);
} }
} }

View File

@@ -11,10 +11,9 @@ public class TableCorrector extends BaseSemanticCorrector {
@Override @Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) { public void correct(SemanticCorrectInfo semanticCorrectInfo) {
super.correct(semanticCorrectInfo);
Long modelId = semanticCorrectInfo.getParseInfo().getModelId(); Long modelId = semanticCorrectInfo.getParseInfo().getModelId();
String preSql = semanticCorrectInfo.getSql(); String sql = SqlParserUpdateHelper.replaceTable(semanticCorrectInfo.getSql(), TABLE_PREFIX + modelId);
semanticCorrectInfo.setPreSql(preSql);
String sql = SqlParserUpdateHelper.replaceTable(preSql, TABLE_PREFIX + modelId);
semanticCorrectInfo.setSql(sql); semanticCorrectInfo.setSql(sql);
} }

View File

@@ -30,7 +30,9 @@ import org.springframework.util.CollectionUtils;
public class WhereCorrector extends BaseSemanticCorrector { public class WhereCorrector extends BaseSemanticCorrector {
@Override @Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) throws JSQLParserException { public void correct(SemanticCorrectInfo semanticCorrectInfo) {
super.correct(semanticCorrectInfo);
addDateIfNotExist(semanticCorrectInfo); addDateIfNotExist(semanticCorrectInfo);
@@ -41,24 +43,27 @@ public class WhereCorrector extends BaseSemanticCorrector {
updateFieldValueByTechName(semanticCorrectInfo); updateFieldValueByTechName(semanticCorrectInfo);
} }
private void addQueryFilter(SemanticCorrectInfo semanticCorrectInfo) throws JSQLParserException { private void addQueryFilter(SemanticCorrectInfo semanticCorrectInfo) {
String queryFilter = getQueryFilter(semanticCorrectInfo.getQueryFilters()); String queryFilter = getQueryFilter(semanticCorrectInfo.getQueryFilters());
String preSql = semanticCorrectInfo.getSql(); String preSql = semanticCorrectInfo.getSql();
if (StringUtils.isNotEmpty(queryFilter)) { if (StringUtils.isNotEmpty(queryFilter)) {
log.info("add queryFilter to preSql :{}", queryFilter); log.info("add queryFilter to preSql :{}", queryFilter);
Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter); Expression expression = null;
try {
expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
} catch (JSQLParserException e) {
log.error("parseCondExpression", e);
}
String sql = SqlParserUpdateHelper.addWhere(preSql, expression); String sql = SqlParserUpdateHelper.addWhere(preSql, expression);
semanticCorrectInfo.setPreSql(preSql);
semanticCorrectInfo.setSql(sql); semanticCorrectInfo.setSql(sql);
} }
} }
private void parserDateDiffFunction(SemanticCorrectInfo semanticCorrectInfo) { private void parserDateDiffFunction(SemanticCorrectInfo semanticCorrectInfo) {
String preSql = semanticCorrectInfo.getSql(); String sql = semanticCorrectInfo.getSql();
semanticCorrectInfo.setPreSql(preSql); sql = SqlParserUpdateHelper.replaceFunction(sql);
String sql = SqlParserUpdateHelper.replaceFunction(preSql);
semanticCorrectInfo.setSql(sql); semanticCorrectInfo.setSql(sql);
} }
@@ -98,9 +103,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
} }
Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions); Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions);
String preSql = semanticCorrectInfo.getSql(); String sql = SqlParserUpdateHelper.replaceValue(semanticCorrectInfo.getSql(), aliasAndBizNameToTechName);
semanticCorrectInfo.setPreSql(preSql);
String sql = SqlParserUpdateHelper.replaceValue(preSql, aliasAndBizNameToTechName);
semanticCorrectInfo.setSql(sql); semanticCorrectInfo.setSql(sql);
return; return;
} }

View File

@@ -37,14 +37,12 @@ public class HanlpDictMapper implements SchemaMapper {
String queryText = queryContext.getRequest().getQueryText(); String queryText = queryContext.getRequest().getQueryText();
List<Term> terms = HanlpHelper.getTerms(queryText); List<Term> terms = HanlpHelper.getTerms(queryText);
for (Term term : terms) {
log.info("word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency());
}
QueryMatchStrategy matchStrategy = ContextUtils.getBean(QueryMatchStrategy.class); QueryMatchStrategy matchStrategy = ContextUtils.getBean(QueryMatchStrategy.class);
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class); MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
Set<Long> detectModelIds = mapperHelper.getModelIds(queryContext.getRequest()); Set<Long> detectModelIds = mapperHelper.getModelIds(queryContext.getRequest());
terms = filterByModelIds(terms, detectModelIds);
Map<MatchText, List<MapResult>> matchResult = matchStrategy.match(queryContext.getRequest(), terms, Map<MatchText, List<MapResult>> matchResult = matchStrategy.match(queryContext.getRequest(), terms,
detectModelIds); detectModelIds);
@@ -57,6 +55,26 @@ public class HanlpDictMapper implements SchemaMapper {
convertTermsToSchemaMapInfo(matches, queryContext.getMapInfo(), terms); convertTermsToSchemaMapInfo(matches, queryContext.getMapInfo(), terms);
} }
private List<Term> filterByModelIds(List<Term> terms, Set<Long> detectModelIds) {
for (Term term : terms) {
log.info("before word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency());
}
if (CollectionUtils.isNotEmpty(detectModelIds)) {
terms = terms.stream().filter(term -> {
Long modelId = NatureHelper.getModelId(term.getNature().toString());
if (Objects.nonNull(modelId)) {
return detectModelIds.contains(modelId);
}
return false;
}).collect(Collectors.toList());
}
for (Term term : terms) {
log.info("after filter word:{},nature:{},frequency:{}", term.word, term.nature.toString(),
term.getFrequency());
}
return terms;
}
private void convertTermsToSchemaMapInfo(List<MapResult> mapResults, SchemaMapInfo schemaMap, List<Term> terms) { private void convertTermsToSchemaMapInfo(List<MapResult> mapResults, SchemaMapInfo schemaMap, List<Term> terms) {
if (CollectionUtils.isEmpty(mapResults)) { if (CollectionUtils.isEmpty(mapResults)) {

View File

@@ -132,6 +132,7 @@ public class LLMDslParser implements SemanticParser {
if (StringUtils.isEmpty(correctorSql)) { if (StringUtils.isEmpty(correctorSql)) {
correctorSql = semanticCorrectInfo.getSql(); correctorSql = semanticCorrectInfo.getSql();
} }
parseInfo.setLogicSql(correctorSql);
List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctorSql); List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctorSql);
//set dataInfo //set dataInfo
try { try {

View File

@@ -165,7 +165,7 @@ public class QueryServiceImpl implements QueryService {
if (Objects.isNull(explain)) { if (Objects.isNull(explain)) {
return; return;
} }
parseInfo.setSql(explain.getSql()); parseInfo.setQuerySql(explain.getSql());
} }
@Override @Override

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.common.util.jsqlparser; package com.tencent.supersonic.common.util.jsqlparser;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.Set; import java.util.Set;

View File

@@ -32,8 +32,8 @@ com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\
com.tencent.supersonic.chat.api.component.SemanticCorrector=\ com.tencent.supersonic.chat.api.component.SemanticCorrector=\
com.tencent.supersonic.chat.corrector.GlobalCorrector, \ com.tencent.supersonic.chat.corrector.GlobalCorrector, \
com.tencent.supersonic.chat.corrector.TableCorrector, \
com.tencent.supersonic.chat.corrector.GroupByCorrector, \ com.tencent.supersonic.chat.corrector.GroupByCorrector, \
com.tencent.supersonic.chat.corrector.SelectCorrector, \ com.tencent.supersonic.chat.corrector.SelectCorrector, \
com.tencent.supersonic.chat.corrector.WhereCorrector, \ com.tencent.supersonic.chat.corrector.WhereCorrector, \
com.tencent.supersonic.chat.corrector.HavingCorrector com.tencent.supersonic.chat.corrector.HavingCorrector, \
com.tencent.supersonic.chat.corrector.TableCorrector

View File

@@ -32,8 +32,8 @@ com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\
com.tencent.supersonic.chat.api.component.SemanticCorrector=\ com.tencent.supersonic.chat.api.component.SemanticCorrector=\
com.tencent.supersonic.chat.corrector.GlobalCorrector, \ com.tencent.supersonic.chat.corrector.GlobalCorrector, \
com.tencent.supersonic.chat.corrector.TableCorrector, \
com.tencent.supersonic.chat.corrector.GroupByCorrector, \ com.tencent.supersonic.chat.corrector.GroupByCorrector, \
com.tencent.supersonic.chat.corrector.SelectCorrector, \ com.tencent.supersonic.chat.corrector.SelectCorrector, \
com.tencent.supersonic.chat.corrector.WhereCorrector, \ com.tencent.supersonic.chat.corrector.WhereCorrector, \
com.tencent.supersonic.chat.corrector.HavingCorrector com.tencent.supersonic.chat.corrector.HavingCorrector, \
com.tencent.supersonic.chat.corrector.TableCorrector