(improvement)(chat) filter nature throw detectModelIds in mapper and add loginSql (#144)

This commit is contained in:
lexluo09
2023-09-25 21:56:47 +08:00
committed by GitHub
parent 0774c35589
commit 5c4e80c8f8
12 changed files with 65 additions and 37 deletions

View File

@@ -37,7 +37,8 @@ public class SemanticParseInfo {
private List<SchemaElementMatch> elementMatches = new ArrayList<>();
private Map<String, Object> properties = new HashMap<>();
private EntityInfo entityInfo;
private String sql;
private String logicSql;
private String querySql;
public Long getModelId() {
return model != null ? model.getId() : 0L;

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.component.SemanticCorrector;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.service.SchemaService;
@@ -14,7 +15,14 @@ import lombok.extern.slf4j.Slf4j;
@Slf4j
public abstract class BaseSemanticCorrector implements SemanticCorrector {
public static final String DATE_FIELD = "数据日期";
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
semanticCorrectInfo.setPreSql(semanticCorrectInfo.getSql());
}
protected Map<String, String> getFieldToBizName(Long modelId) {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();

View File

@@ -21,6 +21,7 @@ public class GlobalCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
super.correct(semanticCorrectInfo);
replaceAlias(semanticCorrectInfo);
@@ -74,9 +75,8 @@ public class GlobalCorrector extends BaseSemanticCorrector {
Collectors.groupingBy(ElementValue::getFieldValue,
Collectors.mapping(ElementValue::getFieldName, Collectors.toSet())));
String preSql = semanticCorrectInfo.getSql();
semanticCorrectInfo.setPreSql(preSql);
String sql = SqlParserUpdateHelper.replaceFieldNameByValue(preSql, fieldValueToFieldNames);
String sql = SqlParserUpdateHelper.replaceFieldNameByValue(semanticCorrectInfo.getSql(),
fieldValueToFieldNames);
semanticCorrectInfo.setSql(sql);
}
}

View File

@@ -17,31 +17,30 @@ public class SelectCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
String preSql = semanticCorrectInfo.getSql();
super.correct(semanticCorrectInfo);
String sql = semanticCorrectInfo.getSql();
if (SqlParserSelectHelper.hasAggregateFunction(preSql)) {
Expression havingExpression = SqlParserSelectHelper.getHavingExpression(preSql);
if (SqlParserSelectHelper.hasAggregateFunction(sql)) {
Expression havingExpression = SqlParserSelectHelper.getHavingExpression(sql);
if (Objects.nonNull(havingExpression)) {
String replaceSql = SqlParserUpdateHelper.addFunctionToSelect(preSql, havingExpression);
semanticCorrectInfo.setPreSql(preSql);
String replaceSql = SqlParserUpdateHelper.addFunctionToSelect(sql, havingExpression);
semanticCorrectInfo.setSql(replaceSql);
}
return;
}
Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(preSql));
Set<String> whereFields = new HashSet<>(SqlParserSelectHelper.getWhereFields(preSql));
Set<String> selectFields = new HashSet<>(SqlParserSelectHelper.getSelectFields(sql));
Set<String> whereFields = new HashSet<>(SqlParserSelectHelper.getWhereFields(sql));
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(whereFields)) {
return;
}
whereFields.addAll(SqlParserSelectHelper.getOrderByFields(preSql));
whereFields.addAll(SqlParserSelectHelper.getOrderByFields(sql));
whereFields.removeAll(selectFields);
whereFields.remove(TimeDimensionEnum.DAY.getName());
whereFields.remove(TimeDimensionEnum.WEEK.getName());
whereFields.remove(TimeDimensionEnum.MONTH.getName());
String replaceFields = SqlParserUpdateHelper.addFieldsToSelect(preSql, new ArrayList<>(whereFields));
semanticCorrectInfo.setPreSql(preSql);
String replaceFields = SqlParserUpdateHelper.addFieldsToSelect(sql, new ArrayList<>(whereFields));
semanticCorrectInfo.setSql(replaceFields);
}
}

View File

@@ -11,10 +11,9 @@ public class TableCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
super.correct(semanticCorrectInfo);
Long modelId = semanticCorrectInfo.getParseInfo().getModelId();
String preSql = semanticCorrectInfo.getSql();
semanticCorrectInfo.setPreSql(preSql);
String sql = SqlParserUpdateHelper.replaceTable(preSql, TABLE_PREFIX + modelId);
String sql = SqlParserUpdateHelper.replaceTable(semanticCorrectInfo.getSql(), TABLE_PREFIX + modelId);
semanticCorrectInfo.setSql(sql);
}

View File

@@ -30,7 +30,9 @@ import org.springframework.util.CollectionUtils;
public class WhereCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) throws JSQLParserException {
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
super.correct(semanticCorrectInfo);
addDateIfNotExist(semanticCorrectInfo);
@@ -41,24 +43,27 @@ public class WhereCorrector extends BaseSemanticCorrector {
updateFieldValueByTechName(semanticCorrectInfo);
}
private void addQueryFilter(SemanticCorrectInfo semanticCorrectInfo) throws JSQLParserException {
private void addQueryFilter(SemanticCorrectInfo semanticCorrectInfo) {
String queryFilter = getQueryFilter(semanticCorrectInfo.getQueryFilters());
String preSql = semanticCorrectInfo.getSql();
if (StringUtils.isNotEmpty(queryFilter)) {
log.info("add queryFilter to preSql :{}", queryFilter);
Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
Expression expression = null;
try {
expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
} catch (JSQLParserException e) {
log.error("parseCondExpression", e);
}
String sql = SqlParserUpdateHelper.addWhere(preSql, expression);
semanticCorrectInfo.setPreSql(preSql);
semanticCorrectInfo.setSql(sql);
}
}
private void parserDateDiffFunction(SemanticCorrectInfo semanticCorrectInfo) {
String preSql = semanticCorrectInfo.getSql();
semanticCorrectInfo.setPreSql(preSql);
String sql = SqlParserUpdateHelper.replaceFunction(preSql);
String sql = semanticCorrectInfo.getSql();
sql = SqlParserUpdateHelper.replaceFunction(sql);
semanticCorrectInfo.setSql(sql);
}
@@ -98,9 +103,7 @@ public class WhereCorrector extends BaseSemanticCorrector {
}
Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions);
String preSql = semanticCorrectInfo.getSql();
semanticCorrectInfo.setPreSql(preSql);
String sql = SqlParserUpdateHelper.replaceValue(preSql, aliasAndBizNameToTechName);
String sql = SqlParserUpdateHelper.replaceValue(semanticCorrectInfo.getSql(), aliasAndBizNameToTechName);
semanticCorrectInfo.setSql(sql);
return;
}

View File

@@ -37,14 +37,12 @@ public class HanlpDictMapper implements SchemaMapper {
String queryText = queryContext.getRequest().getQueryText();
List<Term> terms = HanlpHelper.getTerms(queryText);
for (Term term : terms) {
log.info("word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency());
}
QueryMatchStrategy matchStrategy = ContextUtils.getBean(QueryMatchStrategy.class);
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
Set<Long> detectModelIds = mapperHelper.getModelIds(queryContext.getRequest());
terms = filterByModelIds(terms, detectModelIds);
Map<MatchText, List<MapResult>> matchResult = matchStrategy.match(queryContext.getRequest(), terms,
detectModelIds);
@@ -57,6 +55,26 @@ public class HanlpDictMapper implements SchemaMapper {
convertTermsToSchemaMapInfo(matches, queryContext.getMapInfo(), terms);
}
private List<Term> filterByModelIds(List<Term> terms, Set<Long> detectModelIds) {
for (Term term : terms) {
log.info("before word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency());
}
if (CollectionUtils.isNotEmpty(detectModelIds)) {
terms = terms.stream().filter(term -> {
Long modelId = NatureHelper.getModelId(term.getNature().toString());
if (Objects.nonNull(modelId)) {
return detectModelIds.contains(modelId);
}
return false;
}).collect(Collectors.toList());
}
for (Term term : terms) {
log.info("after filter word:{},nature:{},frequency:{}", term.word, term.nature.toString(),
term.getFrequency());
}
return terms;
}
private void convertTermsToSchemaMapInfo(List<MapResult> mapResults, SchemaMapInfo schemaMap, List<Term> terms) {
if (CollectionUtils.isEmpty(mapResults)) {

View File

@@ -132,6 +132,7 @@ public class LLMDslParser implements SemanticParser {
if (StringUtils.isEmpty(correctorSql)) {
correctorSql = semanticCorrectInfo.getSql();
}
parseInfo.setLogicSql(correctorSql);
List<FilterExpression> expressions = SqlParserSelectHelper.getFilterExpression(correctorSql);
//set dataInfo
try {

View File

@@ -165,7 +165,7 @@ public class QueryServiceImpl implements QueryService {
if (Objects.isNull(explain)) {
return;
}
parseInfo.setSql(explain.getSql());
parseInfo.setQuerySql(explain.getSql());
}
@Override

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.common.util.jsqlparser;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;

View File

@@ -32,8 +32,8 @@ com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\
com.tencent.supersonic.chat.api.component.SemanticCorrector=\
com.tencent.supersonic.chat.corrector.GlobalCorrector, \
com.tencent.supersonic.chat.corrector.TableCorrector, \
com.tencent.supersonic.chat.corrector.GroupByCorrector, \
com.tencent.supersonic.chat.corrector.SelectCorrector, \
com.tencent.supersonic.chat.corrector.WhereCorrector, \
com.tencent.supersonic.chat.corrector.HavingCorrector
com.tencent.supersonic.chat.corrector.HavingCorrector, \
com.tencent.supersonic.chat.corrector.TableCorrector

View File

@@ -32,8 +32,8 @@ com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\
com.tencent.supersonic.chat.api.component.SemanticCorrector=\
com.tencent.supersonic.chat.corrector.GlobalCorrector, \
com.tencent.supersonic.chat.corrector.TableCorrector, \
com.tencent.supersonic.chat.corrector.GroupByCorrector, \
com.tencent.supersonic.chat.corrector.SelectCorrector, \
com.tencent.supersonic.chat.corrector.WhereCorrector, \
com.tencent.supersonic.chat.corrector.HavingCorrector
com.tencent.supersonic.chat.corrector.HavingCorrector, \
com.tencent.supersonic.chat.corrector.TableCorrector