(improvement)(chat) Do not pass default date configuration to the large model uniformly. (#1601)

This commit is contained in:
lexluo09
2024-08-24 08:08:39 +08:00
committed by GitHub
parent d2306464a6
commit bef652892b
4 changed files with 26 additions and 39 deletions

View File

@@ -1,10 +1,5 @@
package com.tencent.supersonic.common.jsqlparser; package com.tencent.supersonic.common.jsqlparser;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.BinaryExpression; import net.sf.jsqlparser.expression.BinaryExpression;
@@ -33,6 +28,12 @@ import net.sf.jsqlparser.statement.select.SelectItem;
import net.sf.jsqlparser.statement.select.SelectVisitorAdapter; import net.sf.jsqlparser.statement.select.SelectVisitorAdapter;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
/** /**
* Sql Parser remove Helper * Sql Parser remove Helper
*/ */
@@ -228,7 +229,6 @@ public class SqlRemoveHelper {
if (selectStatement == null) { if (selectStatement == null) {
return sql; return sql;
} }
//SelectBody selectBody = selectStatement.getSelectBody();
if (!(selectStatement instanceof PlainSelect)) { if (!(selectStatement instanceof PlainSelect)) {
return sql; return sql;
} }

View File

@@ -61,14 +61,15 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
return elements.stream(); return elements.stream();
}) })
.collect(Collectors.toMap(a -> a, a -> a, (k1, k2) -> k1)); .collect(Collectors.toMap(a -> a, a -> a, (k1, k2) -> k1));
if (chatQueryContext.containsPartitionDimensions(dataSetId)) {
result.put(TimeDimensionEnum.DAY.getChName(), TimeDimensionEnum.DAY.getChName());
result.put(TimeDimensionEnum.MONTH.getChName(), TimeDimensionEnum.MONTH.getChName());
result.put(TimeDimensionEnum.WEEK.getChName(), TimeDimensionEnum.WEEK.getChName());
result.put(TimeDimensionEnum.DAY.getChName(), TimeDimensionEnum.DAY.getChName()); result.put(TimeDimensionEnum.DAY.getName(), TimeDimensionEnum.DAY.getChName());
result.put(TimeDimensionEnum.MONTH.getChName(), TimeDimensionEnum.MONTH.getChName()); result.put(TimeDimensionEnum.MONTH.getName(), TimeDimensionEnum.MONTH.getChName());
result.put(TimeDimensionEnum.WEEK.getChName(), TimeDimensionEnum.WEEK.getChName()); result.put(TimeDimensionEnum.WEEK.getName(), TimeDimensionEnum.WEEK.getChName());
}
result.put(TimeDimensionEnum.DAY.getName(), TimeDimensionEnum.DAY.getChName());
result.put(TimeDimensionEnum.MONTH.getName(), TimeDimensionEnum.MONTH.getChName());
result.put(TimeDimensionEnum.WEEK.getName(), TimeDimensionEnum.WEEK.getChName());
return result; return result;
} }

View File

@@ -47,6 +47,7 @@ public class TimeCorrector extends BaseSemanticCorrector {
removeFieldNames.add(TimeDimensionEnum.WEEK.getChName()); removeFieldNames.add(TimeDimensionEnum.WEEK.getChName());
removeFieldNames.add(TimeDimensionEnum.MONTH.getChName()); removeFieldNames.add(TimeDimensionEnum.MONTH.getChName());
correctS2SQL = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames); correctS2SQL = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames);
correctS2SQL = SqlRemoveHelper.removeGroupBy(correctS2SQL, removeFieldNames);
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL); semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
} }

View File

@@ -7,7 +7,6 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.parser.ParserConfig; import com.tencent.supersonic.headless.chat.parser.ParserConfig;
import com.tencent.supersonic.headless.chat.parser.SatisfactionChecker; import com.tencent.supersonic.headless.chat.parser.SatisfactionChecker;
@@ -231,38 +230,24 @@ public class LLMRequestService {
protected List<SchemaElement> getMatchedDimensions(ChatQueryContext queryCtx, Long dataSetId) { protected List<SchemaElement> getMatchedDimensions(ChatQueryContext queryCtx, Long dataSetId) {
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId); List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
Set<SchemaElement> results = new HashSet<>(); Set<SchemaElement> dimensionElements = matchedElements.stream()
.filter(element -> SchemaElementType.DIMENSION.equals(element.getElement().getType()))
if (!CollectionUtils.isEmpty(matchedElements)) { .map(SchemaElementMatch::getElement)
results = matchedElements.stream() .collect(Collectors.toSet());
.filter(element -> SchemaElementType.DIMENSION.equals(element.getElement().getType()))
.map(SchemaElementMatch::getElement)
.collect(Collectors.toSet());
}
SemanticSchema semanticSchema = queryCtx.getSemanticSchema(); SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
if (semanticSchema == null || semanticSchema.getDataSetSchemaMap() == null) { if (semanticSchema == null || semanticSchema.getDataSetSchemaMap() == null) {
return new ArrayList<>(results); return new ArrayList<>(dimensionElements);
} }
DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId); Map<Long, DataSetSchema> dataSetSchemaMap = semanticSchema.getDataSetSchemaMap();
DataSetSchema dataSetSchema = dataSetSchemaMap.get(dataSetId);
if (dataSetSchema == null) { if (dataSetSchema == null) {
return new ArrayList<>(results); return new ArrayList<>(dimensionElements);
} }
TimeDefaultConfig timeDefaultConfig = dataSetSchema.getTagTypeTimeDefaultConfig();
SchemaElement partitionDimension = dataSetSchema.getPartitionDimension(); SchemaElement partitionDimension = dataSetSchema.getPartitionDimension();
if (partitionDimension != null) {
if (timeDefaultConfig == null || partitionDimension == null) { dimensionElements.add(partitionDimension);
return new ArrayList<>(results);
} }
return new ArrayList<>(dimensionElements);
if (Objects.equals(timeDefaultConfig.getUnit(), -1)) {
results.remove(partitionDimension);
} else {
results.add(partitionDimension);
}
return new ArrayList<>(results);
} }
} }