mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
(improvement)(chat) In SchemaCorrector, removing filters from linkingValue that do not exist. (#775)
This commit is contained in:
@@ -140,4 +140,19 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
return semanticSchema.getMetrics(viewId);
|
||||
}
|
||||
|
||||
protected Set<String> getDimensions(Long viewId, SemanticSchema semanticSchema) {
|
||||
Set<String> dimensions = semanticSchema.getDimensions(viewId).stream()
|
||||
.flatMap(
|
||||
schemaElement -> {
|
||||
Set<String> elements = new HashSet<>();
|
||||
elements.add(schemaElement.getName());
|
||||
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
|
||||
elements.addAll(schemaElement.getAlias());
|
||||
}
|
||||
return elements.stream();
|
||||
}
|
||||
).collect(Collectors.toSet());
|
||||
dimensions.add(TimeDimensionEnum.DAY.getChName());
|
||||
return dimensions;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,14 +14,12 @@ import com.tencent.supersonic.headless.api.pojo.response.ViewResp;
|
||||
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
|
||||
import com.tencent.supersonic.headless.server.service.ModelService;
|
||||
import com.tencent.supersonic.headless.server.service.ViewService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* Perform SQL corrections on the "Group by" section in S2SQL.
|
||||
@@ -82,22 +80,6 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
return true;
|
||||
}
|
||||
|
||||
private Set<String> getDimensions(Long viewId, SemanticSchema semanticSchema) {
|
||||
Set<String> dimensions = semanticSchema.getDimensions(viewId).stream()
|
||||
.flatMap(
|
||||
schemaElement -> {
|
||||
Set<String> elements = new HashSet<>();
|
||||
elements.add(schemaElement.getName());
|
||||
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
|
||||
elements.addAll(schemaElement.getAlias());
|
||||
}
|
||||
return elements.stream();
|
||||
}
|
||||
).collect(Collectors.toSet());
|
||||
dimensions.add(TimeDimensionEnum.DAY.getChName());
|
||||
return dimensions;
|
||||
}
|
||||
|
||||
private void addGroupByFields(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
Long viewId = semanticParseInfo.getViewId();
|
||||
//add dimension group by
|
||||
|
||||
@@ -1,22 +1,30 @@
|
||||
package com.tencent.supersonic.chat.core.corrector;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
import com.tencent.supersonic.chat.core.parser.sql.llm.ParseResult;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.AggregateEnum;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.FieldExpression;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlRemoveHelper;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* Perform schema corrections on the Schema information in S2SQL.
|
||||
@@ -27,6 +35,8 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
@Override
|
||||
public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
|
||||
removeFilterIfNotInLinkingValue(queryContext, semanticParseInfo);
|
||||
|
||||
correctAggFunction(semanticParseInfo);
|
||||
|
||||
replaceAlias(semanticParseInfo);
|
||||
@@ -105,4 +115,35 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
String sql = SqlReplaceHelper.replaceValue(sqlInfo.getCorrectS2SQL(), filedNameToValueMap, false);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
}
|
||||
|
||||
public void removeFilterIfNotInLinkingValue(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String correctS2SQL = sqlInfo.getCorrectS2SQL();
|
||||
List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctS2SQL);
|
||||
if (CollectionUtils.isEmpty(whereExpressionList)) {
|
||||
return;
|
||||
}
|
||||
List<ElementValue> linkingValues = getLinkingValues(semanticParseInfo);
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
Set<String> dimensions = getDimensions(semanticParseInfo.getViewId(), semanticSchema);
|
||||
|
||||
if (CollectionUtils.isEmpty(linkingValues)) {
|
||||
linkingValues = new ArrayList<>();
|
||||
}
|
||||
Set<String> linkingFieldNames = linkingValues.stream().map(linking -> linking.getFieldName())
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
Set<String> removeFieldNames = whereExpressionList.stream()
|
||||
.filter(fieldExpression -> StringUtils.isBlank(fieldExpression.getFunction()))
|
||||
.filter(fieldExpression -> !TimeDimensionEnum.containsTimeDimension(fieldExpression.getFieldName()))
|
||||
.filter(fieldExpression -> FilterOperatorEnum.EQUALS.getValue().equals(fieldExpression.getOperator()))
|
||||
.filter(fieldExpression -> dimensions.contains(fieldExpression.getFieldName()))
|
||||
.filter(fieldExpression -> !DateUtils.isAnyDateString(fieldExpression.getFieldValue().toString()))
|
||||
.filter(fieldExpression -> !linkingFieldNames.contains(fieldExpression.getFieldName()))
|
||||
.map(fieldExpression -> fieldExpression.getFieldName()).collect(Collectors.toSet());
|
||||
|
||||
String sql = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,143 @@
|
||||
package com.tencent.supersonic.chat.core.corrector;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.ViewSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
import com.tencent.supersonic.chat.core.parser.sql.llm.ParseResult;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class SchemaCorrectorTest {
|
||||
|
||||
private String json = "{\n"
|
||||
+ " \"viewId\": 1,\n"
|
||||
+ " \"llmReq\": {\n"
|
||||
+ " \"queryText\": \"xxx2024年播放量最高的十首歌\",\n"
|
||||
+ " \"filterCondition\": {\n"
|
||||
+ " \"tableName\": null\n"
|
||||
+ " },\n"
|
||||
+ " \"schema\": {\n"
|
||||
+ " \"domainName\": \"歌曲\",\n"
|
||||
+ " \"viewName\": \"歌曲\",\n"
|
||||
+ " \"fieldNameList\": [\n"
|
||||
+ " \"商务组\",\n"
|
||||
+ " \"歌曲名\",\n"
|
||||
+ " \"播放量\",\n"
|
||||
+ " \"播放份额\",\n"
|
||||
+ " \"数据日期\"\n"
|
||||
+ " ]\n"
|
||||
+ " },\n"
|
||||
+ " \"linking\": [\n"
|
||||
+ "\n"
|
||||
+ " ],\n"
|
||||
+ " \"currentDate\": \"2024-02-24\",\n"
|
||||
+ " \"priorExts\": \"播放份额是小数; \",\n"
|
||||
+ " \"sqlGenerationMode\": \"2_pass_auto_cot\"\n"
|
||||
+ " },\n"
|
||||
+ " \"request\": null,\n"
|
||||
+ " \"commonAgentTool\": {\n"
|
||||
+ " \"id\": \"y3LqVSRL\",\n"
|
||||
+ " \"name\": \"大模型语义解析\",\n"
|
||||
+ " \"type\": \"NL2SQL_LLM\",\n"
|
||||
+ " \"viewIds\": [\n"
|
||||
+ " 1\n"
|
||||
+ " ]\n"
|
||||
+ " },\n"
|
||||
+ " \"linkingValues\": [\n"
|
||||
+ "\n"
|
||||
+ " ]\n"
|
||||
+ "}";
|
||||
|
||||
@Test
|
||||
void doCorrect() throws JsonProcessingException {
|
||||
Long viewId = 1L;
|
||||
QueryContext queryContext = buildQueryContext(viewId);
|
||||
ObjectMapper objectMapper = new ObjectMapper();
|
||||
ParseResult parseResult = objectMapper.readValue(json, ParseResult.class);
|
||||
|
||||
|
||||
String sql = "select 歌曲名 from 歌曲 where 发行日期 >= '2024-01-01' "
|
||||
+ "and 商务组 = 'xxx' order by 播放量 desc limit 10";
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
SqlInfo sqlInfo = new SqlInfo();
|
||||
sqlInfo.setS2SQL(sql);
|
||||
sqlInfo.setCorrectS2SQL(sql);
|
||||
semanticParseInfo.setSqlInfo(sqlInfo);
|
||||
|
||||
SchemaElement schemaElement = new SchemaElement();
|
||||
schemaElement.setView(viewId);
|
||||
semanticParseInfo.setView(schemaElement);
|
||||
|
||||
|
||||
semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult);
|
||||
|
||||
SchemaCorrector schemaCorrector = new SchemaCorrector();
|
||||
schemaCorrector.removeFilterIfNotInLinkingValue(queryContext, semanticParseInfo);
|
||||
|
||||
assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' "
|
||||
+ "ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
||||
|
||||
parseResult = objectMapper.readValue(json, ParseResult.class);
|
||||
|
||||
List<ElementValue> linkingValues = new ArrayList<>();
|
||||
ElementValue elementValue = new ElementValue();
|
||||
elementValue.setFieldName("商务组");
|
||||
elementValue.setFieldValue("xxx");
|
||||
linkingValues.add(elementValue);
|
||||
parseResult.setLinkingValues(linkingValues);
|
||||
semanticParseInfo.getProperties().put(Constants.CONTEXT, parseResult);
|
||||
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(sql);
|
||||
semanticParseInfo.getSqlInfo().setS2SQL(sql);
|
||||
schemaCorrector.removeFilterIfNotInLinkingValue(queryContext, semanticParseInfo);
|
||||
assertEquals("SELECT 歌曲名 FROM 歌曲 WHERE 发行日期 >= '2024-01-01' "
|
||||
+ "AND 商务组 = 'xxx' ORDER BY 播放量 DESC LIMIT 10", semanticParseInfo.getSqlInfo().getCorrectS2SQL());
|
||||
|
||||
}
|
||||
|
||||
private QueryContext buildQueryContext(Long viewId) {
|
||||
QueryContext queryContext = new QueryContext();
|
||||
List<ViewSchema> viewSchemaList = new ArrayList<>();
|
||||
ViewSchema viewSchema = new ViewSchema();
|
||||
QueryConfig queryConfig = new QueryConfig();
|
||||
viewSchema.setQueryConfig(queryConfig);
|
||||
SchemaElement schemaElement = new SchemaElement();
|
||||
schemaElement.setView(viewId);
|
||||
viewSchema.setView(schemaElement);
|
||||
Set<SchemaElement> dimensions = new HashSet<>();
|
||||
SchemaElement element1 = new SchemaElement();
|
||||
element1.setView(1L);
|
||||
element1.setName("歌曲名");
|
||||
dimensions.add(element1);
|
||||
|
||||
SchemaElement element2 = new SchemaElement();
|
||||
element2.setView(1L);
|
||||
element2.setName("商务组");
|
||||
dimensions.add(element2);
|
||||
|
||||
SchemaElement element3 = new SchemaElement();
|
||||
element3.setView(1L);
|
||||
element3.setName("发行日期");
|
||||
dimensions.add(element3);
|
||||
|
||||
viewSchema.setDimensions(dimensions);
|
||||
viewSchemaList.add(viewSchema);
|
||||
|
||||
SemanticSchema semanticSchema = new SemanticSchema(viewSchemaList);
|
||||
queryContext.setSemanticSchema(semanticSchema);
|
||||
return queryContext;
|
||||
}
|
||||
}
|
||||
@@ -1,19 +1,21 @@
|
||||
package com.tencent.supersonic.common.util;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import java.text.DateFormat;
|
||||
import java.text.SimpleDateFormat;
|
||||
import java.time.LocalDate;
|
||||
import java.time.LocalDateTime;
|
||||
import java.time.format.DateTimeFormatter;
|
||||
import java.time.format.DateTimeParseException;
|
||||
import java.time.temporal.ChronoField;
|
||||
import java.time.temporal.TemporalAdjuster;
|
||||
import java.time.temporal.TemporalAdjusters;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Calendar;
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@Slf4j
|
||||
@@ -166,4 +168,27 @@ public class DateUtils {
|
||||
return datesInRange;
|
||||
}
|
||||
|
||||
public static boolean isAnyDateString(String value) {
|
||||
List<String> formats = Arrays.asList("yyyy-MM-dd", "yyyy-MM", "yyyy/MM/dd");
|
||||
return isAnyDateString(value, formats);
|
||||
}
|
||||
|
||||
public static boolean isAnyDateString(String value, List<String> formats) {
|
||||
for (String format : formats) {
|
||||
if (isDateString(value, format)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public static boolean isDateString(String value, String format) {
|
||||
try {
|
||||
DateTimeFormatter formatter = DateTimeFormatter.ofPattern(format);
|
||||
LocalDate.parse(value, formatter);
|
||||
return true;
|
||||
} catch (DateTimeParseException e) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user