[improvement][chat]llm parser corrector is simplified by sql distribution (#120)

This commit is contained in:
lexluo09
2023-09-21 21:57:06 +08:00
committed by GitHub
parent 5c3fd75ed4
commit 03a4719aed
19 changed files with 191 additions and 402 deletions

View File

@@ -7,6 +7,7 @@ import java.util.Map;
import java.util.stream.Collectors;
public class SemanticSchema implements Serializable {
private List<ModelSchema> modelSchemaList;
public SemanticSchema(List<ModelSchema> modelSchemaList) {
@@ -34,12 +35,28 @@ public class SemanticSchema implements Serializable {
return dimensions;
}
public List<SchemaElement> getDimensions(Long modelId) {
List<SchemaElement> dimensions = getDimensions();
return getElementsByModelId(modelId, dimensions);
}
public List<SchemaElement> getMetrics() {
List<SchemaElement> metrics = new ArrayList<>();
modelSchemaList.stream().forEach(d -> metrics.addAll(d.getMetrics()));
return metrics;
}
public List<SchemaElement> getMetrics(Long modelId) {
List<SchemaElement> metrics = getMetrics();
return getElementsByModelId(modelId, metrics);
}
private List<SchemaElement> getElementsByModelId(Long modelId, List<SchemaElement> elements) {
return elements.stream()
.filter(schemaElement -> modelId.equals(schemaElement.getModel()))
.collect(Collectors.toList());
}
public List<SchemaElement> getModels() {
List<SchemaElement> models = new ArrayList<>();
modelSchemaList.stream().forEach(d -> models.add(d.getModel()));

View File

@@ -1,27 +0,0 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import java.util.List;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
@Slf4j
public class DateFieldCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
String sql = semanticCorrectInfo.getSql();
List<String> whereFields = SqlParserSelectHelper.getWhereFields(sql);
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(DATE_FIELD)) {
String currentDate = DSLDateHelper.getReferenceDate(semanticCorrectInfo.getParseInfo().getModelId());
sql = SqlParserUpdateHelper.addWhere(sql, DATE_FIELD, currentDate);
}
semanticCorrectInfo.setPreSql(semanticCorrectInfo.getSql());
semanticCorrectInfo.setSql(sql);
}
}

View File

@@ -1,18 +0,0 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class FieldCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
String preSql = semanticCorrectInfo.getSql();
semanticCorrectInfo.setPreSql(preSql);
String sql = SqlParserUpdateHelper.replaceFields(preSql,
getFieldToBizName(semanticCorrectInfo.getParseInfo().getModelId()));
semanticCorrectInfo.setSql(sql);
}
}

View File

@@ -1,16 +0,0 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class FunctionAliasCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
String replaceAlias = SqlParserUpdateHelper.replaceAlias(semanticCorrectInfo.getSql());
semanticCorrectInfo.setSql(replaceAlias);
}
}

View File

@@ -1,17 +0,0 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class FunctionCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
String preSql = semanticCorrectInfo.getSql();
semanticCorrectInfo.setPreSql(preSql);
String sql = SqlParserUpdateHelper.replaceFunction(preSql);
semanticCorrectInfo.setSql(sql);
}
}

View File

@@ -16,11 +16,39 @@ import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
@Slf4j
public class FieldNameCorrector extends BaseSemanticCorrector {
public class GlobalCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
replaceAlias(semanticCorrectInfo);
updateFieldNameByLinkingValue(semanticCorrectInfo);
updateFieldNameByBizName(semanticCorrectInfo);
addAggregateToMetric(semanticCorrectInfo);
}
private void addAggregateToMetric(SemanticCorrectInfo semanticCorrectInfo) {
}
private void replaceAlias(SemanticCorrectInfo semanticCorrectInfo) {
String replaceAlias = SqlParserUpdateHelper.replaceAlias(semanticCorrectInfo.getSql());
semanticCorrectInfo.setSql(replaceAlias);
}
private void updateFieldNameByBizName(SemanticCorrectInfo semanticCorrectInfo) {
Map<String, String> fieldToBizName = getFieldToBizName(semanticCorrectInfo.getParseInfo().getModelId());
String sql = SqlParserUpdateHelper.replaceFields(semanticCorrectInfo.getSql(), fieldToBizName);
semanticCorrectInfo.setSql(sql);
}
private void updateFieldNameByLinkingValue(SemanticCorrectInfo semanticCorrectInfo) {
Object context = semanticCorrectInfo.getParseInfo().getProperties().get(Constants.CONTEXT);
if (Objects.isNull(context)) {
return;
@@ -45,5 +73,4 @@ public class FieldNameCorrector extends BaseSemanticCorrector {
String sql = SqlParserUpdateHelper.replaceFieldNameByValue(preSql, fieldValueToFieldNames);
semanticCorrectInfo.setSql(sql);
}
}
}

View File

@@ -0,0 +1,15 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class GroupByCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
}
}

View File

@@ -0,0 +1,14 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class HavingCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
}
}

View File

@@ -1,48 +0,0 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.StringUtil;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
@Slf4j
public class QueryFilterAppend extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) throws JSQLParserException {
String queryFilter = getQueryFilter(semanticCorrectInfo.getQueryFilters());
String preSql = semanticCorrectInfo.getSql();
if (StringUtils.isNotEmpty(queryFilter)) {
log.info("add queryFilter to preSql :{}", queryFilter);
Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
String sql = SqlParserUpdateHelper.addWhere(preSql, expression);
semanticCorrectInfo.setPreSql(preSql);
semanticCorrectInfo.setSql(sql);
}
}
private String getQueryFilter(QueryFilters queryFilters) {
if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) {
return null;
}
return queryFilters.getFilters().stream()
.map(filter -> {
String bizNameWrap = StringUtil.getSpaceWrap(filter.getBizName());
String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue());
String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString());
return bizNameWrap + operatorWrap + valueWrap;
})
.collect(Collectors.joining(Constants.AND_UPPER));
}
}

View File

@@ -13,11 +13,12 @@ import net.sf.jsqlparser.expression.Expression;
import org.springframework.util.CollectionUtils;
@Slf4j
public class SelectFieldAppendCorrector extends BaseSemanticCorrector {
public class SelectCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
String preSql = semanticCorrectInfo.getSql();
if (SqlParserSelectHelper.hasAggregateFunction(preSql)) {
Expression havingExpression = SqlParserSelectHelper.getHavingExpression(preSql);
if (Objects.nonNull(havingExpression)) {

View File

@@ -5,7 +5,7 @@ import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class TableNameCorrector extends BaseSemanticCorrector {
public class TableCorrector extends BaseSemanticCorrector {
public static final String TABLE_PREFIX = "t_";

View File

@@ -1,26 +1,92 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.StringUtil;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper;
import com.tencent.supersonic.knowledge.service.SchemaService;
import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.util.Strings;
import org.springframework.util.CollectionUtils;
@Slf4j
public class FieldValueCorrector extends BaseSemanticCorrector {
public class WhereCorrector extends BaseSemanticCorrector {
@Override
public void correct(SemanticCorrectInfo semanticCorrectInfo) {
public void correct(SemanticCorrectInfo semanticCorrectInfo) throws JSQLParserException {
addDateIfNotExist(semanticCorrectInfo);
parserDateDiffFunction(semanticCorrectInfo);
addQueryFilter(semanticCorrectInfo);
updateFieldValueByTechName(semanticCorrectInfo);
}
private void addQueryFilter(SemanticCorrectInfo semanticCorrectInfo) throws JSQLParserException {
String queryFilter = getQueryFilter(semanticCorrectInfo.getQueryFilters());
String preSql = semanticCorrectInfo.getSql();
if (StringUtils.isNotEmpty(queryFilter)) {
log.info("add queryFilter to preSql :{}", queryFilter);
Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
String sql = SqlParserUpdateHelper.addWhere(preSql, expression);
semanticCorrectInfo.setPreSql(preSql);
semanticCorrectInfo.setSql(sql);
}
}
private void parserDateDiffFunction(SemanticCorrectInfo semanticCorrectInfo) {
String preSql = semanticCorrectInfo.getSql();
semanticCorrectInfo.setPreSql(preSql);
String sql = SqlParserUpdateHelper.replaceFunction(preSql);
semanticCorrectInfo.setSql(sql);
}
private void addDateIfNotExist(SemanticCorrectInfo semanticCorrectInfo) {
String sql = semanticCorrectInfo.getSql();
List<String> whereFields = SqlParserSelectHelper.getWhereFields(sql);
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(TimeDimensionEnum.DAY.getName())) {
String currentDate = DSLDateHelper.getReferenceDate(semanticCorrectInfo.getParseInfo().getModelId());
sql = SqlParserUpdateHelper.addWhere(sql, TimeDimensionEnum.DAY.getName(), currentDate);
}
semanticCorrectInfo.setSql(sql);
}
private String getQueryFilter(QueryFilters queryFilters) {
if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) {
return null;
}
return queryFilters.getFilters().stream()
.map(filter -> {
String bizNameWrap = StringUtil.getSpaceWrap(filter.getBizName());
String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue());
String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString());
return bizNameWrap + operatorWrap + valueWrap;
})
.collect(Collectors.joining(Constants.AND_UPPER));
}
private void updateFieldValueByTechName(SemanticCorrectInfo semanticCorrectInfo) {
SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema();
Long modelId = semanticCorrectInfo.getParseInfo().getModel().getId();
List<SchemaElement> dimensions = semanticSchema.getDimensions().stream()
@@ -39,7 +105,6 @@ public class FieldValueCorrector extends BaseSemanticCorrector {
return;
}
private Map<String, Map<String, String>> getAliasAndBizNameToTechName(List<SchemaElement> dimensions) {
if (CollectionUtils.isEmpty(dimensions)) {
return new HashMap<>();

View File

@@ -408,27 +408,20 @@ public class LLMDslParser implements SemanticParser {
protected List<String> getFieldNameList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema,
LLMParserConfig llmParserConfig) {
Set<String> results = getTopNFieldNames(modelId, semanticSchema, llmParserConfig);
Set<String> fieldNameList = getMatchedFieldNames(queryCtx, modelId, semanticSchema);
results.addAll(fieldNameList);
return new ArrayList<>(results);
}
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) {
Map<Long, String> itemIdToName = getItemIdToName(modelId, semanticSchema);
Set<String> results = semanticSchema.getDimensions().stream()
.filter(schemaElement -> modelId.equals(schemaElement.getModel()))
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getDimensionTopN())
.map(entry -> entry.getName())
.collect(Collectors.toSet());
Set<String> metrics = semanticSchema.getMetrics().stream()
.filter(schemaElement -> modelId.equals(schemaElement.getModel()))
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getMetricTopN())
.map(entry -> entry.getName())
.collect(Collectors.toSet());
results.addAll(metrics);
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId);
if (CollectionUtils.isEmpty(matchedElements)) {
return new ArrayList<>(results);
return new HashSet<>();
}
Set<String> fieldNameList = matchedElements.stream()
.filter(schemaElementMatch -> {
@@ -447,13 +440,29 @@ public class LLMDslParser implements SemanticParser {
})
.filter(name -> StringUtils.isNotEmpty(name) && !name.contains("%"))
.collect(Collectors.toSet());
results.addAll(fieldNameList);
return new ArrayList<>(results);
return fieldNameList;
}
private Set<String> getTopNFieldNames(Long modelId, SemanticSchema semanticSchema,
LLMParserConfig llmParserConfig) {
Set<String> results = semanticSchema.getDimensions(modelId).stream()
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getDimensionTopN())
.map(entry -> entry.getName())
.collect(Collectors.toSet());
Set<String> metrics = semanticSchema.getMetrics(modelId).stream()
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(llmParserConfig.getMetricTopN())
.map(entry -> entry.getName())
.collect(Collectors.toSet());
results.addAll(metrics);
return results;
}
protected Map<Long, String> getItemIdToName(Long modelId, SemanticSchema semanticSchema) {
return semanticSchema.getDimensions().stream()
.filter(entry -> modelId.equals(entry.getModel()))
return semanticSchema.getDimensions(modelId).stream()
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
}

View File

@@ -1,45 +0,0 @@
package com.tencent.supersonic.chat.corrector;
import static org.mockito.ArgumentMatchers.any;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
class DateFieldCorrectorTest {
@Test
void corrector() {
MockedStatic<DSLDateHelper> dslDateHelper = Mockito.mockStatic(DSLDateHelper.class);
dslDateHelper.when(() -> DSLDateHelper.getReferenceDate(any())).thenReturn("2023-08-14");
DateFieldCorrector dateFieldCorrector = new DateFieldCorrector();
SemanticParseInfo parseInfo = new SemanticParseInfo();
SchemaElement model = new SchemaElement();
model.setId(2L);
parseInfo.setModel(model);
SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder()
.sql("select count(歌曲名) from 歌曲库 ")
.parseInfo(parseInfo)
.build();
dateFieldCorrector.correct(semanticCorrectInfo);
Assert.assertEquals("SELECT count(歌曲名) FROM 歌曲库 WHERE 数据日期 = '2023-08-14'", semanticCorrectInfo.getSql());
semanticCorrectInfo = SemanticCorrectInfo.builder()
.sql("select count(歌曲名) from 歌曲库 where 数据日期 = '2023-08-14'")
.parseInfo(parseInfo)
.build();
dateFieldCorrector.correct(semanticCorrectInfo);
Assert.assertEquals("select count(歌曲名) from 歌曲库 where 数据日期 = '2023-08-14'", semanticCorrectInfo.getSql());
}
}

View File

@@ -1,65 +0,0 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult;
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq;
import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue;
import com.tencent.supersonic.common.pojo.Constants;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
class FieldNameCorrectorTest {
@Test
void corrector() {
FieldNameCorrector corrector = new FieldNameCorrector();
SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder()
.sql("select 歌曲名 from 歌曲库 where 专辑照片 = '七里香' and 专辑名 = '流行' and 数据日期 = '2023-08-19'")
.build();
SemanticParseInfo parseInfo = new SemanticParseInfo();
DSLParseResult dslParseResult = new DSLParseResult();
LLMReq llmReq = new LLMReq();
List<ElementValue> linking = new ArrayList<>();
ElementValue elementValue = new ElementValue();
elementValue.setFieldValue("流行");
elementValue.setFieldName("歌曲风格");
linking.add(elementValue);
ElementValue elementValue2 = new ElementValue();
elementValue2.setFieldValue("七里香");
elementValue2.setFieldName("歌曲名");
linking.add(elementValue2);
ElementValue elementValue3 = new ElementValue();
elementValue3.setFieldValue("周杰伦");
elementValue3.setFieldName("歌手名");
linking.add(elementValue3);
ElementValue elementValue4 = new ElementValue();
elementValue4.setFieldValue("流行");
elementValue4.setFieldName("歌曲流派");
linking.add(elementValue4);
llmReq.setLinking(linking);
dslParseResult.setLlmReq(llmReq);
Map<String, Object> properties = new HashMap<>();
properties.put(Constants.CONTEXT, dslParseResult);
parseInfo.setProperties(properties);
semanticCorrectInfo.setParseInfo(parseInfo);
corrector.correct(semanticCorrectInfo);
Assert.assertEquals("SELECT 歌曲名 FROM 歌曲库 WHERE 歌曲名 = '七里香' AND 歌曲流派 = '流行' AND 数据日期 = '2023-08-19'",
semanticCorrectInfo.getSql());
}
}

View File

@@ -1,71 +0,0 @@
package com.tencent.supersonic.chat.corrector;
import static org.mockito.Mockito.when;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaValueMap;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.knowledge.service.SchemaService;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
class FieldValueCorrectorTest {
@Test
void corrector() {
MockedStatic<ContextUtils> mockContextUtils = Mockito.mockStatic(ContextUtils.class);
SchemaService mockSchemaService = Mockito.mock(SchemaService.class);
SemanticSchema mockSemanticSchema = Mockito.mock(SemanticSchema.class);
List<SchemaElement> dimensions = new ArrayList<>();
List<SchemaValueMap> schemaValueMaps = new ArrayList<>();
SchemaValueMap value1 = new SchemaValueMap();
value1.setBizName("杰伦");
value1.setTechName("周杰伦");
value1.setAlias(Arrays.asList("周杰倫", "Jay Chou", "周董", "周先生"));
schemaValueMaps.add(value1);
SchemaElement schemaElement = SchemaElement.builder()
.bizName("singer_name")
.name("歌手名")
.model(2L)
.schemaValueMaps(schemaValueMaps)
.build();
dimensions.add(schemaElement);
when(mockSemanticSchema.getDimensions()).thenReturn(dimensions);
when(mockSchemaService.getSemanticSchema()).thenReturn(mockSemanticSchema);
mockContextUtils.when(() -> ContextUtils.getBean(SchemaService.class)).thenReturn(mockSchemaService);
SemanticParseInfo parseInfo = new SemanticParseInfo();
SchemaElement model = new SchemaElement();
model.setId(2L);
parseInfo.setModel(model);
SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder()
.sql("select count(song_name) from 歌曲库 where singer_name = '周先生'")
.parseInfo(parseInfo)
.build();
FieldValueCorrector corrector = new FieldValueCorrector();
corrector.correct(semanticCorrectInfo);
Assert.assertEquals("SELECT count(song_name) FROM 歌曲库 WHERE singer_name = '周杰伦'", semanticCorrectInfo.getSql());
semanticCorrectInfo.setSql("select count(song_name) from 歌曲库 where singer_name = '杰伦'");
corrector.correct(semanticCorrectInfo);
Assert.assertEquals("SELECT count(song_name) FROM 歌曲库 WHERE singer_name = '周杰伦'", semanticCorrectInfo.getSql());
}
}

View File

@@ -1,46 +0,0 @@
package com.tencent.supersonic.chat.corrector;
import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
class SelectFieldAppendCorrectorTest {
@Test
void corrector() {
SelectFieldAppendCorrector corrector = new SelectFieldAppendCorrector();
SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder()
.sql("select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 and 歌手名 = '邓紫棋' "
+ "and sys_imp_date = '2023-08-09' and 歌曲发布时 = '2023-08-01' order by 播放量 desc limit 11")
.build();
corrector.correct(semanticCorrectInfo);
Assert.assertEquals(
"SELECT 歌曲名, 歌手名, 播放量, 歌曲发布时, 发布日期 FROM 歌曲库 WHERE "
+ "datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '邓紫棋' "
+ "AND sys_imp_date = '2023-08-09' AND 歌曲发布时 = '2023-08-01'"
+ " ORDER BY 播放量 DESC LIMIT 11", semanticCorrectInfo.getSql());
semanticCorrectInfo.setSql("select 用户名 from 内容库产品 where datediff('day', 数据日期, '2023-09-14') <= 30"
+ " group by 用户名 having sum(访问次数) > 2000");
corrector.correct(semanticCorrectInfo);
Assert.assertEquals(
"SELECT 用户名, sum(访问次数) FROM 内容库产品 WHERE "
+ "datediff('day', 数据日期, '2023-09-14') <= 30 "
+ "GROUP BY 用户名 HAVING sum(访问次数) > 2000", semanticCorrectInfo.getSql());
semanticCorrectInfo.setSql("SELECT 用户名, sum(访问次数) FROM 内容库产品 WHERE "
+ "datediff('day', 数据日期, '2023-09-14') <= 30 "
+ "GROUP BY 用户名 HAVING sum(访问次数) > 2000");
corrector.correct(semanticCorrectInfo);
Assert.assertEquals(
"SELECT 用户名, sum(访问次数) FROM 内容库产品 WHERE "
+ "datediff('day', 数据日期, '2023-09-14') <= 30 "
+ "GROUP BY 用户名 HAVING sum(访问次数) > 2000", semanticCorrectInfo.getSql());
}
}

View File

@@ -31,12 +31,9 @@ com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\
com.tencent.supersonic.chat.api.component.SemanticCorrector=\
com.tencent.supersonic.chat.corrector.DateFieldCorrector, \
com.tencent.supersonic.chat.corrector.FunctionAliasCorrector, \
com.tencent.supersonic.chat.corrector.FieldNameCorrector, \
com.tencent.supersonic.chat.corrector.FieldCorrector, \
com.tencent.supersonic.chat.corrector.FunctionCorrector, \
com.tencent.supersonic.chat.corrector.TableNameCorrector, \
com.tencent.supersonic.chat.corrector.QueryFilterAppend, \
com.tencent.supersonic.chat.corrector.SelectFieldAppendCorrector, \
com.tencent.supersonic.chat.corrector.FieldValueCorrector
com.tencent.supersonic.chat.corrector.GlobalCorrector, \
com.tencent.supersonic.chat.corrector.TableCorrector, \
com.tencent.supersonic.chat.corrector.GroupByCorrector, \
com.tencent.supersonic.chat.corrector.SelectCorrector, \
com.tencent.supersonic.chat.corrector.WhereCorrector, \
com.tencent.supersonic.chat.corrector.HavingCorrector

View File

@@ -31,12 +31,9 @@ com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\
com.tencent.supersonic.auth.authentication.adaptor.DefaultUserAdaptor
com.tencent.supersonic.chat.api.component.SemanticCorrector=\
com.tencent.supersonic.chat.corrector.DateFieldCorrector, \
com.tencent.supersonic.chat.corrector.FunctionAliasCorrector, \
com.tencent.supersonic.chat.corrector.FieldNameCorrector, \
com.tencent.supersonic.chat.corrector.FieldCorrector, \
com.tencent.supersonic.chat.corrector.FunctionCorrector, \
com.tencent.supersonic.chat.corrector.TableNameCorrector, \
com.tencent.supersonic.chat.corrector.QueryFilterAppend, \
com.tencent.supersonic.chat.corrector.SelectFieldAppendCorrector, \
com.tencent.supersonic.chat.corrector.FieldValueCorrector
com.tencent.supersonic.chat.corrector.GlobalCorrector, \
com.tencent.supersonic.chat.corrector.TableCorrector, \
com.tencent.supersonic.chat.corrector.GroupByCorrector, \
com.tencent.supersonic.chat.corrector.SelectCorrector, \
com.tencent.supersonic.chat.corrector.WhereCorrector, \
com.tencent.supersonic.chat.corrector.HavingCorrector