mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
[improvement][chat]llm parser corrector is simplified by sql distribution (#120)
This commit is contained in:
@@ -7,6 +7,7 @@ import java.util.Map;
|
|||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
public class SemanticSchema implements Serializable {
|
public class SemanticSchema implements Serializable {
|
||||||
|
|
||||||
private List<ModelSchema> modelSchemaList;
|
private List<ModelSchema> modelSchemaList;
|
||||||
|
|
||||||
public SemanticSchema(List<ModelSchema> modelSchemaList) {
|
public SemanticSchema(List<ModelSchema> modelSchemaList) {
|
||||||
@@ -34,12 +35,28 @@ public class SemanticSchema implements Serializable {
|
|||||||
return dimensions;
|
return dimensions;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public List<SchemaElement> getDimensions(Long modelId) {
|
||||||
|
List<SchemaElement> dimensions = getDimensions();
|
||||||
|
return getElementsByModelId(modelId, dimensions);
|
||||||
|
}
|
||||||
|
|
||||||
public List<SchemaElement> getMetrics() {
|
public List<SchemaElement> getMetrics() {
|
||||||
List<SchemaElement> metrics = new ArrayList<>();
|
List<SchemaElement> metrics = new ArrayList<>();
|
||||||
modelSchemaList.stream().forEach(d -> metrics.addAll(d.getMetrics()));
|
modelSchemaList.stream().forEach(d -> metrics.addAll(d.getMetrics()));
|
||||||
return metrics;
|
return metrics;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public List<SchemaElement> getMetrics(Long modelId) {
|
||||||
|
List<SchemaElement> metrics = getMetrics();
|
||||||
|
return getElementsByModelId(modelId, metrics);
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<SchemaElement> getElementsByModelId(Long modelId, List<SchemaElement> elements) {
|
||||||
|
return elements.stream()
|
||||||
|
.filter(schemaElement -> modelId.equals(schemaElement.getModel()))
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
public List<SchemaElement> getModels() {
|
public List<SchemaElement> getModels() {
|
||||||
List<SchemaElement> models = new ArrayList<>();
|
List<SchemaElement> models = new ArrayList<>();
|
||||||
modelSchemaList.stream().forEach(d -> models.add(d.getModel()));
|
modelSchemaList.stream().forEach(d -> models.add(d.getModel()));
|
||||||
|
|||||||
@@ -1,27 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.corrector;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
|
||||||
import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper;
|
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
|
||||||
import java.util.List;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class DateFieldCorrector extends BaseSemanticCorrector {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
|
||||||
|
|
||||||
String sql = semanticCorrectInfo.getSql();
|
|
||||||
List<String> whereFields = SqlParserSelectHelper.getWhereFields(sql);
|
|
||||||
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(DATE_FIELD)) {
|
|
||||||
String currentDate = DSLDateHelper.getReferenceDate(semanticCorrectInfo.getParseInfo().getModelId());
|
|
||||||
sql = SqlParserUpdateHelper.addWhere(sql, DATE_FIELD, currentDate);
|
|
||||||
}
|
|
||||||
semanticCorrectInfo.setPreSql(semanticCorrectInfo.getSql());
|
|
||||||
semanticCorrectInfo.setSql(sql);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,18 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.corrector;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class FieldCorrector extends BaseSemanticCorrector {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
|
||||||
String preSql = semanticCorrectInfo.getSql();
|
|
||||||
semanticCorrectInfo.setPreSql(preSql);
|
|
||||||
String sql = SqlParserUpdateHelper.replaceFields(preSql,
|
|
||||||
getFieldToBizName(semanticCorrectInfo.getParseInfo().getModelId()));
|
|
||||||
semanticCorrectInfo.setSql(sql);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.corrector;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class FunctionAliasCorrector extends BaseSemanticCorrector {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
|
||||||
String replaceAlias = SqlParserUpdateHelper.replaceAlias(semanticCorrectInfo.getSql());
|
|
||||||
semanticCorrectInfo.setSql(replaceAlias);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.corrector;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class FunctionCorrector extends BaseSemanticCorrector {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
|
||||||
String preSql = semanticCorrectInfo.getSql();
|
|
||||||
semanticCorrectInfo.setPreSql(preSql);
|
|
||||||
String sql = SqlParserUpdateHelper.replaceFunction(preSql);
|
|
||||||
semanticCorrectInfo.setSql(sql);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -16,11 +16,39 @@ import lombok.extern.slf4j.Slf4j;
|
|||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class FieldNameCorrector extends BaseSemanticCorrector {
|
public class GlobalCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
||||||
|
|
||||||
|
replaceAlias(semanticCorrectInfo);
|
||||||
|
|
||||||
|
updateFieldNameByLinkingValue(semanticCorrectInfo);
|
||||||
|
|
||||||
|
updateFieldNameByBizName(semanticCorrectInfo);
|
||||||
|
|
||||||
|
addAggregateToMetric(semanticCorrectInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addAggregateToMetric(SemanticCorrectInfo semanticCorrectInfo) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private void replaceAlias(SemanticCorrectInfo semanticCorrectInfo) {
|
||||||
|
String replaceAlias = SqlParserUpdateHelper.replaceAlias(semanticCorrectInfo.getSql());
|
||||||
|
semanticCorrectInfo.setSql(replaceAlias);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void updateFieldNameByBizName(SemanticCorrectInfo semanticCorrectInfo) {
|
||||||
|
|
||||||
|
Map<String, String> fieldToBizName = getFieldToBizName(semanticCorrectInfo.getParseInfo().getModelId());
|
||||||
|
|
||||||
|
String sql = SqlParserUpdateHelper.replaceFields(semanticCorrectInfo.getSql(), fieldToBizName);
|
||||||
|
|
||||||
|
semanticCorrectInfo.setSql(sql);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void updateFieldNameByLinkingValue(SemanticCorrectInfo semanticCorrectInfo) {
|
||||||
Object context = semanticCorrectInfo.getParseInfo().getProperties().get(Constants.CONTEXT);
|
Object context = semanticCorrectInfo.getParseInfo().getProperties().get(Constants.CONTEXT);
|
||||||
if (Objects.isNull(context)) {
|
if (Objects.isNull(context)) {
|
||||||
return;
|
return;
|
||||||
@@ -45,5 +73,4 @@ public class FieldNameCorrector extends BaseSemanticCorrector {
|
|||||||
String sql = SqlParserUpdateHelper.replaceFieldNameByValue(preSql, fieldValueToFieldNames);
|
String sql = SqlParserUpdateHelper.replaceFieldNameByValue(preSql, fieldValueToFieldNames);
|
||||||
semanticCorrectInfo.setSql(sql);
|
semanticCorrectInfo.setSql(sql);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
package com.tencent.supersonic.chat.corrector;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
public class GroupByCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
package com.tencent.supersonic.chat.corrector;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
public class HavingCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,48 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.corrector;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
|
||||||
import com.tencent.supersonic.common.util.StringUtil;
|
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.stream.Collectors;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import net.sf.jsqlparser.JSQLParserException;
|
|
||||||
import net.sf.jsqlparser.expression.Expression;
|
|
||||||
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
|
||||||
import org.apache.commons.collections.CollectionUtils;
|
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
|
|
||||||
@Slf4j
|
|
||||||
public class QueryFilterAppend extends BaseSemanticCorrector {
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) throws JSQLParserException {
|
|
||||||
String queryFilter = getQueryFilter(semanticCorrectInfo.getQueryFilters());
|
|
||||||
String preSql = semanticCorrectInfo.getSql();
|
|
||||||
|
|
||||||
if (StringUtils.isNotEmpty(queryFilter)) {
|
|
||||||
log.info("add queryFilter to preSql :{}", queryFilter);
|
|
||||||
Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
|
|
||||||
String sql = SqlParserUpdateHelper.addWhere(preSql, expression);
|
|
||||||
semanticCorrectInfo.setPreSql(preSql);
|
|
||||||
semanticCorrectInfo.setSql(sql);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private String getQueryFilter(QueryFilters queryFilters) {
|
|
||||||
if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
return queryFilters.getFilters().stream()
|
|
||||||
.map(filter -> {
|
|
||||||
String bizNameWrap = StringUtil.getSpaceWrap(filter.getBizName());
|
|
||||||
String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue());
|
|
||||||
String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString());
|
|
||||||
return bizNameWrap + operatorWrap + valueWrap;
|
|
||||||
})
|
|
||||||
.collect(Collectors.joining(Constants.AND_UPPER));
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -13,11 +13,12 @@ import net.sf.jsqlparser.expression.Expression;
|
|||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class SelectFieldAppendCorrector extends BaseSemanticCorrector {
|
public class SelectCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
||||||
String preSql = semanticCorrectInfo.getSql();
|
String preSql = semanticCorrectInfo.getSql();
|
||||||
|
|
||||||
if (SqlParserSelectHelper.hasAggregateFunction(preSql)) {
|
if (SqlParserSelectHelper.hasAggregateFunction(preSql)) {
|
||||||
Expression havingExpression = SqlParserSelectHelper.getHavingExpression(preSql);
|
Expression havingExpression = SqlParserSelectHelper.getHavingExpression(preSql);
|
||||||
if (Objects.nonNull(havingExpression)) {
|
if (Objects.nonNull(havingExpression)) {
|
||||||
@@ -5,7 +5,7 @@ import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
|||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class TableNameCorrector extends BaseSemanticCorrector {
|
public class TableCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
public static final String TABLE_PREFIX = "t_";
|
public static final String TABLE_PREFIX = "t_";
|
||||||
|
|
||||||
@@ -1,26 +1,92 @@
|
|||||||
package com.tencent.supersonic.chat.corrector;
|
package com.tencent.supersonic.chat.corrector;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
|
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||||
|
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||||
|
import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper;
|
||||||
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
|
import com.tencent.supersonic.common.util.StringUtil;
|
||||||
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
|
||||||
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
|
||||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
import com.tencent.supersonic.knowledge.service.SchemaService;
|
||||||
|
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import net.sf.jsqlparser.JSQLParserException;
|
||||||
|
import net.sf.jsqlparser.expression.Expression;
|
||||||
|
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
import org.apache.logging.log4j.util.Strings;
|
import org.apache.logging.log4j.util.Strings;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class FieldValueCorrector extends BaseSemanticCorrector {
|
public class WhereCorrector extends BaseSemanticCorrector {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
|
public void correct(SemanticCorrectInfo semanticCorrectInfo) throws JSQLParserException {
|
||||||
|
|
||||||
|
addDateIfNotExist(semanticCorrectInfo);
|
||||||
|
|
||||||
|
parserDateDiffFunction(semanticCorrectInfo);
|
||||||
|
|
||||||
|
addQueryFilter(semanticCorrectInfo);
|
||||||
|
|
||||||
|
updateFieldValueByTechName(semanticCorrectInfo);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addQueryFilter(SemanticCorrectInfo semanticCorrectInfo) throws JSQLParserException {
|
||||||
|
String queryFilter = getQueryFilter(semanticCorrectInfo.getQueryFilters());
|
||||||
|
|
||||||
|
String preSql = semanticCorrectInfo.getSql();
|
||||||
|
|
||||||
|
if (StringUtils.isNotEmpty(queryFilter)) {
|
||||||
|
log.info("add queryFilter to preSql :{}", queryFilter);
|
||||||
|
Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
|
||||||
|
String sql = SqlParserUpdateHelper.addWhere(preSql, expression);
|
||||||
|
semanticCorrectInfo.setPreSql(preSql);
|
||||||
|
semanticCorrectInfo.setSql(sql);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void parserDateDiffFunction(SemanticCorrectInfo semanticCorrectInfo) {
|
||||||
|
String preSql = semanticCorrectInfo.getSql();
|
||||||
|
semanticCorrectInfo.setPreSql(preSql);
|
||||||
|
String sql = SqlParserUpdateHelper.replaceFunction(preSql);
|
||||||
|
semanticCorrectInfo.setSql(sql);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addDateIfNotExist(SemanticCorrectInfo semanticCorrectInfo) {
|
||||||
|
String sql = semanticCorrectInfo.getSql();
|
||||||
|
List<String> whereFields = SqlParserSelectHelper.getWhereFields(sql);
|
||||||
|
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(TimeDimensionEnum.DAY.getName())) {
|
||||||
|
String currentDate = DSLDateHelper.getReferenceDate(semanticCorrectInfo.getParseInfo().getModelId());
|
||||||
|
sql = SqlParserUpdateHelper.addWhere(sql, TimeDimensionEnum.DAY.getName(), currentDate);
|
||||||
|
}
|
||||||
|
semanticCorrectInfo.setSql(sql);
|
||||||
|
}
|
||||||
|
|
||||||
|
private String getQueryFilter(QueryFilters queryFilters) {
|
||||||
|
if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return queryFilters.getFilters().stream()
|
||||||
|
.map(filter -> {
|
||||||
|
String bizNameWrap = StringUtil.getSpaceWrap(filter.getBizName());
|
||||||
|
String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue());
|
||||||
|
String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString());
|
||||||
|
return bizNameWrap + operatorWrap + valueWrap;
|
||||||
|
})
|
||||||
|
.collect(Collectors.joining(Constants.AND_UPPER));
|
||||||
|
}
|
||||||
|
|
||||||
|
private void updateFieldValueByTechName(SemanticCorrectInfo semanticCorrectInfo) {
|
||||||
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
|
||||||
Long modelId = semanticCorrectInfo.getParseInfo().getModel().getId();
|
Long modelId = semanticCorrectInfo.getParseInfo().getModel().getId();
|
||||||
List<SchemaElement> dimensions = semanticSchema.getDimensions().stream()
|
List<SchemaElement> dimensions = semanticSchema.getDimensions().stream()
|
||||||
@@ -39,7 +105,6 @@ public class FieldValueCorrector extends BaseSemanticCorrector {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
private Map<String, Map<String, String>> getAliasAndBizNameToTechName(List<SchemaElement> dimensions) {
|
private Map<String, Map<String, String>> getAliasAndBizNameToTechName(List<SchemaElement> dimensions) {
|
||||||
if (CollectionUtils.isEmpty(dimensions)) {
|
if (CollectionUtils.isEmpty(dimensions)) {
|
||||||
return new HashMap<>();
|
return new HashMap<>();
|
||||||
@@ -408,27 +408,20 @@ public class LLMDslParser implements SemanticParser {
|
|||||||
|
|
||||||
protected List<String> getFieldNameList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema,
|
protected List<String> getFieldNameList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema,
|
||||||
LLMParserConfig llmParserConfig) {
|
LLMParserConfig llmParserConfig) {
|
||||||
|
|
||||||
|
Set<String> results = getTopNFieldNames(modelId, semanticSchema, llmParserConfig);
|
||||||
|
|
||||||
|
Set<String> fieldNameList = getMatchedFieldNames(queryCtx, modelId, semanticSchema);
|
||||||
|
|
||||||
|
results.addAll(fieldNameList);
|
||||||
|
return new ArrayList<>(results);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) {
|
||||||
Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema);
|
Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema);
|
||||||
|
|
||||||
Set<String> results = semanticSchema.getDimensions().stream()
|
|
||||||
.filter(schemaElement -> modelId.equals(schemaElement.getModel()))
|
|
||||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
|
||||||
.limit(llmParserConfig.getDimensionTopN())
|
|
||||||
.map(entry -> entry.getName())
|
|
||||||
.collect(Collectors.toSet());
|
|
||||||
|
|
||||||
Set<String> metrics = semanticSchema.getMetrics().stream()
|
|
||||||
.filter(schemaElement -> modelId.equals(schemaElement.getModel()))
|
|
||||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
|
||||||
.limit(llmParserConfig.getMetricTopN())
|
|
||||||
.map(entry -> entry.getName())
|
|
||||||
.collect(Collectors.toSet());
|
|
||||||
|
|
||||||
results.addAll(metrics);
|
|
||||||
|
|
||||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId);
|
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId);
|
||||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||||
return new ArrayList<>(results);
|
return new HashSet<>();
|
||||||
}
|
}
|
||||||
Set<String> fieldNameList = matchedElements.stream()
|
Set<String> fieldNameList = matchedElements.stream()
|
||||||
.filter(schemaElementMatch -> {
|
.filter(schemaElementMatch -> {
|
||||||
@@ -447,13 +440,29 @@ public class LLMDslParser implements SemanticParser {
|
|||||||
})
|
})
|
||||||
.filter(name -> StringUtils.isNotEmpty(name) && !name.contains("%"))
|
.filter(name -> StringUtils.isNotEmpty(name) && !name.contains("%"))
|
||||||
.collect(Collectors.toSet());
|
.collect(Collectors.toSet());
|
||||||
results.addAll(fieldNameList);
|
return fieldNameList;
|
||||||
return new ArrayList<>(results);
|
}
|
||||||
|
|
||||||
|
private Set<String> getTopNFieldNames(Long modelId, SemanticSchema semanticSchema,
|
||||||
|
LLMParserConfig llmParserConfig) {
|
||||||
|
Set<String> results = semanticSchema.getDimensions(modelId).stream()
|
||||||
|
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||||
|
.limit(llmParserConfig.getDimensionTopN())
|
||||||
|
.map(entry -> entry.getName())
|
||||||
|
.collect(Collectors.toSet());
|
||||||
|
|
||||||
|
Set<String> metrics = semanticSchema.getMetrics(modelId).stream()
|
||||||
|
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||||
|
.limit(llmParserConfig.getMetricTopN())
|
||||||
|
.map(entry -> entry.getName())
|
||||||
|
.collect(Collectors.toSet());
|
||||||
|
|
||||||
|
results.addAll(metrics);
|
||||||
|
return results;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected Map<Long, String> getItemIdToName(Long modelId, SemanticSchema semanticSchema) {
|
protected Map<Long, String> getItemIdToName(Long modelId, SemanticSchema semanticSchema) {
|
||||||
return semanticSchema.getDimensions().stream()
|
return semanticSchema.getDimensions(modelId).stream()
|
||||||
.filter(entry -> modelId.equals(entry.getModel()))
|
|
||||||
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
|
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,45 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.corrector;
|
|
||||||
|
|
||||||
import static org.mockito.ArgumentMatchers.any;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
|
||||||
import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper;
|
|
||||||
import org.junit.Assert;
|
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
import org.mockito.MockedStatic;
|
|
||||||
import org.mockito.Mockito;
|
|
||||||
|
|
||||||
class DateFieldCorrectorTest {
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void corrector() {
|
|
||||||
MockedStatic<DSLDateHelper> dslDateHelper = Mockito.mockStatic(DSLDateHelper.class);
|
|
||||||
|
|
||||||
dslDateHelper.when(() -> DSLDateHelper.getReferenceDate(any())).thenReturn("2023-08-14");
|
|
||||||
DateFieldCorrector dateFieldCorrector = new DateFieldCorrector();
|
|
||||||
SemanticParseInfo parseInfo = new SemanticParseInfo();
|
|
||||||
SchemaElement model = new SchemaElement();
|
|
||||||
model.setId(2L);
|
|
||||||
parseInfo.setModel(model);
|
|
||||||
SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder()
|
|
||||||
.sql("select count(歌曲名) from 歌曲库 ")
|
|
||||||
.parseInfo(parseInfo)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
dateFieldCorrector.correct(semanticCorrectInfo);
|
|
||||||
|
|
||||||
Assert.assertEquals("SELECT count(歌曲名) FROM 歌曲库 WHERE 数据日期 = '2023-08-14'", semanticCorrectInfo.getSql());
|
|
||||||
|
|
||||||
semanticCorrectInfo = SemanticCorrectInfo.builder()
|
|
||||||
.sql("select count(歌曲名) from 歌曲库 where 数据日期 = '2023-08-14'")
|
|
||||||
.parseInfo(parseInfo)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
dateFieldCorrector.correct(semanticCorrectInfo);
|
|
||||||
|
|
||||||
Assert.assertEquals("select count(歌曲名) from 歌曲库 where 数据日期 = '2023-08-14'", semanticCorrectInfo.getSql());
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,65 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.corrector;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
|
||||||
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
|
|
||||||
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
|
|
||||||
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue;
|
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.HashMap;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import org.junit.Assert;
|
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
|
|
||||||
class FieldNameCorrectorTest {
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void corrector() {
|
|
||||||
|
|
||||||
FieldNameCorrector corrector = new FieldNameCorrector();
|
|
||||||
SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder()
|
|
||||||
.sql("select 歌曲名 from 歌曲库 where 专辑照片 = '七里香' and 专辑名 = '流行' and 数据日期 = '2023-08-19'")
|
|
||||||
.build();
|
|
||||||
|
|
||||||
SemanticParseInfo parseInfo = new SemanticParseInfo();
|
|
||||||
|
|
||||||
DSLParseResult dslParseResult = new DSLParseResult();
|
|
||||||
LLMReq llmReq = new LLMReq();
|
|
||||||
List<ElementValue> linking = new ArrayList<>();
|
|
||||||
ElementValue elementValue = new ElementValue();
|
|
||||||
elementValue.setFieldValue("流行");
|
|
||||||
elementValue.setFieldName("歌曲风格");
|
|
||||||
linking.add(elementValue);
|
|
||||||
|
|
||||||
ElementValue elementValue2 = new ElementValue();
|
|
||||||
elementValue2.setFieldValue("七里香");
|
|
||||||
elementValue2.setFieldName("歌曲名");
|
|
||||||
linking.add(elementValue2);
|
|
||||||
|
|
||||||
ElementValue elementValue3 = new ElementValue();
|
|
||||||
elementValue3.setFieldValue("周杰伦");
|
|
||||||
elementValue3.setFieldName("歌手名");
|
|
||||||
linking.add(elementValue3);
|
|
||||||
|
|
||||||
ElementValue elementValue4 = new ElementValue();
|
|
||||||
elementValue4.setFieldValue("流行");
|
|
||||||
elementValue4.setFieldName("歌曲流派");
|
|
||||||
linking.add(elementValue4);
|
|
||||||
|
|
||||||
llmReq.setLinking(linking);
|
|
||||||
dslParseResult.setLlmReq(llmReq);
|
|
||||||
|
|
||||||
Map<String, Object> properties = new HashMap<>();
|
|
||||||
properties.put(Constants.CONTEXT, dslParseResult);
|
|
||||||
|
|
||||||
parseInfo.setProperties(properties);
|
|
||||||
semanticCorrectInfo.setParseInfo(parseInfo);
|
|
||||||
|
|
||||||
corrector.correct(semanticCorrectInfo);
|
|
||||||
|
|
||||||
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲库 WHERE 歌曲名 = '七里香' AND 歌曲流派 = '流行' AND 数据日期 = '2023-08-19'",
|
|
||||||
semanticCorrectInfo.getSql());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,71 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.corrector;
|
|
||||||
|
|
||||||
import static org.mockito.Mockito.when;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
|
||||||
import com.tencent.supersonic.knowledge.service.SchemaService;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.List;
|
|
||||||
import org.junit.Assert;
|
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
import org.mockito.MockedStatic;
|
|
||||||
import org.mockito.Mockito;
|
|
||||||
|
|
||||||
class FieldValueCorrectorTest {
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void corrector() {
|
|
||||||
|
|
||||||
MockedStatic<ContextUtils> mockContextUtils = Mockito.mockStatic(ContextUtils.class);
|
|
||||||
|
|
||||||
SchemaService mockSchemaService = Mockito.mock(SchemaService.class);
|
|
||||||
|
|
||||||
SemanticSchema mockSemanticSchema = Mockito.mock(SemanticSchema.class);
|
|
||||||
|
|
||||||
List<SchemaElement> dimensions = new ArrayList<>();
|
|
||||||
List<SchemaValueMap> schemaValueMaps = new ArrayList<>();
|
|
||||||
SchemaValueMap value1 = new SchemaValueMap();
|
|
||||||
value1.setBizName("杰伦");
|
|
||||||
value1.setTechName("周杰伦");
|
|
||||||
value1.setAlias(Arrays.asList("周杰倫", "Jay Chou", "周董", "周先生"));
|
|
||||||
schemaValueMaps.add(value1);
|
|
||||||
|
|
||||||
SchemaElement schemaElement = SchemaElement.builder()
|
|
||||||
.bizName("singer_name")
|
|
||||||
.name("歌手名")
|
|
||||||
.model(2L)
|
|
||||||
.schemaValueMaps(schemaValueMaps)
|
|
||||||
.build();
|
|
||||||
dimensions.add(schemaElement);
|
|
||||||
|
|
||||||
when(mockSemanticSchema.getDimensions()).thenReturn(dimensions);
|
|
||||||
when(mockSchemaService.getSemanticSchema()).thenReturn(mockSemanticSchema);
|
|
||||||
mockContextUtils.when(() -> ContextUtils.getBean(SchemaService.class)).thenReturn(mockSchemaService);
|
|
||||||
|
|
||||||
SemanticParseInfo parseInfo = new SemanticParseInfo();
|
|
||||||
SchemaElement model = new SchemaElement();
|
|
||||||
model.setId(2L);
|
|
||||||
parseInfo.setModel(model);
|
|
||||||
SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder()
|
|
||||||
.sql("select count(song_name) from 歌曲库 where singer_name = '周先生'")
|
|
||||||
.parseInfo(parseInfo)
|
|
||||||
.build();
|
|
||||||
|
|
||||||
FieldValueCorrector corrector = new FieldValueCorrector();
|
|
||||||
corrector.correct(semanticCorrectInfo);
|
|
||||||
|
|
||||||
Assert.assertEquals("SELECT count(song_name) FROM 歌曲库 WHERE singer_name = '周杰伦'", semanticCorrectInfo.getSql());
|
|
||||||
|
|
||||||
semanticCorrectInfo.setSql("select count(song_name) from 歌曲库 where singer_name = '杰伦'");
|
|
||||||
corrector.correct(semanticCorrectInfo);
|
|
||||||
|
|
||||||
Assert.assertEquals("SELECT count(song_name) FROM 歌曲库 WHERE singer_name = '周杰伦'", semanticCorrectInfo.getSql());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,46 +0,0 @@
|
|||||||
package com.tencent.supersonic.chat.corrector;
|
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
|
|
||||||
import org.junit.Assert;
|
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
|
|
||||||
class SelectFieldAppendCorrectorTest {
|
|
||||||
|
|
||||||
@Test
|
|
||||||
void corrector() {
|
|
||||||
SelectFieldAppendCorrector corrector = new SelectFieldAppendCorrector();
|
|
||||||
SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder()
|
|
||||||
.sql("select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 and 歌手名 = '邓紫棋' "
|
|
||||||
+ "and sys_imp_date = '2023-08-09' and 歌曲发布时 = '2023-08-01' order by 播放量 desc limit 11")
|
|
||||||
.build();
|
|
||||||
|
|
||||||
corrector.correct(semanticCorrectInfo);
|
|
||||||
|
|
||||||
Assert.assertEquals(
|
|
||||||
"SELECT 歌曲名, 歌手名, 播放量, 歌曲发布时, 发布日期 FROM 歌曲库 WHERE "
|
|
||||||
+ "datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '邓紫棋' "
|
|
||||||
+ "AND sys_imp_date = '2023-08-09' AND 歌曲发布时 = '2023-08-01'"
|
|
||||||
+ " ORDER BY 播放量 DESC LIMIT 11", semanticCorrectInfo.getSql());
|
|
||||||
|
|
||||||
semanticCorrectInfo.setSql("select 用户名 from 内容库产品 where datediff('day', 数据日期, '2023-09-14') <= 30"
|
|
||||||
+ " group by 用户名 having sum(访问次数) > 2000");
|
|
||||||
|
|
||||||
corrector.correct(semanticCorrectInfo);
|
|
||||||
|
|
||||||
Assert.assertEquals(
|
|
||||||
"SELECT 用户名, sum(访问次数) FROM 内容库产品 WHERE "
|
|
||||||
+ "datediff('day', 数据日期, '2023-09-14') <= 30 "
|
|
||||||
+ "GROUP BY 用户名 HAVING sum(访问次数) > 2000", semanticCorrectInfo.getSql());
|
|
||||||
|
|
||||||
semanticCorrectInfo.setSql("SELECT 用户名, sum(访问次数) FROM 内容库产品 WHERE "
|
|
||||||
+ "datediff('day', 数据日期, '2023-09-14') <= 30 "
|
|
||||||
+ "GROUP BY 用户名 HAVING sum(访问次数) > 2000");
|
|
||||||
|
|
||||||
corrector.correct(semanticCorrectInfo);
|
|
||||||
|
|
||||||
Assert.assertEquals(
|
|
||||||
"SELECT 用户名, sum(访问次数) FROM 内容库产品 WHERE "
|
|
||||||
+ "datediff('day', 数据日期, '2023-09-14') <= 30 "
|
|
||||||
+ "GROUP BY 用户名 HAVING sum(访问次数) > 2000", semanticCorrectInfo.getSql());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -31,12 +31,9 @@ com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\
|
|||||||
|
|
||||||
|
|
||||||
com.tencent.supersonic.chat.api.component.SemanticCorrector=\
|
com.tencent.supersonic.chat.api.component.SemanticCorrector=\
|
||||||
com.tencent.supersonic.chat.corrector.DateFieldCorrector, \
|
com.tencent.supersonic.chat.corrector.GlobalCorrector, \
|
||||||
com.tencent.supersonic.chat.corrector.FunctionAliasCorrector, \
|
com.tencent.supersonic.chat.corrector.TableCorrector, \
|
||||||
com.tencent.supersonic.chat.corrector.FieldNameCorrector, \
|
com.tencent.supersonic.chat.corrector.GroupByCorrector, \
|
||||||
com.tencent.supersonic.chat.corrector.FieldCorrector, \
|
com.tencent.supersonic.chat.corrector.SelectCorrector, \
|
||||||
com.tencent.supersonic.chat.corrector.FunctionCorrector, \
|
com.tencent.supersonic.chat.corrector.WhereCorrector, \
|
||||||
com.tencent.supersonic.chat.corrector.TableNameCorrector, \
|
com.tencent.supersonic.chat.corrector.HavingCorrector
|
||||||
com.tencent.supersonic.chat.corrector.QueryFilterAppend, \
|
|
||||||
com.tencent.supersonic.chat.corrector.SelectFieldAppendCorrector, \
|
|
||||||
com.tencent.supersonic.chat.corrector.FieldValueCorrector
|
|
||||||
@@ -31,12 +31,9 @@ com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\
|
|||||||
com.tencent.supersonic.auth.authentication.adaptor.DefaultUserAdaptor
|
com.tencent.supersonic.auth.authentication.adaptor.DefaultUserAdaptor
|
||||||
|
|
||||||
com.tencent.supersonic.chat.api.component.SemanticCorrector=\
|
com.tencent.supersonic.chat.api.component.SemanticCorrector=\
|
||||||
com.tencent.supersonic.chat.corrector.DateFieldCorrector, \
|
com.tencent.supersonic.chat.corrector.GlobalCorrector, \
|
||||||
com.tencent.supersonic.chat.corrector.FunctionAliasCorrector, \
|
com.tencent.supersonic.chat.corrector.TableCorrector, \
|
||||||
com.tencent.supersonic.chat.corrector.FieldNameCorrector, \
|
com.tencent.supersonic.chat.corrector.GroupByCorrector, \
|
||||||
com.tencent.supersonic.chat.corrector.FieldCorrector, \
|
com.tencent.supersonic.chat.corrector.SelectCorrector, \
|
||||||
com.tencent.supersonic.chat.corrector.FunctionCorrector, \
|
com.tencent.supersonic.chat.corrector.WhereCorrector, \
|
||||||
com.tencent.supersonic.chat.corrector.TableNameCorrector, \
|
com.tencent.supersonic.chat.corrector.HavingCorrector
|
||||||
com.tencent.supersonic.chat.corrector.QueryFilterAppend, \
|
|
||||||
com.tencent.supersonic.chat.corrector.SelectFieldAppendCorrector, \
|
|
||||||
com.tencent.supersonic.chat.corrector.FieldValueCorrector
|
|
||||||
|
|||||||
Reference in New Issue
Block a user