(improvement)[build] Use Spotless to customize the code formatting (#1750)

Author: lexluo09
Date: 2024-10-04 00:05:04 +08:00
Committed by: GitHub
parent 44d1cde34f
commit 71a9954be5
521 changed files with 7811 additions and 13046 deletions
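
The hunks below show what the new Spotless profile does mechanically: signatures that previously broke one parameter per line are packed up to the column limit with an eight-space continuation indent, and long stream chains are reflowed the same way. The build configuration driving this is not part of the diff view, so as a rough sketch of what a Spotless setup of this kind can look like — illustrated here in Gradle Kotlin DSL, with the plugin version, formatter file name, and source layout all assumed rather than taken from this commit — consider:

// build.gradle.kts — hypothetical minimal Spotless setup, not this project's
// actual build file (the commit's real build changes are not shown in this view).
plugins {
    java
    id("com.diffplug.spotless") version "6.25.0" // assumed version
}

spotless {
    java {
        // Format every Java source file under src/.
        target("src/**/*.java")
        // Apply an Eclipse-style formatter definition; the XML file name
        // below is a placeholder, not a file from this repository.
        eclipse().configFile("spotless-formatter.xml")
        // Also drop imports the formatter finds unused.
        removeUnusedImports()
    }
}

With a setup like this, "gradle spotlessCheck" fails the build on formatting violations and "gradle spotlessApply" rewrites the sources in place — the kind of run that produces a sweeping reflow diff like the one below.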

View File

@@ -41,14 +41,17 @@ public class ChatQueryContext {
private Map<Long, List<Long>> modelIdToDataSetIds;
private User user;
private boolean saveAnswer;
@Builder.Default private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM;
@Builder.Default
private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM;
private QueryFilters queryFilters;
private List<SemanticQuery> candidateQueries = new ArrayList<>();
private SchemaMapInfo mapInfo = new SchemaMapInfo();
private SemanticParseInfo contextParseInfo;
private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
@JsonIgnore private SemanticSchema semanticSchema;
@JsonIgnore private ChatWorkflowState chatWorkflowState;
@JsonIgnore
private SemanticSchema semanticSchema;
@JsonIgnore
private ChatWorkflowState chatWorkflowState;
private QueryDataType queryDataType = QueryDataType.ALL;
private ChatModelConfig modelConfig;
private PromptConfig promptConfig;
@@ -58,14 +61,11 @@ public class ChatQueryContext {
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
int parseShowCount =
Integer.parseInt(parserConfig.getParameterValue(ParserConfig.PARSER_SHOW_COUNT));
candidateQueries =
candidateQueries.stream()
.sorted(
Comparator.comparing(
semanticQuery -> semanticQuery.getParseInfo().getScore(),
Comparator.reverseOrder()))
.limit(parseShowCount)
.collect(Collectors.toList());
candidateQueries = candidateQueries.stream()
.sorted(Comparator.comparing(
semanticQuery -> semanticQuery.getParseInfo().getScore(),
Comparator.reverseOrder()))
.limit(parseShowCount).collect(Collectors.toList());
return candidateQueries;
}

View File

@@ -17,11 +17,10 @@ public class AggCorrector extends BaseSemanticCorrector {
addAggregate(chatQueryContext, semanticParseInfo);
}
private void addAggregate(
ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
List<String> sqlGroupByFields =
SqlSelectHelper.getGroupByFields(
semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
private void addAggregate(ChatQueryContext chatQueryContext,
SemanticParseInfo semanticParseInfo) {
List<String> sqlGroupByFields = SqlSelectHelper
.getGroupByFields(semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
if (CollectionUtils.isEmpty(sqlGroupByFields)) {
return;
}

View File

@@ -35,20 +35,18 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
return;
}
doCorrect(chatQueryContext, semanticParseInfo);
log.debug(
"sqlCorrection:{} sql:{}",
this.getClass().getSimpleName(),
log.debug("sqlCorrection:{} sql:{}", this.getClass().getSimpleName(),
semanticParseInfo.getSqlInfo());
} catch (Exception e) {
log.error(String.format("correct error,sqlInfo:%s", semanticParseInfo.getSqlInfo()), e);
}
}
public abstract void doCorrect(
ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo);
public abstract void doCorrect(ChatQueryContext chatQueryContext,
SemanticParseInfo semanticParseInfo);
protected Map<String, String> getFieldNameMap(
ChatQueryContext chatQueryContext, Long dataSetId) {
protected Map<String, String> getFieldNameMap(ChatQueryContext chatQueryContext,
Long dataSetId) {
Map<String, String> result = getFieldNameMapFromDB(chatQueryContext, dataSetId);
if (chatQueryContext.containsPartitionDimensions(dataSetId)) {
@@ -63,8 +61,8 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
return result;
}
private static Map<String, String> getFieldNameMapFromDB(
ChatQueryContext chatQueryContext, Long dataSetId) {
private static Map<String, String> getFieldNameMapFromDB(ChatQueryContext chatQueryContext,
Long dataSetId) {
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
List<SchemaElement> dbAllFields = new ArrayList<>();
@@ -72,51 +70,38 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
dbAllFields.addAll(semanticSchema.getDimensions());
// support fieldName and field alias
return dbAllFields.stream()
.filter(entry -> dataSetId.equals(entry.getDataSetId()))
.flatMap(
schemaElement -> {
Set<String> elements = new HashSet<>();
elements.add(schemaElement.getName());
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
elements.addAll(schemaElement.getAlias());
}
return elements.stream();
})
.collect(Collectors.toMap(a -> a, a -> a, (k1, k2) -> k1));
return dbAllFields.stream().filter(entry -> dataSetId.equals(entry.getDataSetId()))
.flatMap(schemaElement -> {
Set<String> elements = new HashSet<>();
elements.add(schemaElement.getName());
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
elements.addAll(schemaElement.getAlias());
}
return elements.stream();
}).collect(Collectors.toMap(a -> a, a -> a, (k1, k2) -> k1));
}
protected void addAggregateToMetric(
ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
protected void addAggregateToMetric(ChatQueryContext chatQueryContext,
SemanticParseInfo semanticParseInfo) {
// add aggregate to all metric
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
Long dataSetId = semanticParseInfo.getDataSet().getDataSetId();
List<SchemaElement> metrics = getMetricElements(chatQueryContext, dataSetId);
Map<String, String> metricToAggregate =
metrics.stream()
.map(
schemaElement -> {
if (Objects.isNull(schemaElement.getDefaultAgg())) {
schemaElement.setDefaultAgg(AggregateTypeEnum.SUM.name());
}
return schemaElement;
})
.flatMap(
schemaElement -> {
Set<String> elements = new HashSet<>();
elements.add(schemaElement.getName());
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
elements.addAll(schemaElement.getAlias());
}
return elements.stream()
.map(
element ->
Pair.of(
element,
schemaElement.getDefaultAgg()));
})
.collect(Collectors.toMap(Pair::getLeft, Pair::getRight, (k1, k2) -> k1));
Map<String, String> metricToAggregate = metrics.stream().map(schemaElement -> {
if (Objects.isNull(schemaElement.getDefaultAgg())) {
schemaElement.setDefaultAgg(AggregateTypeEnum.SUM.name());
}
return schemaElement;
}).flatMap(schemaElement -> {
Set<String> elements = new HashSet<>();
elements.add(schemaElement.getName());
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
elements.addAll(schemaElement.getAlias());
}
return elements.stream()
.map(element -> Pair.of(element, schemaElement.getDefaultAgg()));
}).collect(Collectors.toMap(Pair::getLeft, Pair::getRight, (k1, k2) -> k1));
if (CollectionUtils.isEmpty(metricToAggregate)) {
return;
@@ -125,39 +110,36 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(aggregateSql);
}
protected List<SchemaElement> getMetricElements(
ChatQueryContext chatQueryContext, Long dataSetId) {
protected List<SchemaElement> getMetricElements(ChatQueryContext chatQueryContext,
Long dataSetId) {
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
return semanticSchema.getMetrics(dataSetId);
}
protected Set<String> getDimensions(Long dataSetId, SemanticSchema semanticSchema) {
Set<String> dimensions =
semanticSchema.getDimensions(dataSetId).stream()
.flatMap(
schemaElement -> {
Set<String> elements = new HashSet<>();
elements.add(schemaElement.getName());
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
elements.addAll(schemaElement.getAlias());
}
return elements.stream();
})
.collect(Collectors.toSet());
semanticSchema.getDimensions(dataSetId).stream().flatMap(schemaElement -> {
Set<String> elements = new HashSet<>();
elements.add(schemaElement.getName());
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
elements.addAll(schemaElement.getAlias());
}
return elements.stream();
}).collect(Collectors.toSet());
dimensions.add(TimeDimensionEnum.DAY.getChName());
return dimensions;
}
protected boolean containsPartitionDimensions(
ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
protected boolean containsPartitionDimensions(ChatQueryContext chatQueryContext,
SemanticParseInfo semanticParseInfo) {
Long dataSetId = semanticParseInfo.getDataSetId();
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId);
return dataSetSchema.containsPartitionDimensions();
}
protected void removeDateIfExist(
ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
protected void removeDateIfExist(ChatQueryContext chatQueryContext,
SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
Set<String> removeFieldNames = new HashSet<>();
removeFieldNames.addAll(TimeDimensionEnum.getChNameList());

View File

@@ -31,8 +31,8 @@ public class GroupByCorrector extends BaseSemanticCorrector {
addGroupByFields(chatQueryContext, semanticParseInfo);
}
private Boolean needAddGroupBy(
ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
private Boolean needAddGroupBy(ChatQueryContext chatQueryContext,
SemanticParseInfo semanticParseInfo) {
if (!QueryType.AGGREGATE.equals(semanticParseInfo.getQueryType())) {
return false;
}
@@ -66,8 +66,8 @@ public class GroupByCorrector extends BaseSemanticCorrector {
return true;
}
private void addGroupByFields(
ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
private void addGroupByFields(ChatQueryContext chatQueryContext,
SemanticParseInfo semanticParseInfo) {
Long dataSetId = semanticParseInfo.getDataSetId();
// add dimension group by
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
@@ -78,19 +78,14 @@ public class GroupByCorrector extends BaseSemanticCorrector {
List<String> selectFields = SqlSelectHelper.gePureSelectFields(correctS2SQL);
List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
Set<String> groupByFields =
selectFields.stream()
.filter(field -> dimensions.contains(field))
.filter(
field -> {
if (!CollectionUtils.isEmpty(aggregateFields)
&& aggregateFields.contains(field)) {
return false;
}
return true;
})
.collect(Collectors.toSet());
semanticParseInfo
.getSqlInfo()
selectFields.stream().filter(field -> dimensions.contains(field)).filter(field -> {
if (!CollectionUtils.isEmpty(aggregateFields)
&& aggregateFields.contains(field)) {
return false;
}
return true;
}).collect(Collectors.toSet());
semanticParseInfo.getSqlInfo()
.setCorrectedS2SQL(SqlAddHelper.addGroupBy(correctS2SQL, groupByFields));
}
}

View File

@@ -42,10 +42,8 @@ public class HavingCorrector extends BaseSemanticCorrector {
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
Set<String> metrics =
semanticSchema.getMetrics(dataSet).stream()
.map(schemaElement -> schemaElement.getName())
.collect(Collectors.toSet());
Set<String> metrics = semanticSchema.getMetrics(dataSet).stream()
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
if (CollectionUtils.isEmpty(metrics)) {
return;

View File

@@ -13,13 +13,13 @@ import java.util.Date;
public class S2SqlDateHelper {
public static Pair<String, String> calculateDateRange(
TimeDefaultConfig timeConfig, String timeFormat) {
public static Pair<String, String> calculateDateRange(TimeDefaultConfig timeConfig,
String timeFormat) {
return calculateDateRange(DateUtils.getBeforeDate(0), timeConfig, timeFormat);
}
public static Pair<String, String> calculateDateRange(
String currentDate, TimeDefaultConfig timeConfig, String timeFormat) {
public static Pair<String, String> calculateDateRange(String currentDate,
TimeDefaultConfig timeConfig, String timeFormat) {
Integer unit = timeConfig.getUnit();
if (timeConfig == null || unit == null || unit < 0) {
return Pair.of(null, null);

View File

@@ -46,8 +46,8 @@ public class SchemaCorrector extends BaseSemanticCorrector {
correctFieldName(chatQueryContext, semanticParseInfo);
}
private void removeDateFields(
ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
private void removeDateFields(ChatQueryContext chatQueryContext,
SemanticParseInfo semanticParseInfo) {
if (containsPartitionDimensions(chatQueryContext, semanticParseInfo)) {
return;
}
@@ -61,8 +61,8 @@ public class SchemaCorrector extends BaseSemanticCorrector {
sqlInfo.setCorrectedS2SQL(sql);
}
private void correctFieldName(
ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
private void correctFieldName(ChatQueryContext chatQueryContext,
SemanticParseInfo semanticParseInfo) {
Map<String, String> fieldNameMap =
getFieldNameMap(chatQueryContext, semanticParseInfo.getDataSetId());
// add as fieldName
@@ -82,19 +82,13 @@ public class SchemaCorrector extends BaseSemanticCorrector {
}
Map<String, Set<String>> fieldValueToFieldNames =
linking.stream()
.collect(
Collectors.groupingBy(
LLMReq.ElementValue::getFieldValue,
Collectors.mapping(
LLMReq.ElementValue::getFieldName,
Collectors.toSet())));
linking.stream().collect(Collectors.groupingBy(LLMReq.ElementValue::getFieldValue,
Collectors.mapping(LLMReq.ElementValue::getFieldName, Collectors.toSet())));
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String sql =
SqlReplaceHelper.replaceFieldNameByValue(
sqlInfo.getCorrectedS2SQL(), fieldValueToFieldNames);
String sql = SqlReplaceHelper.replaceFieldNameByValue(sqlInfo.getCorrectedS2SQL(),
fieldValueToFieldNames);
sqlInfo.setCorrectedS2SQL(sql);
}
@@ -117,27 +111,20 @@ public class SchemaCorrector extends BaseSemanticCorrector {
return;
}
Map<String, Map<String, String>> filedNameToValueMap =
linking.stream()
.collect(
Collectors.groupingBy(
LLMReq.ElementValue::getFieldName,
Collectors.mapping(
LLMReq.ElementValue::getFieldValue,
Collectors.toMap(
oldValue -> oldValue,
newValue -> newValue,
(existingValue, newValue) -> newValue))));
Map<String, Map<String, String>> filedNameToValueMap = linking.stream()
.collect(Collectors.groupingBy(LLMReq.ElementValue::getFieldName,
Collectors.mapping(LLMReq.ElementValue::getFieldValue,
Collectors.toMap(oldValue -> oldValue, newValue -> newValue,
(existingValue, newValue) -> newValue))));
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String sql =
SqlReplaceHelper.replaceValue(
sqlInfo.getCorrectedS2SQL(), filedNameToValueMap, false);
String sql = SqlReplaceHelper.replaceValue(sqlInfo.getCorrectedS2SQL(), filedNameToValueMap,
false);
sqlInfo.setCorrectedS2SQL(sql);
}
public void removeFilterIfNotInLinkingValue(
ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
public void removeFilterIfNotInLinkingValue(ChatQueryContext chatQueryContext,
SemanticParseInfo semanticParseInfo) {
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
String correctS2SQL = sqlInfo.getCorrectedS2SQL();
List<FieldExpression> whereExpressionList =
@@ -152,37 +139,21 @@ public class SchemaCorrector extends BaseSemanticCorrector {
if (CollectionUtils.isEmpty(linkingValues)) {
linkingValues = new ArrayList<>();
}
Set<String> linkingFieldNames =
linkingValues.stream()
.map(linking -> linking.getFieldName())
.collect(Collectors.toSet());
Set<String> linkingFieldNames = linkingValues.stream()
.map(linking -> linking.getFieldName()).collect(Collectors.toSet());
Set<String> removeFieldNames =
whereExpressionList.stream()
.filter(
fieldExpression ->
StringUtils.isBlank(fieldExpression.getFunction()))
.filter(
fieldExpression ->
!TimeDimensionEnum.containsTimeDimension(
fieldExpression.getFieldName()))
.filter(
fieldExpression ->
FilterOperatorEnum.EQUALS
.getValue()
.equals(fieldExpression.getOperator()))
.filter(
fieldExpression ->
dimensions.contains(fieldExpression.getFieldName()))
.filter(
fieldExpression ->
!DateUtils.isAnyDateString(
fieldExpression.getFieldValue().toString()))
.filter(
fieldExpression ->
!linkingFieldNames.contains(fieldExpression.getFieldName()))
.map(fieldExpression -> fieldExpression.getFieldName())
.collect(Collectors.toSet());
Set<String> removeFieldNames = whereExpressionList.stream()
.filter(fieldExpression -> StringUtils.isBlank(fieldExpression.getFunction()))
.filter(fieldExpression -> !TimeDimensionEnum
.containsTimeDimension(fieldExpression.getFieldName()))
.filter(fieldExpression -> FilterOperatorEnum.EQUALS.getValue()
.equals(fieldExpression.getOperator()))
.filter(fieldExpression -> dimensions.contains(fieldExpression.getFieldName()))
.filter(fieldExpression -> !DateUtils
.isAnyDateString(fieldExpression.getFieldValue().toString()))
.filter(fieldExpression -> !linkingFieldNames
.contains(fieldExpression.getFieldName()))
.map(fieldExpression -> fieldExpression.getFieldName()).collect(Collectors.toSet());
String sql = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames);
sqlInfo.setCorrectedS2SQL(sql);

View File

@@ -34,8 +34,7 @@ public class SelectCorrector extends BaseSemanticCorrector {
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
// If the number of aggregated fields is equal to the number of queried fields, do not add
// fields to select.
if (!CollectionUtils.isEmpty(aggregateFields)
&& !CollectionUtils.isEmpty(selectFields)
if (!CollectionUtils.isEmpty(aggregateFields) && !CollectionUtils.isEmpty(selectFields)
&& aggregateFields.size() == selectFields.size()) {
return;
}
@@ -43,10 +42,8 @@ public class SelectCorrector extends BaseSemanticCorrector {
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL);
}
protected String addFieldsToSelect(
ChatQueryContext chatQueryContext,
SemanticParseInfo semanticParseInfo,
String correctS2SQL) {
protected String addFieldsToSelect(ChatQueryContext chatQueryContext,
SemanticParseInfo semanticParseInfo, String correctS2SQL) {
correctS2SQL = addTagDefaultFields(chatQueryContext, semanticParseInfo, correctS2SQL);
Set<String> selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL));
@@ -69,10 +66,8 @@ public class SelectCorrector extends BaseSemanticCorrector {
return addFieldsToSelectSql;
}
private String addTagDefaultFields(
ChatQueryContext chatQueryContext,
SemanticParseInfo semanticParseInfo,
String correctS2SQL) {
private String addTagDefaultFields(ChatQueryContext chatQueryContext,
SemanticParseInfo semanticParseInfo, String correctS2SQL) {
// If it is in DETAIL mode and select *, add default metrics and dimensions.
boolean hasAsterisk = SqlSelectFunctionHelper.hasAsterisk(correctS2SQL);
if (!(hasAsterisk && QueryType.DETAIL.equals(semanticParseInfo.getQueryType()))) {
@@ -84,17 +79,13 @@ public class SelectCorrector extends BaseSemanticCorrector {
Set<String> needAddDefaultFields = new HashSet<>();
if (Objects.nonNull(dataSetSchema)) {
if (!CollectionUtils.isEmpty(dataSetSchema.getTagDefaultMetrics())) {
Set<String> metrics =
dataSetSchema.getTagDefaultMetrics().stream()
.map(schemaElement -> schemaElement.getName())
.collect(Collectors.toSet());
Set<String> metrics = dataSetSchema.getTagDefaultMetrics().stream()
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
needAddDefaultFields.addAll(metrics);
}
if (!CollectionUtils.isEmpty(dataSetSchema.getTagDefaultDimensions())) {
Set<String> dimensions =
dataSetSchema.getTagDefaultDimensions().stream()
.map(schemaElement -> schemaElement.getName())
.collect(Collectors.toSet());
Set<String> dimensions = dataSetSchema.getTagDefaultDimensions().stream()
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
needAddDefaultFields.addAll(dimensions);
}
}

View File

@@ -36,15 +36,14 @@ public class TimeCorrector extends BaseSemanticCorrector {
}
}
private void addDateIfNotExist(
ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
private void addDateIfNotExist(ChatQueryContext chatQueryContext,
SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
Long dataSetId = semanticParseInfo.getDataSetId();
DataSetSchema dataSetSchema =
chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
if (Objects.isNull(dataSetSchema)
|| Objects.isNull(dataSetSchema.getPartitionDimension())
if (Objects.isNull(dataSetSchema) || Objects.isNull(dataSetSchema.getPartitionDimension())
|| Objects.isNull(dataSetSchema.getPartitionDimension().getName())
|| TimeDimensionEnum.containsZhTimeDimension(whereFields)) {
return;
@@ -66,13 +65,8 @@ public class TimeCorrector extends BaseSemanticCorrector {
correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL);
String startDateLeft = dateRange.getLeft();
String endDateRight = dateRange.getRight();
String condExpr =
String.format(
" ( %s >= '%s' and %s <= '%s' )",
partitionDimension,
startDateLeft,
partitionDimension,
endDateRight);
String condExpr = String.format(" ( %s >= '%s' and %s <= '%s' )",
partitionDimension, startDateLeft, partitionDimension, endDateRight);
correctS2SQL = addConditionToSQL(correctS2SQL, condExpr);
}
}
@@ -83,8 +77,7 @@ public class TimeCorrector extends BaseSemanticCorrector {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
DateBoundInfo dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(correctS2SQL);
if (dateBoundInfo != null
&& StringUtils.isBlank(dateBoundInfo.getLowerBound())
if (dateBoundInfo != null && StringUtils.isBlank(dateBoundInfo.getLowerBound())
&& StringUtils.isNotBlank(dateBoundInfo.getUpperBound())
&& StringUtils.isNotBlank(dateBoundInfo.getUpperDate())) {
String upperDate = dateBoundInfo.getUpperDate();

View File

@@ -31,8 +31,8 @@ public class WhereCorrector extends BaseSemanticCorrector {
updateFieldValueByTechName(chatQueryContext, semanticParseInfo);
}
protected void addQueryFilter(
ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
protected void addQueryFilter(ChatQueryContext chatQueryContext,
SemanticParseInfo semanticParseInfo) {
String queryFilter = getQueryFilter(chatQueryContext.getQueryFilters());
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
@@ -55,8 +55,8 @@ public class WhereCorrector extends BaseSemanticCorrector {
return QueryFilterParser.parse(queryFilters);
}
private void updateFieldValueByTechName(
ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
private void updateFieldValueByTechName(ChatQueryContext chatQueryContext,
SemanticParseInfo semanticParseInfo) {
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
Long dataSetId = semanticParseInfo.getDataSetId();
List<SchemaElement> dimensions = semanticSchema.getDimensions(dataSetId);
@@ -75,50 +75,25 @@ public class WhereCorrector extends BaseSemanticCorrector {
private Map<String, Map<String, String>> getAliasAndBizNameToTechName(
List<SchemaElement> dimensions) {
return dimensions.stream()
.filter(
dimension ->
Objects.nonNull(dimension)
&& StringUtils.isNotEmpty(dimension.getName())
&& !CollectionUtils.isEmpty(dimension.getSchemaValueMaps()))
.collect(
Collectors.toMap(
SchemaElement::getName,
dimension ->
dimension.getSchemaValueMaps().stream()
.filter(
valueMap ->
Objects.nonNull(valueMap)
&& StringUtils.isNotEmpty(
valueMap
.getTechName()))
.flatMap(
valueMap -> {
Map<String, String> map =
new HashMap<>();
if (StringUtils.isNotEmpty(
valueMap.getBizName())) {
map.put(
valueMap.getBizName(),
valueMap.getTechName());
}
if (!CollectionUtils.isEmpty(
valueMap.getAlias())) {
valueMap.getAlias().stream()
.filter(
StringUtils
::isNotEmpty)
.forEach(
alias ->
map.put(
alias,
valueMap
.getTechName()));
}
return map.entrySet().stream();
})
.collect(
Collectors.toMap(
Map.Entry::getKey,
Map.Entry::getValue))));
.filter(dimension -> Objects.nonNull(dimension)
&& StringUtils.isNotEmpty(dimension.getName())
&& !CollectionUtils.isEmpty(dimension.getSchemaValueMaps()))
.collect(Collectors.toMap(SchemaElement::getName,
dimension -> dimension.getSchemaValueMaps().stream()
.filter(valueMap -> Objects.nonNull(valueMap)
&& StringUtils.isNotEmpty(valueMap.getTechName()))
.flatMap(valueMap -> {
Map<String, String> map = new HashMap<>();
if (StringUtils.isNotEmpty(valueMap.getBizName())) {
map.put(valueMap.getBizName(), valueMap.getTechName());
}
if (!CollectionUtils.isEmpty(valueMap.getAlias())) {
valueMap.getAlias().stream().filter(StringUtils::isNotEmpty)
.forEach(alias -> map.put(alias,
valueMap.getTechName()));
}
return map.entrySet().stream();
}).collect(
Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue))));
}
}

View File

@@ -31,10 +31,7 @@ public class DatabaseMapResult extends MapResult {
@Override
public String getMapKey() {
return this.getName()
+ Constants.UNDERLINE
+ this.getSchemaElement().getId()
+ Constants.UNDERLINE
+ this.getSchemaElement().getName();
return this.getName() + Constants.UNDERLINE + this.getSchemaElement().getId()
+ Constants.UNDERLINE + this.getSchemaElement().getName();
}
}

View File

@@ -1,11 +1,8 @@
package com.tencent.supersonic.headless.chat.knowledge;
public enum DictUpdateMode {
OFFLINE_FULL("OFFLINE_FULL"),
OFFLINE_MODEL("OFFLINE_MODEL"),
REALTIME_ADD("REALTIME_ADD"),
REALTIME_DELETE("REALTIME_DELETE"),
NOT_SUPPORT("NOT_SUPPORT");
OFFLINE_FULL("OFFLINE_FULL"), OFFLINE_MODEL("OFFLINE_MODEL"), REALTIME_ADD(
"REALTIME_ADD"), REALTIME_DELETE("REALTIME_DELETE"), NOT_SUPPORT("NOT_SUPPORT");
private String value;

View File

@@ -16,49 +16,36 @@ import java.util.stream.IntStream;
/** Dictionary Attribute Util */
public class DictionaryAttributeUtil {
public static CoreDictionary.Attribute getAttribute(
CoreDictionary.Attribute old, CoreDictionary.Attribute add) {
public static CoreDictionary.Attribute getAttribute(CoreDictionary.Attribute old,
CoreDictionary.Attribute add) {
Map<Nature, Integer> map = new HashMap<>();
Map<Nature, String> originalMap = new HashMap<>();
IntStream.range(0, old.nature.length)
.boxed()
.forEach(
i -> {
map.put(old.nature[i], old.frequency[i]);
if (Objects.nonNull(old.originals)) {
originalMap.put(old.nature[i], old.originals[i]);
}
});
IntStream.range(0, add.nature.length)
.boxed()
.forEach(
i -> {
map.put(add.nature[i], add.frequency[i]);
if (Objects.nonNull(add.originals)) {
originalMap.put(add.nature[i], add.originals[i]);
}
});
IntStream.range(0, old.nature.length).boxed().forEach(i -> {
map.put(old.nature[i], old.frequency[i]);
if (Objects.nonNull(old.originals)) {
originalMap.put(old.nature[i], old.originals[i]);
}
});
IntStream.range(0, add.nature.length).boxed().forEach(i -> {
map.put(add.nature[i], add.frequency[i]);
if (Objects.nonNull(add.originals)) {
originalMap.put(add.nature[i], add.originals[i]);
}
});
List<Map.Entry<Nature, Integer>> list =
new LinkedList<Map.Entry<Nature, Integer>>(map.entrySet());
Collections.sort(
list,
new Comparator<Map.Entry<Nature, Integer>>() {
public int compare(
Map.Entry<Nature, Integer> o1, Map.Entry<Nature, Integer> o2) {
return o2.getValue() - o1.getValue();
}
});
Collections.sort(list, new Comparator<Map.Entry<Nature, Integer>>() {
public int compare(Map.Entry<Nature, Integer> o1, Map.Entry<Nature, Integer> o2) {
return o2.getValue() - o1.getValue();
}
});
String[] originals =
list.stream().map(l -> originalMap.get(l.getKey())).toArray(String[]::new);
CoreDictionary.Attribute attribute =
new CoreDictionary.Attribute(
list.stream()
.map(i -> i.getKey())
.collect(Collectors.toList())
.toArray(new Nature[0]),
list.stream().map(i -> i.getValue()).mapToInt(Integer::intValue).toArray(),
originals,
list.stream().map(i -> i.getValue()).findFirst().get());
CoreDictionary.Attribute attribute = new CoreDictionary.Attribute(
list.stream().map(i -> i.getKey()).collect(Collectors.toList())
.toArray(new Nature[0]),
list.stream().map(i -> i.getValue()).mapToInt(Integer::intValue).toArray(),
originals, list.stream().map(i -> i.getValue()).findFirst().get());
return attribute;
}
}

View File

@@ -43,8 +43,7 @@ public class HanlpMapResult extends MapResult {
@Override
public String getMapKey() {
return this.getName()
+ Constants.UNDERLINE
return this.getName() + Constants.UNDERLINE
+ String.join(Constants.UNDERLINE, this.getNatures());
}
}

View File

@@ -17,25 +17,17 @@ public class KnowledgeBaseService {
public void updateSemanticKnowledge(List<DictWord> natures) {
List<DictWord> prefixes =
natures.stream()
.filter(
entry ->
!entry.getNatureWithFrequency()
.contains(DictWordType.SUFFIX.getType()))
.collect(Collectors.toList());
List<DictWord> prefixes = natures.stream().filter(
entry -> !entry.getNatureWithFrequency().contains(DictWordType.SUFFIX.getType()))
.collect(Collectors.toList());
for (DictWord nature : prefixes) {
HanlpHelper.addToCustomDictionary(nature);
}
List<DictWord> suffixes =
natures.stream()
.filter(
entry ->
entry.getNatureWithFrequency()
.contains(DictWordType.SUFFIX.getType()))
.collect(Collectors.toList());
List<DictWord> suffixes = natures.stream().filter(
entry -> entry.getNatureWithFrequency().contains(DictWordType.SUFFIX.getType()))
.collect(Collectors.toList());
SearchService.loadSuffix(suffixes);
}
@@ -64,35 +56,23 @@ public class KnowledgeBaseService {
return HanlpHelper.getTerms(text, modelIdToDataSetIds);
}
public List<HanlpMapResult> prefixSearch(
String key,
int limit,
Map<Long, List<Long>> modelIdToDataSetIds,
Set<Long> detectDataSetIds) {
public List<HanlpMapResult> prefixSearch(String key, int limit,
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
return prefixSearchByModel(key, limit, modelIdToDataSetIds, detectDataSetIds);
}
public List<HanlpMapResult> prefixSearchByModel(
String key,
int limit,
Map<Long, List<Long>> modelIdToDataSetIds,
Set<Long> detectDataSetIds) {
public List<HanlpMapResult> prefixSearchByModel(String key, int limit,
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
return SearchService.prefixSearch(key, limit, modelIdToDataSetIds, detectDataSetIds);
}
public List<HanlpMapResult> suffixSearch(
String key,
int limit,
Map<Long, List<Long>> modelIdToDataSetIds,
Set<Long> detectDataSetIds) {
public List<HanlpMapResult> suffixSearch(String key, int limit,
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
return suffixSearchByModel(key, limit, modelIdToDataSetIds, detectDataSetIds);
}
public List<HanlpMapResult> suffixSearchByModel(
String key,
int limit,
Map<Long, List<Long>> modelIdToDataSetIds,
Set<Long> detectDataSetIds) {
public List<HanlpMapResult> suffixSearchByModel(String key, int limit,
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
return SearchService.suffixSearch(key, limit, modelIdToDataSetIds, detectDataSetIds);
}
}

View File

@@ -26,23 +26,20 @@ import java.util.stream.Stream;
@Slf4j
public class MetaEmbeddingService {
@Autowired private EmbeddingService embeddingService;
@Autowired private EmbeddingConfig embeddingConfig;
@Autowired
private EmbeddingService embeddingService;
@Autowired
private EmbeddingConfig embeddingConfig;
public List<RetrieveQueryResult> retrieveQuery(
RetrieveQuery retrieveQuery,
int num,
Map<Long, List<Long>> modelIdToDataSetIds,
Set<Long> detectDataSetIds) {
public List<RetrieveQueryResult> retrieveQuery(RetrieveQuery retrieveQuery, int num,
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
// dataSetIds->modelIds
Set<Long> allModels = NatureHelper.getModelIds(modelIdToDataSetIds, detectDataSetIds);
if (CollectionUtils.isNotEmpty(allModels)) {
Map<String, Object> filterCondition = new HashMap<>();
filterCondition.put(
"modelId",
allModels.stream()
.map(modelId -> modelId + DictWordType.NATURE_SPILT)
filterCondition.put("modelId",
allModels.stream().map(modelId -> modelId + DictWordType.NATURE_SPILT)
.collect(Collectors.toList()));
retrieveQuery.setFilterCondition(filterCondition);
}
@@ -67,36 +64,22 @@ public class MetaEmbeddingService {
return result;
}
// Process each Retrieval object.
List<Retrieval> updatedRetrievals =
retrievals.stream()
.flatMap(
retrieval -> {
Long modelId =
Retrieval.getLongId(
retrieval.getMetadata().get("modelId"));
List<Long> dataSetIds = modelIdToDataSetIds.get(modelId);
List<Retrieval> updatedRetrievals = retrievals.stream().flatMap(retrieval -> {
Long modelId = Retrieval.getLongId(retrieval.getMetadata().get("modelId"));
List<Long> dataSetIds = modelIdToDataSetIds.get(modelId);
if (CollectionUtils.isEmpty(dataSetIds)) {
return Stream.of(retrieval);
}
if (CollectionUtils.isEmpty(dataSetIds)) {
return Stream.of(retrieval);
}
return dataSetIds.stream()
.map(
dataSetId -> {
Retrieval newRetrieval = new Retrieval();
BeanUtils.copyProperties(
retrieval, newRetrieval);
newRetrieval
.getMetadata()
.putIfAbsent(
"dataSetId",
dataSetId
+ Constants
.UNDERLINE);
return newRetrieval;
});
})
.collect(Collectors.toList());
return dataSetIds.stream().map(dataSetId -> {
Retrieval newRetrieval = new Retrieval();
BeanUtils.copyProperties(retrieval, newRetrieval);
newRetrieval.getMetadata().putIfAbsent("dataSetId",
dataSetId + Constants.UNDERLINE);
return newRetrieval;
});
}).collect(Collectors.toList());
result.setRetrieval(updatedRetrievals);
return result;
}

View File

@@ -60,12 +60,9 @@ public class MultiCustomDictionary extends DynamicCustomDictionary {
* @param addToSuggeterTrie
* @return
*/
public static boolean load(
String path,
Nature defaultNature,
public static boolean load(String path, Nature defaultNature,
TreeMap<String, CoreDictionary.Attribute> map,
LinkedHashSet<Nature> customNatureCollector,
boolean addToSuggeterTrie) {
LinkedHashSet<Nature> customNatureCollector, boolean addToSuggeterTrie) {
try {
String splitter = "\\s";
if (path.endsWith(".csv")) {
@@ -112,9 +109,8 @@ public class MultiCustomDictionary extends DynamicCustomDictionary {
attribute = new CoreDictionary.Attribute(natureCount);
for (int i = 0; i < natureCount; ++i) {
attribute.nature[i] =
LexiconUtility.convertStringToNature(
param[1 + 2 * i], customNatureCollector);
attribute.nature[i] = LexiconUtility.convertStringToNature(param[1 + 2 * i],
customNatureCollector);
attribute.frequency[i] = Integer.parseInt(param[2 + 2 * i]);
attribute.originals[i] = original;
attribute.totalFrequency += attribute.frequency[i];
@@ -133,10 +129,8 @@ public class MultiCustomDictionary extends DynamicCustomDictionary {
Nature nature = attribute.nature[i];
PriorityQueue<Term> priorityQueue = NATURE_TO_VALUES.get(nature.toString());
if (Objects.isNull(priorityQueue)) {
priorityQueue =
new PriorityQueue<>(
MAX_SIZE,
Comparator.comparingInt(Term::getFrequency).reversed());
priorityQueue = new PriorityQueue<>(MAX_SIZE,
Comparator.comparingInt(Term::getFrequency).reversed());
NATURE_TO_VALUES.put(nature.toString(), priorityQueue);
}
Term term = new Term(word, nature);
@@ -159,12 +153,8 @@ public class MultiCustomDictionary extends DynamicCustomDictionary {
logger.warning("自定义词典" + Arrays.toString(path) + "加载失败");
return false;
} else {
logger.info(
"自定义词典加载成功:"
+ this.dat.size()
+ "个词条,耗时"
+ (System.currentTimeMillis() - start)
+ "ms");
logger.info("自定义词典加载成功:" + this.dat.size() + "个词条,耗时"
+ (System.currentTimeMillis() - start) + "ms");
this.path = path;
return true;
}
@@ -180,11 +170,8 @@ public class MultiCustomDictionary extends DynamicCustomDictionary {
* @param addToSuggestTrie
* @return
*/
public static boolean loadMainDictionary(
String mainPath,
String[] path,
DoubleArrayTrie<CoreDictionary.Attribute> dat,
boolean isCache,
public static boolean loadMainDictionary(String mainPath, String[] path,
DoubleArrayTrie<CoreDictionary.Attribute> dat, boolean isCache,
boolean addToSuggestTrie) {
logger.info("自定义词典开始加载:" + mainPath);
if (loadDat(mainPath, dat)) {
@@ -204,9 +191,8 @@ public class MultiCustomDictionary extends DynamicCustomDictionary {
p = file.getParent() + File.separator + fileName.substring(0, cut);
try {
defaultNature =
LexiconUtility.convertStringToNature(
nature, customNatureCollector);
defaultNature = LexiconUtility.convertStringToNature(nature,
customNatureCollector);
} catch (Exception var16) {
logger.severe("配置文件【" + p + "】写错了!" + var16);
continue;
@@ -241,10 +227,8 @@ public class MultiCustomDictionary extends DynamicCustomDictionary {
attributeList.add(entry.getValue());
}
DataOutputStream out =
new DataOutputStream(
new BufferedOutputStream(
IOUtil.newOutputStream(mainPath + ".bin")));
DataOutputStream out = new DataOutputStream(
new BufferedOutputStream(IOUtil.newOutputStream(mainPath + ".bin")));
if (customNatureCollector.isEmpty()) {
for (int i = Nature.begin.ordinal() + 1; i < Nature.values().length; ++i) {
Nature nature = Nature.values()[i];
@@ -287,8 +271,8 @@ public class MultiCustomDictionary extends DynamicCustomDictionary {
return loadDat(path, HanLP.Config.CustomDictionaryPath, dat);
}
public static boolean loadDat(
String path, String[] customDicPath, DoubleArrayTrie<CoreDictionary.Attribute> dat) {
public static boolean loadDat(String path, String[] customDicPath,
DoubleArrayTrie<CoreDictionary.Attribute> dat) {
try {
if (HanLP.Config.CustomDictionaryAutoRefreshCache
&& DynamicCustomDictionary.isDicNeedUpdate(path, customDicPath)) {
@@ -374,8 +358,8 @@ public class MultiCustomDictionary extends DynamicCustomDictionary {
IOUtil.deleteFile(this.path[0] + ".bin");
Boolean loadCacheOk = this.loadDat(this.path[0], this.path, this.dat);
if (!loadCacheOk) {
return this.loadMainDictionary(
this.path[0], this.path, this.dat, true, addToSuggesterTrie);
return this.loadMainDictionary(this.path[0], this.path, this.dat, true,
addToSuggesterTrie);
}
}
return false;
@@ -389,8 +373,7 @@ public class MultiCustomDictionary extends DynamicCustomDictionary {
word = CharTable.convert(word);
}
CoreDictionary.Attribute att =
natureWithFrequency == null
? new CoreDictionary.Attribute(Nature.nz, 1)
natureWithFrequency == null ? new CoreDictionary.Attribute(Nature.nz, 1)
: CoreDictionary.Attribute.create(natureWithFrequency);
boolean isLetters = isLetters(word);
word = getWordBySpace(word);

View File

@@ -43,35 +43,23 @@ public class SearchService {
* @param key
* @return
*/
public static List<HanlpMapResult> prefixSearch(
String key,
int limit,
Map<Long, List<Long>> modelIdToDataSetIds,
Set<Long> detectDataSetIds) {
public static List<HanlpMapResult> prefixSearch(String key, int limit,
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
return prefixSearch(key, limit, trie, modelIdToDataSetIds, detectDataSetIds);
}
public static List<HanlpMapResult> prefixSearch(
String key,
int limit,
BinTrie<List<String>> binTrie,
Map<Long, List<Long>> modelIdToDataSetIds,
public static List<HanlpMapResult> prefixSearch(String key, int limit,
BinTrie<List<String>> binTrie, Map<Long, List<Long>> modelIdToDataSetIds,
Set<Long> detectDataSetIds) {
Set<Map.Entry<String, List<String>>> result = search(key, binTrie);
List<HanlpMapResult> hanlpMapResults =
result.stream()
.map(
entry -> {
String name = entry.getKey().replace("#", " ");
double similarity = EditDistanceUtils.getSimilarity(name, key);
return new HanlpMapResult(
name, entry.getValue(), key, similarity);
})
.sorted((a, b) -> -(b.getName().length() - a.getName().length()))
.collect(Collectors.toList());
hanlpMapResults =
transformAndFilterByDataSet(
hanlpMapResults, modelIdToDataSetIds, detectDataSetIds, limit);
List<HanlpMapResult> hanlpMapResults = result.stream().map(entry -> {
String name = entry.getKey().replace("#", " ");
double similarity = EditDistanceUtils.getSimilarity(name, key);
return new HanlpMapResult(name, entry.getValue(), key, similarity);
}).sorted((a, b) -> -(b.getName().length() - a.getName().length()))
.collect(Collectors.toList());
hanlpMapResults = transformAndFilterByDataSet(hanlpMapResults, modelIdToDataSetIds,
detectDataSetIds, limit);
return hanlpMapResults;
}
@@ -81,87 +69,55 @@ public class SearchService {
* @param key
* @return
*/
public static List<HanlpMapResult> suffixSearch(
String key,
int limit,
Map<Long, List<Long>> modelIdToDataSetIds,
Set<Long> detectDataSetIds) {
public static List<HanlpMapResult> suffixSearch(String key, int limit,
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
String reverseDetectSegment = StringUtils.reverse(key);
return suffixSearch(
reverseDetectSegment, limit, suffixTrie, modelIdToDataSetIds, detectDataSetIds);
return suffixSearch(reverseDetectSegment, limit, suffixTrie, modelIdToDataSetIds,
detectDataSetIds);
}
public static List<HanlpMapResult> suffixSearch(
String key,
int limit,
BinTrie<List<String>> binTrie,
Map<Long, List<Long>> modelIdToDataSetIds,
public static List<HanlpMapResult> suffixSearch(String key, int limit,
BinTrie<List<String>> binTrie, Map<Long, List<Long>> modelIdToDataSetIds,
Set<Long> detectDataSetIds) {
Set<Map.Entry<String, List<String>>> result = search(key, binTrie);
List<HanlpMapResult> hanlpMapResults =
result.stream()
.map(
entry -> {
String name = entry.getKey().replace("#", " ");
List<String> natures =
entry.getValue().stream()
.map(
nature ->
nature.replaceAll(
DictWordType.SUFFIX
.getType(),
""))
.collect(Collectors.toList());
List<HanlpMapResult> hanlpMapResults = result.stream().map(entry -> {
String name = entry.getKey().replace("#", " ");
List<String> natures = entry.getValue().stream()
.map(nature -> nature.replaceAll(DictWordType.SUFFIX.getType(), ""))
.collect(Collectors.toList());
name = StringUtils.reverse(name);
double similarity = EditDistanceUtils.getSimilarity(name, key);
return new HanlpMapResult(name, natures, key, similarity);
})
.sorted((a, b) -> -(b.getName().length() - a.getName().length()))
.collect(Collectors.toList());
return transformAndFilterByDataSet(
hanlpMapResults, modelIdToDataSetIds, detectDataSetIds, limit);
name = StringUtils.reverse(name);
double similarity = EditDistanceUtils.getSimilarity(name, key);
return new HanlpMapResult(name, natures, key, similarity);
}).sorted((a, b) -> -(b.getName().length() - a.getName().length()))
.collect(Collectors.toList());
return transformAndFilterByDataSet(hanlpMapResults, modelIdToDataSetIds, detectDataSetIds,
limit);
}
private static List<HanlpMapResult> transformAndFilterByDataSet(
List<HanlpMapResult> hanlpMapResults,
Map<Long, List<Long>> modelIdToDataSetIds,
Set<Long> detectDataSetIds,
int limit) {
return hanlpMapResults.stream()
.peek(
hanlpMapResult -> {
List<String> natures =
hanlpMapResult.getNatures().stream()
.map(
nature ->
NatureHelper.changeModel2DataSet(
nature, modelIdToDataSetIds))
.flatMap(Collection::stream)
.filter(
nature -> {
if (CollectionUtils.isEmpty(
detectDataSetIds)) {
return true;
}
Long dataSetId =
NatureHelper.getDataSetId(nature);
if (dataSetId != null) {
return detectDataSetIds.contains(
dataSetId);
}
return false;
})
.collect(Collectors.toList());
hanlpMapResult.setNatures(natures);
})
.filter(hanlpMapResult -> !CollectionUtils.isEmpty(hanlpMapResult.getNatures()))
.limit(limit)
.collect(Collectors.toList());
List<HanlpMapResult> hanlpMapResults, Map<Long, List<Long>> modelIdToDataSetIds,
Set<Long> detectDataSetIds, int limit) {
return hanlpMapResults.stream().peek(hanlpMapResult -> {
List<String> natures = hanlpMapResult.getNatures().stream()
.map(nature -> NatureHelper.changeModel2DataSet(nature, modelIdToDataSetIds))
.flatMap(Collection::stream).filter(nature -> {
if (CollectionUtils.isEmpty(detectDataSetIds)) {
return true;
}
Long dataSetId = NatureHelper.getDataSetId(nature);
if (dataSetId != null) {
return detectDataSetIds.contains(dataSetId);
}
return false;
}).collect(Collectors.toList());
hanlpMapResult.setNatures(natures);
}).filter(hanlpMapResult -> !CollectionUtils.isEmpty(hanlpMapResult.getNatures()))
.limit(limit).collect(Collectors.toList());
}
private static Set<Map.Entry<String, List<String>>> search(
String key, BinTrie<List<String>> binTrie) {
private static Set<Map.Entry<String, List<String>>> search(String key,
BinTrie<List<String>> binTrie) {
key = key.toLowerCase();
Set<Map.Entry<String, List<String>>> entrySet =
new TreeSet<Map.Entry<String, List<String>>>();
@@ -202,14 +158,12 @@ public class SearchService {
}
TreeMap<String, CoreDictionary.Attribute> map = new TreeMap();
for (DictWord suffix : suffixes) {
CoreDictionary.Attribute attributeNew =
suffix.getNatureWithFrequency() == null
? new CoreDictionary.Attribute(Nature.nz, 1)
: CoreDictionary.Attribute.create(suffix.getNatureWithFrequency());
CoreDictionary.Attribute attributeNew = suffix.getNatureWithFrequency() == null
? new CoreDictionary.Attribute(Nature.nz, 1)
: CoreDictionary.Attribute.create(suffix.getNatureWithFrequency());
if (map.containsKey(suffix.getWord())) {
attributeNew =
DictionaryAttributeUtil.getAttribute(
map.get(suffix.getWord()), attributeNew);
attributeNew = DictionaryAttributeUtil.getAttribute(map.get(suffix.getWord()),
attributeNew);
}
map.put(suffix.getWord(), attributeNew);
}
@@ -239,11 +193,8 @@ public class SearchService {
}
public static List<String> getDimensionValue(DimensionValueReq dimensionValueReq) {
String nature =
DictWordType.NATURE_SPILT
+ dimensionValueReq.getModelId()
+ DictWordType.NATURE_SPILT
+ dimensionValueReq.getElementID();
String nature = DictWordType.NATURE_SPILT + dimensionValueReq.getModelId()
+ DictWordType.NATURE_SPILT + dimensionValueReq.getElementID();
PriorityQueue<Term> terms = MultiCustomDictionary.NATURE_TO_VALUES.get(nature);
if (CollectionUtils.isEmpty(terms)) {
return new ArrayList<>();

View File

@@ -9,8 +9,8 @@ import java.util.List;
public abstract class BaseWordWithAliasBuilder extends BaseWordBuilder {
public abstract DictWord getOneWordNature(
String word, SchemaElement schemaElement, boolean isSuffix);
public abstract DictWord getOneWordNature(String word, SchemaElement schemaElement,
boolean isSuffix);
public List<DictWord> getOneWordNatureAlias(SchemaElement schemaElement, boolean isSuffix) {
List<DictWord> dictWords = new ArrayList<>();

View File

@@ -29,20 +29,12 @@ public class DimensionWordBuilder extends BaseWordWithAliasBuilder {
DictWord dictWord = new DictWord();
dictWord.setWord(word);
Long modelId = schemaElement.getModel();
String nature =
DictWordType.NATURE_SPILT
+ modelId
+ DictWordType.NATURE_SPILT
+ schemaElement.getId()
+ DictWordType.DIMENSION.getType();
String nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT
+ schemaElement.getId() + DictWordType.DIMENSION.getType();
if (isSuffix) {
nature =
DictWordType.NATURE_SPILT
+ modelId
+ DictWordType.NATURE_SPILT
+ schemaElement.getId()
+ DictWordType.SUFFIX.getType()
+ DictWordType.DIMENSION.getType();
nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT
+ schemaElement.getId() + DictWordType.SUFFIX.getType()
+ DictWordType.DIMENSION.getType();
}
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature));
return dictWord;

View File

@@ -27,12 +27,8 @@ public class EntityWordBuilder extends BaseWordWithAliasBuilder {
@Override
public DictWord getOneWordNature(String word, SchemaElement schemaElement, boolean isSuffix) {
String nature =
DictWordType.NATURE_SPILT
+ schemaElement.getModel()
+ DictWordType.NATURE_SPILT
+ schemaElement.getId()
+ DictWordType.ENTITY.getType();
String nature = DictWordType.NATURE_SPILT + schemaElement.getModel()
+ DictWordType.NATURE_SPILT + schemaElement.getId() + DictWordType.ENTITY.getType();
DictWord dictWord = new DictWord();
dictWord.setWord(word);
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY * 2, nature));

View File

@@ -29,20 +29,12 @@ public class MetricWordBuilder extends BaseWordWithAliasBuilder {
DictWord dictWord = new DictWord();
dictWord.setWord(word);
Long modelId = schemaElement.getModel();
String nature =
DictWordType.NATURE_SPILT
+ modelId
+ DictWordType.NATURE_SPILT
+ schemaElement.getId()
+ DictWordType.METRIC.getType();
String nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT
+ schemaElement.getId() + DictWordType.METRIC.getType();
if (isSuffix) {
nature =
DictWordType.NATURE_SPILT
+ modelId
+ DictWordType.NATURE_SPILT
+ schemaElement.getId()
+ DictWordType.SUFFIX.getType()
+ DictWordType.METRIC.getType();
nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT
+ schemaElement.getId() + DictWordType.SUFFIX.getType()
+ DictWordType.METRIC.getType();
}
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature));
return dictWord;

View File

@@ -29,20 +29,12 @@ public class TermWordBuilder extends BaseWordWithAliasBuilder {
DictWord dictWord = new DictWord();
dictWord.setWord(word);
Long dataSet = schemaElement.getDataSetId();
String nature =
DictWordType.NATURE_SPILT
+ dataSet
+ DictWordType.NATURE_SPILT
+ schemaElement.getId()
+ DictWordType.TERM.getType();
String nature = DictWordType.NATURE_SPILT + dataSet + DictWordType.NATURE_SPILT
+ schemaElement.getId() + DictWordType.TERM.getType();
if (isSuffix) {
nature =
DictWordType.NATURE_SPILT
+ dataSet
+ DictWordType.NATURE_SPILT
+ schemaElement.getId()
+ DictWordType.SUFFIX.getType()
+ DictWordType.TERM.getType();
nature = DictWordType.NATURE_SPILT + dataSet + DictWordType.NATURE_SPILT
+ schemaElement.getId() + DictWordType.SUFFIX.getType()
+ DictWordType.TERM.getType();
}
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature));
return dictWord;

View File

@@ -26,11 +26,8 @@ public class ValueWordBuilder extends BaseWordWithAliasBuilder {
public DictWord getOneWordNature(String word, SchemaElement schemaElement, boolean isSuffix) {
DictWord dictWord = new DictWord();
Long modelId = schemaElement.getModel();
String nature =
DictWordType.NATURE_SPILT
+ modelId
+ DictWordType.NATURE_SPILT
+ schemaElement.getId();
String nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT
+ schemaElement.getId();
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature));
dictWord.setWord(word);
return dictWord;

View File

@@ -81,12 +81,8 @@ public class FileHandlerImpl implements FileHandler {
String filePath = localFileConfig.getDictDirectoryLatest() + FILE_SPILT + fileName;
Long fileLineNum = getFileLineNum(filePath);
Integer startLine = (dictValueReq.getCurrent() - 1) * dictValueReq.getPageSize() + 1;
Integer endLine =
Integer.valueOf(
Math.min(
dictValueReq.getCurrent() * dictValueReq.getPageSize(),
fileLineNum)
+ "");
Integer endLine = Integer.valueOf(
Math.min(dictValueReq.getCurrent() * dictValueReq.getPageSize(), fileLineNum) + "");
List<DictValueResp> dictValueRespList = getFileData(filePath, startLine, endLine);
dictValueRespPageInfo.setPageSize(dictValueReq.getPageSize());
@@ -112,12 +108,9 @@ public class FileHandlerImpl implements FileHandler {
List<DictValueResp> fileData = new ArrayList<>();
try (Stream<String> lines = Files.lines(Paths.get(filePath))) {
fileData =
lines.skip(startLine - 1)
.limit(endLine - startLine + 1)
.map(lineStr -> convert2Resp(lineStr))
.filter(line -> Objects.nonNull(line))
.collect(Collectors.toList());
fileData = lines.skip(startLine - 1).limit(endLine - startLine + 1)
.map(lineStr -> convert2Resp(lineStr)).filter(line -> Objects.nonNull(line))
.collect(Collectors.toList());
} catch (IOException e) {
log.warn("[getFileData] e:{}", e);
}
@@ -204,8 +197,8 @@ public class FileHandlerImpl implements FileHandler {
private BufferedWriter getWriter(String filePath, Boolean append) throws IOException {
if (append) {
return Files.newBufferedWriter(
Paths.get(filePath), StandardCharsets.UTF_8, StandardOpenOption.APPEND);
return Files.newBufferedWriter(Paths.get(filePath), StandardCharsets.UTF_8,
StandardOpenOption.APPEND);
}
return Files.newBufferedWriter(Paths.get(filePath), StandardCharsets.UTF_8);
}

View File

@@ -32,17 +32,15 @@ public class FileHelper {
}
private static File[] getFileList(File customFolder, String suffix) {
File[] customSubFiles =
customFolder.listFiles(
file -> {
if (file.isDirectory()) {
return false;
}
if (file.getName().toLowerCase().endsWith(suffix)) {
return true;
}
return false;
});
File[] customSubFiles = customFolder.listFiles(file -> {
if (file.isDirectory()) {
return false;
}
if (file.getName().toLowerCase().endsWith(suffix)) {
return true;
}
return false;
});
return customSubFiles;
}

View File

@@ -57,21 +57,14 @@ public class HanlpHelper {
if (segment == null) {
synchronized (HanlpHelper.class) {
if (segment == null) {
segment =
HanLP.newSegment()
.enableIndexMode(true)
.enableIndexMode(4)
.enableCustomDictionary(true)
.enableCustomDictionaryForcing(true)
.enableOffset(true)
.enableJapaneseNameRecognize(false)
.enableNameRecognize(false)
.enableAllNamedEntityRecognize(false)
.enableJapaneseNameRecognize(false)
.enableNumberQuantifierRecognize(false)
.enablePlaceRecognize(false)
.enableOrganizationRecognize(false)
.enableCustomDictionary(getDynamicCustomDictionary());
segment = HanLP.newSegment().enableIndexMode(true).enableIndexMode(4)
.enableCustomDictionary(true).enableCustomDictionaryForcing(true)
.enableOffset(true).enableJapaneseNameRecognize(false)
.enableNameRecognize(false).enableAllNamedEntityRecognize(false)
.enableJapaneseNameRecognize(false)
.enableNumberQuantifierRecognize(false).enablePlaceRecognize(false)
.enableOrganizationRecognize(false)
.enableCustomDictionary(getDynamicCustomDictionary());
}
}
}
@@ -112,8 +105,7 @@ public class HanlpHelper {
boolean reload = getDynamicCustomDictionary().reload();
if (reload) {
log.info(
"Custom dictionary has been reloaded in {} milliseconds",
log.info("Custom dictionary has been reloaded in {} milliseconds",
System.currentTimeMillis() - startTime);
}
return reload;
@@ -125,21 +117,15 @@ public class HanlpHelper {
}
String hanlpPropertiesPath = getHanlpPropertiesPath();
HanLP.Config.CustomDictionaryPath =
Arrays.stream(HanLP.Config.CustomDictionaryPath)
.map(path -> hanlpPropertiesPath + FILE_SPILT + path)
.toArray(String[]::new);
log.info(
"hanlpPropertiesPath:{},CustomDictionaryPath:{}",
hanlpPropertiesPath,
HanLP.Config.CustomDictionaryPath = Arrays.stream(HanLP.Config.CustomDictionaryPath)
.map(path -> hanlpPropertiesPath + FILE_SPILT + path).toArray(String[]::new);
log.info("hanlpPropertiesPath:{},CustomDictionaryPath:{}", hanlpPropertiesPath,
HanLP.Config.CustomDictionaryPath);
HanLP.Config.CoreDictionaryPath =
hanlpPropertiesPath + FILE_SPILT + HanLP.Config.BiGramDictionaryPath;
HanLP.Config.CoreDictionaryTransformMatrixDictionaryPath =
hanlpPropertiesPath
+ FILE_SPILT
+ HanLP.Config.CoreDictionaryTransformMatrixDictionaryPath;
HanLP.Config.CoreDictionaryTransformMatrixDictionaryPath = hanlpPropertiesPath + FILE_SPILT
+ HanLP.Config.CoreDictionaryTransformMatrixDictionaryPath;
HanLP.Config.BiGramDictionaryPath =
hanlpPropertiesPath + FILE_SPILT + HanLP.Config.BiGramDictionaryPath;
HanLP.Config.CoreStopWordDictionaryPath =
@@ -201,8 +187,8 @@ public class HanlpHelper {
public static boolean addToCustomDictionary(DictWord dictWord) {
log.debug("dictWord:{}", dictWord);
return getDynamicCustomDictionary()
.insert(dictWord.getWord(), dictWord.getNatureWithFrequency());
return getDynamicCustomDictionary().insert(dictWord.getWord(),
dictWord.getNatureWithFrequency());
}
public static void removeFromCustomDictionary(DictWord dictWord) {
@@ -226,8 +212,8 @@ public class HanlpHelper {
int len = natureWithFrequency.length();
log.info("filtered natureWithFrequency:{}", natureWithFrequency);
if (StringUtils.isNotBlank(natureWithFrequency)) {
getDynamicCustomDictionary()
.add(dictWord.getWord(), natureWithFrequency.substring(0, len - 1));
getDynamicCustomDictionary().add(dictWord.getWord(),
natureWithFrequency.substring(0, len - 1));
}
SearchService.remove(dictWord, natureList.toArray(new Nature[0]));
}
@@ -257,8 +243,8 @@ public class HanlpHelper {
mapResults.addAll(newResults);
}
public static <T extends MapResult> boolean addLetterOriginal(
List<T> mapResults, T mapResult, CoreDictionary.Attribute attribute) {
public static <T extends MapResult> boolean addLetterOriginal(List<T> mapResults, T mapResult,
CoreDictionary.Attribute attribute) {
if (attribute == null) {
return false;
}
@@ -268,12 +254,8 @@ public class HanlpHelper {
for (String nature : hanlpMapResult.getNatures()) {
String orig = attribute.getOriginal(Nature.fromString(nature));
if (orig != null) {
MapResult addMapResult =
new HanlpMapResult(
orig,
Arrays.asList(nature),
hanlpMapResult.getDetectWord(),
hanlpMapResult.getSimilarity());
MapResult addMapResult = new HanlpMapResult(orig, Arrays.asList(nature),
hanlpMapResult.getDetectWord(), hanlpMapResult.getSimilarity());
mapResults.add((T) addMapResult);
isAdd = true;
}
@@ -317,38 +299,30 @@ public class HanlpHelper {
return getSegment().seg(text.toLowerCase()).stream()
.filter(term -> term.getNature().startsWith(DictWordType.NATURE_SPILT))
.map(term -> transform2ApiTerm(term, modelIdToDataSetIds))
.flatMap(Collection::stream)
.collect(Collectors.toList());
.flatMap(Collection::stream).collect(Collectors.toList());
}
public static List<S2Term> getTerms(List<S2Term> terms, Set<Long> dataSetIds) {
logTerms(terms);
if (!CollectionUtils.isEmpty(dataSetIds)) {
terms = terms.stream().filter(term -> {
Long dataSetId = NatureHelper.getDataSetId(term.getNature().toString());
if (Objects.nonNull(dataSetId)) {
return dataSetIds.contains(dataSetId);
}
return false;
}).collect(Collectors.toList());
log.debug("terms filter by dataSetId:{}", dataSetIds);
logTerms(terms);
}
return terms;
}
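// Converts a HanLP term into S2Terms, one per data set derived from its model-level nature.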
public static List<S2Term> transform2ApiTerm(Term term,
Map<Long, List<Long>> modelIdToDataSetIds) {
List<S2Term> s2Terms = Lists.newArrayList();
List<String> natures = NatureHelper.changeModel2DataSet(String.valueOf(term.getNature()),
modelIdToDataSetIds);
for (String nature : natures) {
S2Term s2Term = new S2Term();
BeanUtils.copyProperties(term, s2Term);
@@ -364,10 +338,7 @@ public class HanlpHelper {
return;
}
for (S2Term term : terms) {
log.debug("word:{},nature:{},frequency:{}", term.word, term.nature.toString(),
term.getFrequency());
}
}


@@ -89,8 +89,8 @@ public class NatureHelper {
return null;
}
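// Converts a model-level nature into data-set-level natures; TERM natures pass through unchanged.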
public static List<String> changeModel2DataSet(String nature,
Map<Long, List<Long>> modelIdToDataSetIds) {
if (SchemaElementType.TERM.equals(NatureHelper.convertToElementType(nature))) {
return Collections.singletonList(nature);
}
@@ -99,77 +99,56 @@ public class NatureHelper {
if (CollectionUtils.isEmpty(dataSetIds)) {
return Collections.emptyList();
}
return dataSetIds.stream().map(dataSetId -> changeModel2DataSet(nature, dataSetId))
.filter(Objects::nonNull).map(String::valueOf).collect(Collectors.toList());
}
public static boolean isDimensionValueDataSetId(String nature) {
return isNatureValid(nature)
&& !isNatureType(nature, DictWordType.METRIC, DictWordType.DIMENSION,
DictWordType.TERM)
&& StringUtils.isNumeric(nature.split(DictWordType.NATURE_SPILT)[1]);
}
public static DataSetInfoStat getDataSetStat(List<S2Term> terms) {
return DataSetInfoStat.builder().dataSetCount(getDataSetCount(terms))
.dimensionDataSetCount(getDimensionCount(terms))
.metricDataSetCount(getMetricCount(terms))
.dimensionValueDataSetCount(getDimensionValueCount(terms)).build();
}
private static long getDataSetCount(List<S2Term> terms) {
return terms.stream()
.filter(term -> isDataSetOrEntity(term, getDataSetByNature(term.nature))).count();
}
private static long getDimensionValueCount(List<S2Term> terms) {
return terms.stream().filter(term -> isDimensionValueDataSetId(term.nature.toString()))
.count();
}
private static long getDimensionCount(List<S2Term> terms) {
return terms.stream()
.filter(term -> term.nature.startsWith(DictWordType.NATURE_SPILT)
&& term.nature.toString().endsWith(DictWordType.DIMENSION.getType()))
.count();
}
private static long getMetricCount(List<S2Term> terms) {
return terms.stream().filter(term -> term.nature.startsWith(DictWordType.NATURE_SPILT)
&& term.nature.toString().endsWith(DictWordType.METRIC.getType())).count();
}
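// Counts, per data set id, how many detected terms fall into each nature type.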
public static Map<Long, Map<DictWordType, Integer>> getDataSetToNatureStat(List<S2Term> terms) {
Map<Long, Map<DictWordType, Integer>> modelToNature = new HashMap<>();
terms.stream().filter(term -> term.nature.startsWith(DictWordType.NATURE_SPILT))
.forEach(term -> {
DictWordType dictWordType = DictWordType.getNatureType(term.nature.toString());
Long model = getDataSetId(term.nature.toString());
modelToNature.computeIfAbsent(model, k -> new HashMap<>()).merge(dictWordType,
1, Integer::sum);
});
return modelToNature;
}
@@ -177,12 +156,9 @@ public class NatureHelper {
Map<Long, Map<DictWordType, Integer>> modelToNatureStat = getDataSetToNatureStat(terms);
return modelToNatureStat.entrySet().stream()
.max(Comparator.comparingInt(entry -> entry.getValue().size()))
.map(entry -> modelToNatureStat.entrySet().stream()
.filter(e -> e.getValue().size() == entry.getValue().size())
.map(Map.Entry::getKey).collect(Collectors.toList()))
.orElse(Collections.emptyList());
}
@@ -190,15 +166,14 @@ public class NatureHelper {
return parseIdFromNature(nature, 2);
}
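// Returns the model ids whose data sets intersect the detected ids; all model ids when none were detected.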
public static Set<Long> getModelIds(Map<Long, List<Long>> modelIdToDataSetIds,
Set<Long> detectDataSetIds) {
if (CollectionUtils.isEmpty(detectDataSetIds)) {
return modelIdToDataSetIds.keySet();
}
return modelIdToDataSetIds.entrySet().stream()
.filter(entry -> !Collections.disjoint(entry.getValue(), detectDataSetIds))
.map(Map.Entry::getKey).collect(Collectors.toSet());
}
public static Long parseIdFromNature(String nature, int index) {


@@ -30,9 +30,7 @@ public abstract class BaseMapper implements SchemaMapper {
String simpleName = this.getClass().getSimpleName();
long startTime = System.currentTimeMillis();
log.debug("before {},mapInfo:{}", simpleName,
chatQueryContext.getMapInfo().getDataSetElementMatches());
try {
@@ -43,17 +41,14 @@ public abstract class BaseMapper implements SchemaMapper {
}
long cost = System.currentTimeMillis() - startTime;
log.debug("after {},cost:{},mapInfo:{}", simpleName, cost,
chatQueryContext.getMapInfo().getDataSetElementMatches());
}
public abstract void doMap(ChatQueryContext chatQueryContext);
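// Adds the new match unless an equal element already matched with higher similarity; a weaker duplicate is replaced.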
public void addToSchemaMap(SchemaMapInfo schemaMap, Long dataSetId,
SchemaElementMatch newElementMatch) {
Map<Long, List<SchemaElementMatch>> dataSetElementMatches =
schemaMap.getDataSetElementMatches();
List<SchemaElementMatch> schemaElementMatches =
@@ -61,26 +56,24 @@ public abstract class BaseMapper implements SchemaMapper {
AtomicBoolean shouldAddNew = new AtomicBoolean(true);
schemaElementMatches.removeIf(existingElementMatch -> {
if (isEquals(existingElementMatch, newElementMatch)) {
if (newElementMatch.getSimilarity() > existingElementMatch.getSimilarity()) {
return true;
} else {
shouldAddNew.set(false);
}
}
return false;
});
if (shouldAddNew.get()) {
schemaElementMatches.add(newElementMatch);
}
}
private static boolean isEquals(SchemaElementMatch existElementMatch,
SchemaElementMatch newElementMatch) {
SchemaElement existElement = existElementMatch.getElement();
SchemaElement newElement = newElementMatch.getElement();
if (!existElement.equals(newElement)) {
@@ -92,11 +85,8 @@ public abstract class BaseMapper implements SchemaMapper {
return true;
}
public SchemaElement getSchemaElement(Long dataSetId, SchemaElementType elementType,
Long elementID, SemanticSchema semanticSchema) {
SchemaElement element = new SchemaElement();
DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId);
if (Objects.isNull(dataSetSchema)) {
@@ -124,8 +114,8 @@ public abstract class BaseMapper implements SchemaMapper {
return element.getAlias();
}
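// Segments the query text into terms, runs the given match strategy, and returns the first non-empty result group.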
public <T> List<T> getMatches(ChatQueryContext chatQueryContext,
BaseMatchStrategy matchStrategy) {
String queryText = chatQueryContext.getQueryText();
List<S2Term> terms =
HanlpHelper.getTerms(queryText, chatQueryContext.getModelIdToDataSetIds());
@@ -136,11 +126,9 @@ public abstract class BaseMapper implements SchemaMapper {
if (Objects.isNull(matchResult)) {
return matches;
}
Optional<List<T>> first = matchResult.entrySet().stream()
.filter(entry -> CollectionUtils.isNotEmpty(entry.getValue()))
.map(entry -> entry.getValue()).findFirst();
if (first.isPresent()) {
matches = first.get();


@@ -19,8 +19,8 @@ import java.util.Set;
@Slf4j
public abstract class BaseMatchStrategy<T extends MapResult> implements MatchStrategy<T> {
@Override
public Map<MatchText, List<T>> match(ChatQueryContext chatQueryContext, List<S2Term> terms,
Set<Long> detectDataSetIds) {
String text = chatQueryContext.getQueryText();
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
return null;
@@ -35,8 +35,8 @@ public abstract class BaseMatchStrategy<T extends MapResult> implements MatchStr
return result;
}
public List<T> detect(ChatQueryContext chatQueryContext, List<S2Term> terms,
Set<Long> detectDataSetIds) {
throw new RuntimeException("Not implemented");
}
@@ -46,15 +46,13 @@ public abstract class BaseMatchStrategy<T extends MapResult> implements MatchStr
}
for (T oneRoundResult : oneRoundResults) {
if (existResults.contains(oneRoundResult)) {
boolean isDeleted = existResults.removeIf(existResult -> {
boolean delete = existResult.lessSimilar(oneRoundResult);
if (delete) {
log.info("deleted existResult:{}", existResult);
}
return delete;
});
if (isDeleted) {
log.info("deleted, add oneRoundResult:{}", oneRoundResult);
existResults.add(oneRoundResult);


@@ -15,22 +15,21 @@ import java.util.Set;
@Slf4j
public abstract class BatchMatchStrategy<T extends MapResult> extends BaseMatchStrategy<T> {
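// Slides a fixed-size window (size/step taken from MapperConfig) over the query text and detects candidates in batch.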
@Autowired
protected MapperConfig mapperConfig;
@Override
public List<T> detect(ChatQueryContext chatQueryContext, List<S2Term> terms,
Set<Long> detectDataSetIds) {
String text = chatQueryContext.getQueryText();
Set<String> detectSegments = new HashSet<>();
int embeddingTextSize = Integer
.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_TEXT_SIZE));
int embeddingTextStep = Integer
.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_TEXT_STEP));
for (int startIndex = 0; startIndex < text.length(); startIndex += embeddingTextStep) {
int endIndex = Math.min(startIndex + embeddingTextSize, text.length());
@@ -40,8 +39,6 @@ public abstract class BatchMatchStrategy<T extends MapResult> extends BaseMatchS
return detectByBatch(chatQueryContext, detectDataSetIds, detectSegments);
}
public abstract List<T> detectByBatch(ChatQueryContext chatQueryContext,
Set<Long> detectDataSetIds, Set<String> detectSegments);
}


@@ -30,17 +30,14 @@ public class DatabaseMatchStrategy extends SingleMatchStrategy<DatabaseMapResult
private List<SchemaElement> allElements;
@Override
public Map<MatchText, List<DatabaseMapResult>> match(ChatQueryContext chatQueryContext,
List<S2Term> terms, Set<Long> detectDataSetIds) {
this.allElements = getSchemaElements(chatQueryContext);
return super.match(chatQueryContext, terms, detectDataSetIds);
}
public List<DatabaseMapResult> detectByStep(ChatQueryContext chatQueryContext,
Set<Long> detectDataSetIds, String detectSegment, int offset) {
if (StringUtils.isBlank(detectSegment)) {
return new ArrayList<>();
}
@@ -56,13 +53,9 @@ public class DatabaseMatchStrategy extends SingleMatchStrategy<DatabaseMapResult
}
Set<SchemaElement> schemaElements = entry.getValue();
if (!CollectionUtils.isEmpty(detectDataSetIds)) {
schemaElements = schemaElements.stream().filter(
schemaElement -> detectDataSetIds.contains(schemaElement.getDataSetId()))
.collect(Collectors.toSet());
}
for (SchemaElement schemaElement : schemaElements) {
DatabaseMapResult databaseMapResult = new DatabaseMapResult();
@@ -86,40 +79,31 @@ public class DatabaseMatchStrategy extends SingleMatchStrategy<DatabaseMapResult
private Double getThreshold(ChatQueryContext chatQueryContext) {
Double threshold =
Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD));
Double minThreshold = Double
.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD_MIN));
Map<Long, List<SchemaElementMatch>> modelElementMatches =
chatQueryContext.getMapInfo().getDataSetElementMatches();
boolean existElement = modelElementMatches.entrySet().stream()
.anyMatch(entry -> entry.getValue().size() >= 1);
if (!existElement) {
threshold = threshold / 2;
log.debug("ModelElementMatches:{},not exist Element threshold reduce by half:{}",
modelElementMatches, threshold);
}
return getThreshold(threshold, minThreshold, chatQueryContext.getMapModeEnum());
}
private Map<String, Set<SchemaElement>> getNameToItems(List<SchemaElement> models) {
return models.stream().collect(Collectors.toMap(SchemaElement::getName, a -> {
Set<SchemaElement> result = new HashSet<>();
result.add(a);
return result;
}, (k1, k2) -> {
k1.addAll(k2);
return k1;
}));
}
}


@@ -35,23 +35,15 @@ public class EmbeddingMapper extends BaseMapper {
}
SchemaElementType elementType =
SchemaElementType.valueOf(matchResult.getMetadata().get("type"));
SchemaElement schemaElement = getSchemaElement(dataSetId, elementType, elementId,
chatQueryContext.getSemanticSchema());
if (schemaElement == null) {
continue;
}
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
.element(schemaElement).frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
.word(matchResult.getName()).similarity(matchResult.getSimilarity())
.detectWord(matchResult.getDetectWord()).build();
// 3. add to mapInfo
addToSchemaMap(chatQueryContext.getMapInfo(), dataSetId, schemaElementMatch);
}


@@ -35,21 +35,18 @@ import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING
@Slf4j
public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult> {
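// Recalls schema elements by embedding similarity, keeping only results above a map-mode-adjusted threshold.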
@Autowired
private MetaEmbeddingService metaEmbeddingService;
@Override
public List<EmbeddingResult> detectByBatch(ChatQueryContext chatQueryContext,
Set<Long> detectDataSetIds, Set<String> detectSegments) {
Set<EmbeddingResult> results = new HashSet<>();
int embeddingMapperBatch = Integer
.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH));
List<String> queryTextsList =
detectSegments.stream().map(detectSegment -> detectSegment.trim())
.filter(detectSegment -> StringUtils.isNotBlank(detectSegment))
.collect(Collectors.toList());
@@ -64,20 +61,15 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
return new ArrayList<>(results);
}
private List<EmbeddingResult> detectByQueryTextsSub(Set<Long> detectDataSetIds,
List<String> queryTextsSub, ChatQueryContext chatQueryContext) {
Map<Long, List<Long>> modelIdToDataSetIds = chatQueryContext.getModelIdToDataSetIds();
double embeddingThreshold =
Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD));
double embeddingThresholdMin =
Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD_MIN));
double threshold = getThreshold(embeddingThreshold, embeddingThresholdMin,
chatQueryContext.getMapModeEnum());
// step1. build query params
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
@@ -85,75 +77,45 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
// step2. retrieveQuery by detectSegment
int embeddingNumber =
Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_NUMBER));
List<RetrieveQueryResult> retrieveQueryResults = metaEmbeddingService.retrieveQuery(
retrieveQuery, embeddingNumber, modelIdToDataSetIds, detectDataSetIds);
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
return new ArrayList<>();
}
// step3. build EmbeddingResults
List<EmbeddingResult> collect = retrieveQueryResults.stream().map(retrieveQueryResult -> {
List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
if (CollectionUtils.isNotEmpty(retrievals)) {
retrievals.removeIf(retrieval -> {
if (!retrieveQueryResult.getQuery().contains(retrieval.getQuery())) {
return retrieval.getSimilarity() < threshold;
}
return false;
});
}
return retrieveQueryResult;
}).filter(retrieveQueryResult -> CollectionUtils
.isNotEmpty(retrieveQueryResult.getRetrieval()))
.flatMap(retrieveQueryResult -> retrieveQueryResult.getRetrieval().stream()
.map(retrieval -> {
EmbeddingResult embeddingResult = new EmbeddingResult();
BeanUtils.copyProperties(retrieval, embeddingResult);
embeddingResult.setDetectWord(retrieveQueryResult.getQuery());
embeddingResult.setName(retrieval.getQuery());
Map<String, String> convertedMap = retrieval.getMetadata().entrySet()
.stream().collect(Collectors.toMap(Map.Entry::getKey,
entry -> entry.getValue().toString()));
embeddingResult.setMetadata(convertedMap);
return embeddingResult;
}))
.collect(Collectors.toList());
// step4. select mapResults in one round
int embeddingRoundNumber =
Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_ROUND_NUMBER));
int roundNumber = embeddingRoundNumber * queryTextsSub.size();
return collect.stream().sorted(Comparator.comparingDouble(EmbeddingResult::getSimilarity))
.limit(roundNumber).collect(Collectors.toList());
}
}


@@ -31,19 +31,16 @@ public class EntityMapper extends BaseMapper {
if (entity == null || entity.getId() == null) {
continue;
}
List<SchemaElementMatch> valueSchemaElements = schemaElementMatchList.stream()
.filter(schemaElementMatch -> SchemaElementType.VALUE
.equals(schemaElementMatch.getElement().getType()))
.collect(Collectors.toList());
for (SchemaElementMatch schemaElementMatch : valueSchemaElements) {
if (!entity.getId().equals(schemaElementMatch.getElement().getId())) {
continue;
}
if (!checkExistSameEntitySchemaElements(schemaElementMatch,
schemaElementMatchList)) {
SchemaElementMatch entitySchemaElementMath = new SchemaElementMatch();
BeanUtils.copyProperties(schemaElementMatch, entitySchemaElementMath);
entitySchemaElementMath.setElement(entity);
@@ -54,20 +51,14 @@ public class EntityMapper extends BaseMapper {
}
}
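// Checks whether an ENTITY match with the same element id already exists in the match list.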
private boolean checkExistSameEntitySchemaElements(SchemaElementMatch valueSchemaElementMatch,
List<SchemaElementMatch> schemaElementMatchList) {
List<SchemaElementMatch> entitySchemaElements = schemaElementMatchList.stream()
.filter(schemaElementMatch -> SchemaElementType.ENTITY
.equals(schemaElementMatch.getElement().getType()))
.collect(Collectors.toList());
for (SchemaElementMatch schemaElementMatch : entitySchemaElements) {
if (schemaElementMatch.getElement().getId()
.equals(valueSchemaElementMatch.getElement().getId())) {
return true;
}


@@ -26,35 +26,23 @@ import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.MAPPER_DI
@Slf4j
public class HanlpDictMatchStrategy extends SingleMatchStrategy<HanlpMapResult> {
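// Detects query segments against the HanLP dictionary via prefix/suffix search, then filters by similarity thresholds.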
@Autowired
private KnowledgeBaseService knowledgeBaseService;
public List<HanlpMapResult> detectByStep(ChatQueryContext chatQueryContext,
Set<Long> detectDataSetIds, String detectSegment, int offset) {
// step1. pre search
Integer oneDetectionMaxSize =
Integer.valueOf(mapperConfig.getParameterValue(MAPPER_DETECTION_MAX_SIZE));
LinkedHashSet<HanlpMapResult> hanlpMapResults = knowledgeBaseService
.prefixSearch(detectSegment, oneDetectionMaxSize,
chatQueryContext.getModelIdToDataSetIds(), detectDataSetIds)
.stream().collect(Collectors.toCollection(LinkedHashSet::new));
// step2. suffix search
LinkedHashSet<HanlpMapResult> suffixHanlpMapResults = knowledgeBaseService
.suffixSearch(detectSegment, oneDetectionMaxSize,
chatQueryContext.getModelIdToDataSetIds(), detectDataSetIds)
.stream().collect(Collectors.toCollection(LinkedHashSet::new));
hanlpMapResults.addAll(suffixHanlpMapResults);
@@ -62,40 +50,28 @@ public class HanlpDictMatchStrategy extends SingleMatchStrategy<HanlpMapResult>
return new ArrayList<>();
}
// step3. merge pre/suffix result
hanlpMapResults = hanlpMapResults.stream()
.sorted((a, b) -> -(b.getName().length() - a.getName().length()))
.collect(Collectors.toCollection(LinkedHashSet::new));
// step4. filter by similarity
hanlpMapResults = hanlpMapResults.stream()
.filter(term -> term.getSimilarity() >= getThresholdMatch(term.getNatures(),
chatQueryContext))
.filter(term -> CollectionUtils.isNotEmpty(term.getNatures())).map(parseResult -> {
parseResult.setOffset(offset);
return parseResult;
}).collect(Collectors.toCollection(LinkedHashSet::new));
log.debug("detectSegment:{},after isSimilarity parseResults:{}", detectSegment,
hanlpMapResults);
// step5. take only M dimensionValue or N-M metric/dimension value per round.
int oneDetectionValueSize =
Integer.valueOf(mapperConfig.getParameterValue(MAPPER_DIMENSION_VALUE_SIZE));
List<HanlpMapResult> dimensionValues = hanlpMapResults.stream()
.filter(entry -> mapperHelper.existDimensionValues(entry.getNatures()))
.limit(oneDetectionValueSize).collect(Collectors.toList());
Integer oneDetectionSize =
Integer.valueOf(mapperConfig.getParameterValue(MAPPER_DETECTION_SIZE));
@@ -108,14 +84,10 @@ public class HanlpDictMatchStrategy extends SingleMatchStrategy<HanlpMapResult>
// fill the rest of the list with other results, excluding the dimensionValue if it was
// added
if (oneRoundResults.size() < oneDetectionSize) {
List<HanlpMapResult> additionalResults = hanlpMapResults.stream()
.filter(entry -> !mapperHelper.existDimensionValues(entry.getNatures())
&& !oneRoundResults.contains(entry))
.limit(oneDetectionSize - oneRoundResults.size()).collect(Collectors.toList());
oneRoundResults.addAll(additionalResults);
}
return oneRoundResults;
@@ -124,17 +96,13 @@ public class HanlpDictMatchStrategy extends SingleMatchStrategy<HanlpMapResult>
public double getThresholdMatch(List<String> natures, ChatQueryContext chatQueryContext) {
Double threshold =
Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD));
Double minThreshold = Double
.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD_MIN));
if (mapperHelper.existDimensionValues(natures)) {
threshold = Double
.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_VALUE_THRESHOLD));
minThreshold = Double.valueOf(
mapperConfig.getParameterValue(MapperConfig.MAPPER_VALUE_THRESHOLD_MIN));
}
return getThreshold(threshold, minThreshold, chatQueryContext.getMapModeEnum());


@@ -51,21 +51,15 @@ public class KeywordMapper extends BaseMapper {
convertDatabaseMapResultToMapInfo(chatQueryContext, databaseResults);
}
private void convertHanlpMapResultToMapInfo(List<HanlpMapResult> mapResults,
ChatQueryContext chatQueryContext, List<S2Term> terms) {
if (CollectionUtils.isEmpty(mapResults)) {
return;
}
HanlpHelper.transLetterOriginal(mapResults);
Map<String, Long> wordNatureToFrequency = terms.stream()
.collect(Collectors.toMap(entry -> entry.getWord() + entry.getNature(),
term -> Long.valueOf(term.getFrequency()), (value1, value2) -> value2));
for (HanlpMapResult hanlpMapResult : mapResults) {
for (String nature : hanlpMapResult.getNatures()) {
@@ -78,32 +72,24 @@ public class KeywordMapper extends BaseMapper {
continue;
}
Long elementID = NatureHelper.getElementID(nature);
SchemaElement element = getSchemaElement(dataSetId, elementType, elementID,
chatQueryContext.getSemanticSchema());
if (element == null) {
continue;
}
Long frequency = wordNatureToFrequency.get(hanlpMapResult.getName() + nature);
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
.element(element).frequency(frequency).word(hanlpMapResult.getName())
.similarity(hanlpMapResult.getSimilarity())
.detectWord(hanlpMapResult.getDetectWord()).build();
addToSchemaMap(chatQueryContext.getMapInfo(), dataSetId, schemaElementMatch);
}
}
}
private void convertDatabaseMapResultToMapInfo(ChatQueryContext chatQueryContext,
List<DatabaseMapResult> mapResults) {
for (DatabaseMapResult match : mapResults) {
SchemaElement schemaElement = match.getSchemaElement();
Set<Long> regElementSet =
@@ -111,20 +97,14 @@ public class KeywordMapper extends BaseMapper {
if (regElementSet.contains(schemaElement.getId())) {
continue;
}
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
.element(schemaElement).word(schemaElement.getName())
.detectWord(match.getDetectWord()).frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
.similarity(EditDistanceUtils.getSimilarity(match.getDetectWord(),
schemaElement.getName()))
.build();
log.info("add to schema, elementMatch {}", schemaElementMatch);
addToSchemaMap(chatQueryContext.getMapInfo(), schemaElement.getDataSetId(),
schemaElementMatch);
}
}
@@ -135,13 +115,9 @@ public class KeywordMapper extends BaseMapper {
if (CollectionUtils.isEmpty(elements)) {
return new HashSet<>();
}
return elements.stream().filter(
elementMatch -> SchemaElementType.METRIC.equals(elementMatch.getElement().getType())
|| SchemaElementType.DIMENSION.equals(elementMatch.getElement().getType()))
.map(elementMatch -> elementMatch.getElement().getId()).collect(Collectors.toSet());
}
}


@@ -26,19 +26,16 @@ public class MapFilter {
filterByQueryDataType(chatQueryContext, element -> !(element.getIsTag() > 0));
break;
case METRIC:
filterByQueryDataType(chatQueryContext,
element -> !SchemaElementType.METRIC.equals(element.getType()));
break;
case DIMENSION:
filterByQueryDataType(chatQueryContext, element -> {
boolean isDimensionOrValue =
SchemaElementType.DIMENSION.equals(element.getType())
|| SchemaElementType.VALUE.equals(element.getType());
return !isDimensionOrValue;
});
break;
case ALL:
default:
@@ -67,31 +64,28 @@ public class MapFilter {
for (Map.Entry<Long, List<SchemaElementMatch>> entry : dataSetElementMatches.entrySet()) {
List<SchemaElementMatch> value = entry.getValue();
if (!CollectionUtils.isEmpty(value)) {
value.removeIf(schemaElementMatch -> StringUtils
.length(schemaElementMatch.getDetectWord()) <= 1);
}
}
}
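// Removes matches whose element fails the predicate; ENTITY/DATASET/ID elements are always kept.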
public static void filterByQueryDataType(ChatQueryContext chatQueryContext,
Predicate<SchemaElement> needRemovePredicate) {
Map<Long, List<SchemaElementMatch>> dataSetElementMatches =
chatQueryContext.getMapInfo().getDataSetElementMatches();
for (Map.Entry<Long, List<SchemaElementMatch>> entry : dataSetElementMatches.entrySet()) {
List<SchemaElementMatch> schemaElementMatches = entry.getValue();
schemaElementMatches.removeIf(schemaElementMatch -> {
SchemaElement element = schemaElementMatch.getElement();
SchemaElementType type = element.getType();
boolean isEntityOrDatasetOrId = SchemaElementType.ENTITY.equals(type)
|| SchemaElementType.DATASET.equals(type)
|| SchemaElementType.ID.equals(type);
return !isEntityOrDatasetOrId && needRemovePredicate.test(element);
});
}
}
@@ -116,21 +110,16 @@ public class MapFilter {
List<SchemaElementMatch> group = entry.getValue();
// Select the fully matched objects (similarity=1.0)
List<SchemaElementMatch> fullMatches = group.stream()
.filter(SchemaElementMatch::isFullMatched).collect(Collectors.toList());
if (!fullMatches.isEmpty()) {
// If there are objects with similarity=1.0, choose the one with the longest
// detectWord and smallest offset
SchemaElementMatch bestMatch = fullMatches.stream()
.max(Comparator.comparing(
(SchemaElementMatch match) -> match.getDetectWord().length()))
.orElse(null);
if (bestMatch != null) {
result.add(bestMatch);
}
@@ -145,8 +134,7 @@ public class MapFilter {
public static void filterInExactMatch(List<SchemaElementMatch> matches) {
Map<String, List<SchemaElementMatch>> fullMatches =
matches.stream().filter(schemaElementMatch -> schemaElementMatch.isFullMatched())
.collect(Collectors.groupingBy(SchemaElementMatch::getWord));
Set<String> keys = new HashSet<>(fullMatches.keySet());
for (String key1 : keys) {
@@ -157,8 +145,7 @@ public class MapFilter {
}
}
List<SchemaElementMatch> notFullMatches =
matches.stream().filter(schemaElementMatch -> !schemaElementMatch.isFullMatched())
.collect(Collectors.toList());
List<SchemaElementMatch> mergedMatches = new ArrayList<>();


@@ -7,129 +7,58 @@ import org.springframework.stereotype.Service;
@Service("HeadlessMapperConfig")
public class MapperConfig extends ParameterConfig {
public static final Parameter MAPPER_DETECTION_SIZE = new Parameter("s2.mapper.detection.size",
"8", "一次探测返回结果个数", "在每次探测后, 将前后缀匹配的结果合并, 并根据相似度阈值过滤后的结果个数", "number", "Mapper相关配置");
public static final Parameter MAPPER_DETECTION_MAX_SIZE =
new Parameter("s2.mapper.detection.max.size", "20", "一次探测前后缀匹配结果返回个数", "单次前后缀匹配返回的结果个数",
"number", "Mapper相关配置");
public static final Parameter MAPPER_NAME_THRESHOLD =
new Parameter("s2.mapper.name.threshold", "0.5", "指标名、维度名文本相似度阈值",
"文本片段和匹配到的指标、维度名计算出来的编辑距离阈值, 若超出该阈值, 则舍弃", "number", "Mapper相关配置");
public static final Parameter MAPPER_NAME_THRESHOLD_MIN =
new Parameter("s2.mapper.name.min.threshold", "0.25", "指标名、维度名最小文本相似度阈值",
"指标名、维度名相似度阈值在动态调整中的最低值", "number", "Mapper相关配置");
public static final Parameter MAPPER_DIMENSION_VALUE_SIZE =
new Parameter("s2.mapper.value.size", "1", "一次探测返回维度值结果个数",
"在每次探测后, 将前后缀匹配的结果合并, 并根据相似度阈值过滤后的维度值结果个数", "number", "Mapper相关配置");
public static final Parameter MAPPER_VALUE_THRESHOLD =
new Parameter("s2.mapper.value.threshold", "0.5", "维度值文本相似度阈值",
"文本片段和匹配到的维度值计算出来的编辑距离阈值, 若超出该阈值, 则舍弃", "number", "Mapper相关配置");
public static final Parameter MAPPER_VALUE_THRESHOLD_MIN =
new Parameter("s2.mapper.value.min.threshold", "0.3", "维度值最小文本相似度阈值",
"维度值相似度阈值在动态调整中的最低值", "number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_TEXT_SIZE =
new Parameter("s2.mapper.embedding.word.size", "4", "用于向量召回文本长度",
"为提高向量召回效率, 按指定长度进行向量语义召回", "number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_TEXT_STEP =
new Parameter("s2.mapper.embedding.word.step", "3", "向量召回文本每步长度",
"为提高向量召回效率, 按指定每步长度进行召回", "number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_BATCH =
new Parameter("s2.mapper.embedding.batch", "50", "批量向量召回文本请求个数", "每次进行向量语义召回的原始文本片段个数",
"number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_NUMBER =
new Parameter("s2.mapper.embedding.number", "5", "批量向量召回文本返回结果个数",
"每个文本进行向量语义召回的文本结果个数", "number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_THRESHOLD =
new Parameter("s2.mapper.embedding.threshold", "0.98", "向量召回相似度阈值", "相似度小于该阈值的则舍弃",
"number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_THRESHOLD_MIN =
new Parameter("s2.mapper.embedding.min.threshold", "0.9", "向量召回最小相似度阈值",
"向量召回相似度阈值在动态调整中的最低值", "number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_ROUND_NUMBER =
new Parameter("s2.mapper.embedding.round.number", "10", "单轮向量召回结果个数",
"每轮向量召回保留的结果个数上限", "number", "Mapper相关配置");
}


@@ -28,11 +28,8 @@ public class MapperHelper {
}
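// Returns the offset of the term segment that covers the given character index.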
public Integer getStepOffset(List<S2Term> termList, Integer index) {
List<Integer> offsetList = termList.stream().sorted(Comparator.comparing(S2Term::getOffset))
.map(term -> term.getOffset()).collect(Collectors.toList());
for (int j = 0; j < termList.size() - 1; j++) {
if (offsetList.get(j) <= index && offsetList.get(j + 1) > index) {
@@ -43,13 +40,8 @@ public class MapperHelper {
}
public Map<Integer, Integer> getRegOffsetToLength(List<S2Term> terms) {
return terms.stream().sorted(Comparator.comparing(S2Term::length)).collect(Collectors
.toMap(S2Term::getOffset, term -> term.word.length(), (value1, value2) -> value2));
}
/**


@@ -13,6 +13,6 @@ import java.util.Set;
*/
public interface MatchStrategy<T extends MapResult> {
Map<MatchText, List<T>> match(ChatQueryContext chatQueryContext, List<S2Term> terms,
Set<Long> detectDataSetIds);
}


@@ -43,17 +43,15 @@ public class QueryFilterMapper extends BaseMapper {
}
private void clearOtherSchemaElementMatch(Set<Long> viewIds, SchemaMapInfo schemaMapInfo) {
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMapInfo
.getDataSetElementMatches().entrySet()) {
if (!viewIds.contains(entry.getKey())) {
entry.getValue().clear();
}
}
}
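// Maps each query-filter value to a VALUE schema element match, skipping values that already matched.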
private void addValueSchemaElementMatch(Long dataSetId, ChatQueryContext chatQueryContext,
List<SchemaElementMatch> candidateElementMatches) {
QueryFilters queryFilters = chatQueryContext.getQueryFilters();
if (queryFilters == null || CollectionUtils.isEmpty(queryFilters.getFilters())) {
@@ -63,40 +61,27 @@ public class QueryFilterMapper extends BaseMapper {
if (checkExistSameValueSchemaElementMatch(filter, candidateElementMatches)) {
continue;
}
SchemaElement element = SchemaElement.builder().id(filter.getElementID())
.name(String.valueOf(filter.getValue())).type(SchemaElementType.VALUE)
.bizName(filter.getBizName()).dataSetId(dataSetId).build();
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder().element(element)
.frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
.word(String.valueOf(filter.getValue())).similarity(similarity)
.detectWord(Constants.EMPTY).build();
candidateElementMatches.add(schemaElementMatch);
}
chatQueryContext.getMapInfo().setMatchedElements(dataSetId, candidateElementMatches);
}
private boolean checkExistSameValueSchemaElementMatch(QueryFilter queryFilter,
List<SchemaElementMatch> schemaElementMatches) {
List<SchemaElementMatch> valueSchemaElements = schemaElementMatches.stream()
.filter(schemaElementMatch -> SchemaElementType.VALUE
.equals(schemaElementMatch.getElement().getType()))
.collect(Collectors.toList());
for (SchemaElementMatch schemaElementMatch : valueSchemaElements) {
if (schemaElementMatch.getElement().getId().equals(queryFilter.getElementID())
&& schemaElementMatch.getWord()
.equals(String.valueOf(queryFilter.getValue()))) {
return true;
}


@@ -27,19 +27,21 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
private static final int SEARCH_SIZE = 3;
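// Probes every split point of the query text in parallel, collecting prefix/suffix dictionary matches per segment.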
@Autowired
private KnowledgeBaseService knowledgeBaseService;
@Autowired
private MapperHelper mapperHelper;
@Override
public Map<MatchText, List<HanlpMapResult>> match(ChatQueryContext chatQueryContext,
List<S2Term> originals, Set<Long> detectDataSetIds) {
String text = chatQueryContext.getQueryText();
Map<Integer, Integer> regOffsetToLength = mapperHelper.getRegOffsetToLength(originals);
List<Integer> detectIndexList = Lists.newArrayList();
for (Integer index = 0; index < text.length();) {
if (index < text.length()) {
detectIndexList.add(index);
@@ -52,58 +54,33 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
}
}
Map<MatchText, List<HanlpMapResult>> regTextMap = new ConcurrentHashMap<>();
detectIndexList.stream().parallel().forEach(detectIndex -> {
String regText = text.substring(0, detectIndex);
String detectSegment = text.substring(detectIndex);
if (StringUtils.isNotEmpty(detectSegment)) {
List<HanlpMapResult> hanlpMapResults =
knowledgeBaseService.prefixSearch(detectSegment, SearchService.SEARCH_SIZE,
chatQueryContext.getModelIdToDataSetIds(), detectDataSetIds);
List<HanlpMapResult> suffixHanlpMapResults =
knowledgeBaseService.suffixSearch(detectSegment, SEARCH_SIZE,
chatQueryContext.getModelIdToDataSetIds(), detectDataSetIds);
hanlpMapResults.addAll(suffixHanlpMapResults);
// remove entity-name-only results from the search
hanlpMapResults = hanlpMapResults.stream().filter(entry -> {
List<String> natures = entry.getNatures().stream()
.filter(nature -> !nature.endsWith(DictWordType.ENTITY.getType()))
.collect(Collectors.toList());
if (CollectionUtils.isEmpty(natures)) {
return false;
}
return true;
}).collect(Collectors.toList());
MatchText matchText =
MatchText.builder().regText(regText).detectSegment(detectSegment).build();
regTextMap.put(matchText, hanlpMapResults);
}
});
return regTextMap;
}
}


@@ -16,20 +16,22 @@ import java.util.Set;
@Service
@Slf4j
public abstract class SingleMatchStrategy<T extends MapResult> extends BaseMatchStrategy<T> {
@Autowired
protected MapperConfig mapperConfig;
@Autowired
protected MapperHelper mapperHelper;
public List<T> detect(ChatQueryContext chatQueryContext, List<S2Term> terms,
Set<Long> detectDataSetIds) {
Map<Integer, Integer> regOffsetToLength = mapperHelper.getRegOffsetToLength(terms);
String text = chatQueryContext.getQueryText();
Set<T> results = new HashSet<>();
Set<String> detectSegments = new HashSet<>();
for (Integer startIndex = 0; startIndex <= text.length() - 1;) {
for (Integer index = startIndex; index <= text.length();) {
int offset = mapperHelper.getStepOffset(terms, startIndex);
index = mapperHelper.getStepIndex(regOffsetToLength, index);
if (index <= text.length()) {
@@ -45,9 +47,6 @@ public abstract class SingleMatchStrategy<T extends MapResult> extends BaseMatch
return new ArrayList<>(results);
}
public abstract List<T> detectByStep(ChatQueryContext chatQueryContext,
Set<Long> detectDataSetIds, String detectSegment, int offset);
}


@@ -13,89 +13,45 @@ import java.util.List;
public class ParserConfig extends ParameterConfig {
public static final Parameter PARSER_STRATEGY_TYPE =
new Parameter("s2.parser.s2sql.strategy", "ONE_PASS_SELF_CONSISTENCY", "LLM解析生成S2SQL策略",
"ONE_PASS_SELF_CONSISTENCY: 通过投票方式一步生成sql", "list", "Parser相关配置",
Lists.newArrayList("ONE_PASS_SELF_CONSISTENCY"));
public static final Parameter PARSER_LINKING_VALUE_ENABLE =
new Parameter("s2.parser.linking.value.enable", "true", "是否将Mapper探测识别到的维度值提供给大模型",
"为了数据安全考虑, 这里可进行开关选择", "bool", "Parser相关配置");
public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD =
new Parameter("s2.parser.text.length.threshold", "10", "用户输入文本长短阈值", "文本超过该阈值为长文本",
"number", "Parser相关配置");
public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD_SHORT =
new Parameter("s2.parser.text.threshold.short", "0.5", "短文本匹配阈值",
"由于请求大模型耗时较长, 因此如果有规则类型的Query得分达到阈值,则跳过大模型的调用,"
+ "\n如果是短文本, 若query得分/文本长度>该阈值, 则跳过当前parser",
"number", "Parser相关配置");
public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD_LONG =
new Parameter("s2.parser.text.threshold.long", "0.8", "长文本匹配阈值",
"如果是长文本, 若query得分/文本长度>该阈值, 则跳过当前parser", "number", "Parser相关配置");
public static final Parameter PARSER_EXEMPLAR_RECALL_NUMBER = new Parameter(
"s2.parser.exemplar-recall.number", "10", "exemplar召回个数", "", "number", "Parser相关配置");
public static final Parameter PARSER_FEW_SHOT_NUMBER =
new Parameter("s2.parser.few-shot.number", "3", "few-shot样例个数", "样例越多效果可能越好但token消耗越大",
"number", "Parser相关配置");
public static final Parameter PARSER_SELF_CONSISTENCY_NUMBER =
new Parameter("s2.parser.self-consistency.number", "1", "self-consistency执行个数",
"执行越多效果可能越好但token消耗越大", "number", "Parser相关配置");
public static final Parameter PARSER_SHOW_COUNT = new Parameter("s2.parser.show.count", "3",
"解析结果展示个数", "前端展示的解析个数", "number", "Parser相关配置");
@Override
public List<Parameter> getSysParameters() {
return Lists.newArrayList(PARSER_LINKING_VALUE_ENABLE, PARSER_FEW_SHOT_NUMBER,
PARSER_SELF_CONSISTENCY_NUMBER, PARSER_SHOW_COUNT);
}
}


@@ -59,10 +59,8 @@ public class QueryTypeParser implements SemanticParser {
List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getParsedS2SQL());
List<String> whereFilterByTimeFields = filterByTimeFields(whereFields);
if (CollectionUtils.isNotEmpty(whereFilterByTimeFields)) {
Set<String> ids = semanticSchema.getEntities(dataSetId).stream()
.map(SchemaElement::getName).collect(Collectors.toSet());
if (CollectionUtils.isNotEmpty(ids)
&& ids.stream().anyMatch(whereFilterByTimeFields::contains)) {
return QueryType.ID;
@@ -80,15 +78,14 @@ public class QueryTypeParser implements SemanticParser {
}
private static List<String> filterByTimeFields(List<String> whereFields) {
List<String> selectAndWhereFilterByTimeFields =
whereFields.stream()
.filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
.collect(Collectors.toList());
List<String> selectAndWhereFilterByTimeFields = whereFields.stream()
.filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
.collect(Collectors.toList());
return selectAndWhereFilterByTimeFields;
}
private static boolean selectContainsMetric(
SqlInfo sqlInfo, Long dataSetId, SemanticSchema semanticSchema) {
private static boolean selectContainsMetric(SqlInfo sqlInfo, Long dataSetId,
SemanticSchema semanticSchema) {
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getParsedS2SQL());
List<SchemaElement> metrics = semanticSchema.getMetrics(dataSetId);
if (CollectionUtils.isNotEmpty(metrics)) {

View File

@@ -50,10 +50,7 @@ public class SatisfactionChecker {
} else if (degree < shortTextLengthThreshold) {
return false;
}
log.info(
"queryMode:{}, degree:{}, parse info:{}",
semanticParseInfo.getQueryMode(),
degree,
log.info("queryMode:{}, degree:{}, parse info:{}", semanticParseInfo.getQueryMode(), degree,
semanticParseInfo);
return true;
}

View File

@@ -37,26 +37,19 @@ public class HeuristicDataSetResolver implements DataSetResolver {
protected Long selectDataSetByMatchSimilarity(SchemaMapInfo schemaMap) {
Map<Long, DataSetMatchResult> dataSetMatchRet = getDataSetMatchResult(schemaMap);
Entry<Long, DataSetMatchResult> selectedDataset =
dataSetMatchRet.entrySet().stream()
.sorted(
(o1, o2) -> {
double difference =
o1.getValue().getMaxDatesetSimilarity()
- o2.getValue().getMaxDatesetSimilarity();
if (difference == 0) {
difference =
o1.getValue().getMaxMetricSimilarity()
- o2.getValue().getMaxMetricSimilarity();
if (difference == 0) {
difference =
o1.getValue().getTotalSimilarity()
- o2.getValue().getTotalSimilarity();
}
}
return difference >= 0 ? -1 : 1;
})
.findFirst()
.orElse(null);
dataSetMatchRet.entrySet().stream().sorted((o1, o2) -> {
double difference = o1.getValue().getMaxDatesetSimilarity()
- o2.getValue().getMaxDatesetSimilarity();
if (difference == 0) {
difference = o1.getValue().getMaxMetricSimilarity()
- o2.getValue().getMaxMetricSimilarity();
if (difference == 0) {
difference = o1.getValue().getTotalSimilarity()
- o2.getValue().getTotalSimilarity();
}
}
return difference >= 0 ? -1 : 1;
}).findFirst().orElse(null);
if (selectedDataset != null) {
log.info("selectDataSet with multiple DataSets [{}]", selectedDataset.getKey());
return selectedDataset.getKey();
@@ -67,8 +60,8 @@ public class HeuristicDataSetResolver implements DataSetResolver {
protected Map<Long, DataSetMatchResult> getDataSetMatchResult(SchemaMapInfo schemaMap) {
Map<Long, DataSetMatchResult> dateSetMatchRet = new HashMap<>();
for (Entry<Long, List<SchemaElementMatch>> entry :
schemaMap.getDataSetElementMatches().entrySet()) {
for (Entry<Long, List<SchemaElementMatch>> entry : schemaMap.getDataSetElementMatches()
.entrySet()) {
double maxMetricSimilarity = 0;
double maxDatasetSimilarity = 0;
double totalSimilarity = 0;
@@ -81,13 +74,10 @@ public class HeuristicDataSetResolver implements DataSetResolver {
}
totalSimilarity += match.getSimilarity();
}
dateSetMatchRet.put(
entry.getKey(),
DataSetMatchResult.builder()
.maxMetricSimilarity(maxMetricSimilarity)
dateSetMatchRet.put(entry.getKey(),
DataSetMatchResult.builder().maxMetricSimilarity(maxMetricSimilarity)
.maxDatesetSimilarity(maxDatasetSimilarity)
.totalSimilarity(totalSimilarity)
.build());
.totalSimilarity(totalSimilarity).build());
}
return dateSetMatchRet;
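The tiered tie-breaking in selectDataSetByMatchSimilarity compares dataset similarity first, then metric similarity, then total similarity. A behaviorally equivalent selection written with Comparator chaining (a sketch using the accessor names visible in this diff; not part of the commit):

    // max(...) under an ascending tiered comparator picks the same entry as
    // sorting descending and taking the first element.
    Entry<Long, DataSetMatchResult> selected = dataSetMatchRet.entrySet().stream()
            .max(Comparator
                    .comparingDouble((Entry<Long, DataSetMatchResult> e) ->
                            e.getValue().getMaxDatesetSimilarity())
                    .thenComparingDouble(e -> e.getValue().getMaxMetricSimilarity())
                    .thenComparingDouble(e -> e.getValue().getTotalSimilarity()))
            .orElse(null);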

View File

@@ -31,7 +31,8 @@ import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_ST
@Service
public class LLMRequestService {
@Autowired private ParserConfig parserConfig;
@Autowired
private ParserConfig parserConfig;
public boolean isSkip(ChatQueryContext queryCtx) {
if (!queryCtx.getText2SQLType().enableLLM()) {
@@ -95,88 +96,63 @@ public class LLMRequestService {
if (CollectionUtils.isEmpty(matchedElements)) {
return new ArrayList<>();
}
return matchedElements.stream()
.filter(
schemaElementMatch -> {
SchemaElementType elementType =
schemaElementMatch.getElement().getType();
return SchemaElementType.TERM.equals(elementType);
})
.map(
schemaElementMatch -> {
LLMReq.Term term = new LLMReq.Term();
term.setName(schemaElementMatch.getElement().getName());
term.setDescription(schemaElementMatch.getElement().getDescription());
term.setAlias(schemaElementMatch.getElement().getAlias());
return term;
})
.collect(Collectors.toList());
return matchedElements.stream().filter(schemaElementMatch -> {
SchemaElementType elementType = schemaElementMatch.getElement().getType();
return SchemaElementType.TERM.equals(elementType);
}).map(schemaElementMatch -> {
LLMReq.Term term = new LLMReq.Term();
term.setName(schemaElementMatch.getElement().getName());
term.setDescription(schemaElementMatch.getElement().getDescription());
term.setAlias(schemaElementMatch.getElement().getAlias());
return term;
}).collect(Collectors.toList());
}
protected List<LLMReq.ElementValue> getMappedValues(
@NotNull ChatQueryContext queryCtx, Long dataSetId) {
protected List<LLMReq.ElementValue> getMappedValues(@NotNull ChatQueryContext queryCtx,
Long dataSetId) {
List<SchemaElementMatch> matchedElements =
queryCtx.getMapInfo().getMatchedElements(dataSetId);
if (CollectionUtils.isEmpty(matchedElements)) {
return new ArrayList<>();
}
Set<LLMReq.ElementValue> valueMatches =
matchedElements.stream()
.filter(elementMatch -> !elementMatch.isInherited())
.filter(
schemaElementMatch -> {
SchemaElementType type =
schemaElementMatch.getElement().getType();
return SchemaElementType.VALUE.equals(type)
|| SchemaElementType.ID.equals(type);
})
.map(
elementMatch -> {
LLMReq.ElementValue elementValue = new LLMReq.ElementValue();
elementValue.setFieldName(elementMatch.getElement().getName());
elementValue.setFieldValue(elementMatch.getWord());
return elementValue;
})
.collect(Collectors.toSet());
Set<LLMReq.ElementValue> valueMatches = matchedElements.stream()
.filter(elementMatch -> !elementMatch.isInherited()).filter(schemaElementMatch -> {
SchemaElementType type = schemaElementMatch.getElement().getType();
return SchemaElementType.VALUE.equals(type)
|| SchemaElementType.ID.equals(type);
}).map(elementMatch -> {
LLMReq.ElementValue elementValue = new LLMReq.ElementValue();
elementValue.setFieldName(elementMatch.getElement().getName());
elementValue.setFieldValue(elementMatch.getWord());
return elementValue;
}).collect(Collectors.toSet());
return new ArrayList<>(valueMatches);
}
protected List<SchemaElement> getMappedMetrics(
@NotNull ChatQueryContext queryCtx, Long dataSetId) {
protected List<SchemaElement> getMappedMetrics(@NotNull ChatQueryContext queryCtx,
Long dataSetId) {
List<SchemaElementMatch> matchedElements =
queryCtx.getMapInfo().getMatchedElements(dataSetId);
if (CollectionUtils.isEmpty(matchedElements)) {
return Collections.emptyList();
}
List<SchemaElement> schemaElements =
matchedElements.stream()
.filter(
schemaElementMatch -> {
SchemaElementType elementType =
schemaElementMatch.getElement().getType();
return SchemaElementType.METRIC.equals(elementType);
})
.map(
schemaElementMatch -> {
return schemaElementMatch.getElement();
})
.collect(Collectors.toList());
List<SchemaElement> schemaElements = matchedElements.stream().filter(schemaElementMatch -> {
SchemaElementType elementType = schemaElementMatch.getElement().getType();
return SchemaElementType.METRIC.equals(elementType);
}).map(schemaElementMatch -> {
return schemaElementMatch.getElement();
}).collect(Collectors.toList());
return schemaElements;
}
protected List<SchemaElement> getMappedDimensions(
@NotNull ChatQueryContext queryCtx, Long dataSetId) {
protected List<SchemaElement> getMappedDimensions(@NotNull ChatQueryContext queryCtx,
Long dataSetId) {
List<SchemaElementMatch> matchedElements =
queryCtx.getMapInfo().getMatchedElements(dataSetId);
List<SchemaElement> dimensionElements =
matchedElements.stream()
.filter(
element ->
SchemaElementType.DIMENSION.equals(
element.getElement().getType()))
.map(SchemaElementMatch::getElement)
.collect(Collectors.toList());
List<SchemaElement> dimensionElements = matchedElements.stream().filter(
element -> SchemaElementType.DIMENSION.equals(element.getElement().getType()))
.map(SchemaElementMatch::getElement).collect(Collectors.toList());
return new ArrayList<>(dimensionElements);
}

View File

@@ -23,8 +23,8 @@ import java.util.Objects;
@Service
public class LLMResponseService {
public SemanticParseInfo addParseInfo(
ChatQueryContext queryCtx, ParseResult parseResult, String s2SQL, Double weight) {
public SemanticParseInfo addParseInfo(ChatQueryContext queryCtx, ParseResult parseResult,
String s2SQL, Double weight) {
if (Objects.isNull(weight)) {
weight = 0D;
}
@@ -33,20 +33,16 @@ public class LLMResponseService {
parseInfo.setDataSet(queryCtx.getSemanticSchema().getDataSet(parseResult.getDataSetId()));
parseInfo.setQueryConfig(
queryCtx.getSemanticSchema().getQueryConfig(parseResult.getDataSetId()));
parseInfo
.getElementMatches()
parseInfo.getElementMatches()
.addAll(queryCtx.getMapInfo().getMatchedElements(parseInfo.getDataSetId()));
Map<String, Object> properties = new HashMap<>();
properties.put(Constants.CONTEXT, parseResult);
properties.put("type", "internal");
Text2SQLExemplar exemplar =
Text2SQLExemplar.builder()
.question(queryCtx.getQueryText())
.sideInfo(parseResult.getLlmResp().getSideInfo())
.dbSchema(parseResult.getLlmResp().getSchema())
.sql(parseResult.getLlmResp().getSqlOutput())
.build();
Text2SQLExemplar exemplar = Text2SQLExemplar.builder().question(queryCtx.getQueryText())
.sideInfo(parseResult.getLlmResp().getSideInfo())
.dbSchema(parseResult.getLlmResp().getSchema())
.sql(parseResult.getLlmResp().getSqlOutput()).build();
properties.put(Text2SQLExemplar.PROPERTY_KEY, exemplar);
parseInfo.setProperties(properties);
parseInfo.setScore(queryCtx.getQueryText().length() * (1 + weight));

View File

@@ -61,12 +61,8 @@ public class LLMSqlParser implements SemanticParser {
// deduplicate the S2SQL result list and build parserInfo
sqlRespMap = responseService.getDeduplicationSqlResp(currentRetry, llmResp);
if (MapUtils.isNotEmpty(sqlRespMap)) {
parseResult =
ParseResult.builder()
.dataSetId(dataSetId)
.llmReq(llmReq)
.llmResp(llmResp)
.build();
parseResult = ParseResult.builder().dataSetId(dataSetId).llmReq(llmReq)
.llmResp(llmResp).build();
break;
}
}

View File

@@ -25,21 +25,19 @@ import java.util.concurrent.ConcurrentHashMap;
@Slf4j
public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
public static final String INSTRUCTION =
""
+ "\n#Role: You are a data analyst experienced in SQL languages."
+ "\n#Task: You will be provided with a natural language question asked by users,"
+ "please convert it to a SQL query so that relevant data could be returned "
+ "by executing the SQL query against underlying database."
+ "\n#Rules:"
+ "\n1.ALWAYS generate columns and values specified in the `Schema`, DO NOT hallucinate."
+ "\n2.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator."
+ "\n3.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`."
+ "\n4.DO NOT calculate date range using functions."
+ "\n5.DO NOT calculate date range using DATE_SUB."
+ "\n6.DO NOT miss the AGGREGATE operator of metrics, always add it as needed."
+ "\n#Exemplars:\n{{exemplar}}"
+ "\n#Question:\nQuestion:{{question}},Schema:{{schema}},SideInfo:{{information}}";
public static final String INSTRUCTION = ""
+ "\n#Role: You are a data analyst experienced in SQL languages."
+ "\n#Task: You will be provided with a natural language question asked by users,"
+ "please convert it to a SQL query so that relevant data could be returned "
+ "by executing the SQL query against underlying database." + "\n#Rules:"
+ "\n1.ALWAYS generate columns and values specified in the `Schema`, DO NOT hallucinate."
+ "\n2.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator."
+ "\n3.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`."
+ "\n4.DO NOT calculate date range using functions."
+ "\n5.DO NOT calculate date range using DATE_SUB."
+ "\n6.DO NOT miss the AGGREGATE operator of metrics, always add it as needed."
+ "\n#Exemplars:\n{{exemplar}}"
+ "\n#Question:\nQuestion:{{question}},Schema:{{schema}},SideInfo:{{information}}";
@Data
static class SemanticSql {
@@ -75,21 +73,12 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
// 3.perform multiple self-consistency inferences parallelly
Map<String, Prompt> output2Prompt = new ConcurrentHashMap<>();
prompt2Exemplar
.keySet()
.parallelStream()
.forEach(
prompt -> {
keyPipelineLog.info(
"OnePassSCSqlGenStrategy reqPrompt:\n{}",
prompt.toUserMessage());
SemanticSql s2Sql =
extractor.generateSemanticSql(
prompt.toUserMessage().singleText());
output2Prompt.put(s2Sql.getSql(), prompt);
keyPipelineLog.info(
"OnePassSCSqlGenStrategy modelResp:\n{}", s2Sql.getSql());
});
prompt2Exemplar.keySet().parallelStream().forEach(prompt -> {
keyPipelineLog.info("OnePassSCSqlGenStrategy reqPrompt:\n{}", prompt.toUserMessage());
SemanticSql s2Sql = extractor.generateSemanticSql(prompt.toUserMessage().singleText());
output2Prompt.put(s2Sql.getSql(), prompt);
keyPipelineLog.info("OnePassSCSqlGenStrategy modelResp:\n{}", s2Sql.getSql());
});
// 4.format response.
Pair<String, Map<String, Double>> sqlMapPair =
@@ -105,13 +94,9 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
private Prompt generatePrompt(LLMReq llmReq, LLMResp llmResp) {
StringBuilder exemplars = new StringBuilder();
for (Text2SQLExemplar exemplar : llmReq.getDynamicExemplars()) {
String exemplarStr =
String.format(
"Question:%s,Schema:%s,SideInfo:%s,SQL:%s\n",
exemplar.getQuestion(),
exemplar.getDbSchema(),
exemplar.getSideInfo(),
exemplar.getSql());
String exemplarStr = String.format("Question:%s,Schema:%s,SideInfo:%s,SQL:%s\n",
exemplar.getQuestion(), exemplar.getDbSchema(), exemplar.getSideInfo(),
exemplar.getSql());
exemplars.append(exemplarStr);
}
String dataSemantics = promptHelper.buildSchemaStr(llmReq);
@@ -136,7 +121,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
@Override
public void afterPropertiesSet() {
SqlGenStrategyFactory.addSqlGenerationForFactory(
LLMReq.SqlGenType.ONE_PASS_SELF_CONSISTENCY, this);
SqlGenStrategyFactory
.addSqlGenerationForFactory(LLMReq.SqlGenType.ONE_PASS_SELF_CONSISTENCY, this);
}
}
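For readers new to the ONE_PASS_SELF_CONSISTENCY strategy: each parallel inference yields one SQL string, and identical outputs are tallied so the most frequent candidate wins. A voting sketch (illustrative only; the project's actual reduction lives in ResponseHelper, and generatedSqls is a hypothetical List<String> holding the sampled outputs):

    Map<String, Long> votes = generatedSqls.stream()
            .collect(Collectors.groupingBy(sql -> sql, Collectors.counting()));
    String bestSql =
            Collections.max(votes.entrySet(), Map.Entry.comparingByValue()).getKey();
    double voteShare = (double) votes.get(bestSql) / generatedSqls.size();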

View File

@@ -24,9 +24,11 @@ import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_SE
@Slf4j
public class PromptHelper {
@Autowired private ParserConfig parserConfig;
@Autowired
private ParserConfig parserConfig;
@Autowired private ExemplarService exemplarService;
@Autowired
private ExemplarService exemplarService;
public List<List<Text2SQLExemplar>> getFewShotExemplars(LLMReq llmReq) {
int exemplarRecallNumber =
@@ -36,11 +38,9 @@ public class PromptHelper {
Integer.valueOf(parserConfig.getParameterValue(PARSER_SELF_CONSISTENCY_NUMBER));
List<Text2SQLExemplar> exemplars = Lists.newArrayList();
llmReq.getDynamicExemplars().stream()
.forEach(
e -> {
exemplars.add(e);
});
llmReq.getDynamicExemplars().stream().forEach(e -> {
exemplars.add(e);
});
int recallSize = exemplarRecallNumber - llmReq.getDynamicExemplars().size();
if (recallSize > 0) {
@@ -79,81 +79,65 @@ public class PromptHelper {
String tableStr = llmReq.getSchema().getDataSetName();
List<String> metrics = Lists.newArrayList();
llmReq.getSchema().getMetrics().stream()
.forEach(
metric -> {
StringBuilder metricStr = new StringBuilder();
metricStr.append("<");
metricStr.append(metric.getName());
if (!CollectionUtils.isEmpty(metric.getAlias())) {
StringBuilder alias = new StringBuilder();
metric.getAlias().stream().forEach(a -> alias.append(a + ","));
metricStr.append(" ALIAS '" + alias + "'");
}
if (StringUtils.isNotEmpty(metric.getDataFormatType())) {
String dataFormatType = metric.getDataFormatType();
if (DataFormatTypeEnum.DECIMAL
.getName()
.equalsIgnoreCase(dataFormatType)
|| DataFormatTypeEnum.PERCENT
.getName()
.equalsIgnoreCase(dataFormatType)) {
metricStr.append(" FORMAT '" + dataFormatType + "'");
}
}
if (StringUtils.isNotEmpty(metric.getDescription())) {
metricStr.append(" COMMENT '" + metric.getDescription() + "'");
}
if (StringUtils.isNotEmpty(metric.getDefaultAgg())) {
metricStr.append(
" AGGREGATE '"
+ metric.getDefaultAgg().toUpperCase()
+ "'");
}
metricStr.append(">");
metrics.add(metricStr.toString());
});
llmReq.getSchema().getMetrics().stream().forEach(metric -> {
StringBuilder metricStr = new StringBuilder();
metricStr.append("<");
metricStr.append(metric.getName());
if (!CollectionUtils.isEmpty(metric.getAlias())) {
StringBuilder alias = new StringBuilder();
metric.getAlias().stream().forEach(a -> alias.append(a + ","));
metricStr.append(" ALIAS '" + alias + "'");
}
if (StringUtils.isNotEmpty(metric.getDataFormatType())) {
String dataFormatType = metric.getDataFormatType();
if (DataFormatTypeEnum.DECIMAL.getName().equalsIgnoreCase(dataFormatType)
|| DataFormatTypeEnum.PERCENT.getName().equalsIgnoreCase(dataFormatType)) {
metricStr.append(" FORMAT '" + dataFormatType + "'");
}
}
if (StringUtils.isNotEmpty(metric.getDescription())) {
metricStr.append(" COMMENT '" + metric.getDescription() + "'");
}
if (StringUtils.isNotEmpty(metric.getDefaultAgg())) {
metricStr.append(" AGGREGATE '" + metric.getDefaultAgg().toUpperCase() + "'");
}
metricStr.append(">");
metrics.add(metricStr.toString());
});
List<String> dimensions = Lists.newArrayList();
llmReq.getSchema().getDimensions().stream()
.forEach(
dimension -> {
StringBuilder dimensionStr = new StringBuilder();
dimensionStr.append("<");
dimensionStr.append(dimension.getName());
if (!CollectionUtils.isEmpty(dimension.getAlias())) {
StringBuilder alias = new StringBuilder();
dimension.getAlias().stream().forEach(a -> alias.append(a + ","));
dimensionStr.append(" ALIAS '" + alias + "'");
}
if (StringUtils.isNotEmpty(dimension.getTimeFormat())) {
dimensionStr.append(" FORMAT '" + dimension.getTimeFormat() + "'");
}
if (StringUtils.isNotEmpty(dimension.getDescription())) {
dimensionStr.append(
" COMMENT '" + dimension.getDescription() + "'");
}
dimensionStr.append(">");
dimensions.add(dimensionStr.toString());
});
llmReq.getSchema().getDimensions().stream().forEach(dimension -> {
StringBuilder dimensionStr = new StringBuilder();
dimensionStr.append("<");
dimensionStr.append(dimension.getName());
if (!CollectionUtils.isEmpty(dimension.getAlias())) {
StringBuilder alias = new StringBuilder();
dimension.getAlias().stream().forEach(a -> alias.append(a + ","));
dimensionStr.append(" ALIAS '" + alias + "'");
}
if (StringUtils.isNotEmpty(dimension.getTimeFormat())) {
dimensionStr.append(" FORMAT '" + dimension.getTimeFormat() + "'");
}
if (StringUtils.isNotEmpty(dimension.getDescription())) {
dimensionStr.append(" COMMENT '" + dimension.getDescription() + "'");
}
dimensionStr.append(">");
dimensions.add(dimensionStr.toString());
});
List<String> values = Lists.newArrayList();
llmReq.getSchema().getValues().stream()
.forEach(
value -> {
StringBuilder valueStr = new StringBuilder();
String fieldName = value.getFieldName();
String fieldValue = value.getFieldValue();
valueStr.append(String.format("<%s='%s'>", fieldName, fieldValue));
values.add(valueStr.toString());
});
llmReq.getSchema().getValues().stream().forEach(value -> {
StringBuilder valueStr = new StringBuilder();
String fieldName = value.getFieldName();
String fieldValue = value.getFieldValue();
valueStr.append(String.format("<%s='%s'>", fieldName, fieldValue));
values.add(valueStr.toString());
});
String partitionTimeStr = "";
if (llmReq.getSchema().getPartitionTime() != null) {
partitionTimeStr =
String.format(
"%s FORMAT '%s'",
llmReq.getSchema().getPartitionTime().getName(),
String.format("%s FORMAT '%s'", llmReq.getSchema().getPartitionTime().getName(),
llmReq.getSchema().getPartitionTime().getTimeFormat());
}
@@ -170,30 +154,19 @@ public class PromptHelper {
String template =
"DatabaseType=[%s], Table=[%s], PartitionTimeField=[%s], PrimaryKeyField=[%s], "
+ "Metrics=[%s], Dimensions=[%s], Values=[%s]";
return String.format(
template,
databaseTypeStr,
tableStr,
partitionTimeStr,
primaryKeyStr,
String.join(",", metrics),
String.join(",", dimensions),
String.join(",", values));
return String.format(template, databaseTypeStr, tableStr, partitionTimeStr, primaryKeyStr,
String.join(",", metrics), String.join(",", dimensions), String.join(",", values));
}
private String buildTermStr(LLMReq llmReq) {
List<LLMReq.Term> terms = llmReq.getTerms();
List<String> termStr = Lists.newArrayList();
terms.stream()
.forEach(
term -> {
StringBuilder termsDesc = new StringBuilder();
String description = term.getDescription();
termsDesc.append(
String.format(
"<%s COMMENT '%s'>", term.getName(), description));
termStr.add(termsDesc.toString());
});
terms.stream().forEach(term -> {
StringBuilder termsDesc = new StringBuilder();
String description = term.getDescription();
termsDesc.append(String.format("<%s COMMENT '%s'>", term.getName(), description));
termStr.add(termsDesc.toString());
});
String ret = "";
if (termStr.size() > 0) {
ret = String.join(",", termStr);

View File

@@ -54,19 +54,13 @@ public class ResponseHelper {
return Pair.of(inputMax, votePercentage);
}
public static Map<String, LLMSqlResp> buildSqlRespMap(
List<Text2SQLExemplar> sqlExamples, Map<String, Double> sqlMap) {
public static Map<String, LLMSqlResp> buildSqlRespMap(List<Text2SQLExemplar> sqlExamples,
Map<String, Double> sqlMap) {
if (sqlMap == null) {
return new HashMap<>();
}
return sqlMap.entrySet().stream()
.collect(
Collectors.toMap(
Map.Entry::getKey,
entry ->
LLMSqlResp.builder()
.sqlWeight(entry.getValue())
.fewShots(sqlExamples)
.build()));
.collect(Collectors.toMap(Map.Entry::getKey, entry -> LLMSqlResp.builder()
.sqlWeight(entry.getValue()).fewShots(sqlExamples).build()));
}
}

View File

@@ -20,7 +20,8 @@ public abstract class SqlGenStrategy implements InitializingBean {
protected static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
@Autowired protected PromptHelper promptHelper;
@Autowired
protected PromptHelper promptHelper;
protected ChatLanguageModel getChatLanguageModel(ChatModelConfig modelConfig) {
return ModelProvider.getChatModel(modelConfig);

View File

@@ -14,8 +14,8 @@ public class SqlGenStrategyFactory {
return sqlGenStrategyMap.get(strategyType);
}
public static void addSqlGenerationForFactory(
LLMReq.SqlGenType strategy, SqlGenStrategy sqlGenStrategy) {
public static void addSqlGenerationForFactory(LLMReq.SqlGenType strategy,
SqlGenStrategy sqlGenStrategy) {
sqlGenStrategyMap.put(strategy, sqlGenStrategy);
}
}

View File

@@ -27,27 +27,20 @@ import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.DISTINC
@Slf4j
public class AggregateTypeParser implements SemanticParser {
private static final Map<AggregateTypeEnum, Pattern> REGX_MAP =
Stream.of(
new AbstractMap.SimpleEntry<>(
AggregateTypeEnum.MAX,
Pattern.compile("(?i)(最值|最|max|峰值|最|最)")),
new AbstractMap.SimpleEntry<>(
AggregateTypeEnum.MIN,
Pattern.compile("(?i)(最小值|最小|min|最低|最少)")),
new AbstractMap.SimpleEntry<>(
AggregateTypeEnum.SUM, Pattern.compile("(?i)(汇总|总和|sum)")),
new AbstractMap.SimpleEntry<>(
AggregateTypeEnum.AVG, Pattern.compile("(?i)(平均值|日均|平均|avg)")),
new AbstractMap.SimpleEntry<>(
AggregateTypeEnum.TOPN, Pattern.compile("(?i)(top)")),
new AbstractMap.SimpleEntry<>(DISTINCT, Pattern.compile("(?i)(uv)")),
new AbstractMap.SimpleEntry<>(COUNT, Pattern.compile("(?i)(总数|pv)")),
new AbstractMap.SimpleEntry<>(
AggregateTypeEnum.NONE, Pattern.compile("(?i)(明细)")))
.collect(
Collectors.toMap(
Map.Entry::getKey, Map.Entry::getValue, (k1, k2) -> k2));
private static final Map<AggregateTypeEnum, Pattern> REGX_MAP = Stream.of(
new AbstractMap.SimpleEntry<>(AggregateTypeEnum.MAX,
Pattern.compile("(?i)(最大值|最大|max|峰值|最高|最多)")),
new AbstractMap.SimpleEntry<>(AggregateTypeEnum.MIN,
Pattern.compile("(?i)(最值|最|min|最|最)")),
new AbstractMap.SimpleEntry<>(AggregateTypeEnum.SUM,
Pattern.compile("(?i)(汇总|总和|sum)")),
new AbstractMap.SimpleEntry<>(AggregateTypeEnum.AVG,
Pattern.compile("(?i)(平均值|日均|平均|avg)")),
new AbstractMap.SimpleEntry<>(AggregateTypeEnum.TOPN, Pattern.compile("(?i)(top)")),
new AbstractMap.SimpleEntry<>(DISTINCT, Pattern.compile("(?i)(uv)")),
new AbstractMap.SimpleEntry<>(COUNT, Pattern.compile("(?i)(总数|pv)")),
new AbstractMap.SimpleEntry<>(AggregateTypeEnum.NONE, Pattern.compile("(?i)(明细)")))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (k1, k2) -> k2));
@Override
public void parse(ChatQueryContext chatQueryContext) {
@@ -63,8 +56,7 @@ public class AggregateTypeParser implements SemanticParser {
if (StringUtils.isNotEmpty(aggregateConf.detectWord)) {
detectWordLength = aggregateConf.detectWord.length();
}
semanticQuery
.getParseInfo()
semanticQuery.getParseInfo()
.setScore(semanticQuery.getParseInfo().getScore() + detectWordLength);
}
}
@@ -93,10 +85,8 @@ public class AggregateTypeParser implements SemanticParser {
}
AggregateTypeEnum type =
aggregateCount.entrySet().stream()
.max(Map.Entry.comparingByValue())
.map(entry -> entry.getKey())
.orElse(AggregateTypeEnum.NONE);
aggregateCount.entrySet().stream().max(Map.Entry.comparingByValue())
.map(entry -> entry.getKey()).orElse(AggregateTypeEnum.NONE);
String detectWord = aggregateWord.get(type);
return new AggregateConf(type, detectWord);
}
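A quick detection example against the REGX_MAP declared above (hypothetical driver code, not part of the commit):

    Matcher matcher = REGX_MAP.get(AggregateTypeEnum.MAX).matcher("近7天播放量最大的歌曲");
    boolean hasMax = matcher.find(); // true: detect word "最大"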

View File

@@ -32,25 +32,18 @@ public class ContextInheritParser implements SemanticParser {
private static final Map<SchemaElementType, List<SchemaElementType>> MUTUAL_EXCLUSIVE_MAP =
Stream.of(
new AbstractMap.SimpleEntry<>(
SchemaElementType.METRIC,
Arrays.asList(SchemaElementType.METRIC)),
new AbstractMap.SimpleEntry<>(
SchemaElementType.DIMENSION,
Arrays.asList(
SchemaElementType.DIMENSION, SchemaElementType.VALUE)),
new AbstractMap.SimpleEntry<>(
SchemaElementType.VALUE,
Arrays.asList(
SchemaElementType.VALUE, SchemaElementType.DIMENSION)),
new AbstractMap.SimpleEntry<>(
SchemaElementType.ENTITY,
Arrays.asList(SchemaElementType.ENTITY)),
new AbstractMap.SimpleEntry<>(
SchemaElementType.DATASET,
Arrays.asList(SchemaElementType.DATASET)),
new AbstractMap.SimpleEntry<>(
SchemaElementType.ID, Arrays.asList(SchemaElementType.ID)))
new AbstractMap.SimpleEntry<>(SchemaElementType.METRIC,
Arrays.asList(SchemaElementType.METRIC)),
new AbstractMap.SimpleEntry<>(SchemaElementType.DIMENSION,
Arrays.asList(SchemaElementType.DIMENSION, SchemaElementType.VALUE)),
new AbstractMap.SimpleEntry<>(SchemaElementType.VALUE,
Arrays.asList(SchemaElementType.VALUE, SchemaElementType.DIMENSION)),
new AbstractMap.SimpleEntry<>(SchemaElementType.ENTITY,
Arrays.asList(SchemaElementType.ENTITY)),
new AbstractMap.SimpleEntry<>(SchemaElementType.DATASET,
Arrays.asList(SchemaElementType.DATASET)),
new AbstractMap.SimpleEntry<>(SchemaElementType.ID,
Arrays.asList(SchemaElementType.ID)))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
@Override
@@ -67,13 +60,12 @@ public class ContextInheritParser implements SemanticParser {
chatQueryContext.getMapInfo().getMatchedElements(dataSetId);
List<SchemaElementMatch> matchesToInherit = new ArrayList<>();
for (SchemaElementMatch match :
chatQueryContext.getContextParseInfo().getElementMatches()) {
for (SchemaElementMatch match : chatQueryContext.getContextParseInfo()
.getElementMatches()) {
SchemaElementType matchType = match.getElement().getType();
// mutual exclusive element types should not be inherited
RuleSemanticQuery ruleQuery =
QueryManager.getRuleQuery(
chatQueryContext.getContextParseInfo().getQueryMode());
RuleSemanticQuery ruleQuery = QueryManager
.getRuleQuery(chatQueryContext.getContextParseInfo().getQueryMode());
if (!containsTypes(elementMatches, matchType, ruleQuery)) {
match.setInherited(true);
matchesToInherit.add(match);
@@ -85,16 +77,16 @@ public class ContextInheritParser implements SemanticParser {
RuleSemanticQuery.resolve(dataSetId, elementMatches, chatQueryContext);
for (RuleSemanticQuery query : queries) {
query.fillParseInfo(chatQueryContext);
if (existSameQuery(
query.getParseInfo().getDataSetId(), query.getQueryMode(), chatQueryContext)) {
if (existSameQuery(query.getParseInfo().getDataSetId(), query.getQueryMode(),
chatQueryContext)) {
continue;
}
chatQueryContext.getCandidateQueries().add(query);
}
}
private boolean existSameQuery(
Long dataSetId, String queryMode, ChatQueryContext chatQueryContext) {
private boolean existSameQuery(Long dataSetId, String queryMode,
ChatQueryContext chatQueryContext) {
for (SemanticQuery semanticQuery : chatQueryContext.getCandidateQueries()) {
if (semanticQuery.getQueryMode().equals(queryMode)
&& semanticQuery.getParseInfo().getDataSetId().equals(dataSetId)) {
@@ -104,33 +96,26 @@ public class ContextInheritParser implements SemanticParser {
return false;
}
private boolean containsTypes(
List<SchemaElementMatch> matches,
SchemaElementType matchType,
private boolean containsTypes(List<SchemaElementMatch> matches, SchemaElementType matchType,
RuleSemanticQuery ruleQuery) {
List<SchemaElementType> types = MUTUAL_EXCLUSIVE_MAP.get(matchType);
return matches.stream()
.anyMatch(
m -> {
SchemaElementType type = m.getElement().getType();
if (Objects.nonNull(ruleQuery)
&& ruleQuery instanceof MetricSemanticQuery
&& !(ruleQuery instanceof MetricIdQuery)) {
return types.contains(type);
}
return type.equals(matchType);
});
return matches.stream().anyMatch(m -> {
SchemaElementType type = m.getElement().getType();
if (Objects.nonNull(ruleQuery) && ruleQuery instanceof MetricSemanticQuery
&& !(ruleQuery instanceof MetricIdQuery)) {
return types.contains(type);
}
return type.equals(matchType);
});
}
protected boolean shouldInherit(ChatQueryContext chatQueryContext) {
// if candidates only have MetricModel mode, count in context
List<SemanticQuery> metricModelQueries =
chatQueryContext.getCandidateQueries().stream()
.filter(
query ->
query instanceof MetricModelQuery
|| query instanceof DetailDimensionQuery)
.filter(query -> query instanceof MetricModelQuery
|| query instanceof DetailDimensionQuery)
.collect(Collectors.toList());
return metricModelQueries.size() == chatQueryContext.getCandidateQueries().size();
}

View File

@@ -17,9 +17,8 @@ import java.util.List;
@Slf4j
public class RuleSqlParser implements SemanticParser {
private static List<SemanticParser> auxiliaryParsers =
Arrays.asList(
new ContextInheritParser(), new TimeRangeParser(), new AggregateTypeParser());
private static List<SemanticParser> auxiliaryParsers = Arrays.asList(new ContextInheritParser(),
new TimeRangeParser(), new AggregateTypeParser());
@Override
public void parse(ChatQueryContext chatQueryContext) {

View File

@@ -30,9 +30,8 @@ import java.util.regex.Pattern;
@Slf4j
public class TimeRangeParser implements SemanticParser {
private static final Pattern RECENT_PATTERN_CN =
Pattern.compile(
".*(?<periodStr>(近|过去)((?<enNum>\\d+)|(?<zhNum>[一二三四五六七八九十百千万亿]+))个?(?<zhPeriod>[天周月年])).*");
private static final Pattern RECENT_PATTERN_CN = Pattern.compile(
".*(?<periodStr>(近|过去)((?<enNum>\\d+)|(?<zhNum>[一二三四五六七八九十百千万亿]+))个?(?<zhPeriod>[天周月年])).*");
private static final Pattern DATE_PATTERN_NUMBER = Pattern.compile("(\\d{8})");
private static final DateFormat DATE_FORMAT_NUMBER = new SimpleDateFormat("yyyyMMdd");
private static final DateFormat DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd");
@@ -70,8 +69,8 @@ public class TimeRangeParser implements SemanticParser {
if (queryContext.containsPartitionDimensions(contextParseInfo.getDataSetId())) {
contextParseInfo.setDateInfo(dateConf);
}
contextParseInfo.setScore(
contextParseInfo.getScore() + dateConf.getDetectWord().length());
contextParseInfo
.setScore(contextParseInfo.getScore() + dateConf.getDetectWord().length());
semanticQuery.setParseInfo(contextParseInfo);
queryContext.getCandidateQueries().add(semanticQuery);
}
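An example match for RECENT_PATTERN_CN (a sketch; the named groups come straight from the pattern above):

    Matcher m = RECENT_PATTERN_CN.matcher("查一下近7天的播放量");
    if (m.matches()) {
        String periodStr = m.group("periodStr"); // "近7天"
        String enNum = m.group("enNum");         // "7"
        String zhPeriod = m.group("zhPeriod");   // "天"
    }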

View File

@@ -52,8 +52,8 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
parseInfo.getSqlInfo().setCorrectedS2SQL(querySQLReq.getSql());
}
protected void convertBizNameToName(
DataSetSchema dataSetSchema, QueryStructReq queryStructReq) {
protected void convertBizNameToName(DataSetSchema dataSetSchema,
QueryStructReq queryStructReq) {
Map<String, String> bizNameToName = dataSetSchema.getBizNameToName();
bizNameToName.putAll(TimeDimensionEnum.getNameToNameMap());
@@ -76,8 +76,8 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
}
List<Filter> dimensionFilters = queryStructReq.getDimensionFilters();
if (CollectionUtils.isNotEmpty(dimensionFilters)) {
dimensionFilters.forEach(
filter -> filter.setName(bizNameToName.get(filter.getBizName())));
dimensionFilters
.forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
}
List<Filter> metricFilters = queryStructReq.getMetricFilters();
if (CollectionUtils.isNotEmpty(dimensionFilters)) {

View File

@@ -4,4 +4,5 @@ import com.tencent.supersonic.headless.chat.query.BaseSemanticQuery;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public abstract class LLMSemanticQuery extends BaseSemanticQuery {}
public abstract class LLMSemanticQuery extends BaseSemanticQuery {
}

View File

@@ -46,16 +46,12 @@ public class LLMReq {
public List<String> getFieldNameList() {
List<String> fieldNameList = new ArrayList<>();
if (CollectionUtils.isNotEmpty(metrics)) {
fieldNameList.addAll(
metrics.stream()
.map(metric -> metric.getName())
.collect(Collectors.toList()));
fieldNameList.addAll(metrics.stream().map(metric -> metric.getName())
.collect(Collectors.toList()));
}
if (CollectionUtils.isNotEmpty(dimensions)) {
fieldNameList.addAll(
dimensions.stream()
.map(dimension -> dimension.getName())
.collect(Collectors.toList()));
fieldNameList.addAll(dimensions.stream().map(dimension -> dimension.getName())
.collect(Collectors.toList()));
}
if (Objects.nonNull(partitionTime)) {
fieldNameList.add(partitionTime.getName());
@@ -76,6 +72,7 @@ public class LLMReq {
public enum SqlGenType {
ONE_PASS_SELF_CONSISTENCY("1_pass_self_consistency");
private String name;
SqlGenType(String name) {

View File

@@ -9,10 +9,8 @@ public class QueryMatchOption {
private RequireNumberType requireNumberType;
private Integer requireNumber;
public static QueryMatchOption build(
OptionType schemaElementOption,
RequireNumberType requireNumberType,
Integer requireNumber) {
public static QueryMatchOption build(OptionType schemaElementOption,
RequireNumberType requireNumberType, Integer requireNumber) {
QueryMatchOption queryMatchOption = new QueryMatchOption();
queryMatchOption.requireNumber = requireNumber;
queryMatchOption.requireNumberType = requireNumberType;
@@ -37,14 +35,10 @@ public class QueryMatchOption {
}
public enum RequireNumberType {
AT_MOST,
AT_LEAST,
EQUAL
AT_MOST, AT_LEAST, EQUAL
}
public enum OptionType {
REQUIRED,
OPTIONAL,
UNUSED
REQUIRED, OPTIONAL, UNUSED
}
}

View File

@@ -33,13 +33,10 @@ public class QueryMatcher {
}
}
public QueryMatcher addOption(
SchemaElementType type,
QueryMatchOption.OptionType option,
QueryMatchOption.RequireNumberType requireNumberType,
Integer requireNumber) {
elementOptionMap.put(
type, QueryMatchOption.build(option, requireNumberType, requireNumber));
public QueryMatcher addOption(SchemaElementType type, QueryMatchOption.OptionType option,
QueryMatchOption.RequireNumberType requireNumberType, Integer requireNumber) {
elementOptionMap.put(type,
QueryMatchOption.build(option, requireNumberType, requireNumber));
return this;
}
@@ -55,8 +52,8 @@ public class QueryMatcher {
for (SchemaElementMatch schemaElementMatch : candidateElementMatches) {
SchemaElementType schemaElementType = schemaElementMatch.getElement().getType();
if (schemaElementTypeCount.containsKey(schemaElementType)) {
schemaElementTypeCount.put(
schemaElementType, schemaElementTypeCount.get(schemaElementType) + 1);
schemaElementTypeCount.put(schemaElementType,
schemaElementTypeCount.get(schemaElementType) + 1);
} else {
schemaElementTypeCount.put(schemaElementType, 1);
}
@@ -75,10 +72,8 @@ public class QueryMatcher {
for (SchemaElementMatch elementMatch : candidateElementMatches) {
QueryMatchOption elementOption =
elementOptionMap.get(elementMatch.getElement().getType());
if (Objects.nonNull(elementOption)
&& !elementOption
.getSchemaElementOption()
.equals(QueryMatchOption.OptionType.UNUSED)) {
if (Objects.nonNull(elementOption) && !elementOption.getSchemaElementOption()
.equals(QueryMatchOption.OptionType.UNUSED)) {
elementMatches.add(elementMatch);
}
}
@@ -86,8 +81,7 @@ public class QueryMatcher {
return elementMatches;
}
private int getCount(
HashMap<SchemaElementType, Integer> schemaElementTypeCount,
private int getCount(HashMap<SchemaElementType, Integer> schemaElementTypeCount,
SchemaElementType schemaElementType) {
if (schemaElementTypeCount.containsKey(schemaElementType)) {
return schemaElementTypeCount.get(schemaElementType);
@@ -101,15 +95,13 @@ public class QueryMatcher {
&& count <= 0) {
return false;
}
if (queryMatchOption
.getRequireNumberType()
.equals(QueryMatchOption.RequireNumberType.AT_LEAST)
if (queryMatchOption.getRequireNumberType()
.equals(QueryMatchOption.RequireNumberType.AT_LEAST)
&& count < queryMatchOption.getRequireNumber()) {
return false;
}
if (queryMatchOption
.getRequireNumberType()
.equals(QueryMatchOption.RequireNumberType.AT_MOST)
if (queryMatchOption.getRequireNumberType()
.equals(QueryMatchOption.RequireNumberType.AT_MOST)
&& count > queryMatchOption.getRequireNumber()) {
return false;
}

View File

@@ -40,8 +40,8 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
QueryManager.register(this);
}
public List<SchemaElementMatch> match(
List<SchemaElementMatch> candidateElementMatches, ChatQueryContext queryCtx) {
public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches,
ChatQueryContext queryCtx) {
return queryMatcher.match(candidateElementMatches);
}
@@ -67,17 +67,16 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
return chatQueryContext.containsPartitionDimensions(dataSetId);
}
private void fillDateConfByInherited(
SemanticParseInfo queryParseInfo, ChatQueryContext chatQueryContext) {
private void fillDateConfByInherited(SemanticParseInfo queryParseInfo,
ChatQueryContext chatQueryContext) {
SemanticParseInfo contextParseInfo = chatQueryContext.getContextParseInfo();
if (queryParseInfo.getDateInfo() != null
|| contextParseInfo.getDateInfo() == null
if (queryParseInfo.getDateInfo() != null || contextParseInfo.getDateInfo() == null
|| needFillDateConf(chatQueryContext)) {
return;
}
if ((QueryManager.isDetailQuery(queryParseInfo.getQueryMode())
&& QueryManager.isDetailQuery(contextParseInfo.getQueryMode()))
&& QueryManager.isDetailQuery(contextParseInfo.getQueryMode()))
|| (QueryManager.isMetricQuery(queryParseInfo.getQueryMode())
&& QueryManager.isMetricQuery(contextParseInfo.getQueryMode()))) {
// inherit date info from context
@@ -107,10 +106,8 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
private void fillSchemaElement(SemanticParseInfo parseInfo, SemanticSchema semanticSchema) {
Set<Long> dataSetIds =
parseInfo.getElementMatches().stream()
.map(SchemaElementMatch::getElement)
.map(SchemaElement::getDataSetId)
.collect(Collectors.toSet());
parseInfo.getElementMatches().stream().map(SchemaElementMatch::getElement)
.map(SchemaElement::getDataSetId).collect(Collectors.toSet());
Long dataSetId = dataSetIds.iterator().next();
parseInfo.setDataSet(semanticSchema.getDataSet(dataSetId));
parseInfo.setQueryConfig(semanticSchema.getQueryConfig(dataSetId));
@@ -128,8 +125,8 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
if (id2Values.containsKey(element.getId())) {
id2Values.get(element.getId()).add(schemaMatch);
} else {
id2Values.put(
element.getId(), new ArrayList<>(Arrays.asList(schemaMatch)));
id2Values.put(element.getId(),
new ArrayList<>(Arrays.asList(schemaMatch)));
}
}
break;
@@ -140,8 +137,8 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
if (dim2Values.containsKey(element.getId())) {
dim2Values.get(element.getId()).add(schemaMatch);
} else {
dim2Values.put(
element.getId(), new ArrayList<>(Arrays.asList(schemaMatch)));
dim2Values.put(element.getId(),
new ArrayList<>(Arrays.asList(schemaMatch)));
}
}
break;
@@ -161,11 +158,8 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
addToFilters(dim2Values, parseInfo, semanticSchema, SchemaElementType.DIMENSION);
}
private void addToFilters(
Map<Long, List<SchemaElementMatch>> id2Values,
SemanticParseInfo parseInfo,
SemanticSchema semanticSchema,
SchemaElementType entity) {
private void addToFilters(Map<Long, List<SchemaElementMatch>> id2Values,
SemanticParseInfo parseInfo, SemanticSchema semanticSchema, SchemaElementType entity) {
if (id2Values == null || id2Values.isEmpty()) {
return;
}
@@ -206,8 +200,7 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
public SemanticQueryReq multiStructExecute() {
String queryMode = parseInfo.getQueryMode();
if (parseInfo.getDataSetId() != null
|| StringUtils.isEmpty(queryMode)
if (parseInfo.getDataSetId() != null || StringUtils.isEmpty(queryMode)
|| !QueryManager.containsRuleQuery(queryMode)) {
// reach here some error may happen
log.error("not find QueryMode");
@@ -222,10 +215,8 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
this.parseInfo = parseInfo;
}
public static List<RuleSemanticQuery> resolve(
Long dataSetId,
List<SchemaElementMatch> candidateElementMatches,
ChatQueryContext chatQueryContext) {
public static List<RuleSemanticQuery> resolve(Long dataSetId,
List<SchemaElementMatch> candidateElementMatches, ChatQueryContext chatQueryContext) {
List<RuleSemanticQuery> matchedQueries = new ArrayList<>();
for (RuleSemanticQuery semanticQuery : QueryManager.getRuleQueries()) {
List<SchemaElementMatch> matches =

View File

@@ -20,8 +20,8 @@ public abstract class DetailListQuery extends DetailSemanticQuery {
this.addEntityDetailAndOrderByMetric(chatQueryContext, parseInfo);
}
private void addEntityDetailAndOrderByMetric(
ChatQueryContext chatQueryContext, SemanticParseInfo parseInfo) {
private void addEntityDetailAndOrderByMetric(ChatQueryContext chatQueryContext,
SemanticParseInfo parseInfo) {
Long dataSetId = parseInfo.getDataSetId();
if (Objects.isNull(dataSetId) || dataSetId <= 0L) {
return;
@@ -38,35 +38,23 @@ public abstract class DetailListQuery extends DetailSemanticQuery {
&& detailTypeDefaultConfig.getDefaultDisplayInfo() != null) {
if (CollectionUtils.isNotEmpty(
detailTypeDefaultConfig.getDefaultDisplayInfo().getMetricIds())) {
metrics =
detailTypeDefaultConfig.getDefaultDisplayInfo().getMetricIds().stream()
.map(
id -> {
SchemaElement metric =
dataSetSchema.getElement(
SchemaElementType.METRIC, id);
if (metric != null) {
orders.add(
new Order(
metric.getBizName(),
Constants.DESC_UPPER));
}
return metric;
})
.filter(Objects::nonNull)
.collect(Collectors.toSet());
metrics = detailTypeDefaultConfig.getDefaultDisplayInfo().getMetricIds()
.stream().map(id -> {
SchemaElement metric =
dataSetSchema.getElement(SchemaElementType.METRIC, id);
if (metric != null) {
orders.add(
new Order(metric.getBizName(), Constants.DESC_UPPER));
}
return metric;
}).filter(Objects::nonNull).collect(Collectors.toSet());
}
if (CollectionUtils.isNotEmpty(
detailTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds())) {
dimensions =
detailTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds()
.stream()
.map(
id ->
dataSetSchema.getElement(
SchemaElementType.DIMENSION, id))
.filter(Objects::nonNull)
.collect(Collectors.toSet());
dimensions = detailTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds()
.stream()
.map(id -> dataSetSchema.getElement(SchemaElementType.DIMENSION, id))
.filter(Objects::nonNull).collect(Collectors.toSet());
}
}
parseInfo.setDimensions(dimensions);

View File

@@ -23,8 +23,8 @@ public abstract class DetailSemanticQuery extends RuleSemanticQuery {
}
@Override
public List<SchemaElementMatch> match(
List<SchemaElementMatch> candidateElementMatches, ChatQueryContext queryCtx) {
public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches,
ChatQueryContext queryCtx) {
return super.match(candidateElementMatches, queryCtx);
}
@@ -43,8 +43,7 @@ public abstract class DetailSemanticQuery extends RuleSemanticQuery {
DataSetSchema dataSetSchema = dataSetSchemaMap.get(parseInfo.getDataSetId());
TimeDefaultConfig timeDefaultConfig = dataSetSchema.getTagTypeTimeDefaultConfig();
if (Objects.nonNull(timeDefaultConfig)
&& Objects.nonNull(timeDefaultConfig.getUnit())
if (Objects.nonNull(timeDefaultConfig) && Objects.nonNull(timeDefaultConfig.getUnit())
&& timeDefaultConfig.getUnit() != -1) {
DateConf dateInfo = new DateConf();
int unit = timeDefaultConfig.getUnit();

View File

@@ -71,19 +71,15 @@ public class MetricFilterQuery extends MetricSemanticQuery {
log.debug("addDimension before [{}]", queryStructReq.getGroups());
List<Filter> filters = new ArrayList<>(queryStructReq.getDimensionFilters());
if (onlyOperateInFilter) {
filters =
filters.stream()
.filter(
filter ->
filter.getOperator().equals(FilterOperatorEnum.IN))
.collect(Collectors.toList());
filters = filters.stream()
.filter(filter -> filter.getOperator().equals(FilterOperatorEnum.IN))
.collect(Collectors.toList());
}
filters.forEach(
d -> {
if (!dimensions.contains(d.getBizName())) {
dimensions.add(d.getBizName());
}
});
filters.forEach(d -> {
if (!dimensions.contains(d.getBizName())) {
dimensions.add(d.getBizName());
}
});
queryStructReq.setGroups(dimensions);
log.debug("addDimension after [{}]", queryStructReq.getGroups());
}

View File

@@ -46,8 +46,7 @@ public class MetricIdQuery extends MetricSemanticQuery {
protected boolean isMultiStructQuery() {
Set<String> filterBizName = new HashSet<>();
parseInfo.getDimensionFilters().stream()
.filter(filter -> filter.getElementID() != null)
parseInfo.getDimensionFilters().stream().filter(filter -> filter.getElementID() != null)
.forEach(filter -> filterBizName.add(filter.getBizName()));
return FilterType.UNION.equals(parseInfo.getFilterType()) && filterBizName.size() > 1;
}
@@ -74,19 +73,15 @@ public class MetricIdQuery extends MetricSemanticQuery {
log.info("addDimension before [{}]", queryStructReq.getGroups());
List<Filter> filters = new ArrayList<>(queryStructReq.getDimensionFilters());
if (onlyOperateInFilter) {
filters =
filters.stream()
.filter(
filter ->
filter.getOperator().equals(FilterOperatorEnum.IN))
.collect(Collectors.toList());
filters = filters.stream()
.filter(filter -> filter.getOperator().equals(FilterOperatorEnum.IN))
.collect(Collectors.toList());
}
filters.forEach(
d -> {
if (!dimensions.contains(d.getBizName())) {
dimensions.add(d.getBizName());
}
});
filters.forEach(d -> {
if (!dimensions.contains(d.getBizName())) {
dimensions.add(d.getBizName());
}
});
queryStructReq.setGroups(dimensions);
log.info("addDimension after [{}]", queryStructReq.getGroups());
}

View File

@@ -26,8 +26,8 @@ public abstract class MetricSemanticQuery extends RuleSemanticQuery {
}
@Override
public List<SchemaElementMatch> match(
List<SchemaElementMatch> candidateElementMatches, ChatQueryContext queryCtx) {
public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches,
ChatQueryContext queryCtx) {
return super.match(candidateElementMatches, queryCtx);
}
@@ -42,16 +42,12 @@ public abstract class MetricSemanticQuery extends RuleSemanticQuery {
if (parseInfo.getDateInfo() != null || !needFillDateConf(chatQueryContext)) {
return;
}
DataSetSchema dataSetSchema =
chatQueryContext
.getSemanticSchema()
.getDataSetSchemaMap()
.get(parseInfo.getDataSetId());
DataSetSchema dataSetSchema = chatQueryContext.getSemanticSchema().getDataSetSchemaMap()
.get(parseInfo.getDataSetId());
TimeDefaultConfig timeDefaultConfig = dataSetSchema.getMetricTypeTimeDefaultConfig();
DateConf dateInfo = new DateConf();
// 加上时间!=-1 判断
if (Objects.nonNull(timeDefaultConfig)
&& Objects.nonNull(timeDefaultConfig.getUnit())
if (Objects.nonNull(timeDefaultConfig) && Objects.nonNull(timeDefaultConfig.getUnit())
&& timeDefaultConfig.getUnit() != -1) {
int unit = timeDefaultConfig.getUnit();
String startDate = LocalDate.now().minusDays(unit).toString();

View File

@@ -33,8 +33,8 @@ public class MetricTopNQuery extends MetricSemanticQuery {
}
@Override
public List<SchemaElementMatch> match(
List<SchemaElementMatch> candidateElementMatches, ChatQueryContext queryCtx) {
public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches,
ChatQueryContext queryCtx) {
Matcher matcher = INTENT_PATTERN.matcher(queryCtx.getQueryText());
if (matcher.matches()) {
return super.match(candidateElementMatches, queryCtx);

View File

@@ -26,15 +26,13 @@ public class ComponentFactory {
}
private static <T> List<T> init(Class<T> factoryType, List list) {
list.addAll(
SpringFactoriesLoader.loadFactories(
factoryType, Thread.currentThread().getContextClassLoader()));
list.addAll(SpringFactoriesLoader.loadFactories(factoryType,
Thread.currentThread().getContextClassLoader()));
return list;
}
private static <T> T init(Class<T> factoryType) {
return SpringFactoriesLoader.loadFactories(
factoryType, Thread.currentThread().getContextClassLoader())
.get(0);
return SpringFactoriesLoader
.loadFactories(factoryType, Thread.currentThread().getContextClassLoader()).get(0);
}
}

View File

@@ -20,8 +20,7 @@ public class EditDistanceUtils {
public static double getSimilarity(String detectSegment, String matchName) {
String detectSegmentLower = detectSegment == null ? null : detectSegment.toLowerCase();
String matchNameLower = matchName == null ? null : matchName.toLowerCase();
return 1
- (double) EditDistance.compute(detectSegmentLower, matchNameLower)
/ Math.max(matchName.length(), detectSegment.length());
return 1 - (double) EditDistance.compute(detectSegmentLower, matchNameLower)
/ Math.max(matchName.length(), detectSegment.length());
}
}
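A worked example, assuming EditDistance.compute returns the Levenshtein distance:

    // compute("kitten", "sitting") == 3 and the longer string has length 7,
    // so the normalized similarity is 1 - 3.0 / 7.
    double sim = EditDistanceUtils.getSimilarity("kitten", "sitting"); // ≈ 0.571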

View File

@@ -13,10 +13,8 @@ public class QueryFilterParser {
public static String parse(QueryFilters queryFilters) {
try {
List<String> conditions =
queryFilters.getFilters().stream()
.map(QueryFilterParser::parseFilter)
.collect(Collectors.toList());
List<String> conditions = queryFilters.getFilters().stream()
.map(QueryFilterParser::parseFilter).collect(Collectors.toList());
return String.join(" AND ", conditions);
} catch (Exception e) {
log.error("", e);
@@ -36,10 +34,7 @@ public class QueryFilterParser {
case BETWEEN:
if (value instanceof List && ((List<?>) value).size() == 2) {
List<?> values = (List<?>) value;
return column
+ " BETWEEN "
+ formatValue(values.get(0))
+ " AND "
return column + " BETWEEN " + formatValue(values.get(0)) + " AND "
+ formatValue(values.get(1));
}
throw new IllegalArgumentException(
@@ -58,8 +53,8 @@ public class QueryFilterParser {
private static String parseList(Object value) {
if (value instanceof List) {
return ((List<?>) value)
.stream().map(QueryFilterParser::formatValue).collect(Collectors.joining(", "));
return ((List<?>) value).stream().map(QueryFilterParser::formatValue)
.collect(Collectors.joining(", "));
}
throw new IllegalArgumentException("IN and NOT IN operators require a list of values");
}
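A worked rendering example (the condition strings are illustrative; the String.join call is the one visible in parse above):

    List<String> conditions = Arrays.asList(
            "数据日期 BETWEEN '2024-01-01' AND '2024-01-31'",
            "部门 = 'sales'");
    String where = String.join(" AND ", conditions);
    // -> 数据日期 BETWEEN '2024-01-01' AND '2024-01-31' AND 部门 = 'sales'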

View File

@@ -46,15 +46,10 @@ public class QueryReqBuilder {
List<Filter> dimensionFilters = getFilters(parseInfo.getDimensionFilters());
queryStructReq.setDimensionFilters(dimensionFilters);
List<Filter> metricFilters =
parseInfo.getMetricFilters().stream()
.map(
chatFilter ->
new Filter(
chatFilter.getBizName(),
chatFilter.getOperator(),
chatFilter.getValue()))
.collect(Collectors.toList());
List<Filter> metricFilters = parseInfo
.getMetricFilters().stream().map(chatFilter -> new Filter(chatFilter.getBizName(),
chatFilter.getOperator(), chatFilter.getValue()))
.collect(Collectors.toList());
queryStructReq.setMetricFilters(metricFilters);
addDateDimension(parseInfo);
@@ -62,10 +57,8 @@ public class QueryReqBuilder {
if (isDateFieldAlreadyPresent(parseInfo, getDateField(parseInfo.getDateInfo()))) {
parseInfo.getDimensions().removeIf(schemaElement -> schemaElement.isPartitionTime());
}
queryStructReq.setGroups(
parseInfo.getDimensions().stream()
.map(SchemaElement::getBizName)
.collect(Collectors.toList()));
queryStructReq.setGroups(parseInfo.getDimensions().stream().map(SchemaElement::getBizName)
.collect(Collectors.toList()));
queryStructReq.setLimit(parseInfo.getLimit());
// only one metric is queried at once
Set<SchemaElement> metrics = parseInfo.getMetrics();
@@ -73,8 +66,8 @@ public class QueryReqBuilder {
SchemaElement metricElement = parseInfo.getMetrics().iterator().next();
Set<Order> order =
getOrder(parseInfo.getOrders(), parseInfo.getAggType(), metricElement);
queryStructReq.setAggregators(
getAggregatorByMetric(parseInfo.getAggType(), metricElement));
queryStructReq
.setAggregators(getAggregatorByMetric(parseInfo.getAggType(), metricElement));
queryStructReq.setOrders(new ArrayList<>(order));
}
@@ -87,12 +80,8 @@ public class QueryReqBuilder {
List<Filter> dimensionFilters =
queryFilters.stream()
.filter(chatFilter -> StringUtils.isNotEmpty(chatFilter.getBizName()))
.map(
chatFilter ->
new Filter(
chatFilter.getBizName(),
chatFilter.getOperator(),
chatFilter.getValue()))
.map(chatFilter -> new Filter(chatFilter.getBizName(),
chatFilter.getOperator(), chatFilter.getValue()))
.collect(Collectors.toList());
return dimensionFilters;
}
@@ -149,21 +138,20 @@ public class QueryReqBuilder {
return querySQLReq;
}
private static List<Aggregator> getAggregatorByMetric(
AggregateTypeEnum aggregateType, SchemaElement metric) {
private static List<Aggregator> getAggregatorByMetric(AggregateTypeEnum aggregateType,
SchemaElement metric) {
if (metric == null) {
return Collections.emptyList();
}
String agg = determineAggregator(aggregateType, metric);
return Collections.singletonList(
new Aggregator(metric.getBizName(), AggOperatorEnum.of(agg)));
return Collections
.singletonList(new Aggregator(metric.getBizName(), AggOperatorEnum.of(agg)));
}
private static String determineAggregator(
AggregateTypeEnum aggregateType, SchemaElement metric) {
if (aggregateType == null
|| aggregateType.equals(AggregateTypeEnum.NONE)
private static String determineAggregator(AggregateTypeEnum aggregateType,
SchemaElement metric) {
if (aggregateType == null || aggregateType.equals(AggregateTypeEnum.NONE)
|| AggOperatorEnum.COUNT_DISTINCT.name().equalsIgnoreCase(metric.getDefaultAgg())) {
return StringUtils.defaultIfBlank(metric.getDefaultAgg(), "");
}
@@ -199,28 +187,24 @@ public class QueryReqBuilder {
&& !CollectionUtils.isEmpty(parseInfo.getDimensions());
}
private static boolean isDateFieldAlreadyPresent(
SemanticParseInfo parseInfo, String dateField) {
private static boolean isDateFieldAlreadyPresent(SemanticParseInfo parseInfo,
String dateField) {
return parseInfo.getDimensions().stream()
.anyMatch(dimension -> dimension.getBizName().equalsIgnoreCase(dateField));
}
private static void addDimension(SemanticParseInfo parseInfo, SchemaElement dimension) {
List<String> timeDimensions =
Arrays.asList(
TimeDimensionEnum.DAY.getName(),
TimeDimensionEnum.WEEK.getName(),
TimeDimensionEnum.MONTH.getName());
Set<SchemaElement> dimensions =
parseInfo.getDimensions().stream()
.filter(d -> !timeDimensions.contains(d.getBizName().toLowerCase()))
.collect(Collectors.toSet());
List<String> timeDimensions = Arrays.asList(TimeDimensionEnum.DAY.getName(),
TimeDimensionEnum.WEEK.getName(), TimeDimensionEnum.MONTH.getName());
Set<SchemaElement> dimensions = parseInfo.getDimensions().stream()
.filter(d -> !timeDimensions.contains(d.getBizName().toLowerCase()))
.collect(Collectors.toSet());
dimensions.add(dimension);
parseInfo.setDimensions(dimensions);
}
public static Set<Order> getOrder(
Set<Order> existingOrders, AggregateTypeEnum aggregator, SchemaElement metric) {
public static Set<Order> getOrder(Set<Order> existingOrders, AggregateTypeEnum aggregator,
SchemaElement metric) {
if (existingOrders != null && !existingOrders.isEmpty()) {
return existingOrders;
}
@@ -230,8 +214,7 @@ public class QueryReqBuilder {
}
Set<Order> orders = new LinkedHashSet<>();
if (aggregator == AggregateTypeEnum.TOPN
|| aggregator == AggregateTypeEnum.MAX
if (aggregator == AggregateTypeEnum.TOPN || aggregator == AggregateTypeEnum.MAX
|| aggregator == AggregateTypeEnum.MIN) {
Order order = new Order();
order.setColumn(metric.getBizName());
@@ -256,8 +239,8 @@ public class QueryReqBuilder {
         return dateField;
     }
 
-    public static QueryStructReq buildStructRatioReq(
-            SemanticParseInfo parseInfo, SchemaElement metric, AggOperatorEnum aggOperatorEnum) {
+    public static QueryStructReq buildStructRatioReq(SemanticParseInfo parseInfo,
+            SchemaElement metric, AggOperatorEnum aggOperatorEnum) {
         QueryStructReq queryStructReq = buildStructReq(parseInfo);
         queryStructReq.setQueryType(QueryType.AGGREGATE);
         queryStructReq.setOrders(new ArrayList<>());


@@ -27,10 +27,9 @@ class AggCorrectorTest {
         dataSet.setDataSetId(dataSetId);
         semanticParseInfo.setDataSet(dataSet);
         SqlInfo sqlInfo = new SqlInfo();
-        String sql =
-                "SELECT 用户, 访问次数 FROM 超音数数据集 WHERE 部门 = 'sales' AND"
-                        + " datediff('day', 数据日期, '2024-06-04') <= 7"
-                        + " GROUP BY 用户 ORDER BY SUM(访问次数) DESC LIMIT 1";
+        String sql = "SELECT 用户, 访问次数 FROM 超音数数据集 WHERE 部门 = 'sales' AND"
+                + " datediff('day', 数据日期, '2024-06-04') <= 7"
+                + " GROUP BY 用户 ORDER BY SUM(访问次数) DESC LIMIT 1";
         sqlInfo.setParsedS2SQL(sql);
         sqlInfo.setCorrectedS2SQL(sql);
         semanticParseInfo.setSqlInfo(sqlInfo);


@@ -24,26 +24,20 @@ import java.util.Set;
 @Disabled
 class SchemaCorrectorTest {
 
-    private String json =
-            "{\n"
-                    + " \"dataSetId\": 1,\n"
-                    + " \"llmReq\": {\n"
-                    + " \"queryText\": \"xxx2024年播放量最高的十首歌\",\n"
-                    + " \"schema\": {\n"
-                    + " \"dataSetName\": \"歌曲\",\n"
-                    + " \"fieldNameList\": [\n"
-                    + " \"商务组\",\n"
-                    + " \"歌曲名\",\n"
-                    + " \"播放量\",\n"
-                    + " \"播放份额\",\n"
-                    + " \"数据日期\"\n"
-                    + " ]\n"
-                    + " },\n"
-                    + " \"currentDate\": \"2024-02-24\",\n"
-                    + " \"sqlGenType\": \"1_pass_self_consistency\"\n"
-                    + " },\n"
-                    + " \"request\": null\n"
-                    + "}";
+    private String json = "{\n" + " \"dataSetId\": 1,\n" + " \"llmReq\": {\n"
+            + " \"queryText\": \"xxx2024年播放量最高的十首歌\",\n"
+            + " \"schema\": {\n"
+            + " \"dataSetName\": \"歌曲\",\n"
+            + " \"fieldNameList\": [\n"
+            + " \"商务组\",\n"
+            + " \"歌曲名\",\n"
+            + " \"播放量\",\n"
+            + " \"播放份额\",\n"
+            + " \"数据日期\"\n"
+            + " ]\n" + " },\n"
+            + " \"currentDate\": \"2024-02-24\",\n"
+            + " \"sqlGenType\": \"1_pass_self_consistency\"\n"
+            + " },\n" + " \"request\": null\n" + "}";
 
     @Test
     void doCorrect() throws JsonProcessingException {
@@ -52,9 +46,8 @@ class SchemaCorrectorTest {
         ObjectMapper objectMapper = new ObjectMapper();
         ParseResult parseResult = objectMapper.readValue(json, ParseResult.class);
-        String sql =
-                "select 歌曲名 from 歌曲 where 发行日期 >= '2024-01-01' "
-                        + "and 商务组 = 'xxx' order by 播放量 desc limit 10";
+        String sql = "select 歌曲名 from 歌曲 where 发行日期 >= '2024-01-01' "
+                + "and 商务组 = 'xxx' order by 播放量 desc limit 10";
         SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
         SqlInfo sqlInfo = new SqlInfo();
         sqlInfo.setParsedS2SQL(sql);


@@ -42,8 +42,7 @@ class SelectCorrectorTest {
         sqlInfo.setCorrectedS2SQL(sql);
         semanticParseInfo.setSqlInfo(sqlInfo);
         corrector.correct(chatQueryContext, semanticParseInfo);
-        Assert.assertEquals(
-                "SELECT 粉丝数, 国籍, 艺人名, 性别 FROM 艺人库 WHERE 艺人名 = '周杰伦'",
+        Assert.assertEquals("SELECT 粉丝数, 国籍, 艺人名, 性别 FROM 艺人库 WHERE 艺人名 = '周杰伦'",
                 semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
     }


@@ -17,9 +17,8 @@ class TimeCorrectorTest {
         SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
         SqlInfo sqlInfo = new SqlInfo();
         // 1.数据日期 <=
-        String sql =
-                "SELECT 维度1, SUM(播放量) FROM 数据库 "
-                        + "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1";
+        String sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+                + "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1";
         sqlInfo.setCorrectedS2SQL(sql);
         semanticParseInfo.setSqlInfo(sqlInfo);
         corrector.doCorrect(chatQueryContext, semanticParseInfo);
@@ -30,9 +29,8 @@ class TimeCorrectorTest {
                 sqlInfo.getCorrectedS2SQL());
         // 2.数据日期 <
-        sql =
-                "SELECT 维度1, SUM(播放量) FROM 数据库 "
-                        + "WHERE (歌手名 = '张三') AND 数据日期 < '2023-11-17' GROUP BY 维度1";
+        sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+                + "WHERE (歌手名 = '张三') AND 数据日期 < '2023-11-17' GROUP BY 维度1";
         sqlInfo.setCorrectedS2SQL(sql);
         corrector.doCorrect(chatQueryContext, semanticParseInfo);
@@ -42,9 +40,8 @@ class TimeCorrectorTest {
                 sqlInfo.getCorrectedS2SQL());
         // 3.数据日期 >=
-        sql =
-                "SELECT 维度1, SUM(播放量) FROM 数据库 "
-                        + "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1";
+        sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+                + "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1";
         sqlInfo.setCorrectedS2SQL(sql);
         corrector.doCorrect(chatQueryContext, semanticParseInfo);
@@ -54,9 +51,8 @@ class TimeCorrectorTest {
                 sqlInfo.getCorrectedS2SQL());
         // 4.数据日期 >
-        sql =
-                "SELECT 维度1, SUM(播放量) FROM 数据库 "
-                        + "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1";
+        sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+                + "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1";
         sqlInfo.setCorrectedS2SQL(sql);
         corrector.doCorrect(chatQueryContext, semanticParseInfo);
@@ -70,14 +66,12 @@ class TimeCorrectorTest {
         sqlInfo.setCorrectedS2SQL(sql);
         corrector.doCorrect(chatQueryContext, semanticParseInfo);
-        Assert.assertEquals(
-                "SELECT 维度1, SUM(播放量) FROM 数据库 WHERE 歌手名 = '张三' GROUP BY 维度1",
+        Assert.assertEquals("SELECT 维度1, SUM(播放量) FROM 数据库 WHERE 歌手名 = '张三' GROUP BY 维度1",
                 sqlInfo.getCorrectedS2SQL());
         // 6. 数据日期-月 <=
-        sql =
-                "SELECT 维度1, SUM(播放量) FROM 数据库 "
-                        + "WHERE 歌手名 = '张三' AND 数据日期_月 <= '2024-01' GROUP BY 维度1";
+        sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+                + "WHERE 歌手名 = '张三' AND 数据日期_月 <= '2024-01' GROUP BY 维度1";
         sqlInfo.setCorrectedS2SQL(sql);
         corrector.doCorrect(chatQueryContext, semanticParseInfo);
@@ -87,9 +81,8 @@ class TimeCorrectorTest {
                 sqlInfo.getCorrectedS2SQL());
         // 7. 数据日期-月 >
-        sql =
-                "SELECT 维度1, SUM(播放量) FROM 数据库 "
-                        + "WHERE 歌手名 = '张三' AND 数据日期_月 > '2024-01' GROUP BY 维度1";
+        sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+                + "WHERE 歌手名 = '张三' AND 数据日期_月 > '2024-01' GROUP BY 维度1";
         sqlInfo.setCorrectedS2SQL(sql);
         corrector.doCorrect(chatQueryContext, semanticParseInfo);


@@ -16,9 +16,8 @@ class WhereCorrectorTest {
     void addQueryFilter() {
         SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
         SqlInfo sqlInfo = new SqlInfo();
-        String sql =
-                "SELECT 维度1, SUM(播放量) FROM 数据库 "
-                        + "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1";
+        String sql = "SELECT 维度1, SUM(播放量) FROM 数据库 "
+                + "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1";
         sqlInfo.setCorrectedS2SQL(sql);
         semanticParseInfo.setSqlInfo(sqlInfo);
@@ -56,8 +55,7 @@ class WhereCorrectorTest {
         String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
-        Assert.assertEquals(
-                correctS2SQL,
+        Assert.assertEquals(correctS2SQL,
                 "SELECT 维度1, SUM(播放量) FROM 数据库 WHERE "
                         + "(歌手名 = '张三') AND 数据日期 <= '2023-11-17' AND age > 30 AND "
                         + "name LIKE 'John%' AND id IN (1, 2, 3, 4) AND status GROUP BY 维度1");


@@ -25,49 +25,17 @@ public class HeuristicDataSetResolverTest {
         Map<Long, List<SchemaElementMatch>> dataSet2Matches =
                 chatQueryContext.getMapInfo().getDataSetElementMatches();
         List<SchemaElementMatch> matches = Lists.newArrayList();
-        matches.add(
-                SchemaElementMatch.builder()
-                        .element(
-                                SchemaElement.builder()
-                                        .dataSetId(1L)
-                                        .name("超音数")
-                                        .type(SchemaElementType.DATASET)
-                                        .build())
-                        .similarity(1)
-                        .build());
-        matches.add(
-                SchemaElementMatch.builder()
-                        .element(
-                                SchemaElement.builder()
-                                        .dataSetId(1L)
-                                        .name("访问次数")
-                                        .type(SchemaElementType.METRIC)
-                                        .build())
-                        .similarity(0.5)
-                        .build());
+        matches.add(SchemaElementMatch.builder().element(SchemaElement.builder().dataSetId(1L)
+                .name("超音数").type(SchemaElementType.DATASET).build()).similarity(1).build());
+        matches.add(SchemaElementMatch.builder().element(SchemaElement.builder().dataSetId(1L)
+                .name("访问次数").type(SchemaElementType.METRIC).build()).similarity(0.5).build());
         dataSet2Matches.put(1L, matches);
 
         List<SchemaElementMatch> matches2 = Lists.newArrayList();
-        matches2.add(
-                SchemaElementMatch.builder()
-                        .element(
-                                SchemaElement.builder()
-                                        .dataSetId(2L)
-                                        .name("访问用户数")
-                                        .type(SchemaElementType.METRIC)
-                                        .build())
-                        .similarity(1)
-                        .build());
-        matches2.add(
-                SchemaElementMatch.builder()
-                        .element(
-                                SchemaElement.builder()
-                                        .dataSetId(2L)
-                                        .name("用户")
-                                        .type(SchemaElementType.DIMENSION)
-                                        .build())
-                        .similarity(1)
-                        .build());
+        matches2.add(SchemaElementMatch.builder().element(SchemaElement.builder().dataSetId(2L)
+                .name("访问用户数").type(SchemaElementType.METRIC).build()).similarity(1).build());
+        matches2.add(SchemaElementMatch.builder().element(SchemaElement.builder().dataSetId(2L)
+                .name("用户").type(SchemaElementType.DIMENSION).build()).similarity(1).build());
         dataSet2Matches.put(2L, matches2);
 
         Long resolvedDataset = resolver.resolve(chatQueryContext, dataSets);
@@ -81,39 +49,15 @@ public class HeuristicDataSetResolverTest {
         Map<Long, List<SchemaElementMatch>> dataSet2Matches =
                 chatQueryContext.getMapInfo().getDataSetElementMatches();
         List<SchemaElementMatch> matches = Lists.newArrayList();
-        matches.add(
-                SchemaElementMatch.builder()
-                        .element(
-                                SchemaElement.builder()
-                                        .dataSetId(1L)
-                                        .name("访问次数")
-                                        .type(SchemaElementType.METRIC)
-                                        .build())
-                        .similarity(1)
-                        .build());
+        matches.add(SchemaElementMatch.builder().element(SchemaElement.builder().dataSetId(1L)
+                .name("访问次数").type(SchemaElementType.METRIC).build()).similarity(1).build());
         dataSet2Matches.put(1L, matches);
 
         List<SchemaElementMatch> matches2 = Lists.newArrayList();
-        matches2.add(
-                SchemaElementMatch.builder()
-                        .element(
-                                SchemaElement.builder()
-                                        .dataSetId(2L)
-                                        .name("访问用户数")
-                                        .type(SchemaElementType.METRIC)
-                                        .build())
-                        .similarity(0.6)
-                        .build());
-        matches2.add(
-                SchemaElementMatch.builder()
-                        .element(
-                                SchemaElement.builder()
-                                        .dataSetId(2L)
-                                        .name("用户")
-                                        .type(SchemaElementType.DIMENSION)
-                                        .build())
-                        .similarity(1)
-                        .build());
+        matches2.add(SchemaElementMatch.builder().element(SchemaElement.builder().dataSetId(2L)
+                .name("访问用户数").type(SchemaElementType.METRIC).build()).similarity(0.6).build());
+        matches2.add(SchemaElementMatch.builder().element(SchemaElement.builder().dataSetId(2L)
+                .name("用户").type(SchemaElementType.DIMENSION).build()).similarity(1).build());
         dataSet2Matches.put(2L, matches2);
 
         Long resolvedDataset = resolver.resolve(chatQueryContext, dataSets);
@@ -127,49 +71,17 @@ public class HeuristicDataSetResolverTest {
         Map<Long, List<SchemaElementMatch>> dataSet2Matches =
                 chatQueryContext.getMapInfo().getDataSetElementMatches();
         List<SchemaElementMatch> matches = Lists.newArrayList();
-        matches.add(
-                SchemaElementMatch.builder()
-                        .element(
-                                SchemaElement.builder()
-                                        .dataSetId(1L)
-                                        .name("访问次数")
-                                        .type(SchemaElementType.METRIC)
-                                        .build())
-                        .similarity(0.8)
-                        .build());
-        matches.add(
-                SchemaElementMatch.builder()
-                        .element(
-                                SchemaElement.builder()
-                                        .dataSetId(1L)
-                                        .name("部门")
-                                        .type(SchemaElementType.METRIC)
-                                        .build())
-                        .similarity(0.7)
-                        .build());
+        matches.add(SchemaElementMatch.builder().element(SchemaElement.builder().dataSetId(1L)
+                .name("访问次数").type(SchemaElementType.METRIC).build()).similarity(0.8).build());
+        matches.add(SchemaElementMatch.builder().element(SchemaElement.builder().dataSetId(1L)
+                .name("部门").type(SchemaElementType.METRIC).build()).similarity(0.7).build());
         dataSet2Matches.put(1L, matches);
 
         List<SchemaElementMatch> matches2 = Lists.newArrayList();
-        matches2.add(
-                SchemaElementMatch.builder()
-                        .element(
-                                SchemaElement.builder()
-                                        .dataSetId(2L)
-                                        .name("访问用户数")
-                                        .type(SchemaElementType.METRIC)
-                                        .build())
-                        .similarity(0.8)
-                        .build());
-        matches2.add(
-                SchemaElementMatch.builder()
-                        .element(
-                                SchemaElement.builder()
-                                        .dataSetId(2L)
-                                        .name("用户")
-                                        .type(SchemaElementType.DIMENSION)
-                                        .build())
-                        .similarity(1)
-                        .build());
+        matches2.add(SchemaElementMatch.builder().element(SchemaElement.builder().dataSetId(2L)
+                .name("访问用户数").type(SchemaElementType.METRIC).build()).similarity(0.8).build());
+        matches2.add(SchemaElementMatch.builder().element(SchemaElement.builder().dataSetId(2L)
+                .name("用户").type(SchemaElementType.DIMENSION).build()).similarity(1).build());
         dataSet2Matches.put(2L, matches2);
 
         Long resolvedDataset = resolver.resolve(chatQueryContext, dataSets);


@@ -26,13 +26,8 @@ class LLMSqlParserTest {
         value1.setAlias(Arrays.asList("周杰倫", "Jay Chou", "周董", "周先生"));
         schemaValueMaps.add(value1);
-        SchemaElement schemaElement =
-                SchemaElement.builder()
-                        .bizName("singer_name")
-                        .name("歌手名")
-                        .dataSetId(2L)
-                        .schemaValueMaps(schemaValueMaps)
-                        .build();
+        SchemaElement schemaElement = SchemaElement.builder().bizName("singer_name").name("歌手名")
+                .dataSetId(2L).schemaValueMaps(schemaValueMaps).build();
         dimensions.add(schemaElement);
         SchemaElement schemaElement2 =


@@ -40,9 +40,7 @@ class QueryFilterParserTest {
         String parse = QueryFilterParser.parse(queryFilters);
-        Assert.assertEquals(
-                parse,
-                "age > 30 AND name LIKE 'John%' AND id IN (1, 2, 3, 4)"
-                        + " AND status NOT_IN ('inactive', 'deleted')");
+        Assert.assertEquals(parse, "age > 30 AND name LIKE 'John%' AND id IN (1, 2, 3, 4)"
+                + " AND status NOT_IN ('inactive', 'deleted')");
     }
 }