mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-13 04:57:28 +00:00
(improvement)(build) Add spotless during the build process. (#1639)
This commit is contained in:
@@ -41,17 +41,14 @@ public class ChatQueryContext {
|
||||
private Map<Long, List<Long>> modelIdToDataSetIds;
|
||||
private User user;
|
||||
private boolean saveAnswer;
|
||||
@Builder.Default
|
||||
private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM;
|
||||
@Builder.Default private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM;
|
||||
private QueryFilters queryFilters;
|
||||
private List<SemanticQuery> candidateQueries = new ArrayList<>();
|
||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||
private SemanticParseInfo contextParseInfo;
|
||||
private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
|
||||
@JsonIgnore
|
||||
private SemanticSchema semanticSchema;
|
||||
@JsonIgnore
|
||||
private ChatWorkflowState chatWorkflowState;
|
||||
@JsonIgnore private SemanticSchema semanticSchema;
|
||||
@JsonIgnore private ChatWorkflowState chatWorkflowState;
|
||||
private QueryDataType queryDataType = QueryDataType.ALL;
|
||||
private ChatModelConfig modelConfig;
|
||||
private PromptConfig promptConfig;
|
||||
@@ -59,12 +56,16 @@ public class ChatQueryContext {
|
||||
|
||||
public List<SemanticQuery> getCandidateQueries() {
|
||||
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
||||
int parseShowCount = Integer.parseInt(parserConfig.getParameterValue(ParserConfig.PARSER_SHOW_COUNT));
|
||||
candidateQueries = candidateQueries.stream()
|
||||
.sorted(Comparator.comparing(semanticQuery -> semanticQuery.getParseInfo().getScore(),
|
||||
Comparator.reverseOrder()))
|
||||
.limit(parseShowCount)
|
||||
.collect(Collectors.toList());
|
||||
int parseShowCount =
|
||||
Integer.parseInt(parserConfig.getParameterValue(ParserConfig.PARSER_SHOW_COUNT));
|
||||
candidateQueries =
|
||||
candidateQueries.stream()
|
||||
.sorted(
|
||||
Comparator.comparing(
|
||||
semanticQuery -> semanticQuery.getParseInfo().getScore(),
|
||||
Comparator.reverseOrder()))
|
||||
.limit(parseShowCount)
|
||||
.collect(Collectors.toList());
|
||||
return candidateQueries;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.headless.chat.corrector;
|
||||
|
||||
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
@@ -9,9 +8,7 @@ import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Verify whether the SQL aggregate function is missing. If it is missing, fill it in.
|
||||
*/
|
||||
/** Verify whether the SQL aggregate function is missing. If it is missing, fill it in. */
|
||||
@Slf4j
|
||||
public class AggCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@@ -20,13 +17,14 @@ public class AggCorrector extends BaseSemanticCorrector {
|
||||
addAggregate(chatQueryContext, semanticParseInfo);
|
||||
}
|
||||
|
||||
private void addAggregate(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
List<String> sqlGroupByFields = SqlSelectHelper.getGroupByFields(
|
||||
semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
|
||||
private void addAggregate(
|
||||
ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
List<String> sqlGroupByFields =
|
||||
SqlSelectHelper.getGroupByFields(
|
||||
semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
|
||||
if (CollectionUtils.isEmpty(sqlGroupByFields)) {
|
||||
return;
|
||||
}
|
||||
addAggregateToMetric(chatQueryContext, semanticParseInfo);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -23,8 +23,8 @@ import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* basic semantic correction functionality, offering common methods and an
|
||||
* abstract method called doCorrect
|
||||
* basic semantic correction functionality, offering common methods and an abstract method called
|
||||
* doCorrect
|
||||
*/
|
||||
@Slf4j
|
||||
public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
@@ -35,15 +35,20 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
return;
|
||||
}
|
||||
doCorrect(chatQueryContext, semanticParseInfo);
|
||||
log.debug("sqlCorrection:{} sql:{}", this.getClass().getSimpleName(), semanticParseInfo.getSqlInfo());
|
||||
log.debug(
|
||||
"sqlCorrection:{} sql:{}",
|
||||
this.getClass().getSimpleName(),
|
||||
semanticParseInfo.getSqlInfo());
|
||||
} catch (Exception e) {
|
||||
log.error(String.format("correct error,sqlInfo:%s", semanticParseInfo.getSqlInfo()), e);
|
||||
}
|
||||
}
|
||||
|
||||
public abstract void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo);
|
||||
public abstract void doCorrect(
|
||||
ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo);
|
||||
|
||||
protected Map<String, String> getFieldNameMap(ChatQueryContext chatQueryContext, Long dataSetId) {
|
||||
protected Map<String, String> getFieldNameMap(
|
||||
ChatQueryContext chatQueryContext, Long dataSetId) {
|
||||
|
||||
Map<String, String> result = getFieldNameMapFromDB(chatQueryContext, dataSetId);
|
||||
if (chatQueryContext.containsPartitionDimensions(dataSetId)) {
|
||||
@@ -58,7 +63,8 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
return result;
|
||||
}
|
||||
|
||||
private static Map<String, String> getFieldNameMapFromDB(ChatQueryContext chatQueryContext, Long dataSetId) {
|
||||
private static Map<String, String> getFieldNameMapFromDB(
|
||||
ChatQueryContext chatQueryContext, Long dataSetId) {
|
||||
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||
|
||||
List<SchemaElement> dbAllFields = new ArrayList<>();
|
||||
@@ -68,53 +74,6 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
// support fieldName and field alias
|
||||
return dbAllFields.stream()
|
||||
.filter(entry -> dataSetId.equals(entry.getDataSetId()))
|
||||
.flatMap(schemaElement -> {
|
||||
Set<String> elements = new HashSet<>();
|
||||
elements.add(schemaElement.getName());
|
||||
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
|
||||
elements.addAll(schemaElement.getAlias());
|
||||
}
|
||||
return elements.stream();
|
||||
})
|
||||
.collect(Collectors.toMap(a -> a, a -> a, (k1, k2) -> k1));
|
||||
}
|
||||
|
||||
protected void addAggregateToMetric(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
//add aggregate to all metric
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
Long dataSetId = semanticParseInfo.getDataSet().getDataSetId();
|
||||
List<SchemaElement> metrics = getMetricElements(chatQueryContext, dataSetId);
|
||||
|
||||
Map<String, String> metricToAggregate = metrics.stream()
|
||||
.map(schemaElement -> {
|
||||
if (Objects.isNull(schemaElement.getDefaultAgg())) {
|
||||
schemaElement.setDefaultAgg(AggregateTypeEnum.SUM.name());
|
||||
}
|
||||
return schemaElement;
|
||||
}).flatMap(schemaElement -> {
|
||||
Set<String> elements = new HashSet<>();
|
||||
elements.add(schemaElement.getName());
|
||||
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
|
||||
elements.addAll(schemaElement.getAlias());
|
||||
}
|
||||
return elements.stream().map(element -> Pair.of(element, schemaElement.getDefaultAgg())
|
||||
);
|
||||
}).collect(Collectors.toMap(Pair::getLeft, Pair::getRight, (k1, k2) -> k1));
|
||||
|
||||
if (CollectionUtils.isEmpty(metricToAggregate)) {
|
||||
return;
|
||||
}
|
||||
String aggregateSql = SqlAddHelper.addAggregateToField(correctS2SQL, metricToAggregate);
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(aggregateSql);
|
||||
}
|
||||
|
||||
protected List<SchemaElement> getMetricElements(ChatQueryContext chatQueryContext, Long dataSetId) {
|
||||
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||
return semanticSchema.getMetrics(dataSetId);
|
||||
}
|
||||
|
||||
protected Set<String> getDimensions(Long dataSetId, SemanticSchema semanticSchema) {
|
||||
Set<String> dimensions = semanticSchema.getDimensions(dataSetId).stream()
|
||||
.flatMap(
|
||||
schemaElement -> {
|
||||
Set<String> elements = new HashSet<>();
|
||||
@@ -123,26 +82,88 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector {
|
||||
elements.addAll(schemaElement.getAlias());
|
||||
}
|
||||
return elements.stream();
|
||||
}
|
||||
).collect(Collectors.toSet());
|
||||
})
|
||||
.collect(Collectors.toMap(a -> a, a -> a, (k1, k2) -> k1));
|
||||
}
|
||||
|
||||
protected void addAggregateToMetric(
|
||||
ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
// add aggregate to all metric
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
Long dataSetId = semanticParseInfo.getDataSet().getDataSetId();
|
||||
List<SchemaElement> metrics = getMetricElements(chatQueryContext, dataSetId);
|
||||
|
||||
Map<String, String> metricToAggregate =
|
||||
metrics.stream()
|
||||
.map(
|
||||
schemaElement -> {
|
||||
if (Objects.isNull(schemaElement.getDefaultAgg())) {
|
||||
schemaElement.setDefaultAgg(AggregateTypeEnum.SUM.name());
|
||||
}
|
||||
return schemaElement;
|
||||
})
|
||||
.flatMap(
|
||||
schemaElement -> {
|
||||
Set<String> elements = new HashSet<>();
|
||||
elements.add(schemaElement.getName());
|
||||
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
|
||||
elements.addAll(schemaElement.getAlias());
|
||||
}
|
||||
return elements.stream()
|
||||
.map(
|
||||
element ->
|
||||
Pair.of(
|
||||
element,
|
||||
schemaElement.getDefaultAgg()));
|
||||
})
|
||||
.collect(Collectors.toMap(Pair::getLeft, Pair::getRight, (k1, k2) -> k1));
|
||||
|
||||
if (CollectionUtils.isEmpty(metricToAggregate)) {
|
||||
return;
|
||||
}
|
||||
String aggregateSql = SqlAddHelper.addAggregateToField(correctS2SQL, metricToAggregate);
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(aggregateSql);
|
||||
}
|
||||
|
||||
protected List<SchemaElement> getMetricElements(
|
||||
ChatQueryContext chatQueryContext, Long dataSetId) {
|
||||
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||
return semanticSchema.getMetrics(dataSetId);
|
||||
}
|
||||
|
||||
protected Set<String> getDimensions(Long dataSetId, SemanticSchema semanticSchema) {
|
||||
Set<String> dimensions =
|
||||
semanticSchema.getDimensions(dataSetId).stream()
|
||||
.flatMap(
|
||||
schemaElement -> {
|
||||
Set<String> elements = new HashSet<>();
|
||||
elements.add(schemaElement.getName());
|
||||
if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
|
||||
elements.addAll(schemaElement.getAlias());
|
||||
}
|
||||
return elements.stream();
|
||||
})
|
||||
.collect(Collectors.toSet());
|
||||
dimensions.add(TimeDimensionEnum.DAY.getChName());
|
||||
return dimensions;
|
||||
}
|
||||
|
||||
protected boolean containsPartitionDimensions(ChatQueryContext chatQueryContext,
|
||||
SemanticParseInfo semanticParseInfo) {
|
||||
protected boolean containsPartitionDimensions(
|
||||
ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
Long dataSetId = semanticParseInfo.getDataSetId();
|
||||
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||
DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId);
|
||||
return dataSetSchema.containsPartitionDimensions();
|
||||
}
|
||||
|
||||
protected void removeDateIfExist(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
protected void removeDateIfExist(
|
||||
ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
Set<String> removeFieldNames = new HashSet<>();
|
||||
removeFieldNames.addAll(TimeDimensionEnum.getChNameList());
|
||||
removeFieldNames.addAll(TimeDimensionEnum.getNameList());
|
||||
Map<String, String> fieldNameMap = getFieldNameMapFromDB(chatQueryContext, semanticParseInfo.getDataSetId());
|
||||
Map<String, String> fieldNameMap =
|
||||
getFieldNameMapFromDB(chatQueryContext, semanticParseInfo.getDataSetId());
|
||||
removeFieldNames.removeIf(fieldName -> fieldNameMap.containsKey(fieldName));
|
||||
if (!CollectionUtils.isEmpty(removeFieldNames)) {
|
||||
correctS2SQL = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames);
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package com.tencent.supersonic.headless.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||
@@ -18,9 +18,7 @@ import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Perform SQL corrections on the "Group by" section in S2SQL.
|
||||
*/
|
||||
/** Perform SQL corrections on the "Group by" section in S2SQL. */
|
||||
@Slf4j
|
||||
public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@@ -33,13 +31,14 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
addGroupByFields(chatQueryContext, semanticParseInfo);
|
||||
}
|
||||
|
||||
private Boolean needAddGroupBy(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
private Boolean needAddGroupBy(
|
||||
ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
if (!QueryType.METRIC.equals(semanticParseInfo.getQueryType())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
Long dataSetId = semanticParseInfo.getDataSetId();
|
||||
//add dimension group by
|
||||
// add dimension group by
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String correctS2SQL = sqlInfo.getCorrectedS2SQL();
|
||||
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||
@@ -48,7 +47,7 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
log.debug("no need to add groupby ,existed distinct in s2sql:{}", correctS2SQL);
|
||||
return false;
|
||||
}
|
||||
//add alias field name
|
||||
// add alias field name
|
||||
Set<String> dimensions = getDimensions(dataSetId, semanticSchema);
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
|
||||
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) {
|
||||
@@ -63,33 +62,40 @@ public class GroupByCorrector extends BaseSemanticCorrector {
|
||||
return false;
|
||||
}
|
||||
Environment environment = ContextUtils.getBean(Environment.class);
|
||||
String correctorAdditionalInfo = environment.getProperty("s2.corrector.additional.information");
|
||||
if (StringUtils.isNotBlank(correctorAdditionalInfo) && !Boolean.parseBoolean(correctorAdditionalInfo)) {
|
||||
String correctorAdditionalInfo =
|
||||
environment.getProperty("s2.corrector.additional.information");
|
||||
if (StringUtils.isNotBlank(correctorAdditionalInfo)
|
||||
&& !Boolean.parseBoolean(correctorAdditionalInfo)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
private void addGroupByFields(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
private void addGroupByFields(
|
||||
ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
Long dataSetId = semanticParseInfo.getDataSetId();
|
||||
//add dimension group by
|
||||
// add dimension group by
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String correctS2SQL = sqlInfo.getCorrectedS2SQL();
|
||||
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||
//add alias field name
|
||||
// add alias field name
|
||||
Set<String> dimensions = getDimensions(dataSetId, semanticSchema);
|
||||
List<String> selectFields = SqlSelectHelper.gePureSelectFields(correctS2SQL);
|
||||
List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
|
||||
Set<String> groupByFields = selectFields.stream()
|
||||
.filter(field -> dimensions.contains(field))
|
||||
.filter(field -> {
|
||||
if (!CollectionUtils.isEmpty(aggregateFields) && aggregateFields.contains(field)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
})
|
||||
.collect(Collectors.toSet());
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(SqlAddHelper.addGroupBy(correctS2SQL, groupByFields));
|
||||
Set<String> groupByFields =
|
||||
selectFields.stream()
|
||||
.filter(field -> dimensions.contains(field))
|
||||
.filter(
|
||||
field -> {
|
||||
if (!CollectionUtils.isEmpty(aggregateFields)
|
||||
&& aggregateFields.contains(field)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
})
|
||||
.collect(Collectors.toSet());
|
||||
semanticParseInfo
|
||||
.getSqlInfo()
|
||||
.setCorrectedS2SQL(SqlAddHelper.addGroupBy(correctS2SQL, groupByFields));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package com.tencent.supersonic.headless.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectFunctionHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
@@ -17,25 +17,24 @@ import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Perform SQL corrections on the "Having" section in S2SQL.
|
||||
*/
|
||||
/** Perform SQL corrections on the "Having" section in S2SQL. */
|
||||
@Slf4j
|
||||
public class HavingCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@Override
|
||||
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
|
||||
//add aggregate to all metric
|
||||
// add aggregate to all metric
|
||||
addHaving(chatQueryContext, semanticParseInfo);
|
||||
|
||||
//decide whether add having expression field to select
|
||||
// decide whether add having expression field to select
|
||||
Environment environment = ContextUtils.getBean(Environment.class);
|
||||
String correctorAdditionalInfo = environment.getProperty("s2.corrector.additional.information");
|
||||
if (StringUtils.isNotBlank(correctorAdditionalInfo) && Boolean.parseBoolean(correctorAdditionalInfo)) {
|
||||
String correctorAdditionalInfo =
|
||||
environment.getProperty("s2.corrector.additional.information");
|
||||
if (StringUtils.isNotBlank(correctorAdditionalInfo)
|
||||
&& Boolean.parseBoolean(correctorAdditionalInfo)) {
|
||||
addHavingToSelect(semanticParseInfo);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private void addHaving(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
@@ -43,13 +42,16 @@ public class HavingCorrector extends BaseSemanticCorrector {
|
||||
|
||||
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||
|
||||
Set<String> metrics = semanticSchema.getMetrics(dataSet).stream()
|
||||
.map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());
|
||||
Set<String> metrics =
|
||||
semanticSchema.getMetrics(dataSet).stream()
|
||||
.map(schemaElement -> schemaElement.getName())
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
if (CollectionUtils.isEmpty(metrics)) {
|
||||
return;
|
||||
}
|
||||
String havingSql = SqlAddHelper.addHaving(semanticParseInfo.getSqlInfo().getCorrectedS2SQL(), metrics);
|
||||
String havingSql =
|
||||
SqlAddHelper.addHaving(semanticParseInfo.getSqlInfo().getCorrectedS2SQL(), metrics);
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(havingSql);
|
||||
}
|
||||
|
||||
@@ -60,10 +62,10 @@ public class HavingCorrector extends BaseSemanticCorrector {
|
||||
}
|
||||
List<Expression> havingExpressionList = SqlSelectHelper.getHavingExpression(correctS2SQL);
|
||||
if (!CollectionUtils.isEmpty(havingExpressionList)) {
|
||||
String replaceSql = SqlAddHelper.addFunctionToSelect(correctS2SQL, havingExpressionList);
|
||||
String replaceSql =
|
||||
SqlAddHelper.addFunctionToSelect(correctS2SQL, havingExpressionList);
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(replaceSql);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -21,7 +21,8 @@ public class S2SqlDateHelper {
|
||||
if (Objects.isNull(dataSetId)) {
|
||||
return defaultDate;
|
||||
}
|
||||
DataSetSchema dataSetSchema = chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
|
||||
DataSetSchema dataSetSchema =
|
||||
chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
|
||||
if (dataSetSchema == null || dataSetSchema.getTagTypeTimeDefaultConfig() == null) {
|
||||
return defaultDate;
|
||||
}
|
||||
@@ -30,13 +31,14 @@ public class S2SqlDateHelper {
|
||||
return getDefaultDate(defaultDate, tagTypeTimeDefaultConfig, partitionTimeFormat).getLeft();
|
||||
}
|
||||
|
||||
public static Pair<String, String> getStartEndDate(ChatQueryContext chatQueryContext, Long dataSetId,
|
||||
QueryType queryType) {
|
||||
public static Pair<String, String> getStartEndDate(
|
||||
ChatQueryContext chatQueryContext, Long dataSetId, QueryType queryType) {
|
||||
String defaultDate = DateUtils.getBeforeDate(0);
|
||||
if (Objects.isNull(dataSetId)) {
|
||||
return Pair.of(defaultDate, defaultDate);
|
||||
}
|
||||
DataSetSchema dataSetSchema = chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
|
||||
DataSetSchema dataSetSchema =
|
||||
chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
|
||||
if (Objects.isNull(dataSetSchema)) {
|
||||
return Pair.of(defaultDate, defaultDate);
|
||||
}
|
||||
@@ -48,9 +50,8 @@ public class S2SqlDateHelper {
|
||||
return getDefaultDate(defaultDate, defaultConfig, partitionTimeFormat);
|
||||
}
|
||||
|
||||
private static Pair<String, String> getDefaultDate(String defaultDate,
|
||||
TimeDefaultConfig defaultConfig,
|
||||
String partitionTimeFormat) {
|
||||
private static Pair<String, String> getDefaultDate(
|
||||
String defaultDate, TimeDefaultConfig defaultConfig, String partitionTimeFormat) {
|
||||
if (defaultConfig == null) {
|
||||
return Pair.of(null, null);
|
||||
}
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
package com.tencent.supersonic.headless.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.common.jsqlparser.AggregateEnum;
|
||||
import com.tencent.supersonic.common.jsqlparser.FieldExpression;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.chat.parser.llm.ParseResult;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
@@ -27,9 +27,7 @@ import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Perform schema corrections on the Schema information in S2SQL.
|
||||
*/
|
||||
/** Perform schema corrections on the Schema information in S2SQL. */
|
||||
@Slf4j
|
||||
public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@@ -49,7 +47,8 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
correctFieldName(chatQueryContext, semanticParseInfo);
|
||||
}
|
||||
|
||||
private void removeDateFields(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
private void removeDateFields(
|
||||
ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
if (containsPartitionDimensions(chatQueryContext, semanticParseInfo)) {
|
||||
return;
|
||||
}
|
||||
@@ -69,8 +68,10 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
sqlInfo.setCorrectedS2SQL(replaceAlias);
|
||||
}
|
||||
|
||||
private void correctFieldName(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
Map<String, String> fieldNameMap = getFieldNameMap(chatQueryContext, semanticParseInfo.getDataSetId());
|
||||
private void correctFieldName(
|
||||
ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
Map<String, String> fieldNameMap =
|
||||
getFieldNameMap(chatQueryContext, semanticParseInfo.getDataSetId());
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String sql = SqlReplaceHelper.replaceFields(sqlInfo.getCorrectedS2SQL(), fieldNameMap);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
@@ -82,13 +83,20 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
return;
|
||||
}
|
||||
|
||||
Map<String, Set<String>> fieldValueToFieldNames = linking.stream().collect(
|
||||
Collectors.groupingBy(LLMReq.ElementValue::getFieldValue,
|
||||
Collectors.mapping(LLMReq.ElementValue::getFieldName, Collectors.toSet())));
|
||||
Map<String, Set<String>> fieldValueToFieldNames =
|
||||
linking.stream()
|
||||
.collect(
|
||||
Collectors.groupingBy(
|
||||
LLMReq.ElementValue::getFieldValue,
|
||||
Collectors.mapping(
|
||||
LLMReq.ElementValue::getFieldName,
|
||||
Collectors.toSet())));
|
||||
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
|
||||
String sql = SqlReplaceHelper.replaceFieldNameByValue(sqlInfo.getCorrectedS2SQL(), fieldValueToFieldNames);
|
||||
String sql =
|
||||
SqlReplaceHelper.replaceFieldNameByValue(
|
||||
sqlInfo.getCorrectedS2SQL(), fieldValueToFieldNames);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
}
|
||||
|
||||
@@ -111,24 +119,31 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
return;
|
||||
}
|
||||
|
||||
Map<String, Map<String, String>> filedNameToValueMap = linking.stream().collect(
|
||||
Collectors.groupingBy(LLMReq.ElementValue::getFieldName,
|
||||
Collectors.mapping(LLMReq.ElementValue::getFieldValue, Collectors.toMap(
|
||||
oldValue -> oldValue,
|
||||
newValue -> newValue,
|
||||
(existingValue, newValue) -> newValue)
|
||||
)));
|
||||
Map<String, Map<String, String>> filedNameToValueMap =
|
||||
linking.stream()
|
||||
.collect(
|
||||
Collectors.groupingBy(
|
||||
LLMReq.ElementValue::getFieldName,
|
||||
Collectors.mapping(
|
||||
LLMReq.ElementValue::getFieldValue,
|
||||
Collectors.toMap(
|
||||
oldValue -> oldValue,
|
||||
newValue -> newValue,
|
||||
(existingValue, newValue) -> newValue))));
|
||||
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String sql = SqlReplaceHelper.replaceValue(sqlInfo.getCorrectedS2SQL(), filedNameToValueMap, false);
|
||||
String sql =
|
||||
SqlReplaceHelper.replaceValue(
|
||||
sqlInfo.getCorrectedS2SQL(), filedNameToValueMap, false);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
}
|
||||
|
||||
public void removeFilterIfNotInLinkingValue(ChatQueryContext chatQueryContext,
|
||||
SemanticParseInfo semanticParseInfo) {
|
||||
public void removeFilterIfNotInLinkingValue(
|
||||
ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
|
||||
String correctS2SQL = sqlInfo.getCorrectedS2SQL();
|
||||
List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctS2SQL);
|
||||
List<FieldExpression> whereExpressionList =
|
||||
SqlSelectHelper.getWhereExpressions(correctS2SQL);
|
||||
if (CollectionUtils.isEmpty(whereExpressionList)) {
|
||||
return;
|
||||
}
|
||||
@@ -139,20 +154,39 @@ public class SchemaCorrector extends BaseSemanticCorrector {
|
||||
if (CollectionUtils.isEmpty(linkingValues)) {
|
||||
linkingValues = new ArrayList<>();
|
||||
}
|
||||
Set<String> linkingFieldNames = linkingValues.stream().map(linking -> linking.getFieldName())
|
||||
.collect(Collectors.toSet());
|
||||
Set<String> linkingFieldNames =
|
||||
linkingValues.stream()
|
||||
.map(linking -> linking.getFieldName())
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
Set<String> removeFieldNames = whereExpressionList.stream()
|
||||
.filter(fieldExpression -> StringUtils.isBlank(fieldExpression.getFunction()))
|
||||
.filter(fieldExpression -> !TimeDimensionEnum.containsTimeDimension(fieldExpression.getFieldName()))
|
||||
.filter(fieldExpression -> FilterOperatorEnum.EQUALS.getValue().equals(fieldExpression.getOperator()))
|
||||
.filter(fieldExpression -> dimensions.contains(fieldExpression.getFieldName()))
|
||||
.filter(fieldExpression -> !DateUtils.isAnyDateString(fieldExpression.getFieldValue().toString()))
|
||||
.filter(fieldExpression -> !linkingFieldNames.contains(fieldExpression.getFieldName()))
|
||||
.map(fieldExpression -> fieldExpression.getFieldName()).collect(Collectors.toSet());
|
||||
Set<String> removeFieldNames =
|
||||
whereExpressionList.stream()
|
||||
.filter(
|
||||
fieldExpression ->
|
||||
StringUtils.isBlank(fieldExpression.getFunction()))
|
||||
.filter(
|
||||
fieldExpression ->
|
||||
!TimeDimensionEnum.containsTimeDimension(
|
||||
fieldExpression.getFieldName()))
|
||||
.filter(
|
||||
fieldExpression ->
|
||||
FilterOperatorEnum.EQUALS
|
||||
.getValue()
|
||||
.equals(fieldExpression.getOperator()))
|
||||
.filter(
|
||||
fieldExpression ->
|
||||
dimensions.contains(fieldExpression.getFieldName()))
|
||||
.filter(
|
||||
fieldExpression ->
|
||||
!DateUtils.isAnyDateString(
|
||||
fieldExpression.getFieldValue().toString()))
|
||||
.filter(
|
||||
fieldExpression ->
|
||||
!linkingFieldNames.contains(fieldExpression.getFieldName()))
|
||||
.map(fieldExpression -> fieldExpression.getFieldName())
|
||||
.collect(Collectors.toSet());
|
||||
|
||||
String sql = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames);
|
||||
sqlInfo.setCorrectedS2SQL(sql);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.headless.chat.corrector;
|
||||
|
||||
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlRemoveHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
|
||||
@@ -23,9 +22,7 @@ import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Perform SQL corrections on the "Select" section in S2SQL.
|
||||
*/
|
||||
/** Perform SQL corrections on the "Select" section in S2SQL. */
|
||||
@Slf4j
|
||||
public class SelectCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@@ -36,7 +33,8 @@ public class SelectCorrector extends BaseSemanticCorrector {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
|
||||
// If the number of aggregated fields is equal to the number of queried fields, do not add fields to select.
|
||||
// If the number of aggregated fields is equal to the number of queried fields, do not add
|
||||
// fields to select.
|
||||
if (!CollectionUtils.isEmpty(aggregateFields)
|
||||
&& !CollectionUtils.isEmpty(selectFields)
|
||||
&& aggregateFields.size() == selectFields.size()) {
|
||||
@@ -47,55 +45,65 @@ public class SelectCorrector extends BaseSemanticCorrector {
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(querySql);
|
||||
}
|
||||
|
||||
protected String addFieldsToSelect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo,
|
||||
String correctS2SQL) {
|
||||
protected String addFieldsToSelect(
|
||||
ChatQueryContext chatQueryContext,
|
||||
SemanticParseInfo semanticParseInfo,
|
||||
String correctS2SQL) {
|
||||
correctS2SQL = addTagDefaultFields(chatQueryContext, semanticParseInfo, correctS2SQL);
|
||||
|
||||
Set<String> selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL));
|
||||
Set<String> needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL));
|
||||
|
||||
//decide whether add order by expression field to select
|
||||
// decide whether add order by expression field to select
|
||||
Environment environment = ContextUtils.getBean(Environment.class);
|
||||
String correctorAdditionalInfo = environment.getProperty(ADDITIONAL_INFORMATION);
|
||||
if (StringUtils.isNotBlank(correctorAdditionalInfo) && Boolean.parseBoolean(correctorAdditionalInfo)) {
|
||||
if (StringUtils.isNotBlank(correctorAdditionalInfo)
|
||||
&& Boolean.parseBoolean(correctorAdditionalInfo)) {
|
||||
needAddFields.addAll(SqlSelectHelper.getOrderByFields(correctS2SQL));
|
||||
}
|
||||
if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(needAddFields)) {
|
||||
return correctS2SQL;
|
||||
}
|
||||
needAddFields.removeAll(selectFields);
|
||||
String addFieldsToSelectSql = SqlAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
|
||||
String addFieldsToSelectSql =
|
||||
SqlAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(addFieldsToSelectSql);
|
||||
return addFieldsToSelectSql;
|
||||
}
|
||||
|
||||
private String addTagDefaultFields(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo,
|
||||
String correctS2SQL) {
|
||||
//If it is in DETAIL mode and select *, add default metrics and dimensions.
|
||||
private String addTagDefaultFields(
|
||||
ChatQueryContext chatQueryContext,
|
||||
SemanticParseInfo semanticParseInfo,
|
||||
String correctS2SQL) {
|
||||
// If it is in DETAIL mode and select *, add default metrics and dimensions.
|
||||
boolean hasAsterisk = SqlSelectFunctionHelper.hasAsterisk(correctS2SQL);
|
||||
if (!(hasAsterisk && QueryType.DETAIL.equals(semanticParseInfo.getQueryType()))) {
|
||||
return correctS2SQL;
|
||||
}
|
||||
Long dataSetId = semanticParseInfo.getDataSetId();
|
||||
DataSetSchema dataSetSchema = chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
|
||||
DataSetSchema dataSetSchema =
|
||||
chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
|
||||
Set<String> needAddDefaultFields = new HashSet<>();
|
||||
if (Objects.nonNull(dataSetSchema)) {
|
||||
if (!CollectionUtils.isEmpty(dataSetSchema.getTagDefaultMetrics())) {
|
||||
Set<String> metrics = dataSetSchema.getTagDefaultMetrics()
|
||||
.stream().map(schemaElement -> schemaElement.getName())
|
||||
.collect(Collectors.toSet());
|
||||
Set<String> metrics =
|
||||
dataSetSchema.getTagDefaultMetrics().stream()
|
||||
.map(schemaElement -> schemaElement.getName())
|
||||
.collect(Collectors.toSet());
|
||||
needAddDefaultFields.addAll(metrics);
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(dataSetSchema.getTagDefaultDimensions())) {
|
||||
Set<String> dimensions = dataSetSchema.getTagDefaultDimensions()
|
||||
.stream().map(schemaElement -> schemaElement.getName())
|
||||
.collect(Collectors.toSet());
|
||||
Set<String> dimensions =
|
||||
dataSetSchema.getTagDefaultDimensions().stream()
|
||||
.map(schemaElement -> schemaElement.getName())
|
||||
.collect(Collectors.toSet());
|
||||
needAddDefaultFields.addAll(dimensions);
|
||||
}
|
||||
}
|
||||
// remove * in sql and add default fields.
|
||||
if (!CollectionUtils.isEmpty(needAddDefaultFields)) {
|
||||
correctS2SQL = SqlRemoveHelper.removeAsteriskAndAddFields(correctS2SQL, needAddDefaultFields);
|
||||
correctS2SQL =
|
||||
SqlRemoveHelper.removeAsteriskAndAddFields(correctS2SQL, needAddDefaultFields);
|
||||
}
|
||||
return correctS2SQL;
|
||||
}
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
package com.tencent.supersonic.headless.chat.corrector;
|
||||
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
|
||||
/**
|
||||
* A semantic corrector checks validity of extracted semantic information and
|
||||
* performs correction and optimization if needed.
|
||||
* A semantic corrector checks validity of extracted semantic information and performs correction
|
||||
* and optimization if needed.
|
||||
*/
|
||||
public interface SemanticCorrector {
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.headless.chat.corrector;
|
||||
|
||||
|
||||
import com.tencent.supersonic.common.jsqlparser.DateVisitor.DateBoundInfo;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlDateSelectHelper;
|
||||
@@ -20,9 +19,7 @@ import org.springframework.util.CollectionUtils;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* Perform SQL corrections on the time in S2SQL.
|
||||
*/
|
||||
/** Perform SQL corrections on the time in S2SQL. */
|
||||
@Slf4j
|
||||
public class TimeCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@@ -36,11 +33,13 @@ public class TimeCorrector extends BaseSemanticCorrector {
|
||||
addLowerBoundDate(semanticParseInfo);
|
||||
}
|
||||
|
||||
private void addDateIfNotExist(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
private void addDateIfNotExist(
|
||||
ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);
|
||||
Long dataSetId = semanticParseInfo.getDataSetId();
|
||||
DataSetSchema dataSetSchema = chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
|
||||
DataSetSchema dataSetSchema =
|
||||
chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
|
||||
if (Objects.isNull(dataSetSchema)
|
||||
|| Objects.isNull(dataSetSchema.getPartitionDimension())
|
||||
|| Objects.isNull(dataSetSchema.getPartitionDimension().getName())
|
||||
@@ -49,16 +48,22 @@ public class TimeCorrector extends BaseSemanticCorrector {
|
||||
}
|
||||
String partitionDimension = dataSetSchema.getPartitionDimension().getName();
|
||||
if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(partitionDimension)) {
|
||||
Pair<String, String> startEndDate = S2SqlDateHelper.getStartEndDate(chatQueryContext, dataSetId,
|
||||
semanticParseInfo.getQueryType());
|
||||
Pair<String, String> startEndDate =
|
||||
S2SqlDateHelper.getStartEndDate(
|
||||
chatQueryContext, dataSetId, semanticParseInfo.getQueryType());
|
||||
|
||||
if (isValidDateRange(startEndDate)) {
|
||||
correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL);
|
||||
String startDateLeft = startEndDate.getLeft();
|
||||
String endDateRight = startEndDate.getRight();
|
||||
|
||||
String condExpr = String.format(" ( %s >= '%s' and %s <= '%s' )",
|
||||
partitionDimension, startDateLeft, partitionDimension, endDateRight);
|
||||
String condExpr =
|
||||
String.format(
|
||||
" ( %s >= '%s' and %s <= '%s' )",
|
||||
partitionDimension,
|
||||
startDateLeft,
|
||||
partitionDimension,
|
||||
endDateRight);
|
||||
correctS2SQL = addConditionToSQL(correctS2SQL, condExpr);
|
||||
}
|
||||
}
|
||||
@@ -94,4 +99,4 @@ public class TimeCorrector extends BaseSemanticCorrector {
|
||||
return sql;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.headless.chat.corrector;
|
||||
|
||||
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlAddHelper;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlReplaceHelper;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
@@ -9,11 +8,6 @@ import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.utils.QueryFilterParser;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import net.sf.jsqlparser.JSQLParserException;
|
||||
import net.sf.jsqlparser.expression.Expression;
|
||||
@@ -21,9 +15,13 @@ import net.sf.jsqlparser.parser.CCJSqlParserUtil;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* Perform SQL corrections on the "Where" section in S2SQL.
|
||||
*/
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/** Perform SQL corrections on the "Where" section in S2SQL. */
|
||||
@Slf4j
|
||||
public class WhereCorrector extends BaseSemanticCorrector {
|
||||
|
||||
@@ -33,7 +31,8 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
updateFieldValueByTechName(chatQueryContext, semanticParseInfo);
|
||||
}
|
||||
|
||||
protected void addQueryFilter(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
protected void addQueryFilter(
|
||||
ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
String queryFilter = getQueryFilter(chatQueryContext.getQueryFilters());
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
|
||||
@@ -56,7 +55,8 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
return QueryFilterParser.parse(queryFilters);
|
||||
}
|
||||
|
||||
private void updateFieldValueByTechName(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
private void updateFieldValueByTechName(
|
||||
ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
|
||||
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||
Long dataSetId = semanticParseInfo.getDataSetId();
|
||||
List<SchemaElement> dimensions = semanticSchema.getDimensions(dataSetId);
|
||||
@@ -64,35 +64,61 @@ public class WhereCorrector extends BaseSemanticCorrector {
|
||||
if (CollectionUtils.isEmpty(dimensions)) {
|
||||
return;
|
||||
}
|
||||
Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions);
|
||||
Map<String, Map<String, String>> aliasAndBizNameToTechName =
|
||||
getAliasAndBizNameToTechName(dimensions);
|
||||
String correctedS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL();
|
||||
String replaceSql = SqlReplaceHelper.replaceValue(correctedS2SQL, aliasAndBizNameToTechName);
|
||||
String replaceSql =
|
||||
SqlReplaceHelper.replaceValue(correctedS2SQL, aliasAndBizNameToTechName);
|
||||
semanticParseInfo.getSqlInfo().setCorrectedS2SQL(replaceSql);
|
||||
}
|
||||
|
||||
private Map<String, Map<String, String>> getAliasAndBizNameToTechName(List<SchemaElement> dimensions) {
|
||||
private Map<String, Map<String, String>> getAliasAndBizNameToTechName(
|
||||
List<SchemaElement> dimensions) {
|
||||
return dimensions.stream()
|
||||
.filter(dimension -> Objects.nonNull(dimension)
|
||||
&& StringUtils.isNotEmpty(dimension.getName())
|
||||
&& !CollectionUtils.isEmpty(dimension.getSchemaValueMaps()))
|
||||
.collect(Collectors.toMap(
|
||||
SchemaElement::getName,
|
||||
dimension -> dimension.getSchemaValueMaps().stream()
|
||||
.filter(valueMap -> Objects.nonNull(valueMap)
|
||||
&& StringUtils.isNotEmpty(valueMap.getTechName()))
|
||||
.flatMap(valueMap -> {
|
||||
Map<String, String> map = new HashMap<>();
|
||||
if (StringUtils.isNotEmpty(valueMap.getBizName())) {
|
||||
map.put(valueMap.getBizName(), valueMap.getTechName());
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(valueMap.getAlias())) {
|
||||
valueMap.getAlias().stream()
|
||||
.filter(StringUtils::isNotEmpty)
|
||||
.forEach(alias -> map.put(alias, valueMap.getTechName()));
|
||||
}
|
||||
return map.entrySet().stream();
|
||||
})
|
||||
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue))
|
||||
));
|
||||
.filter(
|
||||
dimension ->
|
||||
Objects.nonNull(dimension)
|
||||
&& StringUtils.isNotEmpty(dimension.getName())
|
||||
&& !CollectionUtils.isEmpty(dimension.getSchemaValueMaps()))
|
||||
.collect(
|
||||
Collectors.toMap(
|
||||
SchemaElement::getName,
|
||||
dimension ->
|
||||
dimension.getSchemaValueMaps().stream()
|
||||
.filter(
|
||||
valueMap ->
|
||||
Objects.nonNull(valueMap)
|
||||
&& StringUtils.isNotEmpty(
|
||||
valueMap
|
||||
.getTechName()))
|
||||
.flatMap(
|
||||
valueMap -> {
|
||||
Map<String, String> map =
|
||||
new HashMap<>();
|
||||
if (StringUtils.isNotEmpty(
|
||||
valueMap.getBizName())) {
|
||||
map.put(
|
||||
valueMap.getBizName(),
|
||||
valueMap.getTechName());
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(
|
||||
valueMap.getAlias())) {
|
||||
valueMap.getAlias().stream()
|
||||
.filter(
|
||||
StringUtils
|
||||
::isNotEmpty)
|
||||
.forEach(
|
||||
alias ->
|
||||
map.put(
|
||||
alias,
|
||||
valueMap
|
||||
.getTechName()));
|
||||
}
|
||||
return map.entrySet().stream();
|
||||
})
|
||||
.collect(
|
||||
Collectors.toMap(
|
||||
Map.Entry::getKey,
|
||||
Map.Entry::getValue))));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,5 +18,4 @@ public class DataSetInfoStat implements Serializable {
|
||||
private long dimensionDataSetCount;
|
||||
|
||||
private long dimensionValueDataSetCount;
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,4 +27,4 @@ public class DatabaseMapResult extends MapResult {
|
||||
public int hashCode() {
|
||||
return Objects.hashCode(name, schemaElement);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
|
||||
package com.tencent.supersonic.headless.chat.knowledge;
|
||||
|
||||
public enum DictUpdateMode {
|
||||
|
||||
OFFLINE_FULL("OFFLINE_FULL"),
|
||||
OFFLINE_MODEL("OFFLINE_MODEL"),
|
||||
REALTIME_ADD("REALTIME_ADD"),
|
||||
@@ -27,5 +25,4 @@ public enum DictUpdateMode {
|
||||
public String getValue() {
|
||||
return value;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
package com.tencent.supersonic.headless.chat.knowledge;
|
||||
|
||||
import java.util.Objects;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
|
||||
/***
|
||||
* word nature
|
||||
*/
|
||||
import java.util.Objects;
|
||||
|
||||
/** * word nature */
|
||||
@Data
|
||||
@ToString
|
||||
public class DictWord {
|
||||
@@ -24,7 +23,8 @@ public class DictWord {
|
||||
return false;
|
||||
}
|
||||
DictWord that = (DictWord) o;
|
||||
return Objects.equals(word, that.word) && Objects.equals(natureWithFrequency, that.natureWithFrequency);
|
||||
return Objects.equals(word, that.word)
|
||||
&& Objects.equals(natureWithFrequency, that.natureWithFrequency);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.chat.knowledge;
|
||||
|
||||
import com.hankcs.hanlp.corpus.tag.Nature;
|
||||
import com.hankcs.hanlp.dictionary.CoreDictionary;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
@@ -12,38 +13,52 @@ import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
/**
|
||||
* Dictionary Attribute Util
|
||||
*/
|
||||
/** Dictionary Attribute Util */
|
||||
public class DictionaryAttributeUtil {
|
||||
|
||||
public static CoreDictionary.Attribute getAttribute(CoreDictionary.Attribute old, CoreDictionary.Attribute add) {
|
||||
public static CoreDictionary.Attribute getAttribute(
|
||||
CoreDictionary.Attribute old, CoreDictionary.Attribute add) {
|
||||
Map<Nature, Integer> map = new HashMap<>();
|
||||
Map<Nature, String> originalMap = new HashMap<>();
|
||||
IntStream.range(0, old.nature.length).boxed().forEach(i -> {
|
||||
map.put(old.nature[i], old.frequency[i]);
|
||||
if (Objects.nonNull(old.originals)) {
|
||||
originalMap.put(old.nature[i], old.originals[i]);
|
||||
}
|
||||
});
|
||||
IntStream.range(0, add.nature.length).boxed().forEach(i -> {
|
||||
map.put(add.nature[i], add.frequency[i]);
|
||||
if (Objects.nonNull(add.originals)) {
|
||||
originalMap.put(add.nature[i], add.originals[i]);
|
||||
}
|
||||
});
|
||||
List<Map.Entry<Nature, Integer>> list = new LinkedList<Map.Entry<Nature, Integer>>(map.entrySet());
|
||||
Collections.sort(list, new Comparator<Map.Entry<Nature, Integer>>() {
|
||||
public int compare(Map.Entry<Nature, Integer> o1, Map.Entry<Nature, Integer> o2) {
|
||||
return o2.getValue() - o1.getValue();
|
||||
}
|
||||
});
|
||||
String[] originals = list.stream().map(l -> originalMap.get(l.getKey())).toArray(String[]::new);
|
||||
CoreDictionary.Attribute attribute = new CoreDictionary.Attribute(
|
||||
list.stream().map(i -> i.getKey()).collect(Collectors.toList()).toArray(new Nature[0]),
|
||||
list.stream().map(i -> i.getValue()).mapToInt(Integer::intValue).toArray(),
|
||||
originals,
|
||||
list.stream().map(i -> i.getValue()).findFirst().get());
|
||||
IntStream.range(0, old.nature.length)
|
||||
.boxed()
|
||||
.forEach(
|
||||
i -> {
|
||||
map.put(old.nature[i], old.frequency[i]);
|
||||
if (Objects.nonNull(old.originals)) {
|
||||
originalMap.put(old.nature[i], old.originals[i]);
|
||||
}
|
||||
});
|
||||
IntStream.range(0, add.nature.length)
|
||||
.boxed()
|
||||
.forEach(
|
||||
i -> {
|
||||
map.put(add.nature[i], add.frequency[i]);
|
||||
if (Objects.nonNull(add.originals)) {
|
||||
originalMap.put(add.nature[i], add.originals[i]);
|
||||
}
|
||||
});
|
||||
List<Map.Entry<Nature, Integer>> list =
|
||||
new LinkedList<Map.Entry<Nature, Integer>>(map.entrySet());
|
||||
Collections.sort(
|
||||
list,
|
||||
new Comparator<Map.Entry<Nature, Integer>>() {
|
||||
public int compare(
|
||||
Map.Entry<Nature, Integer> o1, Map.Entry<Nature, Integer> o2) {
|
||||
return o2.getValue() - o1.getValue();
|
||||
}
|
||||
});
|
||||
String[] originals =
|
||||
list.stream().map(l -> originalMap.get(l.getKey())).toArray(String[]::new);
|
||||
CoreDictionary.Attribute attribute =
|
||||
new CoreDictionary.Attribute(
|
||||
list.stream()
|
||||
.map(i -> i.getKey())
|
||||
.collect(Collectors.toList())
|
||||
.toArray(new Nature[0]),
|
||||
list.stream().map(i -> i.getValue()).mapToInt(Integer::intValue).toArray(),
|
||||
originals,
|
||||
list.stream().map(i -> i.getValue()).findFirst().get());
|
||||
return attribute;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
package com.tencent.supersonic.headless.chat.knowledge;
|
||||
|
||||
import com.google.common.base.Objects;
|
||||
import java.util.Map;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
@ToString
|
||||
public class EmbeddingResult extends MapResult {
|
||||
@@ -31,4 +32,4 @@ public class EmbeddingResult extends MapResult {
|
||||
public int hashCode() {
|
||||
return Objects.hashCode(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
package com.tencent.supersonic.headless.chat.knowledge;
|
||||
|
||||
import com.hankcs.hanlp.corpus.io.IIOAdapter;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
import java.net.URI;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.hadoop.conf.Configuration;
|
||||
import org.apache.hadoop.fs.FileSystem;
|
||||
import org.apache.hadoop.fs.Path;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.io.OutputStream;
|
||||
import java.net.URI;
|
||||
|
||||
@Slf4j
|
||||
public class HadoopFileIOAdapter implements IIOAdapter {
|
||||
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
package com.tencent.supersonic.headless.chat.knowledge;
|
||||
|
||||
import com.google.common.base.Objects;
|
||||
import java.util.List;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@ToString
|
||||
public class HanlpMapResult extends MapResult {
|
||||
@@ -29,7 +30,8 @@ public class HanlpMapResult extends MapResult {
|
||||
return false;
|
||||
}
|
||||
HanlpMapResult hanlpMapResult = (HanlpMapResult) o;
|
||||
return Objects.equal(name, hanlpMapResult.name) && Objects.equal(natures, hanlpMapResult.natures);
|
||||
return Objects.equal(name, hanlpMapResult.name)
|
||||
&& Objects.equal(natures, hanlpMapResult.natures);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -40,5 +42,4 @@ public class HanlpMapResult extends MapResult {
|
||||
public void setOffset(int offset) {
|
||||
this.offset = offset;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,17 +17,25 @@ public class KnowledgeBaseService {
|
||||
|
||||
public void updateSemanticKnowledge(List<DictWord> natures) {
|
||||
|
||||
List<DictWord> prefixes = natures.stream()
|
||||
.filter(entry -> !entry.getNatureWithFrequency().contains(DictWordType.SUFFIX.getType()))
|
||||
.collect(Collectors.toList());
|
||||
List<DictWord> prefixes =
|
||||
natures.stream()
|
||||
.filter(
|
||||
entry ->
|
||||
!entry.getNatureWithFrequency()
|
||||
.contains(DictWordType.SUFFIX.getType()))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
for (DictWord nature : prefixes) {
|
||||
HanlpHelper.addToCustomDictionary(nature);
|
||||
}
|
||||
|
||||
List<DictWord> suffixes = natures.stream()
|
||||
.filter(entry -> entry.getNatureWithFrequency().contains(DictWordType.SUFFIX.getType()))
|
||||
.collect(Collectors.toList());
|
||||
List<DictWord> suffixes =
|
||||
natures.stream()
|
||||
.filter(
|
||||
entry ->
|
||||
entry.getNatureWithFrequency()
|
||||
.contains(DictWordType.SUFFIX.getType()))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
SearchService.loadSuffix(suffixes);
|
||||
}
|
||||
@@ -56,26 +64,35 @@ public class KnowledgeBaseService {
return HanlpHelper.getTerms(text, modelIdToDataSetIds);
}

public List<HanlpMapResult> prefixSearch(String key, int limit,
Map<Long, List<Long>> modelIdToDataSetIds,
Set<Long> detectDataSetIds) {
public List<HanlpMapResult> prefixSearch(
String key,
int limit,
Map<Long, List<Long>> modelIdToDataSetIds,
Set<Long> detectDataSetIds) {
return prefixSearchByModel(key, limit, modelIdToDataSetIds, detectDataSetIds);
}

public List<HanlpMapResult> prefixSearchByModel(String key, int limit,
Map<Long, List<Long>> modelIdToDataSetIds,
Set<Long> detectDataSetIds) {
public List<HanlpMapResult> prefixSearchByModel(
String key,
int limit,
Map<Long, List<Long>> modelIdToDataSetIds,
Set<Long> detectDataSetIds) {
return SearchService.prefixSearch(key, limit, modelIdToDataSetIds, detectDataSetIds);
}

public List<HanlpMapResult> suffixSearch(String key, int limit, Map<Long,
List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
public List<HanlpMapResult> suffixSearch(
String key,
int limit,
Map<Long, List<Long>> modelIdToDataSetIds,
Set<Long> detectDataSetIds) {
return suffixSearchByModel(key, limit, modelIdToDataSetIds, detectDataSetIds);
}

public List<HanlpMapResult> suffixSearchByModel(String key, int limit, Map<Long,
List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
public List<HanlpMapResult> suffixSearchByModel(
String key,
int limit,
Map<Long, List<Long>> modelIdToDataSetIds,
Set<Long> detectDataSetIds) {
return SearchService.suffixSearch(key, limit, modelIdToDataSetIds, detectDataSetIds);
}

}
}

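For orientation, a minimal usage sketch of the two search entry points reformatted above; the signatures are taken from this diff, while the injected service instance and the concrete IDs are illustrative assumptions:

    // Illustrative sketch: the knowledgeBaseService instance and the IDs are
    // assumptions; the method signatures match the ones shown in this diff.
    Map<Long, List<Long>> modelIdToDataSetIds = Map.of(1L, List.of(10L, 11L));
    Set<Long> detectDataSetIds = Set.of(10L);

    // Prefix match against the HanLP trie, capped at five candidates.
    List<HanlpMapResult> byPrefix =
            knowledgeBaseService.prefixSearch("super", 5, modelIdToDataSetIds, detectDataSetIds);

    // Suffix match; SearchService reverses the key and probes the suffix trie.
    List<HanlpMapResult> bySuffix =
            knowledgeBaseService.suffixSearch("sonic", 5, modelIdToDataSetIds, detectDataSetIds);
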
@@ -1,13 +1,14 @@
package com.tencent.supersonic.headless.chat.knowledge;

import java.io.Serializable;
import lombok.Data;
import lombok.ToString;

import java.io.Serializable;

@Data
@ToString
public class MapResult implements Serializable {

protected String name;
protected String detectWord;
}
}

@@ -26,27 +26,30 @@ import java.util.stream.Stream;
|
||||
@Slf4j
|
||||
public class MetaEmbeddingService {
|
||||
|
||||
@Autowired
|
||||
private EmbeddingService embeddingService;
|
||||
@Autowired
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
@Autowired private EmbeddingService embeddingService;
|
||||
@Autowired private EmbeddingConfig embeddingConfig;
|
||||
|
||||
public List<RetrieveQueryResult> retrieveQuery(RetrieveQuery retrieveQuery, int num,
|
||||
Map<Long, List<Long>> modelIdToDataSetIds,
|
||||
Set<Long> detectDataSetIds) {
|
||||
public List<RetrieveQueryResult> retrieveQuery(
|
||||
RetrieveQuery retrieveQuery,
|
||||
int num,
|
||||
Map<Long, List<Long>> modelIdToDataSetIds,
|
||||
Set<Long> detectDataSetIds) {
|
||||
// dataSetIds->modelIds
|
||||
Set<Long> allModels = NatureHelper.getModelIds(modelIdToDataSetIds, detectDataSetIds);
|
||||
|
||||
if (CollectionUtils.isNotEmpty(allModels)) {
|
||||
Map<String, Object> filterCondition = new HashMap<>();
|
||||
filterCondition.put("modelId", allModels.stream()
|
||||
.map(modelId -> modelId + DictWordType.NATURE_SPILT)
|
||||
.collect(Collectors.toList()));
|
||||
filterCondition.put(
|
||||
"modelId",
|
||||
allModels.stream()
|
||||
.map(modelId -> modelId + DictWordType.NATURE_SPILT)
|
||||
.collect(Collectors.toList()));
|
||||
retrieveQuery.setFilterCondition(filterCondition);
|
||||
}
|
||||
|
||||
String collectionName = embeddingConfig.getMetaCollectionName();
|
||||
List<RetrieveQueryResult> resultList = embeddingService.retrieveQuery(collectionName, retrieveQuery, num);
|
||||
List<RetrieveQueryResult> resultList =
|
||||
embeddingService.retrieveQuery(collectionName, retrieveQuery, num);
|
||||
if (CollectionUtils.isEmpty(resultList)) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
@@ -57,31 +60,44 @@ public class MetaEmbeddingService {
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private static RetrieveQueryResult getRetrieveQueryResult(Map<Long,
|
||||
List<Long>> modelIdToDataSetIds, RetrieveQueryResult result) {
|
||||
private static RetrieveQueryResult getRetrieveQueryResult(
|
||||
Map<Long, List<Long>> modelIdToDataSetIds, RetrieveQueryResult result) {
|
||||
List<Retrieval> retrievals = result.getRetrieval();
|
||||
if (CollectionUtils.isEmpty(retrievals)) {
|
||||
return result;
|
||||
}
|
||||
// Process each Retrieval object.
|
||||
List<Retrieval> updatedRetrievals = retrievals.stream()
|
||||
.flatMap(retrieval -> {
|
||||
Long modelId = Retrieval.getLongId(retrieval.getMetadata().get("modelId"));
|
||||
List<Long> dataSetIds = modelIdToDataSetIds.get(modelId);
|
||||
List<Retrieval> updatedRetrievals =
|
||||
retrievals.stream()
|
||||
.flatMap(
|
||||
retrieval -> {
|
||||
Long modelId =
|
||||
Retrieval.getLongId(
|
||||
retrieval.getMetadata().get("modelId"));
|
||||
List<Long> dataSetIds = modelIdToDataSetIds.get(modelId);
|
||||
|
||||
if (CollectionUtils.isEmpty(dataSetIds)) {
|
||||
return Stream.of(retrieval);
|
||||
}
|
||||
if (CollectionUtils.isEmpty(dataSetIds)) {
|
||||
return Stream.of(retrieval);
|
||||
}
|
||||
|
||||
return dataSetIds.stream().map(dataSetId -> {
|
||||
Retrieval newRetrieval = new Retrieval();
|
||||
BeanUtils.copyProperties(retrieval, newRetrieval);
|
||||
newRetrieval.getMetadata().putIfAbsent("dataSetId", dataSetId + Constants.UNDERLINE);
|
||||
return newRetrieval;
|
||||
});
|
||||
})
|
||||
.collect(Collectors.toList());
|
||||
return dataSetIds.stream()
|
||||
.map(
|
||||
dataSetId -> {
|
||||
Retrieval newRetrieval = new Retrieval();
|
||||
BeanUtils.copyProperties(
|
||||
retrieval, newRetrieval);
|
||||
newRetrieval
|
||||
.getMetadata()
|
||||
.putIfAbsent(
|
||||
"dataSetId",
|
||||
dataSetId
|
||||
+ Constants
|
||||
.UNDERLINE);
|
||||
return newRetrieval;
|
||||
});
|
||||
})
|
||||
.collect(Collectors.toList());
|
||||
result.setRetrieval(updatedRetrievals);
|
||||
return result;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
package com.tencent.supersonic.headless.chat.knowledge;
|
||||
|
||||
import static com.hankcs.hanlp.utility.Predefine.logger;
|
||||
|
||||
import com.hankcs.hanlp.HanLP;
|
||||
import com.hankcs.hanlp.collection.trie.DoubleArrayTrie;
|
||||
import com.hankcs.hanlp.collection.trie.bintrie.BinTrie;
|
||||
@@ -34,11 +32,14 @@ import java.util.PriorityQueue;
|
||||
import java.util.TreeMap;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
import static com.hankcs.hanlp.utility.Predefine.logger;
|
||||
|
||||
public class MultiCustomDictionary extends DynamicCustomDictionary {
|
||||
|
||||
public static int MAX_SIZE = 10;
|
||||
public static Boolean removeDuplicates = true;
|
||||
public static ConcurrentHashMap<String, PriorityQueue<Term>> NATURE_TO_VALUES = new ConcurrentHashMap<>();
|
||||
public static ConcurrentHashMap<String, PriorityQueue<Term>> NATURE_TO_VALUES =
|
||||
new ConcurrentHashMap<>();
|
||||
private static boolean addToSuggesterTrie = true;
|
||||
|
||||
public MultiCustomDictionary() {
|
||||
@@ -49,8 +50,9 @@ public class MultiCustomDictionary extends DynamicCustomDictionary {
|
||||
super(path);
|
||||
}
|
||||
|
||||
/***
|
||||
* load dictionary
|
||||
/**
|
||||
* * load dictionary
|
||||
*
|
||||
* @param path
|
||||
* @param defaultNature
|
||||
* @param map
|
||||
@@ -58,15 +60,20 @@ public class MultiCustomDictionary extends DynamicCustomDictionary {
|
||||
* @param addToSuggeterTrie
|
||||
* @return
|
||||
*/
|
||||
public static boolean load(String path, Nature defaultNature, TreeMap<String, CoreDictionary.Attribute> map,
|
||||
LinkedHashSet<Nature> customNatureCollector, boolean addToSuggeterTrie) {
|
||||
public static boolean load(
|
||||
String path,
|
||||
Nature defaultNature,
|
||||
TreeMap<String, CoreDictionary.Attribute> map,
|
||||
LinkedHashSet<Nature> customNatureCollector,
|
||||
boolean addToSuggeterTrie) {
|
||||
try {
|
||||
String splitter = "\\s";
|
||||
if (path.endsWith(".csv")) {
|
||||
splitter = ",";
|
||||
}
|
||||
|
||||
BufferedReader br = new BufferedReader(new InputStreamReader(IOUtil.newInputStream(path), "UTF-8"));
|
||||
BufferedReader br =
|
||||
new BufferedReader(new InputStreamReader(IOUtil.newInputStream(path), "UTF-8"));
|
||||
boolean firstLine = true;
|
||||
|
||||
while (true) {
|
||||
@@ -105,14 +112,15 @@ public class MultiCustomDictionary extends DynamicCustomDictionary {
|
||||
attribute = new CoreDictionary.Attribute(natureCount);
|
||||
|
||||
for (int i = 0; i < natureCount; ++i) {
|
||||
attribute.nature[i] = LexiconUtility.convertStringToNature(param[1 + 2 * i],
|
||||
customNatureCollector);
|
||||
attribute.nature[i] =
|
||||
LexiconUtility.convertStringToNature(
|
||||
param[1 + 2 * i], customNatureCollector);
|
||||
attribute.frequency[i] = Integer.parseInt(param[2 + 2 * i]);
|
||||
attribute.originals[i] = original;
|
||||
attribute.totalFrequency += attribute.frequency[i];
|
||||
}
|
||||
}
|
||||
//attribute.original = original;
|
||||
// attribute.original = original;
|
||||
|
||||
if (removeDuplicates && map.containsKey(word)) {
|
||||
attribute = DictionaryAttributeUtil.getAttribute(map.get(word), attribute);
|
||||
@@ -125,8 +133,10 @@ public class MultiCustomDictionary extends DynamicCustomDictionary {
|
||||
Nature nature = attribute.nature[i];
|
||||
PriorityQueue<Term> priorityQueue = NATURE_TO_VALUES.get(nature.toString());
|
||||
if (Objects.isNull(priorityQueue)) {
|
||||
priorityQueue = new PriorityQueue<>(MAX_SIZE,
|
||||
Comparator.comparingInt(Term::getFrequency).reversed());
|
||||
priorityQueue =
|
||||
new PriorityQueue<>(
|
||||
MAX_SIZE,
|
||||
Comparator.comparingInt(Term::getFrequency).reversed());
|
||||
NATURE_TO_VALUES.put(nature.toString(), priorityQueue);
|
||||
}
|
||||
Term term = new Term(word, nature);
|
||||
@@ -150,14 +160,19 @@ public class MultiCustomDictionary extends DynamicCustomDictionary {
|
||||
return false;
|
||||
} else {
|
||||
logger.info(
|
||||
"自定义词典加载成功:" + this.dat.size() + "个词条,耗时" + (System.currentTimeMillis() - start) + "ms");
|
||||
"自定义词典加载成功:"
|
||||
+ this.dat.size()
|
||||
+ "个词条,耗时"
|
||||
+ (System.currentTimeMillis() - start)
|
||||
+ "ms");
|
||||
this.path = path;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
/***
|
||||
* load main dictionary
|
||||
/**
|
||||
* * load main dictionary
|
||||
*
|
||||
* @param mainPath
|
||||
* @param path
|
||||
* @param dat
|
||||
@@ -165,9 +180,12 @@ public class MultiCustomDictionary extends DynamicCustomDictionary {
|
||||
* @param addToSuggestTrie
|
||||
* @return
|
||||
*/
|
||||
public static boolean loadMainDictionary(String mainPath, String[] path,
|
||||
DoubleArrayTrie<CoreDictionary.Attribute> dat, boolean isCache,
|
||||
boolean addToSuggestTrie) {
|
||||
public static boolean loadMainDictionary(
|
||||
String mainPath,
|
||||
String[] path,
|
||||
DoubleArrayTrie<CoreDictionary.Attribute> dat,
|
||||
boolean isCache,
|
||||
boolean addToSuggestTrie) {
|
||||
logger.info("自定义词典开始加载:" + mainPath);
|
||||
if (loadDat(mainPath, dat)) {
|
||||
return true;
|
||||
@@ -186,7 +204,9 @@ public class MultiCustomDictionary extends DynamicCustomDictionary {
|
||||
p = file.getParent() + File.separator + fileName.substring(0, cut);
|
||||
|
||||
try {
|
||||
defaultNature = LexiconUtility.convertStringToNature(nature, customNatureCollector);
|
||||
defaultNature =
|
||||
LexiconUtility.convertStringToNature(
|
||||
nature, customNatureCollector);
|
||||
} catch (Exception var16) {
|
||||
logger.severe("配置文件【" + p + "】写错了!" + var16);
|
||||
continue;
|
||||
@@ -194,7 +214,8 @@ public class MultiCustomDictionary extends DynamicCustomDictionary {
|
||||
}
|
||||
|
||||
logger.info("以默认词性[" + defaultNature + "]加载自定义词典" + p + "中……");
|
||||
boolean success = load(p, defaultNature, map, customNatureCollector, addToSuggestTrie);
|
||||
boolean success =
|
||||
load(p, defaultNature, map, customNatureCollector, addToSuggestTrie);
|
||||
if (!success) {
|
||||
logger.warning("失败:" + p);
|
||||
}
|
||||
@@ -214,13 +235,16 @@ public class MultiCustomDictionary extends DynamicCustomDictionary {
|
||||
// 缓存成dat文件,下次加载会快很多
|
||||
logger.info("正在缓存词典为dat文件……");
|
||||
// 缓存值文件
|
||||
List<CoreDictionary.Attribute> attributeList = new LinkedList<CoreDictionary.Attribute>();
|
||||
List<CoreDictionary.Attribute> attributeList =
|
||||
new LinkedList<CoreDictionary.Attribute>();
|
||||
for (Map.Entry<String, CoreDictionary.Attribute> entry : map.entrySet()) {
|
||||
attributeList.add(entry.getValue());
|
||||
}
|
||||
|
||||
DataOutputStream out = new DataOutputStream(
|
||||
new BufferedOutputStream(IOUtil.newOutputStream(mainPath + ".bin")));
|
||||
DataOutputStream out =
|
||||
new DataOutputStream(
|
||||
new BufferedOutputStream(
|
||||
IOUtil.newOutputStream(mainPath + ".bin")));
|
||||
if (customNatureCollector.isEmpty()) {
|
||||
for (int i = Nature.begin.ordinal() + 1; i < Nature.values().length; ++i) {
|
||||
Nature nature = Nature.values()[i];
|
||||
@@ -247,7 +271,8 @@ public class MultiCustomDictionary extends DynamicCustomDictionary {
|
||||
logger.severe("自定义词典" + mainPath + "读取错误!" + var18);
|
||||
return false;
|
||||
} catch (Exception var19) {
|
||||
logger.warning("自定义词典" + mainPath + "缓存失败!\n" + TextUtility.exceptionToString(var19));
|
||||
logger.warning(
|
||||
"自定义词典" + mainPath + "缓存失败!\n" + TextUtility.exceptionToString(var19));
|
||||
}
|
||||
|
||||
return true;
|
||||
@@ -262,7 +287,8 @@ public class MultiCustomDictionary extends DynamicCustomDictionary {
|
||||
return loadDat(path, HanLP.Config.CustomDictionaryPath, dat);
|
||||
}
|
||||
|
||||
public static boolean loadDat(String path, String[] customDicPath, DoubleArrayTrie<CoreDictionary.Attribute> dat) {
|
||||
public static boolean loadDat(
|
||||
String path, String[] customDicPath, DoubleArrayTrie<CoreDictionary.Attribute> dat) {
|
||||
try {
|
||||
if (HanLP.Config.CustomDictionaryAutoRefreshCache
|
||||
&& DynamicCustomDictionary.isDicNeedUpdate(path, customDicPath)) {
|
||||
@@ -348,11 +374,11 @@ public class MultiCustomDictionary extends DynamicCustomDictionary {
|
||||
IOUtil.deleteFile(this.path[0] + ".bin");
|
||||
Boolean loadCacheOk = this.loadDat(this.path[0], this.path, this.dat);
|
||||
if (!loadCacheOk) {
|
||||
return this.loadMainDictionary(this.path[0], this.path, this.dat, true, addToSuggesterTrie);
|
||||
return this.loadMainDictionary(
|
||||
this.path[0], this.path, this.dat, true, addToSuggesterTrie);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
|
||||
}
|
||||
|
||||
public synchronized boolean insert(String word, String natureWithFrequency) {
|
||||
@@ -362,8 +388,10 @@ public class MultiCustomDictionary extends DynamicCustomDictionary {
|
||||
if (HanLP.Config.Normalization) {
|
||||
word = CharTable.convert(word);
|
||||
}
|
||||
CoreDictionary.Attribute att = natureWithFrequency == null ? new CoreDictionary.Attribute(Nature.nz, 1)
|
||||
: CoreDictionary.Attribute.create(natureWithFrequency);
|
||||
CoreDictionary.Attribute att =
|
||||
natureWithFrequency == null
|
||||
? new CoreDictionary.Attribute(Nature.nz, 1)
|
||||
: CoreDictionary.Attribute.create(natureWithFrequency);
|
||||
boolean isLetters = isLetters(word);
|
||||
word = getWordBySpace(word);
|
||||
String original = null;
|
||||
@@ -382,7 +410,7 @@ public class MultiCustomDictionary extends DynamicCustomDictionary {
|
||||
if (this.trie == null) {
|
||||
this.trie = new BinTrie();
|
||||
}
|
||||
//att.original = original;
|
||||
// att.original = original;
|
||||
att.setOriginals(original);
|
||||
if (this.trie.containsKey(word)) {
|
||||
att = DictionaryAttributeUtil.getAttribute(this.trie.get(word), att);
|
||||
|
||||
@@ -36,85 +36,130 @@ public class SearchService {
|
||||
suffixTrie = new BinTrie<>();
|
||||
}
|
||||
|
||||
/***
|
||||
* prefix Search
|
||||
/**
|
||||
* * prefix Search
|
||||
*
|
||||
* @param key
|
||||
* @return
|
||||
*/
|
||||
public static List<HanlpMapResult> prefixSearch(String key, int limit, Map<Long, List<Long>> modelIdToDataSetIds,
|
||||
public static List<HanlpMapResult> prefixSearch(
|
||||
String key,
|
||||
int limit,
|
||||
Map<Long, List<Long>> modelIdToDataSetIds,
|
||||
Set<Long> detectDataSetIds) {
|
||||
return prefixSearch(key, limit, trie, modelIdToDataSetIds, detectDataSetIds);
|
||||
}
|
||||
|
||||
public static List<HanlpMapResult> prefixSearch(String key, int limit, BinTrie<List<String>> binTrie,
|
||||
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
|
||||
public static List<HanlpMapResult> prefixSearch(
|
||||
String key,
|
||||
int limit,
|
||||
BinTrie<List<String>> binTrie,
|
||||
Map<Long, List<Long>> modelIdToDataSetIds,
|
||||
Set<Long> detectDataSetIds) {
|
||||
Set<Map.Entry<String, List<String>>> result = search(key, binTrie);
|
||||
List<HanlpMapResult> hanlpMapResults = result.stream().map(
|
||||
entry -> {
|
||||
String name = entry.getKey().replace("#", " ");
|
||||
return new HanlpMapResult(name, entry.getValue(), key);
|
||||
}
|
||||
).sorted((a, b) -> -(b.getName().length() - a.getName().length()))
|
||||
.collect(Collectors.toList());
|
||||
hanlpMapResults = transformAndFilterByDataSet(hanlpMapResults, modelIdToDataSetIds,
|
||||
detectDataSetIds, limit);
|
||||
List<HanlpMapResult> hanlpMapResults =
|
||||
result.stream()
|
||||
.map(
|
||||
entry -> {
|
||||
String name = entry.getKey().replace("#", " ");
|
||||
return new HanlpMapResult(name, entry.getValue(), key);
|
||||
})
|
||||
.sorted((a, b) -> -(b.getName().length() - a.getName().length()))
|
||||
.collect(Collectors.toList());
|
||||
hanlpMapResults =
|
||||
transformAndFilterByDataSet(
|
||||
hanlpMapResults, modelIdToDataSetIds, detectDataSetIds, limit);
|
||||
return hanlpMapResults;
|
||||
}
|
||||
|
||||
/***
|
||||
* suffix Search
|
||||
/**
|
||||
* * suffix Search
|
||||
*
|
||||
* @param key
|
||||
* @return
|
||||
*/
|
||||
public static List<HanlpMapResult> suffixSearch(String key, int limit, Map<Long, List<Long>> modelIdToDataSetIds,
|
||||
public static List<HanlpMapResult> suffixSearch(
|
||||
String key,
|
||||
int limit,
|
||||
Map<Long, List<Long>> modelIdToDataSetIds,
|
||||
Set<Long> detectDataSetIds) {
|
||||
String reverseDetectSegment = StringUtils.reverse(key);
|
||||
return suffixSearch(reverseDetectSegment, limit, suffixTrie, modelIdToDataSetIds, detectDataSetIds);
|
||||
return suffixSearch(
|
||||
reverseDetectSegment, limit, suffixTrie, modelIdToDataSetIds, detectDataSetIds);
|
||||
}
|
||||
|
||||
public static List<HanlpMapResult> suffixSearch(String key, int limit, BinTrie<List<String>> binTrie,
|
||||
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
|
||||
public static List<HanlpMapResult> suffixSearch(
|
||||
String key,
|
||||
int limit,
|
||||
BinTrie<List<String>> binTrie,
|
||||
Map<Long, List<Long>> modelIdToDataSetIds,
|
||||
Set<Long> detectDataSetIds) {
|
||||
Set<Map.Entry<String, List<String>>> result = search(key, binTrie);
|
||||
List<HanlpMapResult> hanlpMapResults = result.stream().map(
|
||||
entry -> {
|
||||
String name = entry.getKey().replace("#", " ");
|
||||
List<String> natures = entry.getValue().stream()
|
||||
.map(nature -> nature.replaceAll(DictWordType.SUFFIX.getType(), ""))
|
||||
.collect(Collectors.toList());
|
||||
name = StringUtils.reverse(name);
|
||||
return new HanlpMapResult(name, natures, key);
|
||||
}
|
||||
).sorted((a, b) -> -(b.getName().length() - a.getName().length()))
|
||||
List<HanlpMapResult> hanlpMapResults =
|
||||
result.stream()
|
||||
.map(
|
||||
entry -> {
|
||||
String name = entry.getKey().replace("#", " ");
|
||||
List<String> natures =
|
||||
entry.getValue().stream()
|
||||
.map(
|
||||
nature ->
|
||||
nature.replaceAll(
|
||||
DictWordType.SUFFIX
|
||||
.getType(),
|
||||
""))
|
||||
.collect(Collectors.toList());
|
||||
name = StringUtils.reverse(name);
|
||||
return new HanlpMapResult(name, natures, key);
|
||||
})
|
||||
.sorted((a, b) -> -(b.getName().length() - a.getName().length()))
|
||||
.collect(Collectors.toList());
|
||||
return transformAndFilterByDataSet(
|
||||
hanlpMapResults, modelIdToDataSetIds, detectDataSetIds, limit);
|
||||
}
|
||||
|
||||
private static List<HanlpMapResult> transformAndFilterByDataSet(
|
||||
List<HanlpMapResult> hanlpMapResults,
|
||||
Map<Long, List<Long>> modelIdToDataSetIds,
|
||||
Set<Long> detectDataSetIds,
|
||||
int limit) {
|
||||
return hanlpMapResults.stream()
|
||||
.peek(
|
||||
hanlpMapResult -> {
|
||||
List<String> natures =
|
||||
hanlpMapResult.getNatures().stream()
|
||||
.map(
|
||||
nature ->
|
||||
NatureHelper.changeModel2DataSet(
|
||||
nature, modelIdToDataSetIds))
|
||||
.flatMap(Collection::stream)
|
||||
.filter(
|
||||
nature -> {
|
||||
if (CollectionUtils.isEmpty(
|
||||
detectDataSetIds)) {
|
||||
return true;
|
||||
}
|
||||
Long dataSetId =
|
||||
NatureHelper.getDataSetId(nature);
|
||||
if (dataSetId != null) {
|
||||
return detectDataSetIds.contains(
|
||||
dataSetId);
|
||||
}
|
||||
return false;
|
||||
})
|
||||
.collect(Collectors.toList());
|
||||
hanlpMapResult.setNatures(natures);
|
||||
})
|
||||
.filter(hanlpMapResult -> !CollectionUtils.isEmpty(hanlpMapResult.getNatures()))
|
||||
.limit(limit)
|
||||
.collect(Collectors.toList());
|
||||
return transformAndFilterByDataSet(hanlpMapResults, modelIdToDataSetIds, detectDataSetIds, limit);
|
||||
}
|
||||
|
||||
private static List<HanlpMapResult> transformAndFilterByDataSet(List<HanlpMapResult> hanlpMapResults,
|
||||
Map<Long, List<Long>> modelIdToDataSetIds,
|
||||
Set<Long> detectDataSetIds, int limit) {
|
||||
return hanlpMapResults.stream().peek(hanlpMapResult -> {
|
||||
List<String> natures = hanlpMapResult.getNatures().stream()
|
||||
.map(nature -> NatureHelper.changeModel2DataSet(nature, modelIdToDataSetIds))
|
||||
.flatMap(Collection::stream)
|
||||
.filter(nature -> {
|
||||
if (CollectionUtils.isEmpty(detectDataSetIds)) {
|
||||
return true;
|
||||
}
|
||||
Long dataSetId = NatureHelper.getDataSetId(nature);
|
||||
if (dataSetId != null) {
|
||||
return detectDataSetIds.contains(dataSetId);
|
||||
}
|
||||
return false;
|
||||
}).collect(Collectors.toList());
|
||||
hanlpMapResult.setNatures(natures);
|
||||
}).filter(hanlpMapResult -> !CollectionUtils.isEmpty(hanlpMapResult.getNatures()))
|
||||
.limit(limit).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private static Set<Map.Entry<String, List<String>>> search(String key,
|
||||
BinTrie<List<String>> binTrie) {
|
||||
private static Set<Map.Entry<String, List<String>>> search(
|
||||
String key, BinTrie<List<String>> binTrie) {
|
||||
key = key.toLowerCase();
|
||||
Set<Map.Entry<String, List<String>>> entrySet = new TreeSet<Map.Entry<String, List<String>>>();
|
||||
Set<Map.Entry<String, List<String>>> entrySet =
|
||||
new TreeSet<Map.Entry<String, List<String>>>();
|
||||
|
||||
StringBuilder sb = new StringBuilder();
|
||||
if (StringUtils.isNotBlank(key)) {
|
||||
@@ -152,11 +197,14 @@ public class SearchService {
|
||||
}
|
||||
TreeMap<String, CoreDictionary.Attribute> map = new TreeMap();
|
||||
for (DictWord suffix : suffixes) {
|
||||
CoreDictionary.Attribute attributeNew = suffix.getNatureWithFrequency() == null
|
||||
? new CoreDictionary.Attribute(Nature.nz, 1)
|
||||
: CoreDictionary.Attribute.create(suffix.getNatureWithFrequency());
|
||||
CoreDictionary.Attribute attributeNew =
|
||||
suffix.getNatureWithFrequency() == null
|
||||
? new CoreDictionary.Attribute(Nature.nz, 1)
|
||||
: CoreDictionary.Attribute.create(suffix.getNatureWithFrequency());
|
||||
if (map.containsKey(suffix.getWord())) {
|
||||
attributeNew = DictionaryAttributeUtil.getAttribute(map.get(suffix.getWord()), attributeNew);
|
||||
attributeNew =
|
||||
DictionaryAttributeUtil.getAttribute(
|
||||
map.get(suffix.getWord()), attributeNew);
|
||||
}
|
||||
map.put(suffix.getWord(), attributeNew);
|
||||
}
|
||||
@@ -179,15 +227,18 @@ public class SearchService {
|
||||
if (Objects.nonNull(natures) && natures.length > 0) {
|
||||
trie.put(dictWord.getWord(), getValue(natures));
|
||||
}
|
||||
if (dictWord.getNature().contains(DictWordType.METRIC.getType()) || dictWord.getNature()
|
||||
.contains(DictWordType.DIMENSION.getType())) {
|
||||
if (dictWord.getNature().contains(DictWordType.METRIC.getType())
|
||||
|| dictWord.getNature().contains(DictWordType.DIMENSION.getType())) {
|
||||
suffixTrie.remove(dictWord.getWord());
|
||||
}
|
||||
}
|
||||
|
||||
public static List<String> getDimensionValue(DimensionValueReq dimensionValueReq) {
|
||||
String nature = DictWordType.NATURE_SPILT + dimensionValueReq.getModelId() + DictWordType.NATURE_SPILT
|
||||
+ dimensionValueReq.getElementID();
|
||||
String nature =
|
||||
DictWordType.NATURE_SPILT
|
||||
+ dimensionValueReq.getModelId()
|
||||
+ DictWordType.NATURE_SPILT
|
||||
+ dimensionValueReq.getElementID();
|
||||
PriorityQueue<Term> terms = MultiCustomDictionary.NATURE_TO_VALUES.get(nature);
|
||||
if (CollectionUtils.isEmpty(terms)) {
|
||||
return new ArrayList<>();
|
||||
@@ -195,4 +246,3 @@ public class SearchService {
|
||||
return terms.stream().map(term -> term.getWord()).collect(Collectors.toList());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,14 +2,12 @@ package com.tencent.supersonic.headless.chat.knowledge.builder;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.DictWord;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
/**
|
||||
* base word nature
|
||||
*/
|
||||
/** base word nature */
|
||||
@Slf4j
|
||||
public abstract class BaseWordBuilder {
|
||||
|
||||
@@ -36,5 +34,4 @@ public abstract class BaseWordBuilder {
|
||||
}
|
||||
|
||||
protected abstract List<DictWord> doGet(String word, SchemaElement schemaElement);
|
||||
|
||||
}
|
||||
|
||||
@@ -2,14 +2,15 @@ package com.tencent.supersonic.headless.chat.knowledge.builder;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.DictWord;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
public abstract class BaseWordWithAliasBuilder extends BaseWordBuilder {
|
||||
|
||||
public abstract DictWord getOneWordNature(String word, SchemaElement schemaElement, boolean isSuffix);
|
||||
public abstract DictWord getOneWordNature(
|
||||
String word, SchemaElement schemaElement, boolean isSuffix);
|
||||
|
||||
public List<DictWord> getOneWordNatureAlias(SchemaElement schemaElement, boolean isSuffix) {
|
||||
List<DictWord> dictWords = new ArrayList<>();
|
||||
@@ -22,5 +23,4 @@ public abstract class BaseWordWithAliasBuilder extends BaseWordBuilder {
|
||||
}
|
||||
return dictWords;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -4,14 +4,12 @@ import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.DictWord;
|
||||
|
||||
import java.util.List;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
/**
|
||||
* dimension word nature
|
||||
*/
|
||||
import java.util.List;
|
||||
|
||||
/** dimension word nature */
|
||||
@Service
|
||||
public class DimensionWordBuilder extends BaseWordWithAliasBuilder {
|
||||
|
||||
@@ -31,14 +29,22 @@ public class DimensionWordBuilder extends BaseWordWithAliasBuilder {
|
||||
DictWord dictWord = new DictWord();
|
||||
dictWord.setWord(word);
|
||||
Long modelId = schemaElement.getModel();
|
||||
String nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + schemaElement.getId()
|
||||
+ DictWordType.DIMENSION.getType();
|
||||
String nature =
|
||||
DictWordType.NATURE_SPILT
|
||||
+ modelId
|
||||
+ DictWordType.NATURE_SPILT
|
||||
+ schemaElement.getId()
|
||||
+ DictWordType.DIMENSION.getType();
|
||||
if (isSuffix) {
|
||||
nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + schemaElement.getId()
|
||||
+ DictWordType.SUFFIX.getType() + DictWordType.DIMENSION.getType();
|
||||
nature =
|
||||
DictWordType.NATURE_SPILT
|
||||
+ modelId
|
||||
+ DictWordType.NATURE_SPILT
|
||||
+ schemaElement.getId()
|
||||
+ DictWordType.SUFFIX.getType()
|
||||
+ DictWordType.DIMENSION.getType();
|
||||
}
|
||||
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature));
|
||||
return dictWord;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
package com.tencent.supersonic.headless.chat.knowledge.builder;
|
||||
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.DictWord;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
@@ -24,17 +23,19 @@ public class EntityWordBuilder extends BaseWordWithAliasBuilder {
|
||||
result.add(getOneWordNature(word, schemaElement, false));
|
||||
result.addAll(getOneWordNatureAlias(schemaElement, false));
|
||||
return result;
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public DictWord getOneWordNature(String word, SchemaElement schemaElement, boolean isSuffix) {
|
||||
String nature = DictWordType.NATURE_SPILT + schemaElement.getModel()
|
||||
+ DictWordType.NATURE_SPILT + schemaElement.getId() + DictWordType.ENTITY.getType();
|
||||
String nature =
|
||||
DictWordType.NATURE_SPILT
|
||||
+ schemaElement.getModel()
|
||||
+ DictWordType.NATURE_SPILT
|
||||
+ schemaElement.getId()
|
||||
+ DictWordType.ENTITY.getType();
|
||||
DictWord dictWord = new DictWord();
|
||||
dictWord.setWord(word);
|
||||
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY * 2, nature));
|
||||
return dictWord;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -4,14 +4,12 @@ import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.DictWord;
|
||||
|
||||
import java.util.List;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
/**
|
||||
* Metric DictWord
|
||||
*/
|
||||
import java.util.List;
|
||||
|
||||
/** Metric DictWord */
|
||||
@Service
|
||||
public class MetricWordBuilder extends BaseWordWithAliasBuilder {
|
||||
|
||||
@@ -31,14 +29,22 @@ public class MetricWordBuilder extends BaseWordWithAliasBuilder {
|
||||
DictWord dictWord = new DictWord();
|
||||
dictWord.setWord(word);
|
||||
Long modelId = schemaElement.getModel();
|
||||
String nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + schemaElement.getId()
|
||||
+ DictWordType.METRIC.getType();
|
||||
String nature =
|
||||
DictWordType.NATURE_SPILT
|
||||
+ modelId
|
||||
+ DictWordType.NATURE_SPILT
|
||||
+ schemaElement.getId()
|
||||
+ DictWordType.METRIC.getType();
|
||||
if (isSuffix) {
|
||||
nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + schemaElement.getId()
|
||||
+ DictWordType.SUFFIX.getType() + DictWordType.METRIC.getType();
|
||||
nature =
|
||||
DictWordType.NATURE_SPILT
|
||||
+ modelId
|
||||
+ DictWordType.NATURE_SPILT
|
||||
+ schemaElement.getId()
|
||||
+ DictWordType.SUFFIX.getType()
|
||||
+ DictWordType.METRIC.getType();
|
||||
}
|
||||
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature));
|
||||
return dictWord;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -4,15 +4,13 @@ import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.DictWord;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
/**
|
||||
* model word nature
|
||||
*/
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
/** model word nature */
|
||||
@Service
|
||||
@Slf4j
|
||||
public class ModelWordBuilder extends BaseWordWithAliasBuilder {
|
||||
@@ -35,5 +33,4 @@ public class ModelWordBuilder extends BaseWordWithAliasBuilder {
|
||||
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature));
|
||||
return dictWord;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -9,9 +9,7 @@ import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Metric DictWord
|
||||
*/
|
||||
/** Metric DictWord */
|
||||
@Service
|
||||
public class TermWordBuilder extends BaseWordWithAliasBuilder {
|
||||
|
||||
@@ -31,14 +29,22 @@ public class TermWordBuilder extends BaseWordWithAliasBuilder {
|
||||
DictWord dictWord = new DictWord();
|
||||
dictWord.setWord(word);
|
||||
Long dataSet = schemaElement.getDataSetId();
|
||||
String nature = DictWordType.NATURE_SPILT + dataSet + DictWordType.NATURE_SPILT + schemaElement.getId()
|
||||
+ DictWordType.TERM.getType();
|
||||
String nature =
|
||||
DictWordType.NATURE_SPILT
|
||||
+ dataSet
|
||||
+ DictWordType.NATURE_SPILT
|
||||
+ schemaElement.getId()
|
||||
+ DictWordType.TERM.getType();
|
||||
if (isSuffix) {
|
||||
nature = DictWordType.NATURE_SPILT + dataSet + DictWordType.NATURE_SPILT + schemaElement.getId()
|
||||
+ DictWordType.SUFFIX.getType() + DictWordType.TERM.getType();
|
||||
nature =
|
||||
DictWordType.NATURE_SPILT
|
||||
+ dataSet
|
||||
+ DictWordType.NATURE_SPILT
|
||||
+ schemaElement.getId()
|
||||
+ DictWordType.SUFFIX.getType()
|
||||
+ DictWordType.TERM.getType();
|
||||
}
|
||||
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature));
|
||||
return dictWord;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
package com.tencent.supersonic.headless.chat.knowledge.builder;
|
||||
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.DictWord;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
@@ -27,10 +26,13 @@ public class ValueWordBuilder extends BaseWordWithAliasBuilder {
|
||||
public DictWord getOneWordNature(String word, SchemaElement schemaElement, boolean isSuffix) {
|
||||
DictWord dictWord = new DictWord();
|
||||
Long modelId = schemaElement.getModel();
|
||||
String nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + schemaElement.getId();
|
||||
String nature =
|
||||
DictWordType.NATURE_SPILT
|
||||
+ modelId
|
||||
+ DictWordType.NATURE_SPILT
|
||||
+ schemaElement.getId();
|
||||
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature));
|
||||
dictWord.setWord(word);
|
||||
return dictWord;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,13 +1,11 @@
package com.tencent.supersonic.headless.chat.knowledge.builder;


import com.tencent.supersonic.common.pojo.enums.DictWordType;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
* DictWord Strategy Factory
*/
/** DictWord Strategy Factory */
public class WordBuilderFactory {

private static Map<DictWordType, BaseWordBuilder> wordNatures = new ConcurrentHashMap<>();
@@ -24,4 +22,4 @@ public class WordBuilderFactory {
public static BaseWordBuilder get(DictWordType strategyType) {
return wordNatures.get(strategyType);
}
}
}

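As context for the word-builder diffs above and below, a brief sketch of how the factory is meant to be used; the factory lookup and the nature layout come straight from this diff, while the literal separator value and the variable names are assumptions:

    // Sketch only: the lookup mirrors WordBuilderFactory.get(...) as shown in
    // this diff; the concrete nature value assumes NATURE_SPILT is "_".
    BaseWordBuilder dimensionBuilder = WordBuilderFactory.get(DictWordType.DIMENSION);
    // A DimensionWordBuilder then emits DictWord entries whose nature string is
    // NATURE_SPILT + modelId + NATURE_SPILT + elementId + "dimension",
    // e.g. "_1_42_dimension" for model 1 and dimension 42.
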
@@ -8,13 +8,11 @@ import org.springframework.context.annotation.Configuration;
|
||||
|
||||
import java.io.FileNotFoundException;
|
||||
|
||||
|
||||
@Data
|
||||
@Configuration
|
||||
@Slf4j
|
||||
public class ChatLocalFileConfig {
|
||||
|
||||
|
||||
@Value("${s2.dict.directory.latest:/data/dictionary/custom}")
|
||||
private String dictDirectoryLatest;
|
||||
|
||||
@@ -38,4 +36,4 @@ public class ChatLocalFileConfig {
|
||||
}
|
||||
return hanlpPropertiesPath;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,8 +9,7 @@ import java.util.List;
|
||||
public interface FileHandler {
|
||||
|
||||
/**
|
||||
* backup files to a specific directory
|
||||
* config: dict.directory.backup
|
||||
* backup files to a specific directory config: dict.directory.backup
|
||||
*
|
||||
* @param fileName
|
||||
*/
|
||||
@@ -26,8 +25,7 @@ public interface FileHandler {
|
||||
Boolean existPath(String path);
|
||||
|
||||
/**
|
||||
* write data to a specific file,
|
||||
* config dir: dict.directory.latest
|
||||
* write data to a specific file, config dir: dict.directory.latest
|
||||
*
|
||||
* @param data
|
||||
* @param fileName
|
||||
@@ -43,8 +41,7 @@ public interface FileHandler {
|
||||
String getDictRootPath();
|
||||
|
||||
/**
|
||||
* delete dictionary file
|
||||
* automatic backup
|
||||
* delete dictionary file automatic backup
|
||||
*
|
||||
* @param fileName
|
||||
* @return
|
||||
|
||||
@@ -8,16 +8,15 @@ import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.io.BufferedWriter;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.nio.file.StandardCopyOption;
|
||||
import java.nio.file.StandardOpenOption;
|
||||
|
||||
import java.io.BufferedWriter;
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
@@ -52,7 +51,6 @@ public class FileHandlerImpl implements FileHandler {
|
||||
} catch (IOException e) {
|
||||
log.info("Failed to copy file: " + e.getMessage());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -83,8 +81,12 @@ public class FileHandlerImpl implements FileHandler {
|
||||
String filePath = localFileConfig.getDictDirectoryLatest() + FILE_SPILT + fileName;
|
||||
Long fileLineNum = getFileLineNum(filePath);
|
||||
Integer startLine = (dictValueReq.getCurrent() - 1) * dictValueReq.getPageSize() + 1;
|
||||
Integer endLine = Integer.valueOf(
|
||||
Math.min(dictValueReq.getCurrent() * dictValueReq.getPageSize(), fileLineNum) + "");
|
||||
Integer endLine =
|
||||
Integer.valueOf(
|
||||
Math.min(
|
||||
dictValueReq.getCurrent() * dictValueReq.getPageSize(),
|
||||
fileLineNum)
|
||||
+ "");
|
||||
List<DictValueResp> dictValueRespList = getFileData(filePath, startLine, endLine);
|
||||
|
||||
dictValueRespPageInfo.setPageSize(dictValueReq.getPageSize());
|
||||
@@ -110,17 +112,16 @@ public class FileHandlerImpl implements FileHandler {
|
||||
List<DictValueResp> fileData = new ArrayList<>();
|
||||
|
||||
try (Stream<String> lines = Files.lines(Paths.get(filePath))) {
|
||||
fileData = lines
|
||||
.skip(startLine - 1)
|
||||
.limit(endLine - startLine + 1)
|
||||
.map(lineStr -> convert2Resp(lineStr))
|
||||
.filter(line -> Objects.nonNull(line))
|
||||
.collect(Collectors.toList());
|
||||
fileData =
|
||||
lines.skip(startLine - 1)
|
||||
.limit(endLine - startLine + 1)
|
||||
.map(lineStr -> convert2Resp(lineStr))
|
||||
.filter(line -> Objects.nonNull(line))
|
||||
.collect(Collectors.toList());
|
||||
} catch (IOException e) {
|
||||
log.warn("[getFileData] e:{}", e);
|
||||
}
|
||||
return fileData;
|
||||
|
||||
}
|
||||
|
||||
private DictValueResp convert2Resp(String lineStr) {
|
||||
@@ -138,8 +139,7 @@ public class FileHandlerImpl implements FileHandler {
|
||||
|
||||
private Long getFileLineNum(String filePath) {
|
||||
try (Stream<String> lines = Files.lines(Paths.get(filePath))) {
|
||||
Long lineCount = lines
|
||||
.count();
|
||||
Long lineCount = lines.count();
|
||||
return lineCount;
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
@@ -204,8 +204,9 @@ public class FileHandlerImpl implements FileHandler {
|
||||
|
||||
private BufferedWriter getWriter(String filePath, Boolean append) throws IOException {
|
||||
if (append) {
|
||||
return Files.newBufferedWriter(Paths.get(filePath), StandardCharsets.UTF_8, StandardOpenOption.APPEND);
|
||||
return Files.newBufferedWriter(
|
||||
Paths.get(filePath), StandardCharsets.UTF_8, StandardOpenOption.APPEND);
|
||||
}
|
||||
return Files.newBufferedWriter(Paths.get(filePath), StandardCharsets.UTF_8);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,7 +13,6 @@ import java.io.FileNotFoundException;
@Slf4j
public class LocalFileConfig {


@Value("${s2.dict.directory.latest:/data/dictionary/custom}")
private String dictDirectoryLatest;

@@ -36,5 +35,4 @@ public class LocalFileConfig {
}
return "";
}

}
}

@@ -2,14 +2,13 @@ package com.tencent.supersonic.headless.chat.knowledge.helper;
|
||||
|
||||
import com.hankcs.hanlp.HanLP.Config;
|
||||
import com.hankcs.hanlp.dictionary.DynamicCustomDictionary;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@Slf4j
|
||||
public class FileHelper {
|
||||
|
||||
@@ -33,15 +32,17 @@ public class FileHelper {
|
||||
}
|
||||
|
||||
private static File[] getFileList(File customFolder, String suffix) {
|
||||
File[] customSubFiles = customFolder.listFiles(file -> {
|
||||
if (file.isDirectory()) {
|
||||
return false;
|
||||
}
|
||||
if (file.getName().toLowerCase().endsWith(suffix)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
File[] customSubFiles =
|
||||
customFolder.listFiles(
|
||||
file -> {
|
||||
if (file.isDirectory()) {
|
||||
return false;
|
||||
}
|
||||
if (file.getName().toLowerCase().endsWith(suffix)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
return customSubFiles;
|
||||
}
|
||||
|
||||
@@ -72,8 +73,10 @@ public class FileHelper {
|
||||
|
||||
log.debug("CustomDictionaryPath:{}", fileList);
|
||||
Config.CustomDictionaryPath = fileList.toArray(new String[0]);
|
||||
customDictionary.path = (Config.CustomDictionaryPath == null || Config.CustomDictionaryPath.length == 0) ? path
|
||||
: Config.CustomDictionaryPath;
|
||||
customDictionary.path =
|
||||
(Config.CustomDictionaryPath == null || Config.CustomDictionaryPath.length == 0)
|
||||
? path
|
||||
: Config.CustomDictionaryPath;
|
||||
if (Config.CustomDictionaryPath == null || Config.CustomDictionaryPath.length == 0) {
|
||||
Config.CustomDictionaryPath = path;
|
||||
}
|
||||
|
||||
@@ -33,9 +33,7 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* HanLP helper
|
||||
*/
|
||||
/** HanLP helper */
|
||||
@Slf4j
|
||||
public class HanlpHelper {
|
||||
|
||||
@@ -57,14 +55,21 @@ public class HanlpHelper {
|
||||
if (segment == null) {
|
||||
synchronized (HanlpHelper.class) {
|
||||
if (segment == null) {
|
||||
segment = HanLP.newSegment()
|
||||
.enableIndexMode(true).enableIndexMode(4)
|
||||
.enableCustomDictionary(true).enableCustomDictionaryForcing(true).enableOffset(true)
|
||||
.enableJapaneseNameRecognize(false).enableNameRecognize(false)
|
||||
.enableAllNamedEntityRecognize(false)
|
||||
.enableJapaneseNameRecognize(false).enableNumberQuantifierRecognize(false)
|
||||
.enablePlaceRecognize(false)
|
||||
.enableOrganizationRecognize(false).enableCustomDictionary(getDynamicCustomDictionary());
|
||||
segment =
|
||||
HanLP.newSegment()
|
||||
.enableIndexMode(true)
|
||||
.enableIndexMode(4)
|
||||
.enableCustomDictionary(true)
|
||||
.enableCustomDictionaryForcing(true)
|
||||
.enableOffset(true)
|
||||
.enableJapaneseNameRecognize(false)
|
||||
.enableNameRecognize(false)
|
||||
.enableAllNamedEntityRecognize(false)
|
||||
.enableJapaneseNameRecognize(false)
|
||||
.enableNumberQuantifierRecognize(false)
|
||||
.enablePlaceRecognize(false)
|
||||
.enableOrganizationRecognize(false)
|
||||
.enableCustomDictionary(getDynamicCustomDictionary());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -82,14 +87,13 @@ public class HanlpHelper {
|
||||
return CustomDictionary;
|
||||
}
|
||||
|
||||
/***
|
||||
* reload custom dictionary
|
||||
*/
|
||||
/** * reload custom dictionary */
|
||||
public static boolean reloadCustomDictionary() throws IOException {
|
||||
|
||||
final long startTime = System.currentTimeMillis();
|
||||
|
||||
if (HanLP.Config.CustomDictionaryPath == null || HanLP.Config.CustomDictionaryPath.length == 0) {
|
||||
if (HanLP.Config.CustomDictionaryPath == null
|
||||
|| HanLP.Config.CustomDictionaryPath.length == 0) {
|
||||
return false;
|
||||
}
|
||||
if (HanLP.Config.IOAdapter instanceof HadoopFileIOAdapter) {
|
||||
@@ -106,7 +110,8 @@ public class HanlpHelper {
|
||||
|
||||
boolean reload = getDynamicCustomDictionary().reload();
|
||||
if (reload) {
|
||||
log.info("Custom dictionary has been reloaded in {} milliseconds",
|
||||
log.info(
|
||||
"Custom dictionary has been reloaded in {} milliseconds",
|
||||
System.currentTimeMillis() - startTime);
|
||||
}
|
||||
return reload;
|
||||
@@ -118,51 +123,74 @@ public class HanlpHelper {
|
||||
}
|
||||
String hanlpPropertiesPath = getHanlpPropertiesPath();
|
||||
|
||||
HanLP.Config.CustomDictionaryPath = Arrays.stream(HanLP.Config.CustomDictionaryPath)
|
||||
.map(path -> hanlpPropertiesPath + FILE_SPILT + path)
|
||||
.toArray(String[]::new);
|
||||
log.info("hanlpPropertiesPath:{},CustomDictionaryPath:{}", hanlpPropertiesPath,
|
||||
HanLP.Config.CustomDictionaryPath =
|
||||
Arrays.stream(HanLP.Config.CustomDictionaryPath)
|
||||
.map(path -> hanlpPropertiesPath + FILE_SPILT + path)
|
||||
.toArray(String[]::new);
|
||||
log.info(
|
||||
"hanlpPropertiesPath:{},CustomDictionaryPath:{}",
|
||||
hanlpPropertiesPath,
|
||||
HanLP.Config.CustomDictionaryPath);
|
||||
|
||||
HanLP.Config.CoreDictionaryPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.BiGramDictionaryPath;
|
||||
HanLP.Config.CoreDictionaryTransformMatrixDictionaryPath = hanlpPropertiesPath + FILE_SPILT
|
||||
+ HanLP.Config.CoreDictionaryTransformMatrixDictionaryPath;
|
||||
HanLP.Config.BiGramDictionaryPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.BiGramDictionaryPath;
|
||||
HanLP.Config.CoreDictionaryPath =
|
||||
hanlpPropertiesPath + FILE_SPILT + HanLP.Config.BiGramDictionaryPath;
|
||||
HanLP.Config.CoreDictionaryTransformMatrixDictionaryPath =
|
||||
hanlpPropertiesPath
|
||||
+ FILE_SPILT
|
||||
+ HanLP.Config.CoreDictionaryTransformMatrixDictionaryPath;
|
||||
HanLP.Config.BiGramDictionaryPath =
|
||||
hanlpPropertiesPath + FILE_SPILT + HanLP.Config.BiGramDictionaryPath;
|
||||
HanLP.Config.CoreStopWordDictionaryPath =
|
||||
hanlpPropertiesPath + FILE_SPILT + HanLP.Config.CoreStopWordDictionaryPath;
|
||||
HanLP.Config.CoreSynonymDictionaryDictionaryPath = hanlpPropertiesPath + FILE_SPILT
|
||||
+ HanLP.Config.CoreSynonymDictionaryDictionaryPath;
|
||||
HanLP.Config.PersonDictionaryPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.PersonDictionaryPath;
|
||||
HanLP.Config.PersonDictionaryTrPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.PersonDictionaryTrPath;
|
||||
HanLP.Config.CoreSynonymDictionaryDictionaryPath =
|
||||
hanlpPropertiesPath + FILE_SPILT + HanLP.Config.CoreSynonymDictionaryDictionaryPath;
|
||||
HanLP.Config.PersonDictionaryPath =
|
||||
hanlpPropertiesPath + FILE_SPILT + HanLP.Config.PersonDictionaryPath;
|
||||
HanLP.Config.PersonDictionaryTrPath =
|
||||
hanlpPropertiesPath + FILE_SPILT + HanLP.Config.PersonDictionaryTrPath;
|
||||
|
||||
HanLP.Config.PinyinDictionaryPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.PinyinDictionaryPath;
|
||||
HanLP.Config.TranslatedPersonDictionaryPath = hanlpPropertiesPath + FILE_SPILT
|
||||
+ HanLP.Config.TranslatedPersonDictionaryPath;
|
||||
HanLP.Config.JapanesePersonDictionaryPath = hanlpPropertiesPath + FILE_SPILT
|
||||
+ HanLP.Config.JapanesePersonDictionaryPath;
|
||||
HanLP.Config.PlaceDictionaryPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.PlaceDictionaryPath;
|
||||
HanLP.Config.PlaceDictionaryTrPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.PlaceDictionaryTrPath;
|
||||
HanLP.Config.OrganizationDictionaryPath = hanlpPropertiesPath + FILE_SPILT
|
||||
+ HanLP.Config.OrganizationDictionaryPath;
|
||||
HanLP.Config.OrganizationDictionaryTrPath = hanlpPropertiesPath + FILE_SPILT
|
||||
+ HanLP.Config.OrganizationDictionaryTrPath;
|
||||
HanLP.Config.PinyinDictionaryPath =
|
||||
hanlpPropertiesPath + FILE_SPILT + HanLP.Config.PinyinDictionaryPath;
|
||||
HanLP.Config.TranslatedPersonDictionaryPath =
|
||||
hanlpPropertiesPath + FILE_SPILT + HanLP.Config.TranslatedPersonDictionaryPath;
|
||||
HanLP.Config.JapanesePersonDictionaryPath =
|
||||
hanlpPropertiesPath + FILE_SPILT + HanLP.Config.JapanesePersonDictionaryPath;
|
||||
HanLP.Config.PlaceDictionaryPath =
|
||||
hanlpPropertiesPath + FILE_SPILT + HanLP.Config.PlaceDictionaryPath;
|
||||
HanLP.Config.PlaceDictionaryTrPath =
|
||||
hanlpPropertiesPath + FILE_SPILT + HanLP.Config.PlaceDictionaryTrPath;
|
||||
HanLP.Config.OrganizationDictionaryPath =
|
||||
hanlpPropertiesPath + FILE_SPILT + HanLP.Config.OrganizationDictionaryPath;
|
||||
HanLP.Config.OrganizationDictionaryTrPath =
|
||||
hanlpPropertiesPath + FILE_SPILT + HanLP.Config.OrganizationDictionaryTrPath;
|
||||
HanLP.Config.CharTypePath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.CharTypePath;
|
||||
HanLP.Config.CharTablePath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.CharTablePath;
|
||||
HanLP.Config.PartOfSpeechTagDictionary =
|
||||
hanlpPropertiesPath + FILE_SPILT + HanLP.Config.PartOfSpeechTagDictionary;
|
||||
HanLP.Config.WordNatureModelPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.WordNatureModelPath;
|
||||
HanLP.Config.MaxEntModelPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.MaxEntModelPath;
|
||||
HanLP.Config.NNParserModelPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.NNParserModelPath;
|
||||
HanLP.Config.WordNatureModelPath =
|
||||
hanlpPropertiesPath + FILE_SPILT + HanLP.Config.WordNatureModelPath;
|
||||
HanLP.Config.MaxEntModelPath =
|
||||
hanlpPropertiesPath + FILE_SPILT + HanLP.Config.MaxEntModelPath;
|
||||
HanLP.Config.NNParserModelPath =
|
||||
hanlpPropertiesPath + FILE_SPILT + HanLP.Config.NNParserModelPath;
|
||||
HanLP.Config.PerceptronParserModelPath =
|
||||
hanlpPropertiesPath + FILE_SPILT + HanLP.Config.PerceptronParserModelPath;
|
||||
HanLP.Config.CRFSegmentModelPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.CRFSegmentModelPath;
|
||||
HanLP.Config.HMMSegmentModelPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.HMMSegmentModelPath;
|
||||
HanLP.Config.CRFCWSModelPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.CRFCWSModelPath;
|
||||
HanLP.Config.CRFPOSModelPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.CRFPOSModelPath;
|
||||
HanLP.Config.CRFNERModelPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.CRFNERModelPath;
|
||||
HanLP.Config.PerceptronCWSModelPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.PerceptronCWSModelPath;
|
||||
HanLP.Config.PerceptronPOSModelPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.PerceptronPOSModelPath;
|
||||
HanLP.Config.PerceptronNERModelPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.PerceptronNERModelPath;
|
||||
HanLP.Config.CRFSegmentModelPath =
|
||||
hanlpPropertiesPath + FILE_SPILT + HanLP.Config.CRFSegmentModelPath;
|
||||
HanLP.Config.HMMSegmentModelPath =
|
||||
hanlpPropertiesPath + FILE_SPILT + HanLP.Config.HMMSegmentModelPath;
|
||||
HanLP.Config.CRFCWSModelPath =
|
||||
hanlpPropertiesPath + FILE_SPILT + HanLP.Config.CRFCWSModelPath;
|
||||
HanLP.Config.CRFPOSModelPath =
|
||||
hanlpPropertiesPath + FILE_SPILT + HanLP.Config.CRFPOSModelPath;
|
||||
HanLP.Config.CRFNERModelPath =
|
||||
hanlpPropertiesPath + FILE_SPILT + HanLP.Config.CRFNERModelPath;
|
||||
HanLP.Config.PerceptronCWSModelPath =
|
||||
hanlpPropertiesPath + FILE_SPILT + HanLP.Config.PerceptronCWSModelPath;
|
||||
HanLP.Config.PerceptronPOSModelPath =
|
||||
hanlpPropertiesPath + FILE_SPILT + HanLP.Config.PerceptronPOSModelPath;
|
||||
HanLP.Config.PerceptronNERModelPath =
|
||||
hanlpPropertiesPath + FILE_SPILT + HanLP.Config.PerceptronNERModelPath;
|
||||
}
|
||||
|
||||
public static String getHanlpPropertiesPath() throws FileNotFoundException {
|
||||
@@ -171,7 +199,8 @@ public class HanlpHelper {
|
||||
|
||||
public static boolean addToCustomDictionary(DictWord dictWord) {
|
||||
log.debug("dictWord:{}", dictWord);
|
||||
return getDynamicCustomDictionary().insert(dictWord.getWord(), dictWord.getNatureWithFrequency());
|
||||
return getDynamicCustomDictionary()
|
||||
.insert(dictWord.getWord(), dictWord.getNatureWithFrequency());
|
||||
}
|
||||
|
||||
public static void removeFromCustomDictionary(DictWord dictWord) {
|
||||
@@ -195,7 +224,8 @@ public class HanlpHelper {
|
||||
int len = natureWithFrequency.length();
|
||||
log.info("filtered natureWithFrequency:{}", natureWithFrequency);
|
||||
if (StringUtils.isNotBlank(natureWithFrequency)) {
|
||||
getDynamicCustomDictionary().add(dictWord.getWord(), natureWithFrequency.substring(0, len - 1));
|
||||
getDynamicCustomDictionary()
|
||||
.add(dictWord.getWord(), natureWithFrequency.substring(0, len - 1));
|
||||
}
|
||||
SearchService.remove(dictWord, natureList.toArray(new Nature[0]));
|
||||
}
|
||||
@@ -225,8 +255,8 @@ public class HanlpHelper {
|
||||
mapResults.addAll(newResults);
|
||||
}
|
||||
|
||||
public static <T extends MapResult> boolean addLetterOriginal(List<T> mapResults, T mapResult,
|
||||
CoreDictionary.Attribute attribute) {
|
||||
public static <T extends MapResult> boolean addLetterOriginal(
|
||||
List<T> mapResults, T mapResult, CoreDictionary.Attribute attribute) {
|
||||
if (attribute == null) {
|
||||
return false;
|
||||
}
|
||||
@@ -236,8 +266,9 @@ public class HanlpHelper {
|
||||
for (String nature : hanlpMapResult.getNatures()) {
|
||||
String orig = attribute.getOriginal(Nature.fromString(nature));
|
||||
if (orig != null) {
|
||||
MapResult addMapResult = new HanlpMapResult(
|
||||
orig, Arrays.asList(nature), hanlpMapResult.getDetectWord());
|
||||
MapResult addMapResult =
|
||||
new HanlpMapResult(
|
||||
orig, Arrays.asList(nature), hanlpMapResult.getDetectWord());
|
||||
mapResults.add((T) addMapResult);
|
||||
isAdd = true;
|
||||
}
|
||||
@@ -285,9 +316,12 @@ public class HanlpHelper {
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public static List<S2Term> transform2ApiTerm(Term term, Map<Long, List<Long>> modelIdToDataSetIds) {
|
||||
public static List<S2Term> transform2ApiTerm(
|
||||
Term term, Map<Long, List<Long>> modelIdToDataSetIds) {
|
||||
List<S2Term> s2Terms = Lists.newArrayList();
|
||||
List<String> natures = NatureHelper.changeModel2DataSet(String.valueOf(term.getNature()), modelIdToDataSetIds);
|
||||
List<String> natures =
|
||||
NatureHelper.changeModel2DataSet(
|
||||
String.valueOf(term.getNature()), modelIdToDataSetIds);
|
||||
for (String nature : natures) {
|
||||
S2Term s2Term = new S2Term();
|
||||
BeanUtils.copyProperties(term, s2Term);
|
||||
@@ -297,5 +331,4 @@ public class HanlpHelper {
|
||||
}
|
||||
return s2Terms;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -3,25 +3,24 @@ package com.tencent.supersonic.headless.chat.knowledge.helper;
|
||||
import com.hankcs.hanlp.HanLP.Config;
|
||||
import com.hankcs.hanlp.dictionary.DynamicCustomDictionary;
|
||||
import com.hankcs.hanlp.utility.Predefine;
|
||||
import java.io.IOException;
|
||||
import java.net.URI;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.hadoop.conf.Configuration;
|
||||
import org.apache.hadoop.fs.FileStatus;
|
||||
import org.apache.hadoop.fs.FileSystem;
|
||||
import org.apache.hadoop.fs.Path;
|
||||
|
||||
/**
|
||||
* Hdfs File Helper
|
||||
*/
|
||||
import java.io.IOException;
|
||||
import java.net.URI;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/** Hdfs File Helper */
|
||||
@Slf4j
|
||||
public class HdfsFileHelper {
|
||||
|
||||
/***
|
||||
* delete cache file
|
||||
/**
|
||||
* * delete cache file
|
||||
*
|
||||
* @param path
|
||||
* @throws IOException
|
||||
*/
|
@@ -35,7 +34,8 @@ public class HdfsFileHelper {
log.error("delete:" + cacheFilePath, e);
}
int customBase = cacheFilePath.lastIndexOf(FileHelper.FILE_SPILT);
String customPath = cacheFilePath.substring(0, customBase) + FileHelper.FILE_SPILT + "*.bin";
String customPath =
cacheFilePath.substring(0, customBase) + FileHelper.FILE_SPILT + "*.bin";
List<String> fileList = getFileList(fs, new Path(customPath));
for (String file : fileList) {
try {
@@ -54,18 +54,22 @@ public class HdfsFileHelper {
* @param customDictionary
* @throws IOException
*/
public static void resetCustomPath(DynamicCustomDictionary customDictionary) throws IOException {
public static void resetCustomPath(DynamicCustomDictionary customDictionary)
throws IOException {
String[] path = Config.CustomDictionaryPath;
FileSystem fs = FileSystem.get(URI.create(path[0]), new Configuration());
String cacheFilePath = path[0] + Predefine.BIN_EXT;
int customBase = cacheFilePath.lastIndexOf(FileHelper.FILE_SPILT);
String customPath = cacheFilePath.substring(0, customBase) + FileHelper.FILE_SPILT + "*.txt";
String customPath =
cacheFilePath.substring(0, customBase) + FileHelper.FILE_SPILT + "*.txt";
log.info("customPath:{}", customPath);
List<String> fileList = getFileList(fs, new Path(customPath));
log.info("CustomDictionaryPath:{}", fileList);
Config.CustomDictionaryPath = fileList.toArray(new String[0]);
customDictionary.path = (Config.CustomDictionaryPath == null || Config.CustomDictionaryPath.length == 0) ? path
: Config.CustomDictionaryPath;
customDictionary.path =
(Config.CustomDictionaryPath == null || Config.CustomDictionaryPath.length == 0)
? path
: Config.CustomDictionaryPath;
if (Config.CustomDictionaryPath == null || Config.CustomDictionaryPath.length == 0) {
Config.CustomDictionaryPath = path;
}
@@ -5,6 +5,10 @@ import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.chat.knowledge.DataSetInfoStat;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;

import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
@@ -13,13 +17,8 @@ import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;

/**
* nature parse helper
*/
/** nature parse helper */
@Slf4j
public class NatureHelper {
@@ -56,8 +55,8 @@ public class NatureHelper {

private static boolean isDataSetOrEntity(S2Term term, Integer model) {
String natureStr = term.nature.toString();
return (DictWordType.NATURE_SPILT + model).equals(natureStr) || natureStr.endsWith(
DictWordType.ENTITY.getType());
return (DictWordType.NATURE_SPILT + model).equals(natureStr)
|| natureStr.endsWith(DictWordType.ENTITY.getType());
}

public static Integer getDataSetByNature(Nature nature) {
@@ -90,7 +89,8 @@ public class NatureHelper {
return null;
}

public static List<String> changeModel2DataSet(String nature, Map<Long, List<Long>> modelIdToDataSetIds) {
public static List<String> changeModel2DataSet(
String nature, Map<Long, List<Long>> modelIdToDataSetIds) {
if (SchemaElementType.TERM.equals(NatureHelper.convertToElementType(nature))) {
return Collections.singletonList(nature);
}
@@ -107,8 +107,10 @@ public class NatureHelper {
}

public static boolean isDimensionValueDataSetId(String nature) {
return isNatureValid(nature) && !isNatureType(nature, DictWordType.METRIC, DictWordType.DIMENSION,
DictWordType.TERM) && StringUtils.isNumeric(nature.split(DictWordType.NATURE_SPILT)[1]);
return isNatureValid(nature)
&& !isNatureType(
nature, DictWordType.METRIC, DictWordType.DIMENSION, DictWordType.TERM)
&& StringUtils.isNumeric(nature.split(DictWordType.NATURE_SPILT)[1]);
}
public static boolean isTermNature(String nature) {
@@ -125,34 +127,53 @@ public class NatureHelper {
}

private static long getDataSetCount(List<S2Term> terms) {
return terms.stream().filter(term -> isDataSetOrEntity(term, getDataSetByNature(term.nature))).count();
return terms.stream()
.filter(term -> isDataSetOrEntity(term, getDataSetByNature(term.nature)))
.count();
}

private static long getDimensionValueCount(List<S2Term> terms) {
return terms.stream().filter(term -> isDimensionValueDataSetId(term.nature.toString())).count();
return terms.stream()
.filter(term -> isDimensionValueDataSetId(term.nature.toString()))
.count();
}

private static long getDimensionCount(List<S2Term> terms) {
return terms.stream().filter(term -> term.nature.startsWith(DictWordType.NATURE_SPILT) && term.nature.toString()
.endsWith(DictWordType.DIMENSION.getType())).count();
return terms.stream()
.filter(
term ->
term.nature.startsWith(DictWordType.NATURE_SPILT)
&& term.nature
.toString()
.endsWith(DictWordType.DIMENSION.getType()))
.count();
}

private static long getMetricCount(List<S2Term> terms) {
return terms.stream().filter(term -> term.nature.startsWith(DictWordType.NATURE_SPILT) && term.nature.toString()
.endsWith(DictWordType.METRIC.getType())).count();
return terms.stream()
.filter(
term ->
term.nature.startsWith(DictWordType.NATURE_SPILT)
&& term.nature
.toString()
.endsWith(DictWordType.METRIC.getType()))
.count();
}
public static Map<Long, Map<DictWordType, Integer>> getDataSetToNatureStat(List<S2Term> terms) {
Map<Long, Map<DictWordType, Integer>> modelToNature = new HashMap<>();
terms.stream()
.filter(term -> term.nature.startsWith(DictWordType.NATURE_SPILT))
.forEach(term -> {
DictWordType dictWordType = DictWordType.getNatureType(term.nature.toString());
Long model = getDataSetId(term.nature.toString());
.forEach(
term -> {
DictWordType dictWordType =
DictWordType.getNatureType(term.nature.toString());
Long model = getDataSetId(term.nature.toString());

modelToNature.computeIfAbsent(model, k -> new HashMap<>())
.merge(dictWordType, 1, Integer::sum);
});
modelToNature
.computeIfAbsent(model, k -> new HashMap<>())
.merge(dictWordType, 1, Integer::sum);
});
return modelToNature;
}
@@ -160,10 +181,12 @@ public class NatureHelper {
Map<Long, Map<DictWordType, Integer>> modelToNatureStat = getDataSetToNatureStat(terms);
return modelToNatureStat.entrySet().stream()
.max(Comparator.comparingInt(entry -> entry.getValue().size()))
.map(entry -> modelToNatureStat.entrySet().stream()
.filter(e -> e.getValue().size() == entry.getValue().size())
.map(Map.Entry::getKey)
.collect(Collectors.toList()))
.map(
entry ->
modelToNatureStat.entrySet().stream()
.filter(e -> e.getValue().size() == entry.getValue().size())
.map(Map.Entry::getKey)
.collect(Collectors.toList()))
.orElse(Collections.emptyList());
}

@@ -171,7 +194,8 @@ public class NatureHelper {
return parseIdFromNature(nature, 2);
}

public static Set<Long> getModelIds(Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
public static Set<Long> getModelIds(
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
if (CollectionUtils.isEmpty(detectDataSetIds)) {
return modelIdToDataSetIds.keySet();
}
@@ -13,11 +13,11 @@ import org.springframework.beans.BeanUtils;
import org.springframework.util.CollectionUtils;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.HashSet;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Predicate;
import java.util.stream.Collectors;
||||
@@ -30,7 +30,9 @@ public abstract class BaseMapper implements SchemaMapper {
|
||||
|
||||
String simpleName = this.getClass().getSimpleName();
|
||||
long startTime = System.currentTimeMillis();
|
||||
log.debug("before {},mapInfo:{}", simpleName,
|
||||
log.debug(
|
||||
"before {},mapInfo:{}",
|
||||
simpleName,
|
||||
chatQueryContext.getMapInfo().getDataSetElementMatches());
|
||||
|
||||
try {
|
||||
@@ -41,7 +43,10 @@ public abstract class BaseMapper implements SchemaMapper {
|
||||
}
|
||||
|
||||
long cost = System.currentTimeMillis() - startTime;
|
||||
log.debug("after {},cost:{},mapInfo:{}", simpleName, cost,
|
||||
log.debug(
|
||||
"after {},cost:{},mapInfo:{}",
|
||||
simpleName,
|
||||
cost,
|
||||
chatQueryContext.getMapInfo().getDataSetElementMatches());
|
||||
}
|
||||
|
||||
@@ -53,14 +58,19 @@ public abstract class BaseMapper implements SchemaMapper {
|
||||
filterByQueryDataType(chatQueryContext, element -> !(element.getIsTag() > 0));
|
||||
break;
|
||||
case METRIC:
|
||||
filterByQueryDataType(chatQueryContext, element -> !SchemaElementType.METRIC.equals(element.getType()));
|
||||
filterByQueryDataType(
|
||||
chatQueryContext,
|
||||
element -> !SchemaElementType.METRIC.equals(element.getType()));
|
||||
break;
|
||||
case DIMENSION:
|
||||
filterByQueryDataType(chatQueryContext, element -> {
|
||||
boolean isDimensionOrValue = SchemaElementType.DIMENSION.equals(element.getType())
|
||||
|| SchemaElementType.VALUE.equals(element.getType());
|
||||
return !isDimensionOrValue;
|
||||
});
|
||||
filterByQueryDataType(
|
||||
chatQueryContext,
|
||||
element -> {
|
||||
boolean isDimensionOrValue =
|
||||
SchemaElementType.DIMENSION.equals(element.getType())
|
||||
|| SchemaElementType.VALUE.equals(element.getType());
|
||||
return !isDimensionOrValue;
|
||||
});
|
||||
break;
|
||||
case ALL:
|
||||
default:
|
||||
@@ -73,7 +83,8 @@ public abstract class BaseMapper implements SchemaMapper {
|
||||
if (CollectionUtils.isEmpty(dataSetIds)) {
|
||||
return;
|
||||
}
|
||||
Set<Long> dataSetIdInMapInfo = new HashSet<>(chatQueryContext.getMapInfo().getDataSetElementMatches().keySet());
|
||||
Set<Long> dataSetIdInMapInfo =
|
||||
new HashSet<>(chatQueryContext.getMapInfo().getDataSetElementMatches().keySet());
|
||||
for (Long dataSetId : dataSetIdInMapInfo) {
|
||||
if (!dataSetIds.contains(dataSetId)) {
|
||||
chatQueryContext.getMapInfo().getDataSetElementMatches().remove(dataSetId);
|
||||
@@ -87,54 +98,68 @@ public abstract class BaseMapper implements SchemaMapper {
|
||||
for (Map.Entry<Long, List<SchemaElementMatch>> entry : dataSetElementMatches.entrySet()) {
|
||||
List<SchemaElementMatch> value = entry.getValue();
|
||||
if (!CollectionUtils.isEmpty(value)) {
|
||||
value.removeIf(schemaElementMatch ->
|
||||
StringUtils.length(schemaElementMatch.getDetectWord()) <= 1);
|
||||
value.removeIf(
|
||||
schemaElementMatch ->
|
||||
StringUtils.length(schemaElementMatch.getDetectWord()) <= 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static void filterByQueryDataType(ChatQueryContext chatQueryContext,
|
||||
Predicate<SchemaElement> needRemovePredicate) {
|
||||
chatQueryContext.getMapInfo().getDataSetElementMatches().values().forEach(schemaElementMatches -> {
|
||||
schemaElementMatches.removeIf(schemaElementMatch -> {
|
||||
SchemaElement element = schemaElementMatch.getElement();
|
||||
SchemaElementType type = element.getType();
|
||||
private static void filterByQueryDataType(
|
||||
ChatQueryContext chatQueryContext, Predicate<SchemaElement> needRemovePredicate) {
|
||||
chatQueryContext
|
||||
.getMapInfo()
|
||||
.getDataSetElementMatches()
|
||||
.values()
|
||||
.forEach(
|
||||
schemaElementMatches -> {
|
||||
schemaElementMatches.removeIf(
|
||||
schemaElementMatch -> {
|
||||
SchemaElement element = schemaElementMatch.getElement();
|
||||
SchemaElementType type = element.getType();
|
||||
|
||||
boolean isEntityOrDatasetOrId = SchemaElementType.ENTITY.equals(type)
|
||||
|| SchemaElementType.DATASET.equals(type)
|
||||
|| SchemaElementType.ID.equals(type);
|
||||
boolean isEntityOrDatasetOrId =
|
||||
SchemaElementType.ENTITY.equals(type)
|
||||
|| SchemaElementType.DATASET.equals(type)
|
||||
|| SchemaElementType.ID.equals(type);
|
||||
|
||||
return !isEntityOrDatasetOrId && needRemovePredicate.test(element);
|
||||
});
|
||||
});
|
||||
return !isEntityOrDatasetOrId
|
||||
&& needRemovePredicate.test(element);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
public abstract void doMap(ChatQueryContext chatQueryContext);
|
||||
|
||||
public void addToSchemaMap(SchemaMapInfo schemaMap, Long dataSetId, SchemaElementMatch newElementMatch) {
|
||||
Map<Long, List<SchemaElementMatch>> dataSetElementMatches = schemaMap.getDataSetElementMatches();
|
||||
List<SchemaElementMatch> schemaElementMatches = dataSetElementMatches.computeIfAbsent(dataSetId,
|
||||
k -> new ArrayList<>());
|
||||
public void addToSchemaMap(
|
||||
SchemaMapInfo schemaMap, Long dataSetId, SchemaElementMatch newElementMatch) {
|
||||
Map<Long, List<SchemaElementMatch>> dataSetElementMatches =
|
||||
schemaMap.getDataSetElementMatches();
|
||||
List<SchemaElementMatch> schemaElementMatches =
|
||||
dataSetElementMatches.computeIfAbsent(dataSetId, k -> new ArrayList<>());
|
||||
|
||||
AtomicBoolean shouldAddNew = new AtomicBoolean(true);
|
||||
|
||||
schemaElementMatches.removeIf(existingElementMatch -> {
|
||||
if (isEquals(existingElementMatch, newElementMatch)) {
|
||||
if (newElementMatch.getSimilarity() > existingElementMatch.getSimilarity()) {
|
||||
return true;
|
||||
} else {
|
||||
shouldAddNew.set(false);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
});
|
||||
schemaElementMatches.removeIf(
|
||||
existingElementMatch -> {
|
||||
if (isEquals(existingElementMatch, newElementMatch)) {
|
||||
if (newElementMatch.getSimilarity()
|
||||
> existingElementMatch.getSimilarity()) {
|
||||
return true;
|
||||
} else {
|
||||
shouldAddNew.set(false);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
});
|
||||
|
||||
if (shouldAddNew.get()) {
|
||||
schemaElementMatches.add(newElementMatch);
|
||||
}
|
||||
}
|
||||
|
||||
private static boolean isEquals(SchemaElementMatch existElementMatch, SchemaElementMatch newElementMatch) {
|
||||
private static boolean isEquals(
|
||||
SchemaElementMatch existElementMatch, SchemaElementMatch newElementMatch) {
|
||||
SchemaElement existElement = existElementMatch.getElement();
|
||||
SchemaElement newElement = newElementMatch.getElement();
|
||||
if (!existElement.equals(newElement)) {
|
||||
@@ -146,8 +171,11 @@ public abstract class BaseMapper implements SchemaMapper {
|
||||
return true;
|
||||
}
|
||||
|
||||
public SchemaElement getSchemaElement(Long dataSetId, SchemaElementType elementType, Long elementID,
|
||||
SemanticSchema semanticSchema) {
|
||||
public SchemaElement getSchemaElement(
|
||||
Long dataSetId,
|
||||
SchemaElementType elementType,
|
||||
Long elementID,
|
||||
SemanticSchema semanticSchema) {
|
||||
SchemaElement element = new SchemaElement();
|
||||
DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId);
|
||||
if (Objects.isNull(dataSetSchema)) {
|
||||
@@ -166,8 +194,8 @@ public abstract class BaseMapper implements SchemaMapper {
|
||||
if (!SchemaElementType.VALUE.equals(element.getType())) {
|
||||
return element.getAlias();
|
||||
}
|
||||
if (org.apache.commons.collections.CollectionUtils.isNotEmpty(element.getAlias()) && StringUtils.isNotEmpty(
|
||||
element.getName())) {
|
||||
if (org.apache.commons.collections.CollectionUtils.isNotEmpty(element.getAlias())
|
||||
&& StringUtils.isNotEmpty(element.getName())) {
|
||||
return element.getAlias().stream()
|
||||
.filter(aliasItem -> aliasItem.contains(element.getName()))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.headless.chat.mapper;
|
||||
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
@@ -26,15 +25,13 @@ import java.util.stream.Collectors;
|
||||
@Slf4j
|
||||
public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
|
||||
@Autowired
|
||||
protected MapperHelper mapperHelper;
|
||||
@Autowired protected MapperHelper mapperHelper;
|
||||
|
||||
@Autowired
|
||||
protected MapperConfig mapperConfig;
|
||||
@Autowired protected MapperConfig mapperConfig;
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<T>> match(ChatQueryContext chatQueryContext, List<S2Term> terms,
|
||||
Set<Long> detectDataSetIds) {
|
||||
public Map<MatchText, List<T>> match(
|
||||
ChatQueryContext chatQueryContext, List<S2Term> terms, Set<Long> detectDataSetIds) {
|
||||
String text = chatQueryContext.getQueryText();
|
||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||
return null;
|
||||
@@ -49,7 +46,8 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
return result;
|
||||
}
|
||||
|
||||
public List<T> detect(ChatQueryContext chatQueryContext, List<S2Term> terms, Set<Long> detectDataSetIds) {
|
||||
public List<T> detect(
|
||||
ChatQueryContext chatQueryContext, List<S2Term> terms, Set<Long> detectDataSetIds) {
|
||||
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(terms);
|
||||
String text = chatQueryContext.getQueryText();
|
||||
Set<T> results = new HashSet<>();
|
||||
@@ -64,7 +62,8 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
if (index <= text.length()) {
|
||||
String detectSegment = text.substring(startIndex, index).trim();
|
||||
detectSegments.add(detectSegment);
|
||||
detectByStep(chatQueryContext, results, detectDataSetIds, detectSegment, offset);
|
||||
detectByStep(
|
||||
chatQueryContext, results, detectDataSetIds, detectSegment, offset);
|
||||
}
|
||||
}
|
||||
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
|
||||
@@ -73,9 +72,13 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
}
|
||||
|
||||
public Map<Integer, Integer> getRegOffsetToLength(List<S2Term> terms) {
|
||||
return terms.stream().sorted(Comparator.comparing(S2Term::length))
|
||||
.collect(Collectors.toMap(S2Term::getOffset, term -> term.word.length(),
|
||||
(value1, value2) -> value2));
|
||||
return terms.stream()
|
||||
.sorted(Comparator.comparing(S2Term::length))
|
||||
.collect(
|
||||
Collectors.toMap(
|
||||
S2Term::getOffset,
|
||||
term -> term.word.length(),
|
||||
(value1, value2) -> value2));
|
||||
}
|
||||
|
||||
public void selectResultInOneRound(Set<T> existResults, List<T> oneRoundResults) {
|
||||
@@ -84,15 +87,15 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
}
|
||||
for (T oneRoundResult : oneRoundResults) {
|
||||
if (existResults.contains(oneRoundResult)) {
|
||||
boolean isDeleted = existResults.removeIf(
|
||||
existResult -> {
|
||||
boolean delete = needDelete(oneRoundResult, existResult);
|
||||
if (delete) {
|
||||
log.info("deleted existResult:{}", existResult);
|
||||
}
|
||||
return delete;
|
||||
}
|
||||
);
|
||||
boolean isDeleted =
|
||||
existResults.removeIf(
|
||||
existResult -> {
|
||||
boolean delete = needDelete(oneRoundResult, existResult);
|
||||
if (delete) {
|
||||
log.info("deleted existResult:{}", existResult);
|
||||
}
|
||||
return delete;
|
||||
});
|
||||
if (isDeleted) {
|
||||
log.info("deleted, add oneRoundResult:{}", oneRoundResult);
|
||||
existResults.add(oneRoundResult);
|
||||
@@ -111,9 +114,11 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
if (Objects.isNull(matchResult)) {
|
||||
return matches;
|
||||
}
|
||||
Optional<List<T>> first = matchResult.entrySet().stream()
|
||||
.filter(entry -> CollectionUtils.isNotEmpty(entry.getValue()))
|
||||
.map(entry -> entry.getValue()).findFirst();
|
||||
Optional<List<T>> first =
|
||||
matchResult.entrySet().stream()
|
||||
.filter(entry -> CollectionUtils.isNotEmpty(entry.getValue()))
|
||||
.map(entry -> entry.getValue())
|
||||
.findFirst();
|
||||
|
||||
if (first.isPresent()) {
|
||||
matches = first.get();
|
||||
@@ -124,13 +129,19 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
public List<S2Term> filterByDataSetId(List<S2Term> terms, Set<Long> dataSetIds) {
|
||||
logTerms(terms);
|
||||
if (CollectionUtils.isNotEmpty(dataSetIds)) {
|
||||
terms = terms.stream().filter(term -> {
|
||||
Long dataSetId = NatureHelper.getDataSetId(term.getNature().toString());
|
||||
if (Objects.nonNull(dataSetId)) {
|
||||
return dataSetIds.contains(dataSetId);
|
||||
}
|
||||
return false;
|
||||
}).collect(Collectors.toList());
|
||||
terms =
|
||||
terms.stream()
|
||||
.filter(
|
||||
term -> {
|
||||
Long dataSetId =
|
||||
NatureHelper.getDataSetId(
|
||||
term.getNature().toString());
|
||||
if (Objects.nonNull(dataSetId)) {
|
||||
return dataSetIds.contains(dataSetId);
|
||||
}
|
||||
return false;
|
||||
})
|
||||
.collect(Collectors.toList());
|
||||
log.debug("terms filter by dataSetId:{}", dataSetIds);
|
||||
logTerms(terms);
|
||||
}
|
||||
@@ -142,7 +153,11 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
return;
|
||||
}
|
||||
for (S2Term term : terms) {
|
||||
log.debug("word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency());
|
||||
log.debug(
|
||||
"word:{},nature:{},frequency:{}",
|
||||
term.word,
|
||||
term.nature.toString(),
|
||||
term.getFrequency());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -150,8 +165,12 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
|
||||
public abstract String getMapKey(T a);
|
||||
|
||||
public abstract void detectByStep(ChatQueryContext chatQueryContext, Set<T> existResults,
|
||||
Set<Long> detectDataSetIds, String detectSegment, int offset);
|
||||
public abstract void detectByStep(
|
||||
ChatQueryContext chatQueryContext,
|
||||
Set<T> existResults,
|
||||
Set<Long> detectDataSetIds,
|
||||
String detectSegment,
|
||||
int offset);
|
||||
|
||||
public double getThreshold(Double threshold, Double minThreshold, MapModeEnum mapModeEnum) {
|
||||
double decreaseAmount = (threshold - minThreshold) / 4;
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.headless.chat.mapper;
|
||||
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
@@ -21,8 +20,8 @@ import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* DatabaseMatchStrategy uses SQL LIKE operator to match schema elements.
|
||||
* It currently supports fuzzy matching against names and aliases.
|
||||
* DatabaseMatchStrategy uses SQL LIKE operator to match schema elements. It currently supports
|
||||
* fuzzy matching against names and aliases.
|
||||
*/
|
||||
@Service
|
||||
@Slf4j
|
||||
@@ -31,8 +30,8 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
||||
private List<SchemaElement> allElements;
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<DatabaseMapResult>> match(ChatQueryContext chatQueryContext, List<S2Term> terms,
|
||||
Set<Long> detectDataSetIds) {
|
||||
public Map<MatchText, List<DatabaseMapResult>> match(
|
||||
ChatQueryContext chatQueryContext, List<S2Term> terms, Set<Long> detectDataSetIds) {
|
||||
this.allElements = getSchemaElements(chatQueryContext);
|
||||
return super.match(chatQueryContext, terms, detectDataSetIds);
|
||||
}
|
||||
@@ -45,12 +44,19 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
||||
|
||||
@Override
|
||||
public String getMapKey(DatabaseMapResult a) {
|
||||
return a.getName() + Constants.UNDERLINE + a.getSchemaElement().getId()
|
||||
+ Constants.UNDERLINE + a.getSchemaElement().getName();
|
||||
return a.getName()
|
||||
+ Constants.UNDERLINE
|
||||
+ a.getSchemaElement().getId()
|
||||
+ Constants.UNDERLINE
|
||||
+ a.getSchemaElement().getName();
|
||||
}
|
||||
|
||||
public void detectByStep(ChatQueryContext chatQueryContext, Set<DatabaseMapResult> existResults,
|
||||
Set<Long> detectDataSetIds, String detectSegment, int offset) {
|
||||
public void detectByStep(
|
||||
ChatQueryContext chatQueryContext,
|
||||
Set<DatabaseMapResult> existResults,
|
||||
Set<Long> detectDataSetIds,
|
||||
String detectSegment,
|
||||
int offset) {
|
||||
if (StringUtils.isBlank(detectSegment)) {
|
||||
return;
|
||||
}
|
||||
@@ -61,14 +67,19 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
||||
for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
|
||||
String name = entry.getKey();
|
||||
if (!name.contains(detectSegment)
|
||||
|| mapperHelper.getSimilarity(detectSegment, name) < metricDimensionThresholdConfig) {
|
||||
|| mapperHelper.getSimilarity(detectSegment, name)
|
||||
< metricDimensionThresholdConfig) {
|
||||
continue;
|
||||
}
|
||||
Set<SchemaElement> schemaElements = entry.getValue();
|
||||
if (!CollectionUtils.isEmpty(detectDataSetIds)) {
|
||||
schemaElements = schemaElements.stream()
|
||||
.filter(schemaElement -> detectDataSetIds.contains(schemaElement.getDataSetId()))
|
||||
.collect(Collectors.toSet());
|
||||
schemaElements =
|
||||
schemaElements.stream()
|
||||
.filter(
|
||||
schemaElement ->
|
||||
detectDataSetIds.contains(
|
||||
schemaElement.getDataSetId()))
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
for (SchemaElement schemaElement : schemaElements) {
|
||||
DatabaseMapResult databaseMapResult = new DatabaseMapResult();
|
||||
@@ -88,31 +99,42 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
||||
}
|
||||
|
||||
private Double getThreshold(ChatQueryContext chatQueryContext) {
|
||||
Double threshold = Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD));
|
||||
Double minThreshold = Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD_MIN));
|
||||
Double threshold =
|
||||
Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD));
|
||||
Double minThreshold =
|
||||
Double.valueOf(
|
||||
mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD_MIN));
|
||||
|
||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = chatQueryContext.getMapInfo()
|
||||
.getDataSetElementMatches();
|
||||
Map<Long, List<SchemaElementMatch>> modelElementMatches =
|
||||
chatQueryContext.getMapInfo().getDataSetElementMatches();
|
||||
|
||||
boolean existElement = modelElementMatches.entrySet().stream().anyMatch(entry -> entry.getValue().size() >= 1);
|
||||
boolean existElement =
|
||||
modelElementMatches.entrySet().stream()
|
||||
.anyMatch(entry -> entry.getValue().size() >= 1);
|
||||
|
||||
if (!existElement) {
|
||||
threshold = threshold / 2;
|
||||
log.debug("ModelElementMatches:{},not exist Element threshold reduce by half:{}",
|
||||
modelElementMatches, threshold);
|
||||
log.debug(
|
||||
"ModelElementMatches:{},not exist Element threshold reduce by half:{}",
|
||||
modelElementMatches,
|
||||
threshold);
|
||||
}
|
||||
return getThreshold(threshold, minThreshold, chatQueryContext.getMapModeEnum());
|
||||
}
|
||||
|
||||
private Map<String, Set<SchemaElement>> getNameToItems(List<SchemaElement> models) {
|
||||
return models.stream().collect(
|
||||
Collectors.toMap(SchemaElement::getName, a -> {
|
||||
Set<SchemaElement> result = new HashSet<>();
|
||||
result.add(a);
|
||||
return result;
|
||||
}, (k1, k2) -> {
|
||||
k1.addAll(k2);
|
||||
return k1;
|
||||
}));
|
||||
return models.stream()
|
||||
.collect(
|
||||
Collectors.toMap(
|
||||
SchemaElement::getName,
|
||||
a -> {
|
||||
Set<SchemaElement> result = new HashSet<>();
|
||||
result.add(a);
|
||||
return result;
|
||||
},
|
||||
(k1, k2) -> {
|
||||
k1.addAll(k2);
|
||||
return k1;
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,59 +1,63 @@
|
||||
package com.tencent.supersonic.headless.chat.mapper;
|
||||
|
||||
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.store.embedding.Retrieval;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.builder.BaseWordBuilder;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import dev.langchain4j.store.embedding.Retrieval;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
/***
|
||||
* A mapper that recognizes schema elements with vector embedding.
|
||||
*/
|
||||
/** * A mapper that recognizes schema elements with vector embedding. */
|
||||
@Slf4j
|
||||
public class EmbeddingMapper extends BaseMapper {
|
||||
|
||||
@Override
|
||||
public void doMap(ChatQueryContext chatQueryContext) {
|
||||
//1. query from embedding by queryText
|
||||
// 1. query from embedding by queryText
|
||||
String queryText = chatQueryContext.getQueryText();
|
||||
List<S2Term> terms = HanlpHelper.getTerms(queryText, chatQueryContext.getModelIdToDataSetIds());
|
||||
List<S2Term> terms =
|
||||
HanlpHelper.getTerms(queryText, chatQueryContext.getModelIdToDataSetIds());
|
||||
|
||||
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
|
||||
List<EmbeddingResult> matchResults = matchStrategy.getMatches(chatQueryContext, terms);
|
||||
|
||||
HanlpHelper.transLetterOriginal(matchResults);
|
||||
|
||||
//2. build SchemaElementMatch by info
|
||||
// 2. build SchemaElementMatch by info
|
||||
for (EmbeddingResult matchResult : matchResults) {
|
||||
Long elementId = Retrieval.getLongId(matchResult.getId());
|
||||
Long dataSetId = Retrieval.getLongId(matchResult.getMetadata().get("dataSetId"));
|
||||
if (Objects.isNull(dataSetId)) {
|
||||
continue;
|
||||
}
|
||||
SchemaElementType elementType = SchemaElementType.valueOf(matchResult.getMetadata().get("type"));
|
||||
SchemaElement schemaElement = getSchemaElement(dataSetId, elementType, elementId,
|
||||
chatQueryContext.getSemanticSchema());
|
||||
SchemaElementType elementType =
|
||||
SchemaElementType.valueOf(matchResult.getMetadata().get("type"));
|
||||
SchemaElement schemaElement =
|
||||
getSchemaElement(
|
||||
dataSetId,
|
||||
elementType,
|
||||
elementId,
|
||||
chatQueryContext.getSemanticSchema());
|
||||
if (schemaElement == null) {
|
||||
continue;
|
||||
}
|
||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||
.element(schemaElement)
|
||||
.frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
|
||||
.word(matchResult.getName())
|
||||
.similarity(1 - matchResult.getDistance())
|
||||
.detectWord(matchResult.getDetectWord())
|
||||
.build();
|
||||
//3. add to mapInfo
|
||||
SchemaElementMatch schemaElementMatch =
|
||||
SchemaElementMatch.builder()
|
||||
.element(schemaElement)
|
||||
.frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
|
||||
.word(matchResult.getName())
|
||||
.similarity(1 - matchResult.getDistance())
|
||||
.detectWord(matchResult.getDetectWord())
|
||||
.build();
|
||||
// 3. add to mapInfo
|
||||
addToSchemaMap(chatQueryContext.getMapInfo(), dataSetId, schemaElementMatch);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,12 +3,12 @@ package com.tencent.supersonic.headless.chat.mapper;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import dev.langchain4j.store.embedding.Retrieval;
|
||||
import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.MetaEmbeddingService;
|
||||
import dev.langchain4j.store.embedding.Retrieval;
|
||||
import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
@@ -30,15 +30,14 @@ import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING
|
||||
import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING_MAPPER_THRESHOLD_MIN;
|
||||
|
||||
/**
|
||||
* EmbeddingMatchStrategy uses vector database to perform
|
||||
* similarity search against the embeddings of schema elements.
|
||||
* EmbeddingMatchStrategy uses vector database to perform similarity search against the embeddings
|
||||
* of schema elements.
|
||||
*/
|
||||
@Service
|
||||
@Slf4j
|
||||
public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
|
||||
@Autowired
|
||||
private MetaEmbeddingService metaEmbeddingService;
|
||||
@Autowired private MetaEmbeddingService metaEmbeddingService;
|
||||
|
||||
@Override
|
||||
public boolean needDelete(EmbeddingResult oneRoundResult, EmbeddingResult existResult) {
|
||||
@@ -52,44 +51,54 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void detectByStep(ChatQueryContext chatQueryContext, Set<EmbeddingResult> existResults,
|
||||
Set<Long> detectDataSetIds, String detectSegment, int offset) {
|
||||
|
||||
}
|
||||
public void detectByStep(
|
||||
ChatQueryContext chatQueryContext,
|
||||
Set<EmbeddingResult> existResults,
|
||||
Set<Long> detectDataSetIds,
|
||||
String detectSegment,
|
||||
int offset) {}
|
||||
|
||||
@Override
|
||||
public List<EmbeddingResult> detect(ChatQueryContext chatQueryContext, List<S2Term> terms,
|
||||
Set<Long> detectDataSetIds) {
|
||||
public List<EmbeddingResult> detect(
|
||||
ChatQueryContext chatQueryContext, List<S2Term> terms, Set<Long> detectDataSetIds) {
|
||||
String text = chatQueryContext.getQueryText();
|
||||
Set<String> detectSegments = new HashSet<>();
|
||||
|
||||
int embeddingTextSize = Integer.valueOf(
|
||||
mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_TEXT_SIZE));
|
||||
int embeddingTextSize =
|
||||
Integer.valueOf(
|
||||
mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_TEXT_SIZE));
|
||||
|
||||
int embeddingTextStep = Integer.valueOf(
|
||||
mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_TEXT_STEP));
|
||||
int embeddingTextStep =
|
||||
Integer.valueOf(
|
||||
mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_TEXT_STEP));
|
||||
|
||||
for (int startIndex = 0; startIndex < text.length(); startIndex += embeddingTextStep) {
|
||||
int endIndex = Math.min(startIndex + embeddingTextSize, text.length());
|
||||
String detectSegment = text.substring(startIndex, endIndex).trim();
|
||||
detectSegments.add(detectSegment);
|
||||
}
|
||||
Set<EmbeddingResult> results = detectByBatch(chatQueryContext, detectDataSetIds, detectSegments);
|
||||
Set<EmbeddingResult> results =
|
||||
detectByBatch(chatQueryContext, detectDataSetIds, detectSegments);
|
||||
return new ArrayList<>(results);
|
||||
}
|
||||
|
||||
protected Set<EmbeddingResult> detectByBatch(ChatQueryContext chatQueryContext,
|
||||
Set<Long> detectDataSetIds, Set<String> detectSegments) {
|
||||
protected Set<EmbeddingResult> detectByBatch(
|
||||
ChatQueryContext chatQueryContext,
|
||||
Set<Long> detectDataSetIds,
|
||||
Set<String> detectSegments) {
|
||||
Set<EmbeddingResult> results = new HashSet<>();
|
||||
int embeddingMapperBatch = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH));
|
||||
int embeddingMapperBatch =
|
||||
Integer.valueOf(
|
||||
mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH));
|
||||
|
||||
List<String> queryTextsList = detectSegments.stream()
|
||||
.map(detectSegment -> detectSegment.trim())
|
||||
.filter(detectSegment -> StringUtils.isNotBlank(detectSegment))
|
||||
.collect(Collectors.toList());
|
||||
List<String> queryTextsList =
|
||||
detectSegments.stream()
|
||||
.map(detectSegment -> detectSegment.trim())
|
||||
.filter(detectSegment -> StringUtils.isNotBlank(detectSegment))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
List<List<String>> queryTextsSubList = Lists.partition(queryTextsList,
|
||||
embeddingMapperBatch);
|
||||
List<List<String>> queryTextsSubList =
|
||||
Lists.partition(queryTextsList, embeddingMapperBatch);
|
||||
|
||||
for (List<String> queryTextsSub : queryTextsSubList) {
|
||||
detectByQueryTextsSub(results, detectDataSetIds, queryTextsSub, chatQueryContext);
|
||||
@@ -97,60 +106,99 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
return results;
|
||||
}
|
||||
|
||||
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectDataSetIds,
|
||||
List<String> queryTextsSub, ChatQueryContext chatQueryContext) {
|
||||
private void detectByQueryTextsSub(
|
||||
Set<EmbeddingResult> results,
|
||||
Set<Long> detectDataSetIds,
|
||||
List<String> queryTextsSub,
|
||||
ChatQueryContext chatQueryContext) {
|
||||
Map<Long, List<Long>> modelIdToDataSetIds = chatQueryContext.getModelIdToDataSetIds();
|
||||
double embeddingThreshold = Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD));
|
||||
double embeddingThresholdMin = Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD_MIN));
|
||||
double threshold = getThreshold(embeddingThreshold, embeddingThresholdMin, chatQueryContext.getMapModeEnum());
|
||||
double embeddingThreshold =
|
||||
Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD));
|
||||
double embeddingThresholdMin =
|
||||
Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD_MIN));
|
||||
double threshold =
|
||||
getThreshold(
|
||||
embeddingThreshold,
|
||||
embeddingThresholdMin,
|
||||
chatQueryContext.getMapModeEnum());
|
||||
|
||||
// step1. build query params
|
||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
|
||||
|
||||
// step2. retrieveQuery by detectSegment
|
||||
int embeddingNumber = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_NUMBER));
|
||||
List<RetrieveQueryResult> retrieveQueryResults = metaEmbeddingService.retrieveQuery(
|
||||
retrieveQuery, embeddingNumber, modelIdToDataSetIds, detectDataSetIds);
|
||||
int embeddingNumber =
|
||||
Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_NUMBER));
|
||||
List<RetrieveQueryResult> retrieveQueryResults =
|
||||
metaEmbeddingService.retrieveQuery(
|
||||
retrieveQuery, embeddingNumber, modelIdToDataSetIds, detectDataSetIds);
|
||||
|
||||
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
|
||||
return;
|
||||
}
|
||||
// step3. build EmbeddingResults
|
||||
List<EmbeddingResult> collect = retrieveQueryResults.stream()
|
||||
.map(retrieveQueryResult -> {
|
||||
List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
|
||||
if (CollectionUtils.isNotEmpty(retrievals)) {
|
||||
retrievals.removeIf(retrieval -> {
|
||||
if (!retrieveQueryResult.getQuery().contains(retrieval.getQuery())) {
|
||||
return retrieval.getDistance() > 1 - threshold;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
}
|
||||
return retrieveQueryResult;
|
||||
})
|
||||
.filter(retrieveQueryResult -> CollectionUtils.isNotEmpty(retrieveQueryResult.getRetrieval()))
|
||||
.flatMap(retrieveQueryResult -> retrieveQueryResult.getRetrieval().stream()
|
||||
.map(retrieval -> {
|
||||
EmbeddingResult embeddingResult = new EmbeddingResult();
|
||||
BeanUtils.copyProperties(retrieval, embeddingResult);
|
||||
embeddingResult.setDetectWord(retrieveQueryResult.getQuery());
|
||||
embeddingResult.setName(retrieval.getQuery());
|
||||
Map<String, String> convertedMap = retrieval.getMetadata().entrySet().stream()
|
||||
.collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().toString()));
|
||||
embeddingResult.setMetadata(convertedMap);
|
||||
return embeddingResult;
|
||||
}))
|
||||
.collect(Collectors.toList());
|
||||
List<EmbeddingResult> collect =
|
||||
retrieveQueryResults.stream()
|
||||
.map(
|
||||
retrieveQueryResult -> {
|
||||
List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
|
||||
if (CollectionUtils.isNotEmpty(retrievals)) {
|
||||
retrievals.removeIf(
|
||||
retrieval -> {
|
||||
if (!retrieveQueryResult
|
||||
.getQuery()
|
||||
.contains(retrieval.getQuery())) {
|
||||
return retrieval.getDistance()
|
||||
> 1 - threshold;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
}
|
||||
return retrieveQueryResult;
|
||||
})
|
||||
.filter(
|
||||
retrieveQueryResult ->
|
||||
CollectionUtils.isNotEmpty(
|
||||
retrieveQueryResult.getRetrieval()))
|
||||
.flatMap(
|
||||
retrieveQueryResult ->
|
||||
retrieveQueryResult.getRetrieval().stream()
|
||||
.map(
|
||||
retrieval -> {
|
||||
EmbeddingResult embeddingResult =
|
||||
new EmbeddingResult();
|
||||
BeanUtils.copyProperties(
|
||||
retrieval, embeddingResult);
|
||||
embeddingResult.setDetectWord(
|
||||
retrieveQueryResult.getQuery());
|
||||
embeddingResult.setName(
|
||||
retrieval.getQuery());
|
||||
Map<String, String> convertedMap =
|
||||
retrieval.getMetadata()
|
||||
.entrySet().stream()
|
||||
.collect(
|
||||
Collectors
|
||||
.toMap(
|
||||
Map
|
||||
.Entry
|
||||
::getKey,
|
||||
entry ->
|
||||
entry.getValue()
|
||||
.toString()));
|
||||
embeddingResult.setMetadata(
|
||||
convertedMap);
|
||||
return embeddingResult;
|
||||
}))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
// step4. select mapResul in one round
|
||||
int embeddingRoundNumber = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_ROUND_NUMBER));
|
||||
int embeddingRoundNumber =
|
||||
Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_ROUND_NUMBER));
|
||||
int roundNumber = embeddingRoundNumber * queryTextsSub.size();
|
||||
List<EmbeddingResult> oneRoundResults = collect.stream()
|
||||
.sorted(Comparator.comparingDouble(EmbeddingResult::getDistance))
|
||||
.limit(roundNumber)
|
||||
.collect(Collectors.toList());
|
||||
List<EmbeddingResult> oneRoundResults =
|
||||
collect.stream()
|
||||
.sorted(Comparator.comparingDouble(EmbeddingResult::getDistance))
|
||||
.limit(roundNumber)
|
||||
.collect(Collectors.toList());
|
||||
selectResultInOneRound(results, oneRoundResults);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -14,9 +14,7 @@ import org.springframework.util.CollectionUtils;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* A mapper capable of converting the VALUE of entity dimension values into ID types.
|
||||
*/
|
||||
/** A mapper capable of converting the VALUE of entity dimension values into ID types. */
|
||||
@Slf4j
|
||||
public class EntityMapper extends BaseMapper {
|
||||
|
||||
@@ -24,7 +22,8 @@ public class EntityMapper extends BaseMapper {
|
||||
public void doMap(ChatQueryContext chatQueryContext) {
|
||||
SchemaMapInfo schemaMapInfo = chatQueryContext.getMapInfo();
|
||||
for (Long dataSetId : schemaMapInfo.getMatchedDataSetInfos()) {
|
||||
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(dataSetId);
|
||||
List<SchemaElementMatch> schemaElementMatchList =
|
||||
schemaMapInfo.getMatchedElements(dataSetId);
|
||||
if (CollectionUtils.isEmpty(schemaElementMatchList)) {
|
||||
continue;
|
||||
}
|
||||
@@ -32,15 +31,19 @@ public class EntityMapper extends BaseMapper {
|
||||
if (entity == null || entity.getId() == null) {
|
||||
continue;
|
||||
}
|
||||
List<SchemaElementMatch> valueSchemaElements = schemaElementMatchList.stream()
|
||||
.filter(schemaElementMatch -> SchemaElementType.VALUE.equals(
|
||||
schemaElementMatch.getElement().getType()))
|
||||
.collect(Collectors.toList());
|
||||
List<SchemaElementMatch> valueSchemaElements =
|
||||
schemaElementMatchList.stream()
|
||||
.filter(
|
||||
schemaElementMatch ->
|
||||
SchemaElementType.VALUE.equals(
|
||||
schemaElementMatch.getElement().getType()))
|
||||
.collect(Collectors.toList());
|
||||
for (SchemaElementMatch schemaElementMatch : valueSchemaElements) {
|
||||
if (!entity.getId().equals(schemaElementMatch.getElement().getId())) {
|
||||
continue;
|
||||
}
|
||||
if (!checkExistSameEntitySchemaElements(schemaElementMatch, schemaElementMatchList)) {
|
||||
if (!checkExistSameEntitySchemaElements(
|
||||
schemaElementMatch, schemaElementMatchList)) {
|
||||
SchemaElementMatch entitySchemaElementMath = new SchemaElementMatch();
|
||||
BeanUtils.copyProperties(schemaElementMatch, entitySchemaElementMath);
|
||||
entitySchemaElementMath.setElement(entity);
|
||||
@@ -51,13 +54,21 @@ public class EntityMapper extends BaseMapper {
|
||||
}
|
||||
}
|
||||
|
||||
private boolean checkExistSameEntitySchemaElements(SchemaElementMatch valueSchemaElementMatch,
|
||||
private boolean checkExistSameEntitySchemaElements(
|
||||
SchemaElementMatch valueSchemaElementMatch,
|
||||
List<SchemaElementMatch> schemaElementMatchList) {
|
||||
List<SchemaElementMatch> entitySchemaElements = schemaElementMatchList.stream().filter(schemaElementMatch ->
|
||||
SchemaElementType.ENTITY.equals(schemaElementMatch.getElement().getType()))
|
||||
.collect(Collectors.toList());
|
||||
List<SchemaElementMatch> entitySchemaElements =
|
||||
schemaElementMatchList.stream()
|
||||
.filter(
|
||||
schemaElementMatch ->
|
||||
SchemaElementType.ENTITY.equals(
|
||||
schemaElementMatch.getElement().getType()))
|
||||
.collect(Collectors.toList());
|
||||
for (SchemaElementMatch schemaElementMatch : entitySchemaElements) {
|
||||
if (schemaElementMatch.getElement().getId().equals(valueSchemaElementMatch.getElement().getId())) {
|
||||
if (schemaElementMatch
|
||||
.getElement()
|
||||
.getId()
|
||||
.equals(valueSchemaElementMatch.getElement().getId())) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,20 +25,18 @@ import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.MAPPER_DE
|
||||
import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.MAPPER_DIMENSION_VALUE_SIZE;
|
||||
|
||||
/**
|
||||
* HanlpDictMatchStrategy uses <a href="https://www.hanlp.com/">HanLP</a> to
|
||||
* match schema elements. It currently supports prefix and suffix matching
|
||||
* against names, values and aliases.
|
||||
* HanlpDictMatchStrategy uses <a href="https://www.hanlp.com/">HanLP</a> to match schema elements.
|
||||
* It currently supports prefix and suffix matching against names, values and aliases.
|
||||
*/
|
||||
@Service
|
||||
@Slf4j
|
||||
public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
|
||||
@Autowired
|
||||
private KnowledgeBaseService knowledgeBaseService;
|
||||
@Autowired private KnowledgeBaseService knowledgeBaseService;
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<HanlpMapResult>> match(ChatQueryContext chatQueryContext, List<S2Term> terms,
|
||||
Set<Long> detectDataSetIds) {
|
||||
public Map<MatchText, List<HanlpMapResult>> match(
|
||||
ChatQueryContext chatQueryContext, List<S2Term> terms, Set<Long> detectDataSetIds) {
|
||||
String text = chatQueryContext.getQueryText();
|
||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||
return null;
|
||||
@@ -59,18 +57,34 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
&& existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length();
|
||||
}
|
||||
|
||||
public void detectByStep(ChatQueryContext chatQueryContext, Set<HanlpMapResult> existResults,
|
||||
Set<Long> detectDataSetIds,
|
||||
String detectSegment, int offset) {
|
||||
public void detectByStep(
|
||||
ChatQueryContext chatQueryContext,
|
||||
Set<HanlpMapResult> existResults,
|
||||
Set<Long> detectDataSetIds,
|
||||
String detectSegment,
|
||||
int offset) {
|
||||
// step1. pre search
|
||||
Integer oneDetectionMaxSize = Integer.valueOf(mapperConfig.getParameterValue(MAPPER_DETECTION_MAX_SIZE));
|
||||
LinkedHashSet<HanlpMapResult> hanlpMapResults = knowledgeBaseService.prefixSearch(detectSegment,
|
||||
oneDetectionMaxSize, chatQueryContext.getModelIdToDataSetIds(), detectDataSetIds)
|
||||
.stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
Integer oneDetectionMaxSize =
|
||||
Integer.valueOf(mapperConfig.getParameterValue(MAPPER_DETECTION_MAX_SIZE));
|
||||
LinkedHashSet<HanlpMapResult> hanlpMapResults =
|
||||
knowledgeBaseService
|
||||
.prefixSearch(
|
||||
detectSegment,
|
||||
oneDetectionMaxSize,
|
||||
chatQueryContext.getModelIdToDataSetIds(),
|
||||
detectDataSetIds)
|
||||
.stream()
|
||||
.collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
// step2. suffix search
|
||||
LinkedHashSet<HanlpMapResult> suffixHanlpMapResults = knowledgeBaseService.suffixSearch(detectSegment,
|
||||
oneDetectionMaxSize, chatQueryContext.getModelIdToDataSetIds(), detectDataSetIds)
|
||||
.stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
LinkedHashSet<HanlpMapResult> suffixHanlpMapResults =
|
||||
knowledgeBaseService
|
||||
.suffixSearch(
|
||||
detectSegment,
|
||||
oneDetectionMaxSize,
|
||||
chatQueryContext.getModelIdToDataSetIds(),
|
||||
detectDataSetIds)
|
||||
.stream()
|
||||
.collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
|
||||
hanlpMapResults.addAll(suffixHanlpMapResults);
|
||||
|
||||
@@ -78,45 +92,67 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
return;
|
||||
}
|
||||
// step3. merge pre/suffix result
|
||||
hanlpMapResults = hanlpMapResults.stream().sorted((a, b) -> -(b.getName().length() - a.getName().length()))
|
||||
.collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
hanlpMapResults =
|
||||
hanlpMapResults.stream()
|
||||
.sorted((a, b) -> -(b.getName().length() - a.getName().length()))
|
||||
.collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
|
||||
// step4. filter by similarity
|
||||
hanlpMapResults = hanlpMapResults.stream()
|
||||
.filter(term -> mapperHelper.getSimilarity(detectSegment, term.getName())
|
||||
>= getThresholdMatch(term.getNatures(), chatQueryContext))
|
||||
.filter(term -> CollectionUtils.isNotEmpty(term.getNatures()))
|
||||
.collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
hanlpMapResults =
|
||||
hanlpMapResults.stream()
|
||||
.filter(
|
||||
term ->
|
||||
mapperHelper.getSimilarity(detectSegment, term.getName())
|
||||
>= getThresholdMatch(
|
||||
term.getNatures(), chatQueryContext))
|
||||
.filter(term -> CollectionUtils.isNotEmpty(term.getNatures()))
|
||||
.collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
|
||||
log.debug("detectSegment:{},after isSimilarity parseResults:{}", detectSegment, hanlpMapResults);
|
||||
log.debug(
|
||||
"detectSegment:{},after isSimilarity parseResults:{}",
|
||||
detectSegment,
|
||||
hanlpMapResults);
|
||||
|
||||
hanlpMapResults = hanlpMapResults.stream().map(parseResult -> {
|
||||
parseResult.setOffset(offset);
|
||||
parseResult.setSimilarity(mapperHelper.getSimilarity(detectSegment, parseResult.getName()));
|
||||
return parseResult;
|
||||
}).collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
hanlpMapResults =
|
||||
hanlpMapResults.stream()
|
||||
.map(
|
||||
parseResult -> {
|
||||
parseResult.setOffset(offset);
|
||||
parseResult.setSimilarity(
|
||||
mapperHelper.getSimilarity(
|
||||
detectSegment, parseResult.getName()));
|
||||
return parseResult;
|
||||
})
|
||||
.collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
|
||||
// step5. take only M dimensionValue or N-M metric/dimension value per rond.
|
||||
int oneDetectionValueSize = Integer.valueOf(mapperConfig.getParameterValue(MAPPER_DIMENSION_VALUE_SIZE));
|
||||
List<HanlpMapResult> dimensionValues = hanlpMapResults.stream()
|
||||
.filter(entry -> mapperHelper.existDimensionValues(entry.getNatures()))
|
||||
.limit(oneDetectionValueSize)
|
||||
.collect(Collectors.toList());
|
||||
int oneDetectionValueSize =
|
||||
Integer.valueOf(mapperConfig.getParameterValue(MAPPER_DIMENSION_VALUE_SIZE));
|
||||
List<HanlpMapResult> dimensionValues =
|
||||
hanlpMapResults.stream()
|
||||
.filter(entry -> mapperHelper.existDimensionValues(entry.getNatures()))
|
||||
.limit(oneDetectionValueSize)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
Integer oneDetectionSize = Integer.valueOf(mapperConfig.getParameterValue(MAPPER_DETECTION_SIZE));
|
||||
Integer oneDetectionSize =
|
||||
Integer.valueOf(mapperConfig.getParameterValue(MAPPER_DETECTION_SIZE));
|
||||
List<HanlpMapResult> oneRoundResults = new ArrayList<>();
|
||||
|
||||
// add the dimensionValue if it exists
|
||||
if (CollectionUtils.isNotEmpty(dimensionValues)) {
|
||||
oneRoundResults.addAll(dimensionValues);
|
||||
}
|
||||
// fill the rest of the list with other results, excluding the dimensionValue if it was added
|
||||
// fill the rest of the list with other results, excluding the dimensionValue if it was
|
||||
// added
|
||||
if (oneRoundResults.size() < oneDetectionSize) {
|
||||
List<HanlpMapResult> additionalResults = hanlpMapResults.stream()
|
||||
.filter(entry -> !mapperHelper.existDimensionValues(entry.getNatures())
|
||||
&& !oneRoundResults.contains(entry))
|
||||
.limit(oneDetectionSize - oneRoundResults.size())
|
||||
.collect(Collectors.toList());
|
||||
List<HanlpMapResult> additionalResults =
|
||||
hanlpMapResults.stream()
|
||||
.filter(
|
||||
entry ->
|
||||
!mapperHelper.existDimensionValues(entry.getNatures())
|
||||
&& !oneRoundResults.contains(entry))
|
||||
.limit(oneDetectionSize - oneRoundResults.size())
|
||||
.collect(Collectors.toList());
|
||||
oneRoundResults.addAll(additionalResults);
|
||||
}
|
||||
// step6. select mapResul in one round
|
||||
@@ -128,14 +164,21 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
}

public double getThresholdMatch(List<String> natures, ChatQueryContext chatQueryContext) {
Double threshold = Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD));
Double minThreshold = Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD_MIN));
Double threshold =
Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD));
Double minThreshold =
Double.valueOf(
mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD_MIN));
if (mapperHelper.existDimensionValues(natures)) {
threshold = Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_VALUE_THRESHOLD));
minThreshold = Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_VALUE_THRESHOLD_MIN));
threshold =
Double.valueOf(
mapperConfig.getParameterValue(MapperConfig.MAPPER_VALUE_THRESHOLD));
minThreshold =
Double.valueOf(
mapperConfig.getParameterValue(
MapperConfig.MAPPER_VALUE_THRESHOLD_MIN));
}

return getThreshold(threshold, minThreshold, chatQueryContext.getMapModeEnum());
}

}

@@ -22,9 +22,9 @@ import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/***
|
||||
* A mapper that recognizes schema elements with keyword.
|
||||
* It leverages two matching strategies: HanlpDictMatchStrategy and DatabaseMatchStrategy.
|
||||
/**
|
||||
* * A mapper that recognizes schema elements with keyword. It leverages two matching strategies:
|
||||
* HanlpDictMatchStrategy and DatabaseMatchStrategy.
|
||||
*/
|
||||
@Slf4j
|
||||
public class KeywordMapper extends BaseMapper {
|
||||
@@ -32,29 +32,40 @@ public class KeywordMapper extends BaseMapper {
|
||||
@Override
|
||||
public void doMap(ChatQueryContext chatQueryContext) {
|
||||
String queryText = chatQueryContext.getQueryText();
|
||||
//1.hanlpDict Match
|
||||
List<S2Term> terms = HanlpHelper.getTerms(queryText, chatQueryContext.getModelIdToDataSetIds());
|
||||
HanlpDictMatchStrategy hanlpMatchStrategy = ContextUtils.getBean(HanlpDictMatchStrategy.class);
|
||||
// 1.hanlpDict Match
|
||||
List<S2Term> terms =
|
||||
HanlpHelper.getTerms(queryText, chatQueryContext.getModelIdToDataSetIds());
|
||||
HanlpDictMatchStrategy hanlpMatchStrategy =
|
||||
ContextUtils.getBean(HanlpDictMatchStrategy.class);
|
||||
|
||||
List<HanlpMapResult> hanlpMapResults = hanlpMatchStrategy.getMatches(chatQueryContext, terms);
|
||||
List<HanlpMapResult> hanlpMapResults =
|
||||
hanlpMatchStrategy.getMatches(chatQueryContext, terms);
|
||||
convertHanlpMapResultToMapInfo(hanlpMapResults, chatQueryContext, terms);
|
||||
|
||||
//2.database Match
|
||||
DatabaseMatchStrategy databaseMatchStrategy = ContextUtils.getBean(DatabaseMatchStrategy.class);
|
||||
// 2.database Match
|
||||
DatabaseMatchStrategy databaseMatchStrategy =
|
||||
ContextUtils.getBean(DatabaseMatchStrategy.class);
|
||||
|
||||
List<DatabaseMapResult> databaseResults = databaseMatchStrategy.getMatches(chatQueryContext, terms);
|
||||
List<DatabaseMapResult> databaseResults =
|
||||
databaseMatchStrategy.getMatches(chatQueryContext, terms);
|
||||
convertDatabaseMapResultToMapInfo(chatQueryContext, databaseResults);
|
||||
}
|
||||
|
||||
private void convertHanlpMapResultToMapInfo(List<HanlpMapResult> mapResults, ChatQueryContext chatQueryContext,
|
||||
List<S2Term> terms) {
|
||||
private void convertHanlpMapResultToMapInfo(
|
||||
List<HanlpMapResult> mapResults,
|
||||
ChatQueryContext chatQueryContext,
|
||||
List<S2Term> terms) {
|
||||
if (CollectionUtils.isEmpty(mapResults)) {
|
||||
return;
|
||||
}
|
||||
HanlpHelper.transLetterOriginal(mapResults);
|
||||
Map<String, Long> wordNatureToFrequency = terms.stream().collect(
|
||||
Collectors.toMap(entry -> entry.getWord() + entry.getNature(),
|
||||
term -> Long.valueOf(term.getFrequency()), (value1, value2) -> value2));
|
||||
Map<String, Long> wordNatureToFrequency =
|
||||
terms.stream()
|
||||
.collect(
|
||||
Collectors.toMap(
|
||||
entry -> entry.getWord() + entry.getNature(),
|
||||
term -> Long.valueOf(term.getFrequency()),
|
||||
(value1, value2) -> value2));
|
||||
|
||||
for (HanlpMapResult hanlpMapResult : mapResults) {
|
||||
for (String nature : hanlpMapResult.getNatures()) {
|
||||
@@ -67,55 +78,70 @@ public class KeywordMapper extends BaseMapper {
|
||||
continue;
|
||||
}
|
||||
Long elementID = NatureHelper.getElementID(nature);
|
||||
SchemaElement element = getSchemaElement(dataSetId, elementType,
|
||||
elementID, chatQueryContext.getSemanticSchema());
|
||||
SchemaElement element =
|
||||
getSchemaElement(
|
||||
dataSetId,
|
||||
elementType,
|
||||
elementID,
|
||||
chatQueryContext.getSemanticSchema());
|
||||
if (element == null) {
|
||||
continue;
|
||||
}
|
||||
Long frequency = wordNatureToFrequency.get(hanlpMapResult.getName() + nature);
|
||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||
.element(element)
|
||||
.frequency(frequency)
|
||||
.word(hanlpMapResult.getName())
|
||||
.similarity(hanlpMapResult.getSimilarity())
|
||||
.detectWord(hanlpMapResult.getDetectWord())
|
||||
.build();
|
||||
SchemaElementMatch schemaElementMatch =
|
||||
SchemaElementMatch.builder()
|
||||
.element(element)
|
||||
.frequency(frequency)
|
||||
.word(hanlpMapResult.getName())
|
||||
.similarity(hanlpMapResult.getSimilarity())
|
||||
.detectWord(hanlpMapResult.getDetectWord())
|
||||
.build();
|
||||
|
||||
addToSchemaMap(chatQueryContext.getMapInfo(), dataSetId, schemaElementMatch);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void convertDatabaseMapResultToMapInfo(ChatQueryContext chatQueryContext,
|
||||
List<DatabaseMapResult> mapResults) {
|
||||
private void convertDatabaseMapResultToMapInfo(
|
||||
ChatQueryContext chatQueryContext, List<DatabaseMapResult> mapResults) {
|
||||
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
|
||||
for (DatabaseMapResult match : mapResults) {
|
||||
SchemaElement schemaElement = match.getSchemaElement();
|
||||
Set<Long> regElementSet = getRegElementSet(chatQueryContext.getMapInfo(), schemaElement);
|
||||
Set<Long> regElementSet =
|
||||
getRegElementSet(chatQueryContext.getMapInfo(), schemaElement);
|
||||
if (regElementSet.contains(schemaElement.getId())) {
|
||||
continue;
|
||||
}
|
||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||
.element(schemaElement)
|
||||
.word(schemaElement.getName())
|
||||
.detectWord(match.getDetectWord())
|
||||
.frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
|
||||
.similarity(mapperHelper.getSimilarity(match.getDetectWord(), schemaElement.getName()))
|
||||
.build();
|
||||
SchemaElementMatch schemaElementMatch =
|
||||
SchemaElementMatch.builder()
|
||||
.element(schemaElement)
|
||||
.word(schemaElement.getName())
|
||||
.detectWord(match.getDetectWord())
|
||||
.frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
|
||||
.similarity(
|
||||
mapperHelper.getSimilarity(
|
||||
match.getDetectWord(), schemaElement.getName()))
|
||||
.build();
|
||||
log.info("add to schema, elementMatch {}", schemaElementMatch);
|
||||
addToSchemaMap(chatQueryContext.getMapInfo(), schemaElement.getDataSetId(), schemaElementMatch);
|
||||
addToSchemaMap(
|
||||
chatQueryContext.getMapInfo(),
|
||||
schemaElement.getDataSetId(),
|
||||
schemaElementMatch);
|
||||
}
|
||||
}
|
||||
|
||||
private Set<Long> getRegElementSet(SchemaMapInfo schemaMap, SchemaElement schemaElement) {
|
||||
List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getDataSetId());
|
||||
List<SchemaElementMatch> elements =
|
||||
schemaMap.getMatchedElements(schemaElement.getDataSetId());
|
||||
if (CollectionUtils.isEmpty(elements)) {
|
||||
return new HashSet<>();
|
||||
}
|
||||
return elements.stream()
|
||||
.filter(elementMatch ->
|
||||
SchemaElementType.METRIC.equals(elementMatch.getElement().getType())
|
||||
|| SchemaElementType.DIMENSION.equals(elementMatch.getElement().getType()))
|
||||
.filter(
|
||||
elementMatch ->
|
||||
SchemaElementType.METRIC.equals(elementMatch.getElement().getType())
|
||||
|| SchemaElementType.DIMENSION.equals(
|
||||
elementMatch.getElement().getType()))
|
||||
.map(elementMatch -> elementMatch.getElement().getId())
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
|
||||
@@ -8,87 +8,128 @@ import org.springframework.stereotype.Service;
|
||||
public class MapperConfig extends ParameterConfig {
|
||||
|
||||
public static final Parameter MAPPER_DETECTION_SIZE =
|
||||
new Parameter("s2.mapper.detection.size", "8",
|
||||
new Parameter(
|
||||
"s2.mapper.detection.size",
|
||||
"8",
|
||||
"一次探测返回结果个数",
|
||||
"在每次探测后, 将前后缀匹配的结果合并, 并根据相似度阈值过滤后的结果个数",
|
||||
"number", "Mapper相关配置");
|
||||
"number",
|
||||
"Mapper相关配置");
|
||||
|
||||
public static final Parameter MAPPER_DETECTION_MAX_SIZE =
|
||||
new Parameter("s2.mapper.detection.max.size", "20",
|
||||
new Parameter(
|
||||
"s2.mapper.detection.max.size",
|
||||
"20",
|
||||
"一次探测前后缀匹配结果返回个数",
|
||||
"单次前后缀匹配返回的结果个数",
|
||||
"number", "Mapper相关配置");
|
||||
"number",
|
||||
"Mapper相关配置");
|
||||
|
||||
public static final Parameter MAPPER_NAME_THRESHOLD =
|
||||
new Parameter("s2.mapper.name.threshold", "0.3",
|
||||
new Parameter(
|
||||
"s2.mapper.name.threshold",
|
||||
"0.3",
|
||||
"指标名、维度名文本相似度阈值",
|
||||
"文本片段和匹配到的指标、维度名计算出来的编辑距离阈值, 若超出该阈值, 则舍弃",
|
||||
"number", "Mapper相关配置");
|
||||
"number",
|
||||
"Mapper相关配置");
|
||||
|
||||
public static final Parameter MAPPER_NAME_THRESHOLD_MIN =
|
||||
new Parameter("s2.mapper.name.min.threshold", "0.25",
|
||||
new Parameter(
|
||||
"s2.mapper.name.min.threshold",
|
||||
"0.25",
|
||||
"指标名、维度名最小文本相似度阈值",
|
||||
"指标名、维度名相似度阈值在动态调整中的最低值",
|
||||
"number", "Mapper相关配置");
|
||||
"number",
|
||||
"Mapper相关配置");
|
||||
|
||||
public static final Parameter MAPPER_DIMENSION_VALUE_SIZE =
|
||||
new Parameter("s2.mapper.value.size", "1",
|
||||
new Parameter(
|
||||
"s2.mapper.value.size",
|
||||
"1",
|
||||
"一次探测返回维度值结果个数",
|
||||
"在每次探测后, 将前后缀匹配的结果合并, 并根据相似度阈值过滤后的维度值结果个数",
|
||||
"number", "Mapper相关配置");
|
||||
"number",
|
||||
"Mapper相关配置");
|
||||
|
||||
public static final Parameter MAPPER_VALUE_THRESHOLD =
|
||||
new Parameter("s2.mapper.value.threshold", "0.5",
|
||||
new Parameter(
|
||||
"s2.mapper.value.threshold",
|
||||
"0.5",
|
||||
"维度值文本相似度阈值",
|
||||
"文本片段和匹配到的维度值计算出来的编辑距离阈值, 若超出该阈值, 则舍弃",
|
||||
"number", "Mapper相关配置");
|
||||
"number",
|
||||
"Mapper相关配置");
|
||||
|
||||
public static final Parameter MAPPER_VALUE_THRESHOLD_MIN =
|
||||
new Parameter("s2.mapper.value.min.threshold", "0.3",
|
||||
new Parameter(
|
||||
"s2.mapper.value.min.threshold",
|
||||
"0.3",
|
||||
"维度值最小文本相似度阈值",
|
||||
"维度值相似度阈值在动态调整中的最低值",
|
||||
"number", "Mapper相关配置");
|
||||
"number",
|
||||
"Mapper相关配置");
|
||||
|
||||
public static final Parameter EMBEDDING_MAPPER_TEXT_SIZE =
|
||||
new Parameter("s2.mapper.embedding.word.size", "4",
|
||||
new Parameter(
|
||||
"s2.mapper.embedding.word.size",
|
||||
"4",
|
||||
"用于向量召回文本长度",
|
||||
"为提高向量召回效率, 按指定长度进行向量语义召回",
|
||||
"number", "Mapper相关配置");
|
||||
"number",
|
||||
"Mapper相关配置");
|
||||
|
||||
public static final Parameter EMBEDDING_MAPPER_TEXT_STEP =
|
||||
new Parameter("s2.mapper.embedding.word.step", "3",
|
||||
new Parameter(
|
||||
"s2.mapper.embedding.word.step",
|
||||
"3",
|
||||
"向量召回文本每步长度",
|
||||
"为提高向量召回效率, 按指定每步长度进行召回",
|
||||
"number", "Mapper相关配置");
|
||||
"number",
|
||||
"Mapper相关配置");
|
||||
|
||||
public static final Parameter EMBEDDING_MAPPER_BATCH =
|
||||
new Parameter("s2.mapper.embedding.batch", "50",
|
||||
new Parameter(
|
||||
"s2.mapper.embedding.batch",
|
||||
"50",
|
||||
"批量向量召回文本请求个数",
|
||||
"每次进行向量语义召回的原始文本片段个数",
|
||||
"number", "Mapper相关配置");
|
||||
"number",
|
||||
"Mapper相关配置");
|
||||
|
||||
public static final Parameter EMBEDDING_MAPPER_NUMBER =
|
||||
new Parameter("s2.mapper.embedding.number", "5",
|
||||
new Parameter(
|
||||
"s2.mapper.embedding.number",
|
||||
"5",
|
||||
"批量向量召回文本返回结果个数",
|
||||
"每个文本进行向量语义召回的文本结果个数",
|
||||
"number", "Mapper相关配置");
|
||||
"number",
|
||||
"Mapper相关配置");
|
||||
|
||||
public static final Parameter EMBEDDING_MAPPER_THRESHOLD =
|
||||
new Parameter("s2.mapper.embedding.threshold", "0.98",
|
||||
new Parameter(
|
||||
"s2.mapper.embedding.threshold",
|
||||
"0.98",
|
||||
"向量召回相似度阈值",
|
||||
"相似度小于该阈值的则舍弃",
|
||||
"number", "Mapper相关配置");
|
||||
"number",
|
||||
"Mapper相关配置");
|
||||
|
||||
public static final Parameter EMBEDDING_MAPPER_THRESHOLD_MIN =
|
||||
new Parameter("s2.mapper.embedding.min.threshold", "0.9",
|
||||
new Parameter(
|
||||
"s2.mapper.embedding.min.threshold",
|
||||
"0.9",
|
||||
"向量召回最小相似度阈值",
|
||||
"向量召回相似度阈值在动态调整中的最低值",
|
||||
"number", "Mapper相关配置");
|
||||
"number",
|
||||
"Mapper相关配置");
|
||||
|
||||
public static final Parameter EMBEDDING_MAPPER_ROUND_NUMBER =
|
||||
new Parameter("s2.mapper.embedding.round.number", "10",
|
||||
new Parameter(
|
||||
"s2.mapper.embedding.round.number",
|
||||
"10",
|
||||
"向量召回最小相似度阈值",
|
||||
"向量召回相似度阈值在动态调整中的最低值",
|
||||
"number", "Mapper相关配置");
|
||||
|
||||
"number",
|
||||
"Mapper相关配置");
|
||||
}
|
||||
|
||||
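All of the Parameter constants above are read back at runtime through getParameterValue, as the mapper code elsewhere in this commit does. A hedged usage sketch; it assumes an injected MapperConfig instance named mapperConfig, and the defaults cited are the ones declared above:

    // one detection round returns at most 8 results, at most 1 of them a dimension value
    int detectionSize = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_DETECTION_SIZE));
    int dimensionValueSize = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_DIMENSION_VALUE_SIZE));
    // text-similarity threshold for metric/dimension names (default 0.3; 0.25 floor during dynamic adjustment)
    double nameThreshold = Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD));
    double minNameThreshold = Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD_MIN));
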
@@ -29,8 +29,11 @@ public class MapperHelper {
|
||||
}
|
||||
|
||||
public Integer getStepOffset(List<S2Term> termList, Integer index) {
|
||||
List<Integer> offsetList = termList.stream().sorted(Comparator.comparing(S2Term::getOffset))
|
||||
.map(term -> term.getOffset()).collect(Collectors.toList());
|
||||
List<Integer> offsetList =
|
||||
termList.stream()
|
||||
.sorted(Comparator.comparing(S2Term::getOffset))
|
||||
.map(term -> term.getOffset())
|
||||
.collect(Collectors.toList());
|
||||
|
||||
for (int j = 0; j < termList.size() - 1; j++) {
|
||||
if (offsetList.get(j) <= index && offsetList.get(j + 1) > index) {
|
||||
@@ -40,8 +43,9 @@ public class MapperHelper {
|
||||
return index;
|
||||
}
|
||||
|
||||
/***
|
||||
* exist dimension values
|
||||
/**
|
||||
* * exist dimension values
|
||||
*
|
||||
* @param natures
|
||||
* @return
|
||||
*/
|
||||
@@ -63,8 +67,9 @@ public class MapperHelper {
|
||||
return false;
|
||||
}
|
||||
|
||||
/***
|
||||
* get similarity
|
||||
/**
|
||||
* * get similarity
|
||||
*
|
||||
* @param detectSegment
|
||||
* @param matchName
|
||||
* @return
|
||||
@@ -72,7 +77,8 @@ public class MapperHelper {
|
||||
public double getSimilarity(String detectSegment, String matchName) {
String detectSegmentLower = detectSegment == null ? null : detectSegment.toLowerCase();
String matchNameLower = matchName == null ? null : matchName.toLowerCase();
return 1 - (double) EditDistance.compute(detectSegmentLower, matchNameLower) / Math.max(matchName.length(),
detectSegment.length());
return 1
- (double) EditDistance.compute(detectSegmentLower, matchNameLower)
/ Math.max(matchName.length(), detectSegment.length());
}
}

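A worked example of the similarity formula above, assuming EditDistance.compute is an ordinary Levenshtein distance (that assumption is not shown in this hunk): for detectSegment "profit" and matchName "profits" the distance is 1 and the longer string has length 7, so similarity = 1 - 1/7 ≈ 0.86; identical strings score 1.0, and completely different strings of equal length score 0.0.
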
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.headless.chat.mapper;
|
||||
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
|
||||
@@ -9,11 +8,10 @@ import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* MatchStrategy encapsulates a concrete matching algorithm
|
||||
* executed during query or search process.
|
||||
* MatchStrategy encapsulates a concrete matching algorithm executed during query or search process.
|
||||
*/
|
||||
public interface MatchStrategy<T> {
|
||||
|
||||
Map<MatchText, List<T>> match(ChatQueryContext chatQueryContext, List<S2Term> terms, Set<Long> detectDataSetIds);
|
||||
|
||||
}
|
||||
Map<MatchText, List<T>> match(
|
||||
ChatQueryContext chatQueryContext, List<S2Term> terms, Set<Long> detectDataSetIds);
|
||||
}
|
||||
|
||||
@@ -24,7 +24,8 @@ public class MatchText {
|
||||
return false;
|
||||
}
|
||||
MatchText that = (MatchText) o;
|
||||
return Objects.equals(regText, that.regText) && Objects.equals(detectSegment, that.detectSegment);
|
||||
return Objects.equals(regText, that.regText)
|
||||
&& Objects.equals(detectSegment, that.detectSegment);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -17,4 +17,4 @@ public class ModelWithSemanticType implements Serializable {
|
||||
this.model = model;
|
||||
this.schemaElementType = schemaElementType;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,7 +32,8 @@ public class QueryFilterMapper extends BaseMapper {
|
||||
SchemaMapInfo schemaMapInfo = chatQueryContext.getMapInfo();
|
||||
clearOtherSchemaElementMatch(dataSetIds, schemaMapInfo);
|
||||
for (Long dataSetId : dataSetIds) {
|
||||
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(dataSetId);
|
||||
List<SchemaElementMatch> schemaElementMatches =
|
||||
schemaMapInfo.getMatchedElements(dataSetId);
|
||||
if (schemaElementMatches == null) {
|
||||
schemaElementMatches = Lists.newArrayList();
|
||||
schemaMapInfo.setMatchedElements(dataSetId, schemaElementMatches);
|
||||
@@ -42,14 +43,17 @@ public class QueryFilterMapper extends BaseMapper {
|
||||
}
|
||||
|
||||
private void clearOtherSchemaElementMatch(Set<Long> viewIds, SchemaMapInfo schemaMapInfo) {
|
||||
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMapInfo.getDataSetElementMatches().entrySet()) {
|
||||
for (Map.Entry<Long, List<SchemaElementMatch>> entry :
|
||||
schemaMapInfo.getDataSetElementMatches().entrySet()) {
|
||||
if (!viewIds.contains(entry.getKey())) {
|
||||
entry.getValue().clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void addValueSchemaElementMatch(Long dataSetId, ChatQueryContext chatQueryContext,
|
||||
private void addValueSchemaElementMatch(
|
||||
Long dataSetId,
|
||||
ChatQueryContext chatQueryContext,
|
||||
List<SchemaElementMatch> candidateElementMatches) {
|
||||
QueryFilters queryFilters = chatQueryContext.getQueryFilters();
|
||||
if (queryFilters == null || CollectionUtils.isEmpty(queryFilters.getFilters())) {
|
||||
@@ -59,33 +63,41 @@ public class QueryFilterMapper extends BaseMapper {
|
||||
if (checkExistSameValueSchemaElementMatch(filter, candidateElementMatches)) {
|
||||
continue;
|
||||
}
|
||||
SchemaElement element = SchemaElement.builder()
|
||||
.id(filter.getElementID())
|
||||
.name(String.valueOf(filter.getValue()))
|
||||
.type(SchemaElementType.VALUE)
|
||||
.bizName(filter.getBizName())
|
||||
.dataSetId(dataSetId)
|
||||
.build();
|
||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||
.element(element)
|
||||
.frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
|
||||
.word(String.valueOf(filter.getValue()))
|
||||
.similarity(similarity)
|
||||
.detectWord(Constants.EMPTY)
|
||||
.build();
|
||||
SchemaElement element =
|
||||
SchemaElement.builder()
|
||||
.id(filter.getElementID())
|
||||
.name(String.valueOf(filter.getValue()))
|
||||
.type(SchemaElementType.VALUE)
|
||||
.bizName(filter.getBizName())
|
||||
.dataSetId(dataSetId)
|
||||
.build();
|
||||
SchemaElementMatch schemaElementMatch =
|
||||
SchemaElementMatch.builder()
|
||||
.element(element)
|
||||
.frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
|
||||
.word(String.valueOf(filter.getValue()))
|
||||
.similarity(similarity)
|
||||
.detectWord(Constants.EMPTY)
|
||||
.build();
|
||||
candidateElementMatches.add(schemaElementMatch);
|
||||
}
|
||||
chatQueryContext.getMapInfo().setMatchedElements(dataSetId, candidateElementMatches);
|
||||
}
|
||||
|
||||
private boolean checkExistSameValueSchemaElementMatch(QueryFilter queryFilter,
|
||||
List<SchemaElementMatch> schemaElementMatches) {
|
||||
List<SchemaElementMatch> valueSchemaElements = schemaElementMatches.stream().filter(schemaElementMatch ->
|
||||
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()))
|
||||
.collect(Collectors.toList());
|
||||
private boolean checkExistSameValueSchemaElementMatch(
|
||||
QueryFilter queryFilter, List<SchemaElementMatch> schemaElementMatches) {
|
||||
List<SchemaElementMatch> valueSchemaElements =
|
||||
schemaElementMatches.stream()
|
||||
.filter(
|
||||
schemaElementMatch ->
|
||||
SchemaElementType.VALUE.equals(
|
||||
schemaElementMatch.getElement().getType()))
|
||||
.collect(Collectors.toList());
|
||||
for (SchemaElementMatch schemaElementMatch : valueSchemaElements) {
|
||||
if (schemaElementMatch.getElement().getId().equals(queryFilter.getElementID())
|
||||
&& schemaElementMatch.getWord().equals(String.valueOf(queryFilter.getValue()))) {
|
||||
&& schemaElementMatch
|
||||
.getWord()
|
||||
.equals(String.valueOf(queryFilter.getValue()))) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
package com.tencent.supersonic.headless.chat.mapper;
|
||||
|
||||
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
|
||||
/**
|
||||
* A schema mapper identifies references to schema elements(metrics/dimensions/entities/values)
|
||||
* in user queries. It matches the query text against the knowledge base.
|
||||
* A schema mapper identifies references to schema elements(metrics/dimensions/entities/values) in
|
||||
* user queries. It matches the query text against the knowledge base.
|
||||
*/
|
||||
public interface SchemaMapper {
|
||||
|
||||
|
||||
@@ -20,20 +20,18 @@ import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* SearchMatchStrategy encapsulates a concrete matching algorithm
|
||||
* executed during search process.
|
||||
* SearchMatchStrategy encapsulates a concrete matching algorithm executed during search process.
|
||||
*/
|
||||
@Service
|
||||
public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
|
||||
private static final int SEARCH_SIZE = 3;
|
||||
|
||||
@Autowired
|
||||
private KnowledgeBaseService knowledgeBaseService;
|
||||
@Autowired private KnowledgeBaseService knowledgeBaseService;
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<HanlpMapResult>> match(ChatQueryContext chatQueryContext, List<S2Term> originals,
|
||||
Set<Long> detectDataSetIds) {
|
||||
public Map<MatchText, List<HanlpMapResult>> match(
|
||||
ChatQueryContext chatQueryContext, List<S2Term> originals, Set<Long> detectDataSetIds) {
|
||||
String text = chatQueryContext.getQueryText();
|
||||
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(originals);
|
||||
|
||||
@@ -52,39 +50,58 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
}
|
||||
}
|
||||
Map<MatchText, List<HanlpMapResult>> regTextMap = new ConcurrentHashMap<>();
|
||||
detectIndexList.stream().parallel().forEach(detectIndex -> {
|
||||
String regText = text.substring(0, detectIndex);
|
||||
String detectSegment = text.substring(detectIndex);
|
||||
detectIndexList.stream()
|
||||
.parallel()
|
||||
.forEach(
|
||||
detectIndex -> {
|
||||
String regText = text.substring(0, detectIndex);
|
||||
String detectSegment = text.substring(detectIndex);
|
||||
|
||||
if (StringUtils.isNotEmpty(detectSegment)) {
|
||||
List<HanlpMapResult> hanlpMapResults = knowledgeBaseService.prefixSearch(detectSegment,
|
||||
SearchService.SEARCH_SIZE,
|
||||
chatQueryContext.getModelIdToDataSetIds(),
|
||||
detectDataSetIds);
|
||||
List<HanlpMapResult> suffixHanlpMapResults = knowledgeBaseService.suffixSearch(
|
||||
detectSegment,
|
||||
SEARCH_SIZE,
|
||||
chatQueryContext.getModelIdToDataSetIds(),
|
||||
detectDataSetIds);
|
||||
hanlpMapResults.addAll(suffixHanlpMapResults);
|
||||
// remove entity name where search
|
||||
hanlpMapResults = hanlpMapResults.stream().filter(entry -> {
|
||||
List<String> natures = entry.getNatures().stream()
|
||||
.filter(nature -> !nature.endsWith(DictWordType.ENTITY.getType()))
|
||||
.collect(Collectors.toList());
|
||||
if (CollectionUtils.isEmpty(natures)) {
|
||||
return false;
|
||||
if (StringUtils.isNotEmpty(detectSegment)) {
|
||||
List<HanlpMapResult> hanlpMapResults =
|
||||
knowledgeBaseService.prefixSearch(
|
||||
detectSegment,
|
||||
SearchService.SEARCH_SIZE,
|
||||
chatQueryContext.getModelIdToDataSetIds(),
|
||||
detectDataSetIds);
|
||||
List<HanlpMapResult> suffixHanlpMapResults =
|
||||
knowledgeBaseService.suffixSearch(
|
||||
detectSegment,
|
||||
SEARCH_SIZE,
|
||||
chatQueryContext.getModelIdToDataSetIds(),
|
||||
detectDataSetIds);
|
||||
hanlpMapResults.addAll(suffixHanlpMapResults);
|
||||
// remove entity name where search
|
||||
hanlpMapResults =
|
||||
hanlpMapResults.stream()
|
||||
.filter(
|
||||
entry -> {
|
||||
List<String> natures =
|
||||
entry.getNatures().stream()
|
||||
.filter(
|
||||
nature ->
|
||||
!nature
|
||||
.endsWith(
|
||||
DictWordType
|
||||
.ENTITY
|
||||
.getType()))
|
||||
.collect(
|
||||
Collectors
|
||||
.toList());
|
||||
if (CollectionUtils.isEmpty(natures)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
})
|
||||
.collect(Collectors.toList());
|
||||
MatchText matchText =
|
||||
MatchText.builder()
|
||||
.regText(regText)
|
||||
.detectSegment(detectSegment)
|
||||
.build();
|
||||
regTextMap.put(matchText, hanlpMapResults);
|
||||
}
|
||||
return true;
|
||||
}).collect(Collectors.toList());
|
||||
MatchText matchText = MatchText.builder()
|
||||
.regText(regText)
|
||||
.detectSegment(detectSegment)
|
||||
.build();
|
||||
regTextMap.put(matchText, hanlpMapResults);
|
||||
}
|
||||
}
|
||||
);
|
||||
});
|
||||
return regTextMap;
|
||||
}
|
||||
|
||||
@@ -99,9 +116,10 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void detectByStep(ChatQueryContext chatQueryContext, Set<HanlpMapResult> existResults,
|
||||
Set<Long> detectDataSetIds, String detectSegment, int offset) {
|
||||
|
||||
}
|
||||
|
||||
public void detectByStep(
|
||||
ChatQueryContext chatQueryContext,
|
||||
Set<HanlpMapResult> existResults,
|
||||
Set<Long> detectDataSetIds,
|
||||
String detectSegment,
|
||||
int offset) {}
|
||||
}
|
||||
|
||||
@@ -5,17 +5,17 @@ import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/***
|
||||
* A mapper that map the description of the term.
|
||||
*/
|
||||
/** * A mapper that map the description of the term. */
|
||||
@Slf4j
|
||||
public class TermDescMapper extends BaseMapper {
|
||||
|
||||
@Override
|
||||
public void doMap(ChatQueryContext chatQueryContext) {
|
||||
List<SchemaElement> termDescriptionToMap = chatQueryContext.getMapInfo().getTermDescriptionToMap();
|
||||
List<SchemaElement> termDescriptionToMap =
|
||||
chatQueryContext.getMapInfo().getTermDescriptionToMap();
|
||||
if (CollectionUtils.isEmpty(termDescriptionToMap)) {
|
||||
return;
|
||||
}
|
||||
@@ -37,5 +37,4 @@ public class TermDescMapper extends BaseMapper {
|
||||
chatQueryContext.setQueryText(chatQueryContext.getOriQueryText());
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -13,53 +13,82 @@ import java.util.List;
|
||||
public class ParserConfig extends ParameterConfig {
|
||||
|
||||
public static final Parameter PARSER_STRATEGY_TYPE =
|
||||
new Parameter("s2.parser.s2sql.strategy", "ONE_PASS_SELF_CONSISTENCY",
|
||||
new Parameter(
|
||||
"s2.parser.s2sql.strategy",
|
||||
"ONE_PASS_SELF_CONSISTENCY",
|
||||
"LLM解析生成S2SQL策略",
|
||||
"ONE_PASS_SELF_CONSISTENCY: 通过投票方式一步生成sql",
|
||||
"list", "Parser相关配置",
|
||||
"list",
|
||||
"Parser相关配置",
|
||||
Lists.newArrayList("ONE_PASS_SELF_CONSISTENCY"));
|
||||
|
||||
public static final Parameter PARSER_LINKING_VALUE_ENABLE =
|
||||
new Parameter("s2.parser.linking.value.enable", "true",
|
||||
"是否将Mapper探测识别到的维度值提供给大模型", "为了数据安全考虑, 这里可进行开关选择",
|
||||
"bool", "Parser相关配置");
|
||||
new Parameter(
|
||||
"s2.parser.linking.value.enable",
|
||||
"true",
|
||||
"是否将Mapper探测识别到的维度值提供给大模型",
|
||||
"为了数据安全考虑, 这里可进行开关选择",
|
||||
"bool",
|
||||
"Parser相关配置");
|
||||
|
||||
public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD =
|
||||
new Parameter("s2.parser.text.length.threshold", "10",
|
||||
"用户输入文本长短阈值", "文本超过该阈值为长文本",
|
||||
"number", "Parser相关配置");
|
||||
new Parameter(
|
||||
"s2.parser.text.length.threshold",
|
||||
"10",
|
||||
"用户输入文本长短阈值",
|
||||
"文本超过该阈值为长文本",
|
||||
"number",
|
||||
"Parser相关配置");
|
||||
|
||||
public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD_SHORT =
|
||||
new Parameter("s2.parser.text.threshold.short", "0.5",
|
||||
new Parameter(
|
||||
"s2.parser.text.threshold.short",
|
||||
"0.5",
|
||||
"短文本匹配阈值",
|
||||
"由于请求大模型耗时较长, 因此如果有规则类型的Query得分达到阈值,则跳过大模型的调用,"
|
||||
+ "\n如果是短文本, 若query得分/文本长度>该阈值, 则跳过当前parser",
|
||||
"number", "Parser相关配置");
|
||||
+ "\n如果是短文本, 若query得分/文本长度>该阈值, 则跳过当前parser",
|
||||
"number",
|
||||
"Parser相关配置");
|
||||
|
||||
public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD_LONG =
|
||||
new Parameter("s2.parser.text.threshold.long", "0.8",
|
||||
"长文本匹配阈值", "如果是长文本, 若query得分/文本长度>该阈值, 则跳过当前parser",
|
||||
"number", "Parser相关配置");
|
||||
new Parameter(
|
||||
"s2.parser.text.threshold.long",
|
||||
"0.8",
|
||||
"长文本匹配阈值",
|
||||
"如果是长文本, 若query得分/文本长度>该阈值, 则跳过当前parser",
|
||||
"number",
|
||||
"Parser相关配置");
|
||||
|
||||
public static final Parameter PARSER_EXEMPLAR_RECALL_NUMBER =
|
||||
new Parameter("s2.parser.exemplar-recall.number", "10",
|
||||
"exemplar召回个数", "",
|
||||
"number", "Parser相关配置");
|
||||
new Parameter(
|
||||
"s2.parser.exemplar-recall.number",
|
||||
"10",
|
||||
"exemplar召回个数",
|
||||
"",
|
||||
"number",
|
||||
"Parser相关配置");
|
||||
|
||||
public static final Parameter PARSER_FEW_SHOT_NUMBER =
|
||||
new Parameter("s2.parser.few-shot.number", "3",
|
||||
"few-shot样例个数", "样例越多效果可能越好,但token消耗越大",
|
||||
"number", "Parser相关配置");
|
||||
new Parameter(
|
||||
"s2.parser.few-shot.number",
|
||||
"3",
|
||||
"few-shot样例个数",
|
||||
"样例越多效果可能越好,但token消耗越大",
|
||||
"number",
|
||||
"Parser相关配置");
|
||||
|
||||
public static final Parameter PARSER_SELF_CONSISTENCY_NUMBER =
|
||||
new Parameter("s2.parser.self-consistency.number", "1",
|
||||
"self-consistency执行个数", "执行越多效果可能越好,但token消耗越大",
|
||||
"number", "Parser相关配置");
|
||||
new Parameter(
|
||||
"s2.parser.self-consistency.number",
|
||||
"1",
|
||||
"self-consistency执行个数",
|
||||
"执行越多效果可能越好,但token消耗越大",
|
||||
"number",
|
||||
"Parser相关配置");
|
||||
|
||||
public static final Parameter PARSER_SHOW_COUNT =
|
||||
new Parameter("s2.parser.show.count", "3",
|
||||
"解析结果展示个数", "前端展示的解析个数",
|
||||
"number", "Parser相关配置");
|
||||
new Parameter(
|
||||
"s2.parser.show.count", "3", "解析结果展示个数", "前端展示的解析个数", "number", "Parser相关配置");
|
||||
|
||||
@Override
|
||||
public List<Parameter> getSysParameters() {
|
||||
@@ -67,8 +96,6 @@ public class ParserConfig extends ParameterConfig {
|
||||
PARSER_LINKING_VALUE_ENABLE,
|
||||
PARSER_FEW_SHOT_NUMBER,
|
||||
PARSER_SELF_CONSISTENCY_NUMBER,
|
||||
PARSER_SHOW_COUNT
|
||||
);
|
||||
PARSER_SHOW_COUNT);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
package com.tencent.supersonic.headless.chat.parser;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
@@ -22,9 +22,7 @@ import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* QueryTypeParser resolves query type as either METRIC or TAG, or ID.
|
||||
*/
|
||||
/** QueryTypeParser resolves query type as either METRIC or TAG, or ID. */
|
||||
@Slf4j
|
||||
public class QueryTypeParser implements SemanticParser {
|
||||
|
||||
@@ -37,8 +35,8 @@ public class QueryTypeParser implements SemanticParser {
|
||||
for (SemanticQuery semanticQuery : candidateQueries) {
|
||||
// 1.init S2SQL
|
||||
Long dataSetId = semanticQuery.getParseInfo().getDataSetId();
|
||||
DataSetSchema dataSetSchema = chatQueryContext.getSemanticSchema()
|
||||
.getDataSetSchemaMap().get(dataSetId);
|
||||
DataSetSchema dataSetSchema =
|
||||
chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
|
||||
semanticQuery.initS2Sql(dataSetSchema, user);
|
||||
// 2.set queryType
|
||||
QueryType queryType = getQueryType(chatQueryContext, semanticQuery);
|
||||
@@ -53,23 +51,25 @@ public class QueryTypeParser implements SemanticParser {
|
||||
return QueryType.DETAIL;
|
||||
}
|
||||
|
||||
//1. entity queryType
|
||||
// 1. entity queryType
|
||||
Long dataSetId = parseInfo.getDataSetId();
|
||||
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
|
||||
List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getParsedS2SQL());
|
||||
List<String> whereFilterByTimeFields = filterByTimeFields(whereFields);
|
||||
if (CollectionUtils.isNotEmpty(whereFilterByTimeFields)) {
|
||||
Set<String> ids = semanticSchema.getEntities(dataSetId).stream().map(SchemaElement::getName)
|
||||
.collect(Collectors.toSet());
|
||||
if (CollectionUtils.isNotEmpty(ids) && ids.stream()
|
||||
.anyMatch(whereFilterByTimeFields::contains)) {
|
||||
Set<String> ids =
|
||||
semanticSchema.getEntities(dataSetId).stream()
|
||||
.map(SchemaElement::getName)
|
||||
.collect(Collectors.toSet());
|
||||
if (CollectionUtils.isNotEmpty(ids)
|
||||
&& ids.stream().anyMatch(whereFilterByTimeFields::contains)) {
|
||||
return QueryType.ID;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//2. metric queryType
|
||||
// 2. metric queryType
|
||||
if (selectContainsMetric(sqlInfo, dataSetId, semanticSchema)) {
|
||||
return QueryType.METRIC;
|
||||
}
|
||||
@@ -78,20 +78,22 @@ public class QueryTypeParser implements SemanticParser {
|
||||
}
|
||||
|
||||
private static List<String> filterByTimeFields(List<String> whereFields) {
|
||||
List<String> selectAndWhereFilterByTimeFields = whereFields
|
||||
.stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
|
||||
.collect(Collectors.toList());
|
||||
List<String> selectAndWhereFilterByTimeFields =
|
||||
whereFields.stream()
|
||||
.filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
|
||||
.collect(Collectors.toList());
|
||||
return selectAndWhereFilterByTimeFields;
|
||||
}
|
||||
|
||||
private static boolean selectContainsMetric(SqlInfo sqlInfo, Long dataSetId, SemanticSchema semanticSchema) {
|
||||
private static boolean selectContainsMetric(
|
||||
SqlInfo sqlInfo, Long dataSetId, SemanticSchema semanticSchema) {
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getParsedS2SQL());
|
||||
List<SchemaElement> metrics = semanticSchema.getMetrics(dataSetId);
|
||||
if (CollectionUtils.isNotEmpty(metrics)) {
|
||||
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
|
||||
Set<String> metricNameSet =
|
||||
metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
|
||||
return selectFields.stream().anyMatch(metricNameSet::contains);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
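Condensing the reformatted getQueryType logic above into a short sketch (helper and collection names are illustrative; the real code works on SqlInfo and SemanticSchema as shown, and the ID branch only applies to rule-based and LLM SQL queries):

    // 1. ID: a non-time WHERE field matches one of the data set's entity names
    // 2. METRIC: a SELECT field matches one of the data set's metric names
    // 3. otherwise DETAIL
    if (whereFieldsWithoutTime.stream().anyMatch(entityNames::contains)) {
        return QueryType.ID;
    }
    if (selectFields.stream().anyMatch(metricNames::contains)) {
        return QueryType.METRIC;
    }
    return QueryType.DETAIL;
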
@@ -11,11 +11,9 @@ import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_TE
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_TEXT_LENGTH_THRESHOLD_LONG;
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_TEXT_LENGTH_THRESHOLD_SHORT;


/**
* This checker can be used by semantic parsers to check if query intent
* has already been satisfied by current candidate queries. If so, current
* parser could be skipped.
* This checker can be used by semantic parsers to check if query intent has already been satisfied
* by current candidate queries. If so, current parser could be skipped.
*/
@Slf4j
public class SatisfactionChecker {
@@ -51,9 +49,11 @@ public class SatisfactionChecker {
} else if (degree < shortTextLengthThreshold) {
return false;
}
log.info("queryMode:{}, degree:{}, parse info:{}",
semanticParseInfo.getQueryMode(), degree, semanticParseInfo);
log.info(
"queryMode:{}, degree:{}, parse info:{}",
semanticParseInfo.getQueryMode(),
degree,
semanticParseInfo);
return true;
}

}

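Read together with the s2.parser.text.* parameters defined earlier in this commit, the check above boils down to roughly the following sketch; the exact computation of degree is outside this hunk, but the ParserConfig descriptions state it is the query score divided by the text length:

    // long queries use the 0.8 threshold, short queries the 0.5 threshold (defaults above)
    double threshold = isLongText ? longTextLengthThreshold : shortTextLengthThreshold;
    if (degree < threshold) {
        return false;   // intent not yet satisfied, keep running the remaining parsers
    }
    // otherwise an existing candidate query already covers the intent, so the LLM parser can be skipped
    return true;
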
@@ -3,9 +3,9 @@ package com.tencent.supersonic.headless.chat.parser;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
|
||||
/**
|
||||
* A semantic parser understands user queries and generates semantic query statement.
|
||||
* SuperSonic leverages a combination of rule-based and LLM-based parsers,
|
||||
* each of which deals with specific scenarios.
|
||||
* A semantic parser understands user queries and generates semantic query statement. SuperSonic
|
||||
* leverages a combination of rule-based and LLM-based parsers, each of which deals with specific
|
||||
* scenarios.
|
||||
*/
|
||||
public interface SemanticParser {
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
|
||||
import java.util.Set;
|
||||
@@ -8,5 +7,4 @@ import java.util.Set;
|
||||
public interface DataSetResolver {
|
||||
|
||||
Long resolve(ChatQueryContext chatQueryContext, Set<Long> restrictiveModels);
|
||||
|
||||
}
|
||||
|
||||
@@ -3,10 +3,11 @@ package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
@@ -21,9 +22,9 @@ import java.util.stream.Collectors;
|
||||
@Slf4j
|
||||
public class HeuristicDataSetResolver implements DataSetResolver {
|
||||
|
||||
protected static Long selectDataSetBySchemaElementMatchScore(Map<Long, SemanticQuery> dataSetQueryModes,
|
||||
SchemaMapInfo schemaMap) {
|
||||
//dataSet count priority
|
||||
protected static Long selectDataSetBySchemaElementMatchScore(
|
||||
Map<Long, SemanticQuery> dataSetQueryModes, SchemaMapInfo schemaMap) {
|
||||
// dataSet count priority
|
||||
Long dataSetIdByDataSetCount = getDataSetIdByMatchDataSetScore(schemaMap);
|
||||
if (Objects.nonNull(dataSetIdByDataSetCount)) {
|
||||
log.info("selectDataSet by dataSet count:{}", dataSetIdByDataSetCount);
|
||||
@@ -38,16 +39,24 @@ public class HeuristicDataSetResolver implements DataSetResolver {
|
||||
return dataSetSelect;
|
||||
}
|
||||
} else {
|
||||
Entry<Long, DataSetMatchResult> maxDataSet = dataSetTypeMap.entrySet().stream()
|
||||
.filter(entry -> dataSetQueryModes.containsKey(entry.getKey()))
|
||||
.sorted((o1, o2) -> {
|
||||
int difference = o2.getValue().getCount() - o1.getValue().getCount();
|
||||
if (difference == 0) {
|
||||
return (int) ((o2.getValue().getMaxSimilarity()
|
||||
- o1.getValue().getMaxSimilarity()) * 100);
|
||||
}
|
||||
return difference;
|
||||
}).findFirst().orElse(null);
|
||||
Entry<Long, DataSetMatchResult> maxDataSet =
|
||||
dataSetTypeMap.entrySet().stream()
|
||||
.filter(entry -> dataSetQueryModes.containsKey(entry.getKey()))
|
||||
.sorted(
|
||||
(o1, o2) -> {
|
||||
int difference =
|
||||
o2.getValue().getCount() - o1.getValue().getCount();
|
||||
if (difference == 0) {
|
||||
return (int)
|
||||
((o2.getValue().getMaxSimilarity()
|
||||
- o1.getValue()
|
||||
.getMaxSimilarity())
|
||||
* 100);
|
||||
}
|
||||
return difference;
|
||||
})
|
||||
.findFirst()
|
||||
.orElse(null);
|
||||
if (maxDataSet != null) {
|
||||
log.info("selectDataSet with multiple DataSets [{}]", maxDataSet.getKey());
|
||||
return maxDataSet.getKey();
|
||||
@@ -57,26 +66,40 @@ public class HeuristicDataSetResolver implements DataSetResolver {
|
||||
}
|
||||
|
||||
private static Long getDataSetIdByMatchDataSetScore(SchemaMapInfo schemaMap) {
|
||||
Map<Long, List<SchemaElementMatch>> dataSetElementMatches = schemaMap.getDataSetElementMatches();
|
||||
// calculate dataSet match score, matched element gets 1.0 point, and inherit element gets 0.5 point
|
||||
Map<Long, List<SchemaElementMatch>> dataSetElementMatches =
|
||||
schemaMap.getDataSetElementMatches();
|
||||
// calculate dataSet match score, matched element gets 1.0 point, and inherit element gets
|
||||
// 0.5 point
|
||||
Map<Long, Double> dataSetIdToDataSetScore = new HashMap<>();
|
||||
if (Objects.nonNull(dataSetElementMatches)) {
|
||||
for (Entry<Long, List<SchemaElementMatch>> dataSetElementMatch : dataSetElementMatches.entrySet()) {
|
||||
for (Entry<Long, List<SchemaElementMatch>> dataSetElementMatch :
|
||||
dataSetElementMatches.entrySet()) {
|
||||
Long dataSetId = dataSetElementMatch.getKey();
|
||||
List<Double> dataSetMatchesScore = dataSetElementMatch.getValue().stream()
|
||||
.filter(elementMatch -> elementMatch.getSimilarity() >= 1)
|
||||
.filter(elementMatch -> SchemaElementType.DATASET.equals(elementMatch.getElement().getType()))
|
||||
.map(elementMatch -> elementMatch.isInherited() ? 0.5 : 1.0).collect(Collectors.toList());
|
||||
List<Double> dataSetMatchesScore =
|
||||
dataSetElementMatch.getValue().stream()
|
||||
.filter(elementMatch -> elementMatch.getSimilarity() >= 1)
|
||||
.filter(
|
||||
elementMatch ->
|
||||
SchemaElementType.DATASET.equals(
|
||||
elementMatch.getElement().getType()))
|
||||
.map(elementMatch -> elementMatch.isInherited() ? 0.5 : 1.0)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
if (!CollectionUtils.isEmpty(dataSetMatchesScore)) {
|
||||
// get sum of dataSet match score
|
||||
double score = dataSetMatchesScore.stream().mapToDouble(Double::doubleValue).sum();
|
||||
double score =
|
||||
dataSetMatchesScore.stream().mapToDouble(Double::doubleValue).sum();
|
||||
dataSetIdToDataSetScore.put(dataSetId, score);
|
||||
}
|
||||
}
|
||||
Entry<Long, Double> maxDataSetScore = dataSetIdToDataSetScore.entrySet().stream()
|
||||
.max(Comparator.comparingDouble(Entry::getValue)).orElse(null);
|
||||
log.info("maxDataSetCount:{},dataSetIdToDataSetCount:{}", maxDataSetScore, dataSetIdToDataSetScore);
|
||||
Entry<Long, Double> maxDataSetScore =
|
||||
dataSetIdToDataSetScore.entrySet().stream()
|
||||
.max(Comparator.comparingDouble(Entry::getValue))
|
||||
.orElse(null);
|
||||
log.info(
|
||||
"maxDataSetCount:{},dataSetIdToDataSetCount:{}",
|
||||
maxDataSetScore,
|
||||
dataSetIdToDataSetScore);
|
||||
if (Objects.nonNull(maxDataSetScore)) {
|
||||
return maxDataSetScore.getKey();
|
||||
}
|
||||
@@ -86,8 +109,10 @@ public class HeuristicDataSetResolver implements DataSetResolver {
|
||||
|
||||
public static Map<Long, DataSetMatchResult> getDataSetTypeMap(SchemaMapInfo schemaMap) {
|
||||
Map<Long, DataSetMatchResult> dataSetCount = new HashMap<>();
|
||||
for (Entry<Long, List<SchemaElementMatch>> entry : schemaMap.getDataSetElementMatches().entrySet()) {
|
||||
List<SchemaElementMatch> schemaElementMatches = schemaMap.getMatchedElements(entry.getKey());
|
||||
for (Entry<Long, List<SchemaElementMatch>> entry :
|
||||
schemaMap.getDataSetElementMatches().entrySet()) {
|
||||
List<SchemaElementMatch> schemaElementMatches =
|
||||
schemaMap.getMatchedElements(entry.getKey());
|
||||
if (schemaElementMatches != null && schemaElementMatches.size() > 0) {
|
||||
if (!dataSetCount.containsKey(entry.getKey())) {
|
||||
dataSetCount.put(entry.getKey(), new DataSetMatchResult());
|
||||
@@ -95,17 +120,23 @@ public class HeuristicDataSetResolver implements DataSetResolver {
|
||||
DataSetMatchResult dataSetMatchResult = dataSetCount.get(entry.getKey());
|
||||
Set<SchemaElementType> schemaElementTypes = new HashSet<>();
|
||||
schemaElementMatches.stream()
|
||||
.forEach(schemaElementMatch -> schemaElementTypes.add(
|
||||
schemaElementMatch.getElement().getType()));
|
||||
SchemaElementMatch schemaElementMatchMax = schemaElementMatches.stream()
|
||||
.sorted((o1, o2) ->
|
||||
((int) ((o2.getSimilarity() - o1.getSimilarity()) * 100))
|
||||
).findFirst().orElse(null);
|
||||
.forEach(
|
||||
schemaElementMatch ->
|
||||
schemaElementTypes.add(
|
||||
schemaElementMatch.getElement().getType()));
|
||||
SchemaElementMatch schemaElementMatchMax =
|
||||
schemaElementMatches.stream()
|
||||
.sorted(
|
||||
(o1, o2) ->
|
||||
((int)
|
||||
((o2.getSimilarity() - o1.getSimilarity())
|
||||
* 100)))
|
||||
.findFirst()
|
||||
.orElse(null);
|
||||
if (schemaElementMatchMax != null) {
|
||||
dataSetMatchResult.setMaxSimilarity(schemaElementMatchMax.getSimilarity());
|
||||
}
|
||||
dataSetMatchResult.setCount(schemaElementTypes.size());
|
||||
|
||||
}
|
||||
}
|
||||
return dataSetCount;
|
||||
@@ -126,5 +157,4 @@ public class HeuristicDataSetResolver implements DataSetResolver {
|
||||
}
|
||||
return selectDataSetBySchemaElementMatchScore(dataSetQueryModes, mapInfo);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
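Stated compactly, the scoring rule in getDataSetIdByMatchDataSetScore above is the following (a sketch using the element types visible in the diff; variable names are illustrative):

    // a DATASET match with similarity >= 1 scores 1.0, or 0.5 when the match is inherited;
    // the data set with the highest total score wins
    Map<Long, Double> scoreByDataSet = new HashMap<>();
    for (Map.Entry<Long, List<SchemaElementMatch>> e : dataSetElementMatches.entrySet()) {
        double score = e.getValue().stream()
                .filter(m -> m.getSimilarity() >= 1)
                .filter(m -> SchemaElementType.DATASET.equals(m.getElement().getType()))
                .mapToDouble(m -> m.isInherited() ? 0.5 : 1.0)
                .sum();
        if (score > 0) {
            scoreByDataSet.put(e.getKey(), score);
        }
    }
    // the caller then picks the map entry with the maximum value, if any
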
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_LINKING_VALUE_ENABLE;
|
||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_STRATEGY_TYPE;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
@@ -16,6 +13,13 @@ import com.tencent.supersonic.headless.chat.parser.SatisfactionChecker;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashSet;
|
||||
@@ -25,19 +29,15 @@ import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_LINKING_VALUE_ENABLE;
|
||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_STRATEGY_TYPE;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
public class LLMRequestService {
|
||||
|
||||
@Autowired
|
||||
private ParserConfig parserConfig;
|
||||
@Autowired private ParserConfig parserConfig;
|
||||
|
||||
public boolean isSkip(ChatQueryContext queryCtx) {
|
||||
if (!queryCtx.getText2SQLType().enableLLM()) {
|
||||
@@ -79,7 +79,8 @@ public class LLMRequestService {
|
||||
llmReq.setPriorExts(priorKnowledge);
|
||||
|
||||
List<LLMReq.ElementValue> linking = new ArrayList<>();
|
||||
boolean linkingValueEnabled = Boolean.valueOf(parserConfig.getParameterValue(PARSER_LINKING_VALUE_ENABLE));
|
||||
boolean linkingValueEnabled =
|
||||
Boolean.valueOf(parserConfig.getParameterValue(PARSER_LINKING_VALUE_ENABLE));
|
||||
|
||||
if (linkingValueEnabled) {
|
||||
linking.addAll(linkingValues);
|
||||
@@ -87,7 +88,8 @@ public class LLMRequestService {
|
||||
llmReq.setLinking(linking);
|
||||
|
||||
llmReq.setCurrentDate(DateUtils.getBeforeDate(0));
|
||||
llmReq.setSqlGenType(LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE)));
|
||||
llmReq.setSqlGenType(
|
||||
LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE)));
|
||||
llmReq.setModelConfig(queryCtx.getModelConfig());
|
||||
llmReq.setPromptConfig(queryCtx.getPromptConfig());
|
||||
llmReq.setDynamicExemplars(queryCtx.getDynamicExemplars());
|
||||
@@ -105,21 +107,27 @@ public class LLMRequestService {
|
||||
}
|
||||
|
||||
protected List<LLMReq.Term> getTerms(ChatQueryContext queryCtx, Long dataSetId) {
|
||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||
List<SchemaElementMatch> matchedElements =
|
||||
queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
return matchedElements.stream()
|
||||
.filter(schemaElementMatch -> {
|
||||
SchemaElementType elementType = schemaElementMatch.getElement().getType();
|
||||
return SchemaElementType.TERM.equals(elementType);
|
||||
}).map(schemaElementMatch -> {
|
||||
LLMReq.Term term = new LLMReq.Term();
|
||||
term.setName(schemaElementMatch.getElement().getName());
|
||||
term.setDescription(schemaElementMatch.getElement().getDescription());
|
||||
term.setAlias(schemaElementMatch.getElement().getAlias());
|
||||
return term;
|
||||
}).collect(Collectors.toList());
|
||||
.filter(
|
||||
schemaElementMatch -> {
|
||||
SchemaElementType elementType =
|
||||
schemaElementMatch.getElement().getType();
|
||||
return SchemaElementType.TERM.equals(elementType);
|
||||
})
|
||||
.map(
|
||||
schemaElementMatch -> {
|
||||
LLMReq.Term term = new LLMReq.Term();
|
||||
term.setName(schemaElementMatch.getElement().getName());
|
||||
term.setDescription(schemaElementMatch.getElement().getDescription());
|
||||
term.setAlias(schemaElementMatch.getElement().getAlias());
|
||||
return term;
|
||||
})
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private String getPriorKnowledge(ChatQueryContext queryContext, LLMReq.LLMSchema llmSchema) {
|
||||
@@ -137,23 +145,33 @@ public class LLMRequestService {
|
||||
private Map<String, String> getFieldNameToDataFormatTypeMap(SemanticSchema semanticSchema) {
|
||||
return semanticSchema.getMetrics().stream()
|
||||
.filter(metric -> Objects.nonNull(metric.getDataFormatType()))
|
||||
.flatMap(metric -> {
|
||||
Set<Pair<String, String>> fieldFormatPairs = new HashSet<>();
|
||||
String dataFormatType = metric.getDataFormatType();
|
||||
fieldFormatPairs.add(Pair.of(metric.getName(), dataFormatType));
|
||||
List<String> aliasList = metric.getAlias();
|
||||
if (!CollectionUtils.isEmpty(aliasList)) {
|
||||
aliasList.forEach(alias -> fieldFormatPairs.add(Pair.of(alias, dataFormatType)));
|
||||
}
|
||||
return fieldFormatPairs.stream();
|
||||
})
|
||||
.collect(Collectors.toMap(Pair::getLeft, Pair::getRight, (existing, replacement) -> existing));
|
||||
.flatMap(
|
||||
metric -> {
|
||||
Set<Pair<String, String>> fieldFormatPairs = new HashSet<>();
|
||||
String dataFormatType = metric.getDataFormatType();
|
||||
fieldFormatPairs.add(Pair.of(metric.getName(), dataFormatType));
|
||||
List<String> aliasList = metric.getAlias();
|
||||
if (!CollectionUtils.isEmpty(aliasList)) {
|
||||
aliasList.forEach(
|
||||
alias ->
|
||||
fieldFormatPairs.add(
|
||||
Pair.of(alias, dataFormatType)));
|
||||
}
|
||||
return fieldFormatPairs.stream();
|
||||
})
|
||||
.collect(
|
||||
Collectors.toMap(
|
||||
Pair::getLeft,
|
||||
Pair::getRight,
|
||||
(existing, replacement) -> existing));
|
||||
}
|
||||
|
||||
private void appendMetricPriorKnowledge(LLMReq.LLMSchema llmSchema,
|
||||
private void appendMetricPriorKnowledge(
|
||||
LLMReq.LLMSchema llmSchema,
|
||||
StringBuilder priorKnowledgeBuilder,
|
||||
SemanticSchema semanticSchema) {
|
||||
Map<String, String> fieldNameToDataFormatType = getFieldNameToDataFormatTypeMap(semanticSchema);
|
||||
Map<String, String> fieldNameToDataFormatType =
|
||||
getFieldNameToDataFormatTypeMap(semanticSchema);
|
||||
|
||||
for (SchemaElement schemaElement : llmSchema.getMetrics()) {
|
||||
String fieldName = schemaElement.getName();
|
||||
@@ -168,14 +186,15 @@ public class LLMRequestService {
|
||||
private Map<String, String> getFieldNameToDateFormatMap(SemanticSchema semanticSchema) {
|
||||
return semanticSchema.getDimensions().stream()
|
||||
.filter(dimension -> StringUtils.isNotBlank(dimension.getTimeFormat()))
|
||||
.collect(Collectors.toMap(
|
||||
SchemaElement::getName,
|
||||
value -> Optional.ofNullable(value.getTimeFormat()).orElse(""),
|
||||
(k1, k2) -> k1)
|
||||
);
|
||||
.collect(
|
||||
Collectors.toMap(
|
||||
SchemaElement::getName,
|
||||
value -> Optional.ofNullable(value.getTimeFormat()).orElse(""),
|
||||
(k1, k2) -> k1));
|
||||
}
|
||||
|
||||
private void appendDimensionPriorKnowledge(LLMReq.LLMSchema llmSchema,
|
||||
private void appendDimensionPriorKnowledge(
|
||||
LLMReq.LLMSchema llmSchema,
|
||||
StringBuilder priorKnowledgeBuilder,
|
||||
SemanticSchema semanticSchema) {
|
||||
Map<String, String> fieldNameToDateFormat = getFieldNameToDateFormatMap(semanticSchema);
|
||||
@@ -187,7 +206,8 @@ public class LLMRequestService {
|
||||
continue;
|
||||
}
|
||||
if (schemaElement.containsPartitionTime()) {
|
||||
priorKnowledgeBuilder.append(String.format("%s 是分区时间且格式是%s", fieldName, timeFormat));
|
||||
priorKnowledgeBuilder.append(
|
||||
String.format("%s 是分区时间且格式是%s", fieldName, timeFormat));
|
||||
} else {
|
||||
priorKnowledgeBuilder.append(String.format("%s 的时间格式是%s", fieldName, timeFormat));
|
||||
}
|
||||
@@ -195,50 +215,66 @@ public class LLMRequestService {
|
||||
}
|
||||
|
||||
public List<LLMReq.ElementValue> getValues(ChatQueryContext queryCtx, Long dataSetId) {
|
||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||
List<SchemaElementMatch> matchedElements =
|
||||
queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
Set<LLMReq.ElementValue> valueMatches = matchedElements
|
||||
.stream()
|
||||
.filter(elementMatch -> !elementMatch.isInherited())
|
||||
.filter(schemaElementMatch -> {
|
||||
SchemaElementType type = schemaElementMatch.getElement().getType();
|
||||
return SchemaElementType.VALUE.equals(type) || SchemaElementType.ID.equals(type);
|
||||
})
|
||||
.map(elementMatch -> {
|
||||
LLMReq.ElementValue elementValue = new LLMReq.ElementValue();
|
||||
elementValue.setFieldName(elementMatch.getElement().getName());
|
||||
elementValue.setFieldValue(elementMatch.getWord());
|
||||
return elementValue;
|
||||
}).collect(Collectors.toSet());
|
||||
Set<LLMReq.ElementValue> valueMatches =
|
||||
matchedElements.stream()
|
||||
.filter(elementMatch -> !elementMatch.isInherited())
|
||||
.filter(
|
||||
schemaElementMatch -> {
|
||||
SchemaElementType type =
|
||||
schemaElementMatch.getElement().getType();
|
||||
return SchemaElementType.VALUE.equals(type)
|
||||
|| SchemaElementType.ID.equals(type);
|
||||
})
|
||||
.map(
|
||||
elementMatch -> {
|
||||
LLMReq.ElementValue elementValue = new LLMReq.ElementValue();
|
||||
elementValue.setFieldName(elementMatch.getElement().getName());
|
||||
elementValue.setFieldValue(elementMatch.getWord());
|
||||
return elementValue;
|
||||
})
|
||||
.collect(Collectors.toSet());
|
||||
return new ArrayList<>(valueMatches);
|
||||
}
|
||||
|
||||
protected List<SchemaElement> getMatchedMetrics(ChatQueryContext queryCtx, Long dataSetId) {
|
||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||
List<SchemaElementMatch> matchedElements =
|
||||
queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
List<SchemaElement> schemaElements = matchedElements.stream()
|
||||
.filter(schemaElementMatch -> {
|
||||
SchemaElementType elementType = schemaElementMatch.getElement().getType();
|
||||
return SchemaElementType.METRIC.equals(elementType);
|
||||
})
|
||||
.map(schemaElementMatch -> {
|
||||
return schemaElementMatch.getElement();
|
||||
})
|
||||
.collect(Collectors.toList());
|
||||
List<SchemaElement> schemaElements =
|
||||
matchedElements.stream()
|
||||
.filter(
|
||||
schemaElementMatch -> {
|
||||
SchemaElementType elementType =
|
||||
schemaElementMatch.getElement().getType();
|
||||
return SchemaElementType.METRIC.equals(elementType);
|
||||
})
|
||||
.map(
|
||||
schemaElementMatch -> {
|
||||
return schemaElementMatch.getElement();
|
||||
})
|
||||
.collect(Collectors.toList());
|
||||
return schemaElements;
|
||||
}
|
||||
|
||||
protected List<SchemaElement> getMatchedDimensions(ChatQueryContext queryCtx, Long dataSetId) {
|
||||
|
||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||
Set<SchemaElement> dimensionElements = matchedElements.stream()
|
||||
.filter(element -> SchemaElementType.DIMENSION.equals(element.getElement().getType()))
|
||||
.map(SchemaElementMatch::getElement)
|
||||
.collect(Collectors.toSet());
|
||||
List<SchemaElementMatch> matchedElements =
|
||||
queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||
Set<SchemaElement> dimensionElements =
|
||||
matchedElements.stream()
|
||||
.filter(
|
||||
element ->
|
||||
SchemaElementType.DIMENSION.equals(
|
||||
element.getElement().getType()))
|
||||
.map(SchemaElementMatch::getElement)
|
||||
.collect(Collectors.toSet());
|
||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||
if (semanticSchema == null || semanticSchema.getDataSetSchemaMap() == null) {
|
||||
return new ArrayList<>(dimensionElements);
|
||||
|
||||
@@ -1,20 +1,20 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.LLMSemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlResp;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import java.util.ArrayList;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.MapUtils;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
@@ -23,25 +23,28 @@ import java.util.Objects;
|
||||
@Service
|
||||
public class LLMResponseService {
|
||||
|
||||
public SemanticParseInfo addParseInfo(ChatQueryContext queryCtx, ParseResult parseResult,
|
||||
String s2SQL, Double weight) {
|
||||
public SemanticParseInfo addParseInfo(
|
||||
ChatQueryContext queryCtx, ParseResult parseResult, String s2SQL, Double weight) {
|
||||
if (Objects.isNull(weight)) {
|
||||
weight = 0D;
|
||||
}
|
||||
LLMSemanticQuery semanticQuery = QueryManager.createLLMQuery(LLMSqlQuery.QUERY_MODE);
|
||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||
parseInfo.setDataSet(queryCtx.getSemanticSchema().getDataSet(parseResult.getDataSetId()));
|
||||
parseInfo.getElementMatches().addAll(queryCtx.getMapInfo().getMatchedElements(parseInfo.getDataSetId()));
|
||||
parseInfo
|
||||
.getElementMatches()
|
||||
.addAll(queryCtx.getMapInfo().getMatchedElements(parseInfo.getDataSetId()));
|
||||
|
||||
Map<String, Object> properties = new HashMap<>();
|
||||
properties.put(Constants.CONTEXT, parseResult);
|
||||
properties.put("type", "internal");
|
||||
Text2SQLExemplar exemplar = Text2SQLExemplar.builder()
|
||||
.question(queryCtx.getQueryText())
|
||||
.sideInfo(parseResult.getLlmResp().getSideInfo())
|
||||
.dbSchema(parseResult.getLlmResp().getSchema())
|
||||
.sql(parseResult.getLlmResp().getSqlOutput())
|
||||
.build();
|
||||
Text2SQLExemplar exemplar =
|
||||
Text2SQLExemplar.builder()
|
||||
.question(queryCtx.getQueryText())
|
||||
.sideInfo(parseResult.getLlmResp().getSideInfo())
|
||||
.dbSchema(parseResult.getLlmResp().getSchema())
|
||||
.sql(parseResult.getLlmResp().getSqlOutput())
|
||||
.build();
|
||||
properties.put(Text2SQLExemplar.PROPERTY_KEY, exemplar);
|
||||
parseInfo.setProperties(properties);
|
||||
parseInfo.setScore(queryCtx.getQueryText().length() * (1 + weight));
|
||||
@@ -60,7 +63,8 @@ public class LLMResponseService {
|
||||
Map<String, LLMSqlResp> result = new HashMap<>();
|
||||
for (Map.Entry<String, LLMSqlResp> entry : sqlRespMap.entrySet()) {
|
||||
String key = entry.getKey();
|
||||
if (result.keySet().stream().anyMatch(existKey -> SqlValidHelper.equals(existKey, key))) {
|
||||
if (result.keySet().stream()
|
||||
.anyMatch(existKey -> SqlValidHelper.equals(existKey, key))) {
|
||||
continue;
|
||||
}
|
||||
if (!SqlValidHelper.isValidSQL(key)) {
|
||||
|
||||
@@ -1,22 +1,22 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.MapUtils;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.MapUtils;
|
||||
|
||||
/**
|
||||
* LLMSqlParser uses large language model to understand query semantics and
|
||||
* generate S2SQL statements to be executed by the semantic query engine.
|
||||
* LLMSqlParser uses large language model to understand query semantics and generate S2SQL
|
||||
* statements to be executed by the semantic query engine.
|
||||
*/
|
||||
@Slf4j
|
||||
public class LLMSqlParser implements SemanticParser {
|
||||
@@ -25,18 +25,18 @@ public class LLMSqlParser implements SemanticParser {
|
||||
public void parse(ChatQueryContext queryCtx) {
|
||||
try {
|
||||
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
|
||||
//1.determine whether to skip this parser.
|
||||
// 1.determine whether to skip this parser.
|
||||
if (requestService.isSkip(queryCtx)) {
|
||||
return;
|
||||
}
|
||||
//2.get dataSetId from queryCtx and chatCtx.
|
||||
// 2.get dataSetId from queryCtx and chatCtx.
|
||||
Long dataSetId = requestService.getDataSetId(queryCtx);
|
||||
if (dataSetId == null) {
|
||||
return;
|
||||
}
|
||||
log.info("try generating query statement for dataSetId:{}", dataSetId);
|
||||
|
||||
//3.invoke LLM service to do parsing.
|
||||
// 3.invoke LLM service to do parsing.
|
||||
tryParse(queryCtx, dataSetId);
|
||||
} catch (Exception e) {
|
||||
log.error("failed to parse query:", e);
|
||||
@@ -58,11 +58,16 @@ public class LLMSqlParser implements SemanticParser {
|
||||
try {
|
||||
LLMResp llmResp = requestService.runText2SQL(llmReq);
|
||||
if (Objects.nonNull(llmResp)) {
|
||||
//deduplicate the S2SQL result list and build parserInfo
|
||||
// deduplicate the S2SQL result list and build parserInfo
|
||||
sqlRespMap = responseService.getDeduplicationSqlResp(currentRetry, llmResp);
|
||||
if (MapUtils.isNotEmpty(sqlRespMap)) {
|
||||
parseResult = ParseResult.builder().dataSetId(dataSetId).llmReq(llmReq)
|
||||
.llmResp(llmResp).linkingValues(llmReq.getLinking()).build();
|
||||
parseResult =
|
||||
ParseResult.builder()
|
||||
.dataSetId(dataSetId)
|
||||
.llmReq(llmReq)
|
||||
.llmResp(llmResp)
|
||||
.linkingValues(llmReq.getLinking())
|
||||
.build();
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -80,5 +85,4 @@ public class LLMSqlParser implements SemanticParser {
|
||||
responseService.addParseInfo(queryCtx, parseResult, sql, sqlWeight);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -20,35 +20,35 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
|
||||
private static final String INSTRUCTION = ""
|
||||
+ "\n#Role: You are a data analyst experienced in SQL languages."
|
||||
+ "#Task: You will be provided with a natural language question asked by users,"
|
||||
+ "please convert it to a SQL query so that relevant data could be returned "
|
||||
+ "by executing the SQL query against underlying database."
|
||||
+ "\n#Rules:"
|
||||
+ "1.ALWAYS generate column specified in the `Schema`, DO NOT hallucinate."
|
||||
+ "2.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator."
|
||||
+ "3.ALWAYS calculate the absolute date range by yourself."
|
||||
+ "4.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`."
|
||||
+ "5.DO NOT miss the AGGREGATE operator of metrics, always add it if needed."
|
||||
+ "6.ONLY respond with the converted SQL statement."
|
||||
+ "\n#Exemplars:\n{{exemplar}}"
|
||||
+ "Question:{{question}},Schema:{{schema}},SideInfo:{{information}},SQL:";
|
||||
private static final String INSTRUCTION =
|
||||
""
|
||||
+ "\n#Role: You are a data analyst experienced in SQL languages."
|
||||
+ "#Task: You will be provided with a natural language question asked by users,"
|
||||
+ "please convert it to a SQL query so that relevant data could be returned "
|
||||
+ "by executing the SQL query against underlying database."
|
||||
+ "\n#Rules:"
|
||||
+ "1.ALWAYS generate column specified in the `Schema`, DO NOT hallucinate."
|
||||
+ "2.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator."
|
||||
+ "3.ALWAYS calculate the absolute date range by yourself."
|
||||
+ "4.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`."
|
||||
+ "5.DO NOT miss the AGGREGATE operator of metrics, always add it if needed."
|
||||
+ "6.ONLY respond with the converted SQL statement."
|
||||
+ "\n#Exemplars:\n{{exemplar}}"
|
||||
+ "Question:{{question}},Schema:{{schema}},SideInfo:{{information}},SQL:";
|
||||
|
||||
@Override
|
||||
public LLMResp generate(LLMReq llmReq) {
|
||||
LLMResp llmResp = new LLMResp();
|
||||
llmResp.setQuery(llmReq.getQueryText());
|
||||
//1.recall exemplars
|
||||
// 1.recall exemplars
|
||||
keyPipelineLog.info("OnePassSCSqlGenStrategy llmReq:\n{}", llmReq);
|
||||
List<List<Text2SQLExemplar>> exemplarsList = promptHelper.getFewShotExemplars(llmReq);
|
||||
|
||||
//2.generate sql generation prompt for each self-consistency inference
|
||||
// 2.generate sql generation prompt for each self-consistency inference
|
||||
Map<Prompt, List<Text2SQLExemplar>> prompt2Exemplar = new HashMap<>();
|
||||
for (List<Text2SQLExemplar> exemplars : exemplarsList) {
|
||||
llmReq.setDynamicExemplars(exemplars);
|
||||
@@ -56,25 +56,36 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
prompt2Exemplar.put(prompt, exemplars);
|
||||
}
|
||||
|
||||
//3.perform multiple self-consistency inferences parallelly
|
||||
// 3.perform multiple self-consistency inferences parallelly
|
||||
Map<String, Prompt> output2Prompt = new ConcurrentHashMap<>();
|
||||
prompt2Exemplar.keySet().parallelStream().forEach(prompt -> {
|
||||
keyPipelineLog.info("OnePassSCSqlGenStrategy reqPrompt:\n{}", prompt.toUserMessage());
|
||||
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getModelConfig());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||
String sqlOutput = StringUtils.normalizeSpace(response.content().text());
|
||||
// replace ```
|
||||
String sqlOutputFormat = sqlOutput.replaceAll("(?s)```sql\\s*(.*?)\\s*```", "$1").trim();
|
||||
output2Prompt.put(sqlOutputFormat, prompt);
|
||||
keyPipelineLog.info("OnePassSCSqlGenStrategy modelResp:\n{}", sqlOutputFormat);
|
||||
}
|
||||
);
|
||||
prompt2Exemplar
|
||||
.keySet()
|
||||
.parallelStream()
|
||||
.forEach(
|
||||
prompt -> {
|
||||
keyPipelineLog.info(
|
||||
"OnePassSCSqlGenStrategy reqPrompt:\n{}",
|
||||
prompt.toUserMessage());
|
||||
ChatLanguageModel chatLanguageModel =
|
||||
getChatLanguageModel(llmReq.getModelConfig());
|
||||
Response<AiMessage> response =
|
||||
chatLanguageModel.generate(prompt.toUserMessage());
|
||||
String sqlOutput =
|
||||
StringUtils.normalizeSpace(response.content().text());
|
||||
// replace ```
|
||||
String sqlOutputFormat =
|
||||
sqlOutput.replaceAll("(?s)```sql\\s*(.*?)\\s*```", "$1").trim();
|
||||
output2Prompt.put(sqlOutputFormat, prompt);
|
||||
keyPipelineLog.info(
|
||||
"OnePassSCSqlGenStrategy modelResp:\n{}", sqlOutputFormat);
|
||||
});
|
||||
|
||||
//4.format response.
|
||||
Pair<String, Map<String, Double>> sqlMapPair = ResponseHelper.selfConsistencyVote(
|
||||
Lists.newArrayList(output2Prompt.keySet()));
|
||||
// 4.format response.
|
||||
Pair<String, Map<String, Double>> sqlMapPair =
|
||||
ResponseHelper.selfConsistencyVote(Lists.newArrayList(output2Prompt.keySet()));
|
||||
llmResp.setSqlOutput(sqlMapPair.getLeft());
|
||||
List<Text2SQLExemplar> usedExemplars = prompt2Exemplar.get(output2Prompt.get(sqlMapPair.getLeft()));
|
||||
List<Text2SQLExemplar> usedExemplars =
|
||||
prompt2Exemplar.get(output2Prompt.get(sqlMapPair.getLeft()));
|
||||
llmResp.setSqlRespMap(ResponseHelper.buildSqlRespMap(usedExemplars, sqlMapPair.getRight()));
|
||||
|
||||
return llmResp;
|
||||
@@ -83,9 +94,13 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
private Prompt generatePrompt(LLMReq llmReq, LLMResp llmResp) {
|
||||
StringBuilder exemplars = new StringBuilder();
|
||||
for (Text2SQLExemplar exemplar : llmReq.getDynamicExemplars()) {
|
||||
String exemplarStr = String.format("Question:%s,Schema:%s,SideInfo:%s,SQL:%s\n",
|
||||
exemplar.getQuestion(), exemplar.getDbSchema(),
|
||||
exemplar.getSideInfo(), exemplar.getSql());
|
||||
String exemplarStr =
|
||||
String.format(
|
||||
"Question:%s,Schema:%s,SideInfo:%s,SQL:%s\n",
|
||||
exemplar.getQuestion(),
|
||||
exemplar.getDbSchema(),
|
||||
exemplar.getSideInfo(),
|
||||
exemplar.getSql());
|
||||
exemplars.append(exemplarStr);
|
||||
}
|
||||
String dataSemantics = promptHelper.buildSchemaStr(llmReq);
|
||||
@@ -110,6 +125,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
SqlGenStrategyFactory.addSqlGenerationForFactory(LLMReq.SqlGenType.ONE_PASS_SELF_CONSISTENCY, this);
|
||||
SqlGenStrategyFactory.addSqlGenerationForFactory(
|
||||
LLMReq.SqlGenType.ONE_PASS_SELF_CONSISTENCY, this);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,21 +23,23 @@ import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_SE
|
||||
@Slf4j
|
||||
public class PromptHelper {
|
||||
|
||||
@Autowired
|
||||
private ParserConfig parserConfig;
|
||||
@Autowired private ParserConfig parserConfig;
|
||||
|
||||
@Autowired
|
||||
private ExemplarService exemplarService;
|
||||
@Autowired private ExemplarService exemplarService;
|
||||
|
||||
public List<List<Text2SQLExemplar>> getFewShotExemplars(LLMReq llmReq) {
|
||||
int exemplarRecallNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER));
|
||||
int exemplarRecallNumber =
|
||||
Integer.valueOf(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER));
|
||||
int fewShotNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_FEW_SHOT_NUMBER));
|
||||
int selfConsistencyNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_SELF_CONSISTENCY_NUMBER));
|
||||
int selfConsistencyNumber =
|
||||
Integer.valueOf(parserConfig.getParameterValue(PARSER_SELF_CONSISTENCY_NUMBER));
|
||||
|
||||
List<Text2SQLExemplar> exemplars = Lists.newArrayList();
|
||||
llmReq.getDynamicExemplars().stream().forEach(e -> {
|
||||
exemplars.add(e);
|
||||
});
|
||||
llmReq.getDynamicExemplars().stream()
|
||||
.forEach(
|
||||
e -> {
|
||||
exemplars.add(e);
|
||||
});
|
||||
|
||||
int recallSize = exemplarRecallNumber - llmReq.getDynamicExemplars().size();
|
||||
if (recallSize > 0) {
|
||||
@@ -76,73 +78,83 @@ public class PromptHelper {
|
||||
String tableStr = llmReq.getSchema().getDataSetName();
|
||||
|
||||
List<String> metrics = Lists.newArrayList();
|
||||
llmReq.getSchema().getMetrics().stream().forEach(
|
||||
metric -> {
|
||||
StringBuilder metricStr = new StringBuilder();
|
||||
metricStr.append("<");
|
||||
metricStr.append(metric.getName());
|
||||
if (!CollectionUtils.isEmpty(metric.getAlias())) {
|
||||
StringBuilder alias = new StringBuilder();
|
||||
metric.getAlias().stream().forEach(a -> alias.append(a + ","));
|
||||
metricStr.append(" ALIAS '" + alias + "'");
|
||||
}
|
||||
if (StringUtils.isNotEmpty(metric.getDescription())) {
|
||||
metricStr.append(" COMMENT '" + metric.getDescription() + "'");
|
||||
}
|
||||
if (StringUtils.isNotEmpty(metric.getDefaultAgg())) {
|
||||
metricStr.append(" AGGREGATE '" + metric.getDefaultAgg().toUpperCase() + "'");
|
||||
}
|
||||
metricStr.append(">");
|
||||
metrics.add(metricStr.toString());
|
||||
}
|
||||
);
|
||||
llmReq.getSchema().getMetrics().stream()
|
||||
.forEach(
|
||||
metric -> {
|
||||
StringBuilder metricStr = new StringBuilder();
|
||||
metricStr.append("<");
|
||||
metricStr.append(metric.getName());
|
||||
if (!CollectionUtils.isEmpty(metric.getAlias())) {
|
||||
StringBuilder alias = new StringBuilder();
|
||||
metric.getAlias().stream().forEach(a -> alias.append(a + ","));
|
||||
metricStr.append(" ALIAS '" + alias + "'");
|
||||
}
|
||||
if (StringUtils.isNotEmpty(metric.getDescription())) {
|
||||
metricStr.append(" COMMENT '" + metric.getDescription() + "'");
|
||||
}
|
||||
if (StringUtils.isNotEmpty(metric.getDefaultAgg())) {
|
||||
metricStr.append(
|
||||
" AGGREGATE '"
|
||||
+ metric.getDefaultAgg().toUpperCase()
|
||||
+ "'");
|
||||
}
|
||||
metricStr.append(">");
|
||||
metrics.add(metricStr.toString());
|
||||
});
|
||||
|
||||
List<String> dimensions = Lists.newArrayList();
|
||||
llmReq.getSchema().getDimensions().stream().forEach(
|
||||
dimension -> {
|
||||
StringBuilder dimensionStr = new StringBuilder();
|
||||
dimensionStr.append("<");
|
||||
dimensionStr.append(dimension.getName());
|
||||
if (!CollectionUtils.isEmpty(dimension.getAlias())) {
|
||||
StringBuilder alias = new StringBuilder();
|
||||
dimension.getAlias().stream().forEach(a -> alias.append(a + ","));
|
||||
dimensionStr.append(" ALIAS '" + alias + "'");
|
||||
}
|
||||
if (StringUtils.isNotEmpty(dimension.getDescription())) {
|
||||
dimensionStr.append(" COMMENT '" + dimension.getDescription() + "'");
|
||||
}
|
||||
dimensionStr.append(">");
|
||||
dimensions.add(dimensionStr.toString());
|
||||
}
|
||||
);
|
||||
llmReq.getSchema().getDimensions().stream()
|
||||
.forEach(
|
||||
dimension -> {
|
||||
StringBuilder dimensionStr = new StringBuilder();
|
||||
dimensionStr.append("<");
|
||||
dimensionStr.append(dimension.getName());
|
||||
if (!CollectionUtils.isEmpty(dimension.getAlias())) {
|
||||
StringBuilder alias = new StringBuilder();
|
||||
dimension.getAlias().stream().forEach(a -> alias.append(a + ","));
|
||||
dimensionStr.append(" ALIAS '" + alias + "'");
|
||||
}
|
||||
if (StringUtils.isNotEmpty(dimension.getDescription())) {
|
||||
dimensionStr.append(
|
||||
" COMMENT '" + dimension.getDescription() + "'");
|
||||
}
|
||||
dimensionStr.append(">");
|
||||
dimensions.add(dimensionStr.toString());
|
||||
});
|
||||
|
||||
List<String> values = Lists.newArrayList();
|
||||
llmReq.getLinking().stream().forEach(
|
||||
value -> {
|
||||
StringBuilder valueStr = new StringBuilder();
|
||||
String fieldName = value.getFieldName();
|
||||
String fieldValue = value.getFieldValue();
|
||||
valueStr.append(String.format("<%s='%s'>", fieldName, fieldValue));
|
||||
values.add(valueStr.toString());
|
||||
}
|
||||
);
|
||||
llmReq.getLinking().stream()
|
||||
.forEach(
|
||||
value -> {
|
||||
StringBuilder valueStr = new StringBuilder();
|
||||
String fieldName = value.getFieldName();
|
||||
String fieldValue = value.getFieldValue();
|
||||
valueStr.append(String.format("<%s='%s'>", fieldName, fieldValue));
|
||||
values.add(valueStr.toString());
|
||||
});
|
||||
|
||||
String template = "Table=[%s], Metrics=[%s], Dimensions=[%s], Values=[%s]";
|
||||
return String.format(template, tableStr, String.join(",", metrics),
|
||||
String.join(",", dimensions), String.join(",", values));
|
||||
return String.format(
|
||||
template,
|
||||
tableStr,
|
||||
String.join(",", metrics),
|
||||
String.join(",", dimensions),
|
||||
String.join(",", values));
|
||||
}
|
||||
|
||||
private String buildTermStr(LLMReq llmReq) {
|
||||
List<LLMReq.Term> terms = llmReq.getSchema().getTerms();
|
||||
List<String> termStr = Lists.newArrayList();
|
||||
terms.stream().forEach(
|
||||
term -> {
|
||||
StringBuilder termsDesc = new StringBuilder();
|
||||
String description = term.getDescription();
|
||||
termsDesc.append(String.format("<%s COMMENT '%s'>", term.getName(), description));
|
||||
termStr.add(termsDesc.toString());
|
||||
}
|
||||
);
|
||||
terms.stream()
|
||||
.forEach(
|
||||
term -> {
|
||||
StringBuilder termsDesc = new StringBuilder();
|
||||
String description = term.getDescription();
|
||||
termsDesc.append(
|
||||
String.format(
|
||||
"<%s COMMENT '%s'>", term.getName(), description));
|
||||
termStr.add(termsDesc.toString());
|
||||
});
|
||||
String ret = "";
|
||||
if (termStr.size() > 0) {
|
||||
ret = String.join(",", termStr);
|
||||
@@ -150,5 +162,4 @@ public class PromptHelper {
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -54,16 +54,19 @@ public class ResponseHelper {
|
||||
return Pair.of(inputMax, votePercentage);
|
||||
}
|
||||
|
||||
public static Map<String, LLMSqlResp> buildSqlRespMap(List<Text2SQLExemplar> sqlExamples,
|
||||
Map<String, Double> sqlMap) {
|
||||
public static Map<String, LLMSqlResp> buildSqlRespMap(
|
||||
List<Text2SQLExemplar> sqlExamples, Map<String, Double> sqlMap) {
|
||||
if (sqlMap == null) {
|
||||
return new HashMap<>();
|
||||
}
|
||||
return sqlMap.entrySet().stream()
|
||||
.collect(Collectors.toMap(
|
||||
Map.Entry::getKey,
|
||||
entry -> LLMSqlResp.builder().sqlWeight(entry.getValue()).fewShots(sqlExamples).build())
|
||||
);
|
||||
.collect(
|
||||
Collectors.toMap(
|
||||
Map.Entry::getKey,
|
||||
entry ->
|
||||
LLMSqlResp.builder()
|
||||
.sqlWeight(entry.getValue())
|
||||
.fewShots(sqlExamples)
|
||||
.build()));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -12,16 +12,15 @@ import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
/**
|
||||
* SqlGenStrategy abstracts generation step so that
|
||||
* different LLM prompting strategies can be implemented.
|
||||
* SqlGenStrategy abstracts generation step so that different LLM prompting strategies can be
|
||||
* implemented.
|
||||
*/
|
||||
@Service
|
||||
public abstract class SqlGenStrategy implements InitializingBean {
|
||||
|
||||
protected static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
|
||||
@Autowired
|
||||
protected PromptHelper promptHelper;
|
||||
@Autowired protected PromptHelper promptHelper;
|
||||
|
||||
protected ChatLanguageModel getChatLanguageModel(ChatModelConfig modelConfig) {
|
||||
return ModelProvider.getChatModel(modelConfig);
|
||||
|
||||
@@ -7,13 +7,15 @@ import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
public class SqlGenStrategyFactory {
|
||||
|
||||
private static Map<LLMReq.SqlGenType, SqlGenStrategy> sqlGenStrategyMap = new ConcurrentHashMap<>();
|
||||
private static Map<LLMReq.SqlGenType, SqlGenStrategy> sqlGenStrategyMap =
|
||||
new ConcurrentHashMap<>();
|
||||
|
||||
public static SqlGenStrategy get(LLMReq.SqlGenType strategyType) {
|
||||
return sqlGenStrategyMap.get(strategyType);
|
||||
}
|
||||
|
||||
public static void addSqlGenerationForFactory(LLMReq.SqlGenType strategy, SqlGenStrategy sqlGenStrategy) {
|
||||
public static void addSqlGenerationForFactory(
|
||||
LLMReq.SqlGenType strategy, SqlGenStrategy sqlGenStrategy) {
|
||||
sqlGenStrategyMap.put(strategy, sqlGenStrategy);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,8 +2,8 @@ package com.tencent.supersonic.headless.chat.parser.rule;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
@@ -20,24 +20,34 @@ import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.COUNT;
|
||||
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.DISTINCT;
|
||||
|
||||
/**
|
||||
* AggregateTypeParser extracts aggregation type specified in the user query
|
||||
* based on keyword matching.
|
||||
* Currently, it supports 7 types of aggregation: max, min, sum, avg, topN,
|
||||
* distinct count, count.
|
||||
* AggregateTypeParser extracts aggregation type specified in the user query based on keyword
|
||||
* matching. Currently, it supports 7 types of aggregation: max, min, sum, avg, topN, distinct
|
||||
* count, count.
|
||||
*/
|
||||
@Slf4j
|
||||
public class AggregateTypeParser implements SemanticParser {
|
||||
|
||||
private static final Map<AggregateTypeEnum, Pattern> REGX_MAP = Stream.of(
|
||||
new AbstractMap.SimpleEntry<>(AggregateTypeEnum.MAX, Pattern.compile("(?i)(最大值|最大|max|峰值|最高|最多)")),
|
||||
new AbstractMap.SimpleEntry<>(AggregateTypeEnum.MIN, Pattern.compile("(?i)(最小值|最小|min|最低|最少)")),
|
||||
new AbstractMap.SimpleEntry<>(AggregateTypeEnum.SUM, Pattern.compile("(?i)(汇总|总和|sum)")),
|
||||
new AbstractMap.SimpleEntry<>(AggregateTypeEnum.AVG, Pattern.compile("(?i)(平均值|日均|平均|avg)")),
|
||||
new AbstractMap.SimpleEntry<>(AggregateTypeEnum.TOPN, Pattern.compile("(?i)(top)")),
|
||||
new AbstractMap.SimpleEntry<>(DISTINCT, Pattern.compile("(?i)(uv)")),
|
||||
new AbstractMap.SimpleEntry<>(COUNT, Pattern.compile("(?i)(总数|pv)")),
|
||||
new AbstractMap.SimpleEntry<>(AggregateTypeEnum.NONE, Pattern.compile("(?i)(明细)"))
|
||||
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (k1, k2) -> k2));
|
||||
private static final Map<AggregateTypeEnum, Pattern> REGX_MAP =
|
||||
Stream.of(
|
||||
new AbstractMap.SimpleEntry<>(
|
||||
AggregateTypeEnum.MAX,
|
||||
Pattern.compile("(?i)(最大值|最大|max|峰值|最高|最多)")),
|
||||
new AbstractMap.SimpleEntry<>(
|
||||
AggregateTypeEnum.MIN,
|
||||
Pattern.compile("(?i)(最小值|最小|min|最低|最少)")),
|
||||
new AbstractMap.SimpleEntry<>(
|
||||
AggregateTypeEnum.SUM, Pattern.compile("(?i)(汇总|总和|sum)")),
|
||||
new AbstractMap.SimpleEntry<>(
|
||||
AggregateTypeEnum.AVG, Pattern.compile("(?i)(平均值|日均|平均|avg)")),
|
||||
new AbstractMap.SimpleEntry<>(
|
||||
AggregateTypeEnum.TOPN, Pattern.compile("(?i)(top)")),
|
||||
new AbstractMap.SimpleEntry<>(DISTINCT, Pattern.compile("(?i)(uv)")),
|
||||
new AbstractMap.SimpleEntry<>(COUNT, Pattern.compile("(?i)(总数|pv)")),
|
||||
new AbstractMap.SimpleEntry<>(
|
||||
AggregateTypeEnum.NONE, Pattern.compile("(?i)(明细)")))
|
||||
.collect(
|
||||
Collectors.toMap(
|
||||
Map.Entry::getKey, Map.Entry::getValue, (k1, k2) -> k2));
|
||||
|
||||
@Override
|
||||
public void parse(ChatQueryContext chatQueryContext) {
|
||||
@@ -53,7 +63,9 @@ public class AggregateTypeParser implements SemanticParser {
|
||||
if (StringUtils.isNotEmpty(aggregateConf.detectWord)) {
|
||||
detectWordLength = aggregateConf.detectWord.length();
|
||||
}
|
||||
semanticQuery.getParseInfo().setScore(semanticQuery.getParseInfo().getScore() + detectWordLength);
|
||||
semanticQuery
|
||||
.getParseInfo()
|
||||
.setScore(semanticQuery.getParseInfo().getScore() + detectWordLength);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,7 +78,6 @@ public class AggregateTypeParser implements SemanticParser {
|
||||
Map<AggregateTypeEnum, Integer> aggregateCount = new HashMap<>(REGX_MAP.size());
|
||||
Map<AggregateTypeEnum, String> aggregateWord = new HashMap<>(REGX_MAP.size());
|
||||
|
||||
|
||||
for (Map.Entry<AggregateTypeEnum, Pattern> entry : REGX_MAP.entrySet()) {
|
||||
Matcher matcher = entry.getValue().matcher(queryText);
|
||||
int count = 0;
|
||||
@@ -81,8 +92,11 @@ public class AggregateTypeParser implements SemanticParser {
|
||||
}
|
||||
}
|
||||
|
||||
AggregateTypeEnum type = aggregateCount.entrySet().stream().max(Map.Entry.comparingByValue())
|
||||
.map(entry -> entry.getKey()).orElse(AggregateTypeEnum.NONE);
|
||||
AggregateTypeEnum type =
|
||||
aggregateCount.entrySet().stream()
|
||||
.max(Map.Entry.comparingByValue())
|
||||
.map(entry -> entry.getKey())
|
||||
.orElse(AggregateTypeEnum.NONE);
|
||||
String detectWord = aggregateWord.get(type);
|
||||
return new AggregateConf(type, detectWord);
|
||||
}
|
||||
@@ -92,5 +106,4 @@ public class AggregateTypeParser implements SemanticParser {
|
||||
public AggregateTypeEnum type;
|
||||
public String detectWord;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -3,14 +3,14 @@ package com.tencent.supersonic.headless.chat.parser.rule;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.detail.DetailDimensionQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricIdQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricModelQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricSemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricIdQuery;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.AbstractMap;
|
||||
@@ -24,23 +24,34 @@ import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
/**
|
||||
* ContextInheritParser tries to inherit certain schema elements from context
|
||||
* so that in multi-turn conversations users don't need to mention some keyword
|
||||
* repeatedly.
|
||||
* ContextInheritParser tries to inherit certain schema elements from context so that in multi-turn
|
||||
* conversations users don't need to mention some keyword repeatedly.
|
||||
*/
|
||||
@Slf4j
|
||||
public class ContextInheritParser implements SemanticParser {
|
||||
|
||||
private static final Map<SchemaElementType, List<SchemaElementType>> MUTUAL_EXCLUSIVE_MAP = Stream.of(
|
||||
new AbstractMap.SimpleEntry<>(SchemaElementType.METRIC, Arrays.asList(SchemaElementType.METRIC)),
|
||||
new AbstractMap.SimpleEntry<>(
|
||||
SchemaElementType.DIMENSION, Arrays.asList(SchemaElementType.DIMENSION, SchemaElementType.VALUE)),
|
||||
new AbstractMap.SimpleEntry<>(
|
||||
SchemaElementType.VALUE, Arrays.asList(SchemaElementType.VALUE, SchemaElementType.DIMENSION)),
|
||||
new AbstractMap.SimpleEntry<>(SchemaElementType.ENTITY, Arrays.asList(SchemaElementType.ENTITY)),
|
||||
new AbstractMap.SimpleEntry<>(SchemaElementType.DATASET, Arrays.asList(SchemaElementType.DATASET)),
|
||||
new AbstractMap.SimpleEntry<>(SchemaElementType.ID, Arrays.asList(SchemaElementType.ID))
|
||||
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
|
||||
private static final Map<SchemaElementType, List<SchemaElementType>> MUTUAL_EXCLUSIVE_MAP =
|
||||
Stream.of(
|
||||
new AbstractMap.SimpleEntry<>(
|
||||
SchemaElementType.METRIC,
|
||||
Arrays.asList(SchemaElementType.METRIC)),
|
||||
new AbstractMap.SimpleEntry<>(
|
||||
SchemaElementType.DIMENSION,
|
||||
Arrays.asList(
|
||||
SchemaElementType.DIMENSION, SchemaElementType.VALUE)),
|
||||
new AbstractMap.SimpleEntry<>(
|
||||
SchemaElementType.VALUE,
|
||||
Arrays.asList(
|
||||
SchemaElementType.VALUE, SchemaElementType.DIMENSION)),
|
||||
new AbstractMap.SimpleEntry<>(
|
||||
SchemaElementType.ENTITY,
|
||||
Arrays.asList(SchemaElementType.ENTITY)),
|
||||
new AbstractMap.SimpleEntry<>(
|
||||
SchemaElementType.DATASET,
|
||||
Arrays.asList(SchemaElementType.DATASET)),
|
||||
new AbstractMap.SimpleEntry<>(
|
||||
SchemaElementType.ID, Arrays.asList(SchemaElementType.ID)))
|
||||
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
|
||||
|
||||
@Override
|
||||
public void parse(ChatQueryContext chatQueryContext) {
|
||||
@@ -52,14 +63,17 @@ public class ContextInheritParser implements SemanticParser {
|
||||
return;
|
||||
}
|
||||
|
||||
List<SchemaElementMatch> elementMatches = chatQueryContext.getMapInfo().getMatchedElements(dataSetId);
|
||||
List<SchemaElementMatch> elementMatches =
|
||||
chatQueryContext.getMapInfo().getMatchedElements(dataSetId);
|
||||
|
||||
List<SchemaElementMatch> matchesToInherit = new ArrayList<>();
|
||||
for (SchemaElementMatch match : chatQueryContext.getContextParseInfo().getElementMatches()) {
|
||||
for (SchemaElementMatch match :
|
||||
chatQueryContext.getContextParseInfo().getElementMatches()) {
|
||||
SchemaElementType matchType = match.getElement().getType();
|
||||
// mutual exclusive element types should not be inherited
|
||||
RuleSemanticQuery ruleQuery = QueryManager.getRuleQuery(
|
||||
chatQueryContext.getContextParseInfo().getQueryMode());
|
||||
RuleSemanticQuery ruleQuery =
|
||||
QueryManager.getRuleQuery(
|
||||
chatQueryContext.getContextParseInfo().getQueryMode());
|
||||
if (!containsTypes(elementMatches, matchType, ruleQuery)) {
|
||||
match.setInherited(true);
|
||||
matchesToInherit.add(match);
|
||||
@@ -67,17 +81,20 @@ public class ContextInheritParser implements SemanticParser {
|
||||
}
|
||||
elementMatches.addAll(matchesToInherit);
|
||||
|
||||
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(dataSetId, elementMatches, chatQueryContext);
|
||||
List<RuleSemanticQuery> queries =
|
||||
RuleSemanticQuery.resolve(dataSetId, elementMatches, chatQueryContext);
|
||||
for (RuleSemanticQuery query : queries) {
|
||||
query.fillParseInfo(chatQueryContext);
|
||||
if (existSameQuery(query.getParseInfo().getDataSetId(), query.getQueryMode(), chatQueryContext)) {
|
||||
if (existSameQuery(
|
||||
query.getParseInfo().getDataSetId(), query.getQueryMode(), chatQueryContext)) {
|
||||
continue;
|
||||
}
|
||||
chatQueryContext.getCandidateQueries().add(query);
|
||||
}
|
||||
}
|
||||
|
||||
private boolean existSameQuery(Long dataSetId, String queryMode, ChatQueryContext chatQueryContext) {
|
||||
private boolean existSameQuery(
|
||||
Long dataSetId, String queryMode, ChatQueryContext chatQueryContext) {
|
||||
for (SemanticQuery semanticQuery : chatQueryContext.getCandidateQueries()) {
|
||||
if (semanticQuery.getQueryMode().equals(queryMode)
|
||||
&& semanticQuery.getParseInfo().getDataSetId().equals(dataSetId)) {
|
||||
@@ -87,26 +104,34 @@ public class ContextInheritParser implements SemanticParser {
|
||||
return false;
|
||||
}
|
||||
|
||||
private boolean containsTypes(List<SchemaElementMatch> matches, SchemaElementType matchType,
|
||||
private boolean containsTypes(
|
||||
List<SchemaElementMatch> matches,
|
||||
SchemaElementType matchType,
|
||||
RuleSemanticQuery ruleQuery) {
|
||||
List<SchemaElementType> types = MUTUAL_EXCLUSIVE_MAP.get(matchType);
|
||||
|
||||
return matches.stream().anyMatch(m -> {
|
||||
SchemaElementType type = m.getElement().getType();
|
||||
if (Objects.nonNull(ruleQuery) && ruleQuery instanceof MetricSemanticQuery
|
||||
&& !(ruleQuery instanceof MetricIdQuery)) {
|
||||
return types.contains(type);
|
||||
}
|
||||
return type.equals(matchType);
|
||||
});
|
||||
return matches.stream()
|
||||
.anyMatch(
|
||||
m -> {
|
||||
SchemaElementType type = m.getElement().getType();
|
||||
if (Objects.nonNull(ruleQuery)
|
||||
&& ruleQuery instanceof MetricSemanticQuery
|
||||
&& !(ruleQuery instanceof MetricIdQuery)) {
|
||||
return types.contains(type);
|
||||
}
|
||||
return type.equals(matchType);
|
||||
});
|
||||
}
|
||||
|
||||
protected boolean shouldInherit(ChatQueryContext chatQueryContext) {
|
||||
// if candidates only have MetricModel mode, count in context
|
||||
List<SemanticQuery> metricModelQueries = chatQueryContext.getCandidateQueries().stream()
|
||||
.filter(query -> query instanceof MetricModelQuery
|
||||
|| query instanceof DetailDimensionQuery).collect(
|
||||
Collectors.toList());
|
||||
List<SemanticQuery> metricModelQueries =
|
||||
chatQueryContext.getCandidateQueries().stream()
|
||||
.filter(
|
||||
query ->
|
||||
query instanceof MetricModelQuery
|
||||
|| query instanceof DetailDimensionQuery)
|
||||
.collect(Collectors.toList());
|
||||
return metricModelQueries.size() == chatQueryContext.getCandidateQueries().size();
|
||||
}
|
||||
|
||||
@@ -121,5 +146,4 @@ public class ContextInheritParser implements SemanticParser {
|
||||
}
|
||||
return dataSetId;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -3,24 +3,23 @@ package com.tencent.supersonic.headless.chat.parser.rule;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* RuleSqlParser resolves a specific SemanticQuery according to co-appearance
|
||||
* of certain schema element types.
|
||||
* RuleSqlParser resolves a specific SemanticQuery according to co-appearance of certain schema
|
||||
* element types.
|
||||
*/
|
||||
@Slf4j
|
||||
public class RuleSqlParser implements SemanticParser {
|
||||
|
||||
private static List<SemanticParser> auxiliaryParsers = Arrays.asList(
|
||||
new ContextInheritParser(),
|
||||
new TimeRangeParser(),
|
||||
new AggregateTypeParser()
|
||||
);
|
||||
private static List<SemanticParser> auxiliaryParsers =
|
||||
Arrays.asList(
|
||||
new ContextInheritParser(), new TimeRangeParser(), new AggregateTypeParser());
|
||||
|
||||
@Override
|
||||
public void parse(ChatQueryContext chatQueryContext) {
|
||||
@@ -31,7 +30,8 @@ public class RuleSqlParser implements SemanticParser {
|
||||
// iterate all schemaElementMatches to resolve query mode
|
||||
for (Long dataSetId : mapInfo.getMatchedDataSetInfos()) {
|
||||
List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(dataSetId);
|
||||
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(dataSetId, elementMatches, chatQueryContext);
|
||||
List<RuleSemanticQuery> queries =
|
||||
RuleSemanticQuery.resolve(dataSetId, elementMatches, chatQueryContext);
|
||||
for (RuleSemanticQuery query : queries) {
|
||||
query.fillParseInfo(chatQueryContext);
|
||||
chatQueryContext.getCandidateQueries().add(query);
|
||||
|
||||
@@ -23,17 +23,16 @@ import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
|
||||
/**
|
||||
* TimeRangeParser extracts time range specified in the user query
|
||||
* based on keyword matching.
|
||||
* Currently, it supports two kinds of expression:
|
||||
* 1. Recent unit: 近N天/周/月/年、过去N天/周/月/年
|
||||
* 2. Concrete date: 2023年11月15日、20231115
|
||||
* TimeRangeParser extracts time range specified in the user query based on keyword matching.
|
||||
* Currently, it supports two kinds of expression: 1. Recent unit: 近N天/周/月/年、过去N天/周/月/年 2. Concrete
|
||||
* date: 2023年11月15日、20231115
|
||||
*/
|
||||
@Slf4j
|
||||
public class TimeRangeParser implements SemanticParser {
|
||||
|
||||
private static final Pattern RECENT_PATTERN_CN = Pattern.compile(
|
||||
".*(?<periodStr>(近|过去)((?<enNum>\\d+)|(?<zhNum>[一二三四五六七八九十百千万亿]+))个?(?<zhPeriod>[天周月年])).*");
|
||||
private static final Pattern RECENT_PATTERN_CN =
|
||||
Pattern.compile(
|
||||
".*(?<periodStr>(近|过去)((?<enNum>\\d+)|(?<zhNum>[一二三四五六七八九十百千万亿]+))个?(?<zhPeriod>[天周月年])).*");
|
||||
private static final Pattern DATE_PATTERN_NUMBER = Pattern.compile("(\\d{8})");
|
||||
private static final DateFormat DATE_FORMAT_NUMBER = new SimpleDateFormat("yyyyMMdd");
|
||||
private static final DateFormat DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd");
|
||||
@@ -66,11 +65,13 @@ public class TimeRangeParser implements SemanticParser {
|
||||
} else {
|
||||
SemanticParseInfo contextParseInfo = queryContext.getContextParseInfo();
|
||||
if (QueryManager.containsRuleQuery(contextParseInfo.getQueryMode())) {
|
||||
RuleSemanticQuery semanticQuery = QueryManager.createRuleQuery(contextParseInfo.getQueryMode());
|
||||
RuleSemanticQuery semanticQuery =
|
||||
QueryManager.createRuleQuery(contextParseInfo.getQueryMode());
|
||||
if (queryContext.containsPartitionDimensions(contextParseInfo.getDataSetId())) {
|
||||
contextParseInfo.setDateInfo(dateConf);
|
||||
}
|
||||
contextParseInfo.setScore(contextParseInfo.getScore() + dateConf.getDetectWord().length());
|
||||
contextParseInfo.setScore(
|
||||
contextParseInfo.getScore() + dateConf.getDetectWord().length());
|
||||
semanticQuery.setParseInfo(contextParseInfo);
|
||||
queryContext.getCandidateQueries().add(semanticQuery);
|
||||
}
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
package com.tencent.supersonic.headless.chat.query;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Aggregator;
|
||||
@@ -53,7 +52,8 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
|
||||
parseInfo.getSqlInfo().setCorrectedS2SQL(querySQLReq.getSql());
|
||||
}
|
||||
|
||||
protected void convertBizNameToName(DataSetSchema dataSetSchema, QueryStructReq queryStructReq) {
|
||||
protected void convertBizNameToName(
|
||||
DataSetSchema dataSetSchema, QueryStructReq queryStructReq) {
|
||||
Map<String, String> bizNameToName = dataSetSchema.getBizNameToName();
|
||||
bizNameToName.putAll(TimeDimensionEnum.getNameToNameMap());
|
||||
|
||||
@@ -76,12 +76,12 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
|
||||
}
|
||||
List<Filter> dimensionFilters = queryStructReq.getDimensionFilters();
|
||||
if (CollectionUtils.isNotEmpty(dimensionFilters)) {
|
||||
dimensionFilters.forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
|
||||
dimensionFilters.forEach(
|
||||
filter -> filter.setName(bizNameToName.get(filter.getBizName())));
|
||||
}
|
||||
List<Filter> metricFilters = queryStructReq.getMetricFilters();
|
||||
if (CollectionUtils.isNotEmpty(dimensionFilters)) {
|
||||
metricFilters.forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -2,8 +2,9 @@ package com.tencent.supersonic.headless.chat.query;
|
||||
|
||||
import com.tencent.supersonic.headless.chat.query.llm.LLMSemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricSemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.detail.DetailSemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricSemanticQuery;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -28,7 +29,6 @@ public class QueryManager {
|
||||
return createRuleQuery(queryMode);
|
||||
}
|
||||
return createLLMQuery(queryMode);
|
||||
|
||||
}
|
||||
|
||||
public static RuleSemanticQuery createRuleQuery(String queryMode) {
|
||||
@@ -83,5 +83,4 @@ public class QueryManager {
|
||||
public static List<RuleSemanticQuery> getRuleQueries() {
|
||||
return new ArrayList<>(ruleQueryMap.values());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -6,9 +6,7 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import org.apache.calcite.sql.parser.SqlParseException;
|
||||
|
||||
/**
|
||||
* A semantic query executes specific type of query based on the results of semantic parsing.
|
||||
*/
|
||||
/** A semantic query executes specific type of query based on the results of semantic parsing. */
|
||||
public interface SemanticQuery {
|
||||
|
||||
String getQueryMode();
|
||||
|
||||
@@ -4,5 +4,4 @@ import com.tencent.supersonic.headless.chat.query.BaseSemanticQuery;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@Slf4j
|
||||
public abstract class LLMSemanticQuery extends BaseSemanticQuery {
|
||||
}
|
||||
public abstract class LLMSemanticQuery extends BaseSemanticQuery {}
|
||||
|
||||
@@ -29,7 +29,6 @@ public class LLMReq {
|
||||
public static class ElementValue {
|
||||
private String fieldName;
|
||||
private String fieldValue;
|
||||
|
||||
}
|
||||
|
||||
@Data
|
||||
@@ -44,10 +43,16 @@ public class LLMReq {
|
||||
public List<String> getFieldNameList() {
|
||||
List<String> fieldNameList = new ArrayList<>();
|
||||
if (CollectionUtils.isNotEmpty(metrics)) {
|
||||
fieldNameList.addAll(metrics.stream().map(metric -> metric.getName()).collect(Collectors.toList()));
|
||||
fieldNameList.addAll(
|
||||
metrics.stream()
|
||||
.map(metric -> metric.getName())
|
||||
.collect(Collectors.toList()));
|
||||
}
|
||||
if (CollectionUtils.isNotEmpty(dimensions)) {
|
||||
fieldNameList.addAll(dimensions.stream().map(metric -> metric.getName()).collect(Collectors.toList()));
|
||||
fieldNameList.addAll(
|
||||
dimensions.stream()
|
||||
.map(metric -> metric.getName())
|
||||
.collect(Collectors.toList()));
|
||||
}
|
||||
return fieldNameList;
|
||||
}
|
||||
@@ -72,6 +77,5 @@ public class LLMReq {
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,5 +21,4 @@ public class LLMResp {
|
||||
private List<String> fields;
|
||||
|
||||
private Map<String, LLMSqlResp> sqlRespMap;
|
||||
|
||||
}
|
||||
|
||||
@@ -17,5 +17,4 @@ public class LLMSqlResp {
|
||||
private double sqlWeight;
|
||||
|
||||
private List<Text2SQLExemplar> fewShots;
|
||||
|
||||
}
|
||||
|
||||
@@ -9,8 +9,10 @@ public class QueryMatchOption {
|
||||
private RequireNumberType requireNumberType;
|
||||
private Integer requireNumber;
|
||||
|
||||
public static QueryMatchOption build(OptionType schemaElementOption,
|
||||
RequireNumberType requireNumberType, Integer requireNumber) {
|
||||
public static QueryMatchOption build(
|
||||
OptionType schemaElementOption,
|
||||
RequireNumberType requireNumberType,
|
||||
Integer requireNumber) {
|
||||
QueryMatchOption queryMatchOption = new QueryMatchOption();
|
||||
queryMatchOption.requireNumber = requireNumber;
|
||||
queryMatchOption.requireNumberType = requireNumberType;
|
||||
@@ -35,11 +37,14 @@ public class QueryMatchOption {
|
||||
}
|
||||
|
||||
public enum RequireNumberType {
|
||||
AT_MOST, AT_LEAST, EQUAL
|
||||
AT_MOST,
|
||||
AT_LEAST,
|
||||
EQUAL
|
||||
}
|
||||
|
||||
public enum OptionType {
|
||||
REQUIRED, OPTIONAL, UNUSED
|
||||
REQUIRED,
|
||||
OPTIONAL,
|
||||
UNUSED
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.headless.chat.query.rule;
|
||||
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
@@ -21,8 +20,8 @@ public class QueryMatcher {
|
||||
private HashMap<SchemaElementType, QueryMatchOption> elementOptionMap = new HashMap<>();
|
||||
private boolean supportCompare;
|
||||
private boolean supportOrderBy;
|
||||
private List<AggregateTypeEnum> orderByTypes = Arrays.asList(AggregateTypeEnum.MAX, AggregateTypeEnum.MIN,
|
||||
AggregateTypeEnum.TOPN);
|
||||
private List<AggregateTypeEnum> orderByTypes =
|
||||
Arrays.asList(AggregateTypeEnum.MAX, AggregateTypeEnum.MIN, AggregateTypeEnum.TOPN);
|
||||
|
||||
public QueryMatcher() {
|
||||
for (SchemaElementType type : SchemaElementType.values()) {
|
||||
@@ -34,9 +33,13 @@ public class QueryMatcher {
|
||||
}
|
||||
}
|
||||
|
||||
public QueryMatcher addOption(SchemaElementType type, QueryMatchOption.OptionType option,
|
||||
QueryMatchOption.RequireNumberType requireNumberType, Integer requireNumber) {
|
||||
elementOptionMap.put(type, QueryMatchOption.build(option, requireNumberType, requireNumber));
|
||||
public QueryMatcher addOption(
|
||||
SchemaElementType type,
|
||||
QueryMatchOption.OptionType option,
|
||||
QueryMatchOption.RequireNumberType requireNumberType,
|
||||
Integer requireNumber) {
|
||||
elementOptionMap.put(
|
||||
type, QueryMatchOption.build(option, requireNumberType, requireNumber));
|
||||
return this;
|
||||
}
|
||||
|
||||
@@ -44,8 +47,7 @@ public class QueryMatcher {
|
||||
* Match schema element with current query according to the options.
|
||||
*
|
||||
* @param candidateElementMatches
|
||||
* @return a list of all matched schema elements,
|
||||
* empty list if no matches can be found
|
||||
* @return a list of all matched schema elements, empty list if no matches can be found
|
||||
*/
|
||||
public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches) {
|
||||
List<SchemaElementMatch> elementMatches = new ArrayList<>();
|
||||
@@ -53,7 +55,8 @@ public class QueryMatcher {
|
||||
for (SchemaElementMatch schemaElementMatch : candidateElementMatches) {
|
||||
SchemaElementType schemaElementType = schemaElementMatch.getElement().getType();
|
||||
if (schemaElementTypeCount.containsKey(schemaElementType)) {
|
||||
schemaElementTypeCount.put(schemaElementType, schemaElementTypeCount.get(schemaElementType) + 1);
|
||||
schemaElementTypeCount.put(
|
||||
schemaElementType, schemaElementTypeCount.get(schemaElementType) + 1);
|
||||
} else {
|
||||
schemaElementTypeCount.put(schemaElementType, 1);
|
||||
}
|
||||
@@ -70,9 +73,12 @@ public class QueryMatcher {
|
||||
|
||||
// add element match if its element type is not declared as unused
|
||||
for (SchemaElementMatch elementMatch : candidateElementMatches) {
|
||||
QueryMatchOption elementOption = elementOptionMap.get(elementMatch.getElement().getType());
|
||||
if (Objects.nonNull(elementOption) && !elementOption.getSchemaElementOption()
|
||||
.equals(QueryMatchOption.OptionType.UNUSED)) {
|
||||
QueryMatchOption elementOption =
|
||||
elementOptionMap.get(elementMatch.getElement().getType());
|
||||
if (Objects.nonNull(elementOption)
|
||||
&& !elementOption
|
||||
.getSchemaElementOption()
|
||||
.equals(QueryMatchOption.OptionType.UNUSED)) {
|
||||
elementMatches.add(elementMatch);
|
||||
}
|
||||
}
|
||||
@@ -80,7 +86,8 @@ public class QueryMatcher {
|
||||
return elementMatches;
|
||||
}
|
||||
|
||||
private int getCount(HashMap<SchemaElementType, Integer> schemaElementTypeCount,
|
||||
private int getCount(
|
||||
HashMap<SchemaElementType, Integer> schemaElementTypeCount,
|
||||
SchemaElementType schemaElementType) {
|
||||
if (schemaElementTypeCount.containsKey(schemaElementType)) {
|
||||
return schemaElementTypeCount.get(schemaElementType);
|
||||
@@ -90,14 +97,19 @@ public class QueryMatcher {
|
||||
|
||||
private boolean isMatch(QueryMatchOption queryMatchOption, int count) {
|
||||
// check if required but empty
|
||||
if (queryMatchOption.getSchemaElementOption().equals(QueryMatchOption.OptionType.REQUIRED) && count <= 0) {
|
||||
if (queryMatchOption.getSchemaElementOption().equals(QueryMatchOption.OptionType.REQUIRED)
|
||||
&& count <= 0) {
|
||||
return false;
|
||||
}
|
||||
if (queryMatchOption.getRequireNumberType().equals(QueryMatchOption.RequireNumberType.AT_LEAST)
|
||||
if (queryMatchOption
|
||||
.getRequireNumberType()
|
||||
.equals(QueryMatchOption.RequireNumberType.AT_LEAST)
|
||||
&& count < queryMatchOption.getRequireNumber()) {
|
||||
return false;
|
||||
}
|
||||
if (queryMatchOption.getRequireNumberType().equals(QueryMatchOption.RequireNumberType.AT_MOST)
|
||||
if (queryMatchOption
.getRequireNumberType()
.equals(QueryMatchOption.RequireNumberType.AT_MOST)
&& count > queryMatchOption.getRequireNumber()) {
return false;
}

@@ -1,4 +1,3 @@

package com.tencent.supersonic.headless.chat.query.rule;

import com.tencent.supersonic.auth.api.authentication.pojo.User;
@@ -41,8 +40,8 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
QueryManager.register(this);
}

public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches,
ChatQueryContext queryCtx) {
public List<SchemaElementMatch> match(
List<SchemaElementMatch> candidateElementMatches, ChatQueryContext queryCtx) {
return queryMatcher.match(candidateElementMatches);
}

@@ -68,17 +67,19 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
return chatQueryContext.containsPartitionDimensions(dataSetId);
}

private void fillDateConfByInherited(SemanticParseInfo queryParseInfo, ChatQueryContext chatQueryContext) {
private void fillDateConfByInherited(
SemanticParseInfo queryParseInfo, ChatQueryContext chatQueryContext) {
SemanticParseInfo contextParseInfo = chatQueryContext.getContextParseInfo();
if (queryParseInfo.getDateInfo() != null || contextParseInfo.getDateInfo() == null
if (queryParseInfo.getDateInfo() != null
|| contextParseInfo.getDateInfo() == null
|| needFillDateConf(chatQueryContext)) {
return;
}

if ((QueryManager.isDetailQuery(queryParseInfo.getQueryMode())
&& QueryManager.isDetailQuery(contextParseInfo.getQueryMode()))
&& QueryManager.isDetailQuery(contextParseInfo.getQueryMode()))
|| (QueryManager.isMetricQuery(queryParseInfo.getQueryMode())
&& QueryManager.isMetricQuery(contextParseInfo.getQueryMode()))) {
&& QueryManager.isMetricQuery(contextParseInfo.getQueryMode()))) {
// inherit date info from context
queryParseInfo.setDateInfo(contextParseInfo.getDateInfo());
queryParseInfo.getDateInfo().setInherited(true);
@@ -105,8 +106,11 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
}

private void fillSchemaElement(SemanticParseInfo parseInfo, SemanticSchema semanticSchema) {
Set<Long> dataSetIds = parseInfo.getElementMatches().stream().map(SchemaElementMatch::getElement)
.map(SchemaElement::getDataSetId).collect(Collectors.toSet());
Set<Long> dataSetIds =
parseInfo.getElementMatches().stream()
.map(SchemaElementMatch::getElement)
.map(SchemaElement::getDataSetId)
.collect(Collectors.toSet());
Long dataSetId = dataSetIds.iterator().next();
parseInfo.setDataSet(semanticSchema.getDataSet(dataSetId));
Map<Long, List<SchemaElementMatch>> dim2Values = new HashMap<>();
@@ -117,22 +121,26 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
element.setOrder(1 - schemaMatch.getSimilarity());
switch (element.getType()) {
case ID:
SchemaElement entityElement = semanticSchema.getElement(SchemaElementType.ENTITY, element.getId());
SchemaElement entityElement =
semanticSchema.getElement(SchemaElementType.ENTITY, element.getId());
if (entityElement != null) {
if (id2Values.containsKey(element.getId())) {
id2Values.get(element.getId()).add(schemaMatch);
} else {
id2Values.put(element.getId(), new ArrayList<>(Arrays.asList(schemaMatch)));
id2Values.put(
element.getId(), new ArrayList<>(Arrays.asList(schemaMatch)));
}
}
break;
case VALUE:
SchemaElement dimElement = semanticSchema.getElement(SchemaElementType.DIMENSION, element.getId());
SchemaElement dimElement =
semanticSchema.getElement(SchemaElementType.DIMENSION, element.getId());
if (dimElement != null) {
if (dim2Values.containsKey(element.getId())) {
dim2Values.get(element.getId()).add(schemaMatch);
} else {
dim2Values.put(element.getId(), new ArrayList<>(Arrays.asList(schemaMatch)));
dim2Values.put(
element.getId(), new ArrayList<>(Arrays.asList(schemaMatch)));
}
}
break;
@@ -152,8 +160,11 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
addToFilters(dim2Values, parseInfo, semanticSchema, SchemaElementType.DIMENSION);
}

private void addToFilters(Map<Long, List<SchemaElementMatch>> id2Values, SemanticParseInfo parseInfo,
SemanticSchema semanticSchema, SchemaElementType entity) {
private void addToFilters(
Map<Long, List<SchemaElementMatch>> id2Values,
SemanticParseInfo parseInfo,
SemanticSchema semanticSchema,
SchemaElementType entity) {
if (id2Values == null || id2Values.isEmpty()) {
return;
}
@@ -170,7 +181,8 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
dimensionFilter.setName(dimension.getName());
dimensionFilter.setOperator(FilterOperatorEnum.EQUALS);
dimensionFilter.setElementID(schemaMatch.getElement().getId());
parseInfo.setEntity(semanticSchema.getElement(SchemaElementType.ENTITY, entry.getKey()));
parseInfo.setEntity(
semanticSchema.getElement(SchemaElementType.ENTITY, entry.getKey()));
parseInfo.getDimensionFilters().add(dimensionFilter);
} else {
QueryFilter dimensionFilter = new QueryFilter();
@@ -193,7 +205,8 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
public SemanticQueryReq multiStructExecute() {
String queryMode = parseInfo.getQueryMode();

if (parseInfo.getDataSetId() != null || StringUtils.isEmpty(queryMode)
if (parseInfo.getDataSetId() != null
|| StringUtils.isEmpty(queryMode)
|| !QueryManager.containsRuleQuery(queryMode)) {
// reach here some error may happen
log.error("not find QueryMode");
@@ -208,14 +221,18 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
this.parseInfo = parseInfo;
}

public static List<RuleSemanticQuery> resolve(Long dataSetId, List<SchemaElementMatch> candidateElementMatches,
ChatQueryContext chatQueryContext) {
public static List<RuleSemanticQuery> resolve(
Long dataSetId,
List<SchemaElementMatch> candidateElementMatches,
ChatQueryContext chatQueryContext) {
List<RuleSemanticQuery> matchedQueries = new ArrayList<>();
for (RuleSemanticQuery semanticQuery : QueryManager.getRuleQueries()) {
List<SchemaElementMatch> matches = semanticQuery.match(candidateElementMatches, chatQueryContext);
List<SchemaElementMatch> matches =
semanticQuery.match(candidateElementMatches, chatQueryContext);

if (matches.size() > 0) {
RuleSemanticQuery query = QueryManager.createRuleQuery(semanticQuery.getQueryMode());
RuleSemanticQuery query =
QueryManager.createRuleQuery(semanticQuery.getQueryMode());
query.getParseInfo().getElementMatches().addAll(matches);
matchedQueries.add(query);
}
@@ -230,5 +247,4 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
protected QueryMultiStructReq convertQueryMultiStruct() {
return QueryReqBuilder.buildMultiStructReq(parseInfo);
}

}

@@ -1,6 +1,5 @@
package com.tencent.supersonic.headless.chat.query.rule.detail;


import org.springframework.stereotype.Component;

import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.DIMENSION;
@@ -26,5 +25,4 @@ public class DetailDimensionQuery extends DetailSemanticQuery {
public String getQueryMode() {
return QUERY_MODE;
}

}

@@ -24,5 +24,4 @@ public class DetailFilterQuery extends DetailListQuery {
public String getQueryMode() {
return QUERY_MODE;
}

}

@@ -20,5 +20,4 @@ public class DetailIdQuery extends DetailListQuery {
public String getQueryMode() {
return QUERY_MODE;
}

}

@@ -23,32 +23,51 @@ public abstract class DetailListQuery extends DetailSemanticQuery {
this.addEntityDetailAndOrderByMetric(chatQueryContext, parseInfo);
}

private void addEntityDetailAndOrderByMetric(ChatQueryContext chatQueryContext, SemanticParseInfo parseInfo) {
private void addEntityDetailAndOrderByMetric(
ChatQueryContext chatQueryContext, SemanticParseInfo parseInfo) {
Long dataSetId = parseInfo.getDataSetId();
if (Objects.isNull(dataSetId) || dataSetId <= 0L) {
return;
}
DataSetSchema dataSetSchema = chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
DataSetSchema dataSetSchema =
chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId);
if (dataSetSchema != null && Objects.nonNull(dataSetSchema.getEntity())) {
Set<SchemaElement> dimensions = new LinkedHashSet<>();
Set<SchemaElement> metrics = new LinkedHashSet<>();
Set<Order> orders = new LinkedHashSet<>();
TagTypeDefaultConfig tagTypeDefaultConfig = dataSetSchema.getTagTypeDefaultConfig();
if (tagTypeDefaultConfig != null && tagTypeDefaultConfig.getDefaultDisplayInfo() != null) {
if (CollectionUtils.isNotEmpty(tagTypeDefaultConfig.getDefaultDisplayInfo().getMetricIds())) {
metrics = tagTypeDefaultConfig.getDefaultDisplayInfo().getMetricIds()
.stream().map(id -> {
SchemaElement metric = dataSetSchema.getElement(SchemaElementType.METRIC, id);
if (metric != null) {
orders.add(new Order(metric.getBizName(), Constants.DESC_UPPER));
}
return metric;
}).filter(Objects::nonNull).collect(Collectors.toSet());
if (tagTypeDefaultConfig != null
&& tagTypeDefaultConfig.getDefaultDisplayInfo() != null) {
if (CollectionUtils.isNotEmpty(
tagTypeDefaultConfig.getDefaultDisplayInfo().getMetricIds())) {
metrics =
tagTypeDefaultConfig.getDefaultDisplayInfo().getMetricIds().stream()
.map(
id -> {
SchemaElement metric =
dataSetSchema.getElement(
SchemaElementType.METRIC, id);
if (metric != null) {
orders.add(
new Order(
metric.getBizName(),
Constants.DESC_UPPER));
}
return metric;
})
.filter(Objects::nonNull)
.collect(Collectors.toSet());
}
if (CollectionUtils.isNotEmpty(tagTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds())) {
dimensions = tagTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds().stream()
.map(id -> dataSetSchema.getElement(SchemaElementType.DIMENSION, id))
.filter(Objects::nonNull).collect(Collectors.toSet());
if (CollectionUtils.isNotEmpty(
tagTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds())) {
dimensions =
tagTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds().stream()
.map(
id ->
dataSetSchema.getElement(
SchemaElementType.DIMENSION, id))
.filter(Objects::nonNull)
.collect(Collectors.toSet());
}
}
parseInfo.setDimensions(dimensions);
@@ -56,5 +75,4 @@ public abstract class DetailListQuery extends DetailSemanticQuery {
parseInfo.setOrders(orders);
}
}

}

@@ -25,8 +25,8 @@ public abstract class DetailSemanticQuery extends RuleSemanticQuery {
}

@Override
public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches,
ChatQueryContext queryCtx) {
public List<SchemaElementMatch> match(
List<SchemaElementMatch> candidateElementMatches, ChatQueryContext queryCtx) {
return super.match(candidateElementMatches, queryCtx);
}

@@ -39,7 +39,8 @@ public abstract class DetailSemanticQuery extends RuleSemanticQuery {
if (!needFillDateConf(chatQueryContext)) {
return;
}
Map<Long, DataSetSchema> dataSetSchemaMap = chatQueryContext.getSemanticSchema().getDataSetSchemaMap();
Map<Long, DataSetSchema> dataSetSchemaMap =
chatQueryContext.getSemanticSchema().getDataSetSchemaMap();
DataSetSchema dataSetSchema = dataSetSchemaMap.get(parseInfo.getDataSetId());
TimeDefaultConfig timeDefaultConfig = dataSetSchema.getTagTypeTimeDefaultConfig();

@@ -63,5 +64,4 @@ public abstract class DetailSemanticQuery extends RuleSemanticQuery {
parseInfo.setDateInfo(dateInfo);
}
}

}

@@ -19,7 +19,6 @@ import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.VALUE;
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;


@Slf4j
@Component
public class MetricFilterQuery extends MetricSemanticQuery {
@@ -46,8 +45,7 @@ public class MetricFilterQuery extends MetricSemanticQuery {

protected boolean isMultiStructQuery() {
Set<String> filterBizName = new HashSet<>();
parseInfo.getDimensionFilters().forEach(filter ->
filterBizName.add(filter.getBizName()));
parseInfo.getDimensionFilters().forEach(filter -> filterBizName.add(filter.getBizName()));
return FilterType.UNION.equals(parseInfo.getFilterType()) && filterBizName.size() > 1;
}

@@ -73,17 +71,21 @@ public class MetricFilterQuery extends MetricSemanticQuery {
log.debug("addDimension before [{}]", queryStructReq.getGroups());
List<Filter> filters = new ArrayList<>(queryStructReq.getDimensionFilters());
if (onlyOperateInFilter) {
filters = filters.stream().filter(filter
-> filter.getOperator().equals(FilterOperatorEnum.IN)).collect(Collectors.toList());
filters =
filters.stream()
.filter(
filter ->
filter.getOperator().equals(FilterOperatorEnum.IN))
.collect(Collectors.toList());
}
filters.forEach(d -> {
if (!dimensions.contains(d.getBizName())) {
dimensions.add(d.getBizName());
}
});
filters.forEach(
d -> {
if (!dimensions.contains(d.getBizName())) {
dimensions.add(d.getBizName());
}
});
queryStructReq.setGroups(dimensions);
log.debug("addDimension after [{}]", queryStructReq.getGroups());
}
}

}

@@ -23,5 +23,4 @@ public class MetricGroupByQuery extends MetricSemanticQuery {
public String getQueryMode() {
return QUERY_MODE;
}

}

@@ -28,8 +28,7 @@ public class MetricIdQuery extends MetricSemanticQuery {

public MetricIdQuery() {
super();
queryMatcher.addOption(ID, REQUIRED, AT_LEAST, 1)
.addOption(ENTITY, REQUIRED, AT_LEAST, 1);
queryMatcher.addOption(ID, REQUIRED, AT_LEAST, 1).addOption(ENTITY, REQUIRED, AT_LEAST, 1);
}

@Override
@@ -75,17 +74,21 @@ public class MetricIdQuery extends MetricSemanticQuery {
log.info("addDimension before [{}]", queryStructReq.getGroups());
List<Filter> filters = new ArrayList<>(queryStructReq.getDimensionFilters());
if (onlyOperateInFilter) {
filters = filters.stream().filter(filter
-> filter.getOperator().equals(FilterOperatorEnum.IN)).collect(Collectors.toList());
filters =
filters.stream()
.filter(
filter ->
filter.getOperator().equals(FilterOperatorEnum.IN))
.collect(Collectors.toList());
}
filters.forEach(d -> {
if (!dimensions.contains(d.getBizName())) {
dimensions.add(d.getBizName());
}
});
filters.forEach(
d -> {
if (!dimensions.contains(d.getBizName())) {
dimensions.add(d.getBizName());
}
});
queryStructReq.setGroups(dimensions);
log.info("addDimension after [{}]", queryStructReq.getGroups());
}
}

}

@@ -20,5 +20,4 @@ public class MetricModelQuery extends MetricSemanticQuery {
public String getQueryMode() {
return QUERY_MODE;
}

}

@@ -1,6 +1,5 @@
package com.tencent.supersonic.headless.chat.query.rule.metric;


import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.enums.TimeMode;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
@@ -29,8 +28,8 @@ public abstract class MetricSemanticQuery extends RuleSemanticQuery {
}

@Override
public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches,
ChatQueryContext queryCtx) {
public List<SchemaElementMatch> match(
List<SchemaElementMatch> candidateElementMatches, ChatQueryContext queryCtx) {
return super.match(candidateElementMatches, queryCtx);
}

@@ -46,11 +45,15 @@ public abstract class MetricSemanticQuery extends RuleSemanticQuery {
return;
}
DataSetSchema dataSetSchema =
chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(parseInfo.getDataSetId());
chatQueryContext
.getSemanticSchema()
.getDataSetSchemaMap()
.get(parseInfo.getDataSetId());
TimeDefaultConfig timeDefaultConfig = dataSetSchema.getMetricTypeTimeDefaultConfig();
DateConf dateInfo = new DateConf();
//add a check for time unit != -1
if (Objects.nonNull(timeDefaultConfig) && Objects.nonNull(timeDefaultConfig.getUnit())
// add a check for time unit != -1
if (Objects.nonNull(timeDefaultConfig)
&& Objects.nonNull(timeDefaultConfig.getUnit())
&& timeDefaultConfig.getUnit() != -1) {
int unit = timeDefaultConfig.getUnit();
String startDate = LocalDate.now().plusDays(-unit).toString();

@@ -34,8 +34,8 @@ public class MetricTopNQuery extends MetricSemanticQuery {
}

@Override
public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches,
ChatQueryContext queryCtx) {
public List<SchemaElementMatch> match(
List<SchemaElementMatch> candidateElementMatches, ChatQueryContext queryCtx) {
Matcher matcher = INTENT_PATTERN.matcher(queryCtx.getQueryText());
if (matcher.matches()) {
return super.match(candidateElementMatches, queryCtx);
@@ -59,5 +59,4 @@ public class MetricTopNQuery extends MetricSemanticQuery {
SchemaElement metric = parseInfo.getMetrics().iterator().next();
parseInfo.getOrders().add(new Order(metric.getBizName(), Constants.DESC_UPPER));
}

}

@@ -8,9 +8,7 @@ import org.springframework.core.io.support.SpringFactoriesLoader;
import java.util.List;
import java.util.Objects;

/**
* HeadlessConverter QueryOptimizer QueryExecutor object factory
*/
/** HeadlessConverter QueryOptimizer QueryExecutor object factory */
@Slf4j
public class ComponentFactory {

@@ -28,14 +26,15 @@ public class ComponentFactory {
}

private static <T> List<T> init(Class<T> factoryType, List list) {
list.addAll(SpringFactoriesLoader.loadFactories(factoryType,
Thread.currentThread().getContextClassLoader()));
list.addAll(
SpringFactoriesLoader.loadFactories(
factoryType, Thread.currentThread().getContextClassLoader()));
return list;
}

private static <T> T init(Class<T> factoryType) {
return SpringFactoriesLoader.loadFactories(factoryType,
Thread.currentThread().getContextClassLoader()).get(0);
return SpringFactoriesLoader.loadFactories(
factoryType, Thread.currentThread().getContextClassLoader())
.get(0);
}

}

@@ -13,9 +13,10 @@ public class QueryFilterParser {

public static String parse(QueryFilters queryFilters) {
try {
List<String> conditions = queryFilters.getFilters().stream()
.map(QueryFilterParser::parseFilter)
.collect(Collectors.toList());
List<String> conditions =
queryFilters.getFilters().stream()
.map(QueryFilterParser::parseFilter)
.collect(Collectors.toList());
return String.join(" AND ", conditions);
} catch (Exception e) {
log.error("", e);
@@ -35,9 +36,14 @@ public class QueryFilterParser {
case BETWEEN:
if (value instanceof List && ((List<?>) value).size() == 2) {
List<?> values = (List<?>) value;
return column + " BETWEEN " + formatValue(values.get(0)) + " AND " + formatValue(values.get(1));
return column
+ " BETWEEN "
+ formatValue(values.get(0))
+ " AND "
+ formatValue(values.get(1));
}
throw new IllegalArgumentException("BETWEEN operator requires a list of two values");
throw new IllegalArgumentException(
"BETWEEN operator requires a list of two values");
case IS_NULL:
case IS_NOT_NULL:
return column + " " + operator.getValue();
@@ -52,9 +58,8 @@ public class QueryFilterParser {

private static String parseList(Object value) {
if (value instanceof List) {
return ((List<?>) value).stream()
.map(QueryFilterParser::formatValue)
.collect(Collectors.joining(", "));
return ((List<?>) value)
.stream().map(QueryFilterParser::formatValue).collect(Collectors.joining(", "));
}
throw new IllegalArgumentException("IN and NOT IN operators require a list of values");
}
@@ -69,4 +74,4 @@ public class QueryFilterParser {
}
throw new IllegalArgumentException("Unsupported value type: " + value.getClass());
}
}
}

Some files were not shown because too many files have changed in this diff.