Mirror of https://github.com/tencentmusic/supersonic.git (synced 2025-12-12 20:51:48 +00:00)
(improvement)(headless)Introduce headless-chat. #1155
@@ -0,0 +1,14 @@
package com.tencent.supersonic.headless.chat;

import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import lombok.Data;

@Data
public class ChatContext {

    private Integer chatId;
    private String queryText;
    private SemanticParseInfo parseInfo = new SemanticParseInfo();
    private String user;
}
@@ -0,0 +1,63 @@
package com.tencent.supersonic.headless.chat;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.config.LLMConfig;
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
import com.tencent.supersonic.headless.api.pojo.enums.WorkflowState;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class QueryContext {

    private String queryText;
    private Integer chatId;
    private Set<Long> dataSetIds;
    private Map<Long, List<Long>> modelIdToDataSetIds;
    private User user;
    private boolean saveAnswer;
    private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM;
    private QueryFilters queryFilters;
    private List<SemanticQuery> candidateQueries = new ArrayList<>();
    private SchemaMapInfo mapInfo = new SchemaMapInfo();
    private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
    @JsonIgnore
    private SemanticSchema semanticSchema;
    @JsonIgnore
    private WorkflowState workflowState;
    private QueryDataType queryDataType = QueryDataType.ALL;
    private LLMConfig llmConfig;

    public List<SemanticQuery> getCandidateQueries() {
        ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
        int parseShowCount = Integer.valueOf(parserConfig.getParameterValue(ParserConfig.PARSER_SHOW_COUNT));
        candidateQueries = candidateQueries.stream()
                .sorted(Comparator.comparing(semanticQuery -> semanticQuery.getParseInfo().getScore(),
                        Comparator.reverseOrder()))
                .limit(parseShowCount)
                .collect(Collectors.toList());
        return candidateQueries;
    }
}
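A short usage sketch (illustrative only, not part of this commit). It uses the no-args constructor so the field initializers above apply, and it assumes a Spring context is available because getCandidateQueries() resolves ParserConfig through ContextUtils:

// Hypothetical assembly of a QueryContext for one natural-language question.
QueryContext queryContext = new QueryContext();
queryContext.setQueryText("total play count by city last week");
queryContext.setChatId(1);
queryContext.setText2SQLType(Text2SQLType.RULE_AND_LLM);

// Candidates are re-sorted by parse score (descending) and truncated to
// ParserConfig.PARSER_SHOW_COUNT before being returned.
List<SemanticQuery> topCandidates = queryContext.getCandidateQueries();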
@@ -0,0 +1,32 @@
package com.tencent.supersonic.headless.chat.corrector;

import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.QueryContext;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;

import java.util.List;

/**
 * Verify whether the SQL aggregate function is missing. If it is missing, fill it in.
 */
@Slf4j
public class AggCorrector extends BaseSemanticCorrector {

    @Override
    public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
        addAggregate(queryContext, semanticParseInfo);
    }

    private void addAggregate(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
        List<String> sqlGroupByFields = SqlSelectHelper.getGroupByFields(
                semanticParseInfo.getSqlInfo().getCorrectS2SQL());
        if (CollectionUtils.isEmpty(sqlGroupByFields)) {
            return;
        }
        addAggregateToMetric(queryContext, semanticParseInfo);
    }
}
@@ -0,0 +1,148 @@
package com.tencent.supersonic.headless.chat.corrector;

import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.chat.QueryContext;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.core.env.Environment;
import org.springframework.util.CollectionUtils;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Provides basic semantic correction functionality, offering common methods and an
 * abstract method called doCorrect.
 */
@Slf4j
public abstract class BaseSemanticCorrector implements SemanticCorrector {

    public void correct(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
        try {
            if (StringUtils.isBlank(semanticParseInfo.getSqlInfo().getCorrectS2SQL())) {
                return;
            }
            doCorrect(queryContext, semanticParseInfo);
            log.debug("sqlCorrection:{} sql:{}", this.getClass().getSimpleName(), semanticParseInfo.getSqlInfo());
        } catch (Exception e) {
            log.error(String.format("correct error,sqlInfo:%s", semanticParseInfo.getSqlInfo()), e);
        }
    }

    public abstract void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo);

    protected Map<String, String> getFieldNameMap(QueryContext queryContext, Long dataSetId) {

        SemanticSchema semanticSchema = queryContext.getSemanticSchema();

        List<SchemaElement> dbAllFields = new ArrayList<>();
        dbAllFields.addAll(semanticSchema.getMetrics());
        dbAllFields.addAll(semanticSchema.getDimensions());

        // support both field names and field aliases
        Map<String, String> result = dbAllFields.stream()
                .filter(entry -> dataSetId.equals(entry.getDataSet()))
                .flatMap(schemaElement -> {
                    Set<String> elements = new HashSet<>();
                    elements.add(schemaElement.getName());
                    if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
                        elements.addAll(schemaElement.getAlias());
                    }
                    return elements.stream();
                })
                .collect(Collectors.toMap(a -> a, a -> a, (k1, k2) -> k1));
        result.put(TimeDimensionEnum.DAY.getChName(), TimeDimensionEnum.DAY.getChName());
        result.put(TimeDimensionEnum.MONTH.getChName(), TimeDimensionEnum.MONTH.getChName());
        result.put(TimeDimensionEnum.WEEK.getChName(), TimeDimensionEnum.WEEK.getChName());

        result.put(TimeDimensionEnum.DAY.getName(), TimeDimensionEnum.DAY.getChName());
        result.put(TimeDimensionEnum.MONTH.getName(), TimeDimensionEnum.MONTH.getChName());
        result.put(TimeDimensionEnum.WEEK.getName(), TimeDimensionEnum.WEEK.getChName());

        return result;
    }

    protected String addFieldsToSelect(SemanticParseInfo semanticParseInfo, String correctS2SQL) {
        Set<String> selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL));
        Set<String> needAddFields = new HashSet<>(SqlSelectHelper.getGroupByFields(correctS2SQL));

        // decide whether to add order-by expression fields to select
        Environment environment = ContextUtils.getBean(Environment.class);
        String correctorAdditionalInfo = environment.getProperty("s2.corrector.additional.information");
        if (StringUtils.isNotBlank(correctorAdditionalInfo) && Boolean.parseBoolean(correctorAdditionalInfo)) {
            needAddFields.addAll(SqlSelectHelper.getOrderByFields(correctS2SQL));
        }

        if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(needAddFields)) {
            return correctS2SQL;
        }

        needAddFields.removeAll(selectFields);
        String addFieldsToSelectSql = SqlAddHelper.addFieldsToSelect(correctS2SQL, new ArrayList<>(needAddFields));
        semanticParseInfo.getSqlInfo().setCorrectS2SQL(addFieldsToSelectSql);
        return addFieldsToSelectSql;
    }

    protected void addAggregateToMetric(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
        // add an aggregate function to every metric
        String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
        Long dataSetId = semanticParseInfo.getDataSet().getDataSet();
        List<SchemaElement> metrics = getMetricElements(queryContext, dataSetId);

        Map<String, String> metricToAggregate = metrics.stream()
                .map(schemaElement -> {
                    if (Objects.isNull(schemaElement.getDefaultAgg())) {
                        schemaElement.setDefaultAgg(AggregateTypeEnum.SUM.name());
                    }
                    return schemaElement;
                }).flatMap(schemaElement -> {
                    Set<String> elements = new HashSet<>();
                    elements.add(schemaElement.getName());
                    if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
                        elements.addAll(schemaElement.getAlias());
                    }
                    return elements.stream().map(element -> Pair.of(element, schemaElement.getDefaultAgg()));
                }).collect(Collectors.toMap(Pair::getLeft, Pair::getRight, (k1, k2) -> k1));

        if (CollectionUtils.isEmpty(metricToAggregate)) {
            return;
        }
        String aggregateSql = SqlAddHelper.addAggregateToField(correctS2SQL, metricToAggregate);
        semanticParseInfo.getSqlInfo().setCorrectS2SQL(aggregateSql);
    }

    protected List<SchemaElement> getMetricElements(QueryContext queryContext, Long dataSetId) {
        SemanticSchema semanticSchema = queryContext.getSemanticSchema();
        return semanticSchema.getMetrics(dataSetId);
    }

    protected Set<String> getDimensions(Long dataSetId, SemanticSchema semanticSchema) {
        Set<String> dimensions = semanticSchema.getDimensions(dataSetId).stream()
                .flatMap(schemaElement -> {
                    Set<String> elements = new HashSet<>();
                    elements.add(schemaElement.getName());
                    if (!CollectionUtils.isEmpty(schemaElement.getAlias())) {
                        elements.addAll(schemaElement.getAlias());
                    }
                    return elements.stream();
                }).collect(Collectors.toSet());
        dimensions.add(TimeDimensionEnum.DAY.getChName());
        return dimensions;
    }
}
@@ -0,0 +1,41 @@
package com.tencent.supersonic.headless.chat.corrector;

import com.tencent.supersonic.common.util.jsqlparser.SqlRemoveHelper;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.QueryContext;
import lombok.extern.slf4j.Slf4j;

import java.util.ArrayList;
import java.util.List;

/**
 * Corrects SQL syntax, primarily including fixes to the select, where, groupBy and having clauses.
 */
@Slf4j
public class GrammarCorrector extends BaseSemanticCorrector {

    private List<BaseSemanticCorrector> correctors;

    public GrammarCorrector() {
        correctors = new ArrayList<>();
        correctors.add(new SelectCorrector());
        correctors.add(new WhereCorrector());
        correctors.add(new GroupByCorrector());
        correctors.add(new AggCorrector());
        correctors.add(new HavingCorrector());
    }

    @Override
    public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
        for (BaseSemanticCorrector corrector : correctors) {
            corrector.correct(queryContext, semanticParseInfo);
        }
        removeSameFieldFromSelect(semanticParseInfo);
    }

    public void removeSameFieldFromSelect(SemanticParseInfo semanticParseInfo) {
        String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
        correctS2SQL = SqlRemoveHelper.removeSameFieldFromSelect(correctS2SQL);
        semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
    }
}
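A brief driver sketch (illustrative only, not part of this commit). GrammarCorrector chains the clause-level correctors; invoking it through correct(), inherited from BaseSemanticCorrector, adds the blank-SQL guard and error logging. It assumes a populated QueryContext and a SemanticParseInfo whose SqlInfo already holds the parsed S2SQL:

// Hypothetical driver: run the grammar-correction chain over one parsed query.
SemanticParseInfo parseInfo = new SemanticParseInfo();
parseInfo.getSqlInfo().setCorrectS2SQL("SELECT 城市, 播放量 FROM 歌曲库");

new GrammarCorrector().correct(queryContext, parseInfo);

// The rewritten S2SQL (group-by, aggregates, deduplicated select) is read back here.
String corrected = parseInfo.getSqlInfo().getCorrectS2SQL();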
@@ -0,0 +1,89 @@
package com.tencent.supersonic.headless.chat.corrector;

import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.chat.QueryContext;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.core.env.Environment;
import org.springframework.util.CollectionUtils;

import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Perform SQL corrections on the "Group by" section in S2SQL.
 */
@Slf4j
public class GroupByCorrector extends BaseSemanticCorrector {

    @Override
    public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
        Boolean needAddGroupBy = needAddGroupBy(queryContext, semanticParseInfo);
        if (!needAddGroupBy) {
            return;
        }
        addGroupByFields(queryContext, semanticParseInfo);
    }

    private Boolean needAddGroupBy(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
        Long dataSetId = semanticParseInfo.getDataSetId();
        // add dimension group by
        SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
        String correctS2SQL = sqlInfo.getCorrectS2SQL();
        SemanticSchema semanticSchema = queryContext.getSemanticSchema();
        // check whether distinct is present
        if (SqlSelectHelper.hasDistinct(correctS2SQL)) {
            log.info("not add group by ,exist distinct in correctS2SQL:{}", correctS2SQL);
            return false;
        }
        // add alias field name
        Set<String> dimensions = getDimensions(dataSetId, semanticSchema);
        List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
        if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(dimensions)) {
            return false;
        }
        // if only the date field is selected, do not add group by
        if (selectFields.size() == 1 && selectFields.contains(TimeDimensionEnum.DAY.getChName())) {
            return false;
        }
        if (SqlSelectHelper.hasGroupBy(correctS2SQL)) {
            log.info("not add group by ,exist group by in correctS2SQL:{}", correctS2SQL);
            return false;
        }
        Environment environment = ContextUtils.getBean(Environment.class);
        String correctorAdditionalInfo = environment.getProperty("s2.corrector.additional.information");
        if (StringUtils.isNotBlank(correctorAdditionalInfo) && !Boolean.parseBoolean(correctorAdditionalInfo)) {
            return false;
        }
        return true;
    }

    private void addGroupByFields(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
        Long dataSetId = semanticParseInfo.getDataSetId();
        // add dimension group by
        SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
        String correctS2SQL = sqlInfo.getCorrectS2SQL();
        SemanticSchema semanticSchema = queryContext.getSemanticSchema();
        // add alias field name
        Set<String> dimensions = getDimensions(dataSetId, semanticSchema);
        List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
        List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
        Set<String> groupByFields = selectFields.stream()
                .filter(field -> dimensions.contains(field))
                .filter(field -> {
                    if (!CollectionUtils.isEmpty(aggregateFields) && aggregateFields.contains(field)) {
                        return false;
                    }
                    return true;
                })
                .collect(Collectors.toSet());
        semanticParseInfo.getSqlInfo().setCorrectS2SQL(SqlAddHelper.addGroupBy(correctS2SQL, groupByFields));
    }
}
@@ -0,0 +1,69 @@
package com.tencent.supersonic.headless.chat.corrector;

import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectFunctionHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.chat.QueryContext;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import org.apache.commons.lang3.StringUtils;
import org.springframework.core.env.Environment;
import org.springframework.util.CollectionUtils;

import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Perform SQL corrections on the "Having" section in S2SQL.
 */
@Slf4j
public class HavingCorrector extends BaseSemanticCorrector {

    @Override
    public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {

        // add aggregate functions to the metrics in the having clause
        addHaving(queryContext, semanticParseInfo);

        // decide whether to add having expression fields to select
        Environment environment = ContextUtils.getBean(Environment.class);
        String correctorAdditionalInfo = environment.getProperty("s2.corrector.additional.information");
        if (StringUtils.isNotBlank(correctorAdditionalInfo) && Boolean.parseBoolean(correctorAdditionalInfo)) {
            addHavingToSelect(semanticParseInfo);
        }
    }

    private void addHaving(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
        Long dataSet = semanticParseInfo.getDataSet().getDataSet();

        SemanticSchema semanticSchema = queryContext.getSemanticSchema();

        Set<String> metrics = semanticSchema.getMetrics(dataSet).stream()
                .map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet());

        if (CollectionUtils.isEmpty(metrics)) {
            return;
        }
        String havingSql = SqlAddHelper.addHaving(semanticParseInfo.getSqlInfo().getCorrectS2SQL(), metrics);
        semanticParseInfo.getSqlInfo().setCorrectS2SQL(havingSql);
    }

    private void addHavingToSelect(SemanticParseInfo semanticParseInfo) {
        String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
        if (!SqlSelectFunctionHelper.hasAggregateFunction(correctS2SQL)) {
            return;
        }
        List<Expression> havingExpressionList = SqlSelectHelper.getHavingExpression(correctS2SQL);
        if (!CollectionUtils.isEmpty(havingExpressionList)) {
            String replaceSql = SqlAddHelper.addFunctionToSelect(correctS2SQL, havingExpressionList);
            semanticParseInfo.getSqlInfo().setCorrectS2SQL(replaceSql);
        }
    }
}
@@ -0,0 +1,148 @@
package com.tencent.supersonic.headless.chat.corrector;

import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.DateUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.jsqlparser.AggregateEnum;
import com.tencent.supersonic.common.util.jsqlparser.FieldExpression;
import com.tencent.supersonic.common.util.jsqlparser.SqlRemoveHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.chat.QueryContext;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.parser.llm.ParseResult;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Perform schema corrections on the Schema information in S2SQL.
 */
@Slf4j
public class SchemaCorrector extends BaseSemanticCorrector {

    @Override
    public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {

        correctAggFunction(semanticParseInfo);

        replaceAlias(semanticParseInfo);

        updateFieldNameByLinkingValue(semanticParseInfo);

        updateFieldValueByLinkingValue(semanticParseInfo);

        correctFieldName(queryContext, semanticParseInfo);
    }

    private void correctAggFunction(SemanticParseInfo semanticParseInfo) {
        Map<String, String> aggregateEnum = AggregateEnum.getAggregateEnum();
        SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
        String sql = SqlReplaceHelper.replaceFunction(sqlInfo.getCorrectS2SQL(), aggregateEnum);
        sqlInfo.setCorrectS2SQL(sql);
    }

    private void replaceAlias(SemanticParseInfo semanticParseInfo) {
        SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
        String replaceAlias = SqlReplaceHelper.replaceAlias(sqlInfo.getCorrectS2SQL());
        sqlInfo.setCorrectS2SQL(replaceAlias);
    }

    private void correctFieldName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
        Map<String, String> fieldNameMap = getFieldNameMap(queryContext, semanticParseInfo.getDataSetId());
        SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
        String sql = SqlReplaceHelper.replaceFields(sqlInfo.getCorrectS2SQL(), fieldNameMap);
        sqlInfo.setCorrectS2SQL(sql);
    }

    private void updateFieldNameByLinkingValue(SemanticParseInfo semanticParseInfo) {
        List<LLMReq.ElementValue> linking = getLinkingValues(semanticParseInfo);
        if (CollectionUtils.isEmpty(linking)) {
            return;
        }

        Map<String, Set<String>> fieldValueToFieldNames = linking.stream().collect(
                Collectors.groupingBy(LLMReq.ElementValue::getFieldValue,
                        Collectors.mapping(LLMReq.ElementValue::getFieldName, Collectors.toSet())));

        SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();

        String sql = SqlReplaceHelper.replaceFieldNameByValue(sqlInfo.getCorrectS2SQL(), fieldValueToFieldNames);
        sqlInfo.setCorrectS2SQL(sql);
    }

    private List<LLMReq.ElementValue> getLinkingValues(SemanticParseInfo semanticParseInfo) {
        Object context = semanticParseInfo.getProperties().get(Constants.CONTEXT);
        if (Objects.isNull(context)) {
            return null;
        }

        ParseResult parseResult = JsonUtil.toObject(JsonUtil.toString(context), ParseResult.class);
        if (Objects.isNull(parseResult) || Objects.isNull(parseResult.getLlmReq())) {
            return null;
        }
        return parseResult.getLinkingValues();
    }

    private void updateFieldValueByLinkingValue(SemanticParseInfo semanticParseInfo) {
        List<LLMReq.ElementValue> linking = getLinkingValues(semanticParseInfo);
        if (CollectionUtils.isEmpty(linking)) {
            return;
        }

        Map<String, Map<String, String>> filedNameToValueMap = linking.stream().collect(
                Collectors.groupingBy(LLMReq.ElementValue::getFieldName,
                        Collectors.mapping(LLMReq.ElementValue::getFieldValue, Collectors.toMap(
                                oldValue -> oldValue,
                                newValue -> newValue,
                                (existingValue, newValue) -> newValue))));

        SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
        String sql = SqlReplaceHelper.replaceValue(sqlInfo.getCorrectS2SQL(), filedNameToValueMap, false);
        sqlInfo.setCorrectS2SQL(sql);
    }

    public void removeFilterIfNotInLinkingValue(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
        SqlInfo sqlInfo = semanticParseInfo.getSqlInfo();
        String correctS2SQL = sqlInfo.getCorrectS2SQL();
        List<FieldExpression> whereExpressionList = SqlSelectHelper.getWhereExpressions(correctS2SQL);
        if (CollectionUtils.isEmpty(whereExpressionList)) {
            return;
        }
        List<LLMReq.ElementValue> linkingValues = getLinkingValues(semanticParseInfo);
        SemanticSchema semanticSchema = queryContext.getSemanticSchema();
        Set<String> dimensions = getDimensions(semanticParseInfo.getDataSetId(), semanticSchema);

        if (CollectionUtils.isEmpty(linkingValues)) {
            linkingValues = new ArrayList<>();
        }
        Set<String> linkingFieldNames = linkingValues.stream().map(linking -> linking.getFieldName())
                .collect(Collectors.toSet());

        Set<String> removeFieldNames = whereExpressionList.stream()
                .filter(fieldExpression -> StringUtils.isBlank(fieldExpression.getFunction()))
                .filter(fieldExpression -> !TimeDimensionEnum.containsTimeDimension(fieldExpression.getFieldName()))
                .filter(fieldExpression -> FilterOperatorEnum.EQUALS.getValue().equals(fieldExpression.getOperator()))
                .filter(fieldExpression -> dimensions.contains(fieldExpression.getFieldName()))
                .filter(fieldExpression -> !DateUtils.isAnyDateString(fieldExpression.getFieldValue().toString()))
                .filter(fieldExpression -> !linkingFieldNames.contains(fieldExpression.getFieldName()))
                .map(fieldExpression -> fieldExpression.getFieldName()).collect(Collectors.toSet());

        String sql = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames);
        sqlInfo.setCorrectS2SQL(sql);
    }
}
@@ -0,0 +1,34 @@
package com.tencent.supersonic.headless.chat.corrector;

import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.QueryContext;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;

import java.util.List;

/**
 * Perform SQL corrections on the "Select" section in S2SQL.
 */
@Slf4j
public class SelectCorrector extends BaseSemanticCorrector {

    @Override
    public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
        String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
        List<String> aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL);
        List<String> selectFields = SqlSelectHelper.getSelectFields(correctS2SQL);
        // If the number of aggregated fields is equal to the number of queried fields, do not add fields to select.
        if (!CollectionUtils.isEmpty(aggregateFields)
                && !CollectionUtils.isEmpty(selectFields)
                && aggregateFields.size() == selectFields.size()) {
            return;
        }
        correctS2SQL = addFieldsToSelect(semanticParseInfo, correctS2SQL);
        String querySql = SqlReplaceHelper.dealAliasToOrderBy(correctS2SQL);
        semanticParseInfo.getSqlInfo().setCorrectS2SQL(querySql);
    }
}
@@ -0,0 +1,14 @@
package com.tencent.supersonic.headless.chat.corrector;

import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.QueryContext;

/**
 * A semantic corrector checks validity of extracted semantic information and
 * performs correction and optimization if needed.
 */
public interface SemanticCorrector {

    void correct(QueryContext queryContext, SemanticParseInfo semanticParseInfo);
}
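A minimal sketch of a custom corrector (illustrative only; the LimitCorrector name and the appended clause are hypothetical, not part of this commit). Extending BaseSemanticCorrector rather than implementing the interface directly inherits the blank-SQL guard and error logging in correct():

// Hypothetical corrector that appends a row limit when none is present.
@Slf4j
public class LimitCorrector extends BaseSemanticCorrector {

    @Override
    public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
        String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
        // Naive string check, kept simple for illustration; real code would parse the SQL.
        if (!correctS2SQL.toUpperCase().contains("LIMIT")) {
            semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL + " LIMIT 1000");
        }
    }
}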
@@ -0,0 +1,123 @@
package com.tencent.supersonic.headless.chat.corrector;

import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlDateSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlRemoveHelper;
import com.tencent.supersonic.common.util.jsqlparser.DateVisitor.DateBoundInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.QueryContext;
import com.tencent.supersonic.headless.chat.utils.S2SqlDateHelper;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.core.env.Environment;
import org.springframework.util.CollectionUtils;

import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;

/**
 * Perform SQL corrections on the time conditions in S2SQL.
 */
@Slf4j
public class TimeCorrector extends BaseSemanticCorrector {

    @Override
    public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {

        addDateIfNotExist(queryContext, semanticParseInfo);

        removeDateIfExist(queryContext, semanticParseInfo);

        parserDateDiffFunction(semanticParseInfo);

        addLowerBoundDate(semanticParseInfo);
    }

    private void removeDateIfExist(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
        String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
        // decide whether to remove date fields from where
        Environment environment = ContextUtils.getBean(Environment.class);
        String correctorDate = environment.getProperty("s2.corrector.date");
        if (StringUtils.isNotBlank(correctorDate) && !Boolean.parseBoolean(correctorDate)) {
            Set<String> removeFieldNames = new HashSet<>();
            removeFieldNames.add(TimeDimensionEnum.DAY.getChName());
            removeFieldNames.add(TimeDimensionEnum.WEEK.getChName());
            removeFieldNames.add(TimeDimensionEnum.MONTH.getChName());
            correctS2SQL = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames);
            semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
        }
    }

    private void addDateIfNotExist(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
        String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
        List<String> whereFields = SqlSelectHelper.getWhereFields(correctS2SQL);

        // decide whether to add a date field to where
        Environment environment = ContextUtils.getBean(Environment.class);
        String correctorDate = environment.getProperty("s2.corrector.date");
        log.info("correctorDate:{}", correctorDate);
        if (StringUtils.isNotBlank(correctorDate) && !Boolean.parseBoolean(correctorDate)) {
            return;
        }
        if (CollectionUtils.isEmpty(whereFields) || !TimeDimensionEnum.containsZhTimeDimension(whereFields)) {

            Pair<String, String> startEndDate = S2SqlDateHelper.getStartEndDate(queryContext,
                    semanticParseInfo.getDataSetId(), semanticParseInfo.getQueryType());

            if (StringUtils.isNotBlank(startEndDate.getLeft())
                    && StringUtils.isNotBlank(startEndDate.getRight())) {
                correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL);
                String dateChName = TimeDimensionEnum.DAY.getChName();
                String condExpr = String.format(" ( %s >= '%s' and %s <= '%s' )", dateChName,
                        startEndDate.getLeft(), dateChName, startEndDate.getRight());
                try {
                    Expression expression = CCJSqlParserUtil.parseCondExpression(condExpr);
                    correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, expression);
                } catch (JSQLParserException e) {
                    log.error("parseCondExpression:{}", e);
                }
            }
        }
        semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
    }

    private void addLowerBoundDate(SemanticParseInfo semanticParseInfo) {
        String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
        DateBoundInfo dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(correctS2SQL);
        if (Objects.isNull(dateBoundInfo)) {
            return;
        }
        if (StringUtils.isBlank(dateBoundInfo.getLowerBound())
                && StringUtils.isNotBlank(dateBoundInfo.getUpperBound())
                && StringUtils.isNotBlank(dateBoundInfo.getUpperDate())) {
            String upperDate = dateBoundInfo.getUpperDate();
            try {
                correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL);
                String condExpr = dateBoundInfo.getColumName() + " >= '" + upperDate + "'";
                correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, CCJSqlParserUtil.parseCondExpression(condExpr));
            } catch (JSQLParserException e) {
                log.error("parseCondExpression", e);
            }
            semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
        }
    }

    private void parserDateDiffFunction(SemanticParseInfo semanticParseInfo) {
        String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
        correctS2SQL = SqlReplaceHelper.replaceFunction(correctS2SQL);
        semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
    }
}
@@ -0,0 +1,127 @@
package com.tencent.supersonic.headless.chat.corrector;

import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.StringUtil;
import com.tencent.supersonic.common.util.jsqlparser.SqlAddHelper;
import com.tencent.supersonic.common.util.jsqlparser.SqlReplaceHelper;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaValueMap;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import com.tencent.supersonic.headless.chat.QueryContext;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.util.Strings;
import org.springframework.util.CollectionUtils;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

/**
 * Perform SQL corrections on the "Where" section in S2SQL.
 */
@Slf4j
public class WhereCorrector extends BaseSemanticCorrector {

    @Override
    public void doCorrect(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {

        addQueryFilter(queryContext, semanticParseInfo);

        updateFieldValueByTechName(queryContext, semanticParseInfo);
    }

    private void addQueryFilter(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
        String queryFilter = getQueryFilter(queryContext.getQueryFilters());

        String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();

        if (StringUtils.isNotEmpty(queryFilter)) {
            log.info("add queryFilter to correctS2SQL :{}", queryFilter);
            Expression expression = null;
            try {
                expression = CCJSqlParserUtil.parseCondExpression(queryFilter);
            } catch (JSQLParserException e) {
                log.error("parseCondExpression", e);
            }
            correctS2SQL = SqlAddHelper.addWhere(correctS2SQL, expression);
            semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
        }
    }

    private String getQueryFilter(QueryFilters queryFilters) {
        if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) {
            return null;
        }
        return queryFilters.getFilters().stream()
                .map(filter -> {
                    String bizNameWrap = StringUtil.getSpaceWrap(filter.getName());
                    String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue());
                    String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString());
                    return bizNameWrap + operatorWrap + valueWrap;
                })
                .collect(Collectors.joining(Constants.AND_UPPER));
    }

    private void updateFieldValueByTechName(QueryContext queryContext, SemanticParseInfo semanticParseInfo) {
        SemanticSchema semanticSchema = queryContext.getSemanticSchema();
        Long dataSetId = semanticParseInfo.getDataSetId();
        List<SchemaElement> dimensions = semanticSchema.getDimensions(dataSetId);

        if (CollectionUtils.isEmpty(dimensions)) {
            return;
        }

        Map<String, Map<String, String>> aliasAndBizNameToTechName = getAliasAndBizNameToTechName(dimensions);
        String correctS2SQL = SqlReplaceHelper.replaceValue(semanticParseInfo.getSqlInfo().getCorrectS2SQL(),
                aliasAndBizNameToTechName);
        semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
    }

    private Map<String, Map<String, String>> getAliasAndBizNameToTechName(List<SchemaElement> dimensions) {
        if (CollectionUtils.isEmpty(dimensions)) {
            return new HashMap<>();
        }

        Map<String, Map<String, String>> result = new HashMap<>();

        for (SchemaElement dimension : dimensions) {
            if (Objects.isNull(dimension)
                    || Strings.isEmpty(dimension.getName())
                    || CollectionUtils.isEmpty(dimension.getSchemaValueMaps())) {
                continue;
            }
            String name = dimension.getName();

            Map<String, String> aliasAndBizNameToTechName = new HashMap<>();

            for (SchemaValueMap valueMap : dimension.getSchemaValueMaps()) {
                if (Objects.isNull(valueMap) || Strings.isEmpty(valueMap.getTechName())) {
                    continue;
                }
                if (Strings.isNotEmpty(valueMap.getBizName())) {
                    aliasAndBizNameToTechName.put(valueMap.getBizName(), valueMap.getTechName());
                }
                if (!CollectionUtils.isEmpty(valueMap.getAlias())) {
                    valueMap.getAlias().stream().forEach(alias -> {
                        if (Strings.isNotEmpty(alias)) {
                            aliasAndBizNameToTechName.put(alias, valueMap.getTechName());
                        }
                    });
                }
            }
            if (!CollectionUtils.isEmpty(aliasAndBizNameToTechName)) {
                result.put(name, aliasAndBizNameToTechName);
            }
        }
        return result;
    }
}
@@ -0,0 +1,22 @@
package com.tencent.supersonic.headless.chat.knowledge;

import lombok.Builder;
import lombok.Data;
import lombok.ToString;

import java.io.Serializable;

@Data
@ToString
@Builder
public class DataSetInfoStat implements Serializable {

    private long dataSetCount;

    private long metricDataSetCount;

    private long dimensionDataSetCount;

    private long dimensionValueDataSetCount;
}
@@ -0,0 +1,30 @@
package com.tencent.supersonic.headless.chat.knowledge;

import com.google.common.base.Objects;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import lombok.Data;
import lombok.ToString;

@Data
@ToString
public class DatabaseMapResult extends MapResult {

    private SchemaElement schemaElement;

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        DatabaseMapResult that = (DatabaseMapResult) o;
        return Objects.equal(name, that.name) && Objects.equal(schemaElement, that.schemaElement);
    }

    @Override
    public int hashCode() {
        return Objects.hashCode(name, schemaElement);
    }
}
@@ -0,0 +1,31 @@
package com.tencent.supersonic.headless.chat.knowledge;

public enum DictUpdateMode {

    OFFLINE_FULL("OFFLINE_FULL"),
    OFFLINE_MODEL("OFFLINE_MODEL"),
    REALTIME_ADD("REALTIME_ADD"),
    REALTIME_DELETE("REALTIME_DELETE"),
    NOT_SUPPORT("NOT_SUPPORT");

    private String value;

    DictUpdateMode(String value) {
        this.value = value;
    }

    public static DictUpdateMode of(String value) {
        for (DictUpdateMode item : DictUpdateMode.values()) {
            if (item.value.equalsIgnoreCase(value)) {
                return item;
            }
        }
        return DictUpdateMode.NOT_SUPPORT;
    }

    public String getValue() {
        return value;
    }
}
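A small usage sketch (illustrative only): of() resolves a mode name case-insensitively and falls back to NOT_SUPPORT for anything unrecognized.

DictUpdateMode full = DictUpdateMode.of("offline_full");    // OFFLINE_FULL
DictUpdateMode unknown = DictUpdateMode.of("bulk_rebuild"); // NOT_SUPPORT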
@@ -0,0 +1,34 @@
package com.tencent.supersonic.headless.chat.knowledge;

import java.util.Objects;
import lombok.Data;
import lombok.ToString;

/***
 * word nature
 */
@Data
@ToString
public class DictWord {

    private String word;
    private String nature;
    private String natureWithFrequency;

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        DictWord that = (DictWord) o;
        return Objects.equals(word, that.word) && Objects.equals(natureWithFrequency, that.natureWithFrequency);
    }

    @Override
    public int hashCode() {
        return Objects.hash(word, natureWithFrequency);
    }
}
@@ -0,0 +1,38 @@
package com.tencent.supersonic.headless.chat.knowledge;

import com.hankcs.hanlp.corpus.tag.Nature;
import com.hankcs.hanlp.dictionary.CoreDictionary;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Dictionary Attribute Util
 */
public class DictionaryAttributeUtil {

    public static CoreDictionary.Attribute getAttribute(CoreDictionary.Attribute old, CoreDictionary.Attribute add) {
        Map<Nature, Integer> map = new HashMap<>();
        IntStream.range(0, old.nature.length).boxed().forEach(i -> map.put(old.nature[i], old.frequency[i]));
        IntStream.range(0, add.nature.length).boxed().forEach(i -> map.put(add.nature[i], add.frequency[i]));
        List<Map.Entry<Nature, Integer>> list = new LinkedList<Map.Entry<Nature, Integer>>(map.entrySet());
        Collections.sort(list, new Comparator<Map.Entry<Nature, Integer>>() {
            public int compare(Map.Entry<Nature, Integer> o1, Map.Entry<Nature, Integer> o2) {
                return o2.getValue() - o1.getValue();
            }
        });
        CoreDictionary.Attribute attribute = new CoreDictionary.Attribute(
                list.stream().map(i -> i.getKey()).collect(Collectors.toList()).toArray(new Nature[0]),
                list.stream().map(i -> i.getValue()).mapToInt(Integer::intValue).toArray(),
                list.stream().map(i -> i.getValue()).findFirst().get());
        if (old.original != null || add.original != null) {
            attribute.original = add.original != null ? add.original : old.original;
        }
        return attribute;
    }
}
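In plain terms, getAttribute merges two HanLP dictionary attributes: natures from both inputs are pooled, a nature present in both keeps the frequency from add, and the merged natures are ordered by descending frequency with the largest frequency reused as the total. A hedged illustration (the nature constants are only examples; the three-argument Attribute constructor is the one used above):

// Illustrative merge, assuming Nature.n and Nature.nz from HanLP's tag set.
CoreDictionary.Attribute old = new CoreDictionary.Attribute(new Nature[]{Nature.n}, new int[]{10}, 10);
CoreDictionary.Attribute add = new CoreDictionary.Attribute(new Nature[]{Nature.nz}, new int[]{25}, 25);

// Expected result: natures [nz, n], frequencies [25, 10], total frequency 25.
CoreDictionary.Attribute merged = DictionaryAttributeUtil.getAttribute(old, add);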
@@ -0,0 +1,34 @@
package com.tencent.supersonic.headless.chat.knowledge;

import com.google.common.base.Objects;
import java.util.Map;
import lombok.Data;
import lombok.ToString;

@Data
@ToString
public class EmbeddingResult extends MapResult {

    private String id;

    private double distance;

    private Map<String, String> metadata;

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        EmbeddingResult that = (EmbeddingResult) o;
        return Objects.equal(id, that.id);
    }

    @Override
    public int hashCode() {
        return Objects.hashCode(id);
    }
}
@@ -0,0 +1,32 @@
package com.tencent.supersonic.headless.chat.knowledge;

import com.hankcs.hanlp.corpus.io.IIOAdapter;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.URI;

import lombok.extern.slf4j.Slf4j;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;

@Slf4j
public class HadoopFileIOAdapter implements IIOAdapter {

    @Override
    public InputStream open(String path) throws IOException {
        log.info("open:{}", path);
        Configuration conf = new Configuration();
        FileSystem fs = FileSystem.get(URI.create(path), conf);
        return fs.open(new Path(path));
    }

    @Override
    public OutputStream create(String path) throws IOException {
        log.info("create:{}", path);
        Configuration conf = new Configuration();
        FileSystem fs = FileSystem.get(URI.create(path), conf);
        return fs.create(new Path(path));
    }
}
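A hedged wiring sketch (not part of this commit): HanLP routes dictionary IO through a globally configured IIOAdapter, so installing this adapter lets dictionary paths point at HDFS. The HanLP.Config.IOAdapter hook and the path below are assumptions to verify against the HanLP version in use:

// Assumed wiring: route HanLP dictionary IO through HDFS before dictionaries are loaded.
HanLP.Config.IOAdapter = new HadoopFileIOAdapter();
HanLP.Config.CustomDictionaryPath = new String[]{"hdfs://namenode:8020/dict/custom.txt"}; // hypothetical path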
@@ -0,0 +1,44 @@
package com.tencent.supersonic.headless.chat.knowledge;

import com.google.common.base.Objects;
import java.util.List;
import lombok.Data;
import lombok.ToString;

@Data
@ToString
public class HanlpMapResult extends MapResult {

    private List<String> natures;
    private int offset = 0;

    private double similarity;

    public HanlpMapResult(String name, List<String> natures, String detectWord) {
        this.name = name;
        this.natures = natures;
        this.detectWord = detectWord;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        HanlpMapResult hanlpMapResult = (HanlpMapResult) o;
        return Objects.equal(name, hanlpMapResult.name) && Objects.equal(natures, hanlpMapResult.natures);
    }

    @Override
    public int hashCode() {
        return Objects.hashCode(name, natures);
    }

    public void setOffset(int offset) {
        this.offset = offset;
    }
}
@@ -0,0 +1,79 @@
package com.tencent.supersonic.headless.chat.knowledge;

import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;

import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

@Service
@Slf4j
public class KnowledgeBaseService {

    public void updateSemanticKnowledge(List<DictWord> natures) {

        List<DictWord> prefixes = natures.stream()
                .filter(entry -> !entry.getNatureWithFrequency().contains(DictWordType.SUFFIX.getType()))
                .collect(Collectors.toList());

        for (DictWord nature : prefixes) {
            HanlpHelper.addToCustomDictionary(nature);
        }

        List<DictWord> suffixes = natures.stream()
                .filter(entry -> entry.getNatureWithFrequency().contains(DictWordType.SUFFIX.getType()))
                .collect(Collectors.toList());

        SearchService.loadSuffix(suffixes);
    }

    public void reloadAllData(List<DictWord> natures) {
        // 1. reload custom knowledge
        try {
            HanlpHelper.reloadCustomDictionary();
        } catch (Exception e) {
            log.error("reloadCustomDictionary error", e);
        }

        // 2. update online knowledge
        updateOnlineKnowledge(natures);
    }

    public void updateOnlineKnowledge(List<DictWord> natures) {
        try {
            updateSemanticKnowledge(natures);
        } catch (Exception e) {
            log.error("updateSemanticKnowledge error", e);
        }
    }

    public List<S2Term> getTerms(String text, Map<Long, List<Long>> modelIdToDataSetIds) {
        return HanlpHelper.getTerms(text, modelIdToDataSetIds);
    }

    public List<HanlpMapResult> prefixSearch(String key, int limit, Map<Long, List<Long>> modelIdToDataSetIds,
            Set<Long> detectDataSetIds) {
        return prefixSearchByModel(key, limit, modelIdToDataSetIds, detectDataSetIds);
    }

    public List<HanlpMapResult> prefixSearchByModel(String key, int limit,
            Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
        return SearchService.prefixSearch(key, limit, modelIdToDataSetIds, detectDataSetIds);
    }

    public List<HanlpMapResult> suffixSearch(String key, int limit, Map<Long, List<Long>> modelIdToDataSetIds,
            Set<Long> detectDataSetIds) {
        return suffixSearchByModel(key, limit, modelIdToDataSetIds, detectDataSetIds);
    }

    public List<HanlpMapResult> suffixSearchByModel(String key, int limit, Map<Long, List<Long>> modelIdToDataSetIds,
            Set<Long> detectDataSetIds) {
        return SearchService.suffixSearch(key, limit, modelIdToDataSetIds, detectDataSetIds);
    }
}
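A hedged refresh sketch (illustrative only; the word and the nature-with-frequency encoding below are made-up examples, and updateOnlineKnowledge logs failures rather than propagating them):

// Hypothetical dictionary refresh: register one word, then push it to the online knowledge.
DictWord word = new DictWord();
word.setWord("周杰伦");
word.setNatureWithFrequency("_4_7 100000"); // example encoding; real natures come from DictWordType
knowledgeBaseService.updateOnlineKnowledge(java.util.Collections.singletonList(word));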
@@ -0,0 +1,13 @@
package com.tencent.supersonic.headless.chat.knowledge;

import java.io.Serializable;
import lombok.Data;
import lombok.ToString;

@Data
@ToString
public class MapResult implements Serializable {

    protected String name;
    protected String detectWord;
}
@@ -0,0 +1,91 @@
|
||||
package com.tencent.supersonic.headless.chat.knowledge;
|
||||
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.ComponentFactory;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class MetaEmbeddingService {
|
||||
|
||||
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
|
||||
@Autowired
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
|
||||
public List<RetrieveQueryResult> retrieveQuery(RetrieveQuery retrieveQuery, int num,
|
||||
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
|
||||
// dataSetIds->modelIds
|
||||
Set<Long> allModels = NatureHelper.getModelIds(modelIdToDataSetIds, detectDataSetIds);
|
||||
|
||||
if (CollectionUtils.isNotEmpty(allModels) && allModels.size() == 1) {
|
||||
Map<String, String> filterCondition = new HashMap<>();
|
||||
filterCondition.put("modelId", allModels.stream().findFirst().get().toString());
|
||||
retrieveQuery.setFilterCondition(filterCondition);
|
||||
}
|
||||
|
||||
String collectionName = embeddingConfig.getMetaCollectionName();
|
||||
List<RetrieveQueryResult> resultList = s2EmbeddingStore.retrieveQuery(collectionName, retrieveQuery, num);
|
||||
if (CollectionUtils.isEmpty(resultList)) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
//filter by modelId
|
||||
if (CollectionUtils.isEmpty(allModels)) {
|
||||
return resultList;
|
||||
}
|
||||
return resultList.stream()
|
||||
.map(retrieveQueryResult -> {
|
||||
List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
|
||||
if (CollectionUtils.isEmpty(retrievals)) {
|
||||
return retrieveQueryResult;
|
||||
}
|
||||
//filter by modelId
|
||||
retrievals.removeIf(retrieval -> {
|
||||
Long modelId = Retrieval.getLongId(retrieval.getMetadata().get("modelId"));
|
||||
if (Objects.isNull(modelId)) {
|
||||
return CollectionUtils.isEmpty(allModels);
|
||||
}
|
||||
return !allModels.contains(modelId);
|
||||
});
|
||||
//add dataSetId
|
||||
retrievals = retrievals.stream().flatMap(retrieval -> {
|
||||
Long modelId = Retrieval.getLongId(retrieval.getMetadata().get("modelId"));
|
||||
List<Long> dataSetIdsByModelId = modelIdToDataSetIds.get(modelId);
|
||||
if (!CollectionUtils.isEmpty(dataSetIdsByModelId)) {
|
||||
Set<Retrieval> result = new HashSet<>();
|
||||
for (Long dataSetId : dataSetIdsByModelId) {
|
||||
Retrieval retrievalNew = new Retrieval();
|
||||
BeanUtils.copyProperties(retrieval, retrievalNew);
|
||||
retrievalNew.getMetadata().putIfAbsent("dataSetId", dataSetId + Constants.UNDERLINE);
|
||||
result.add(retrievalNew);
|
||||
}
|
||||
return result.stream();
|
||||
}
|
||||
Set<Retrieval> result = new HashSet<>();
|
||||
result.add(retrieval);
|
||||
return result.stream();
|
||||
}).collect(Collectors.toList());
|
||||
retrieveQueryResult.setRetrieval(retrievals);
|
||||
return retrieveQueryResult;
|
||||
})
|
||||
.filter(retrieveQueryResult -> CollectionUtils.isNotEmpty(retrieveQueryResult.getRetrieval()))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
}
|
||||
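A minimal usage sketch of the retrieval flow above. How RetrieveQuery is populated, and every id and query text below, are assumptions for illustration rather than the project's confirmed API:

// Hypothetical caller; ids, query text, and the RetrieveQuery setters are illustrative assumptions.
Map<Long, List<Long>> modelIdToDataSetIds = new HashMap<>();
modelIdToDataSetIds.put(1L, Arrays.asList(10L, 11L));
Set<Long> detectDataSetIds = new HashSet<>(Arrays.asList(10L));
RetrieveQuery retrieveQuery = new RetrieveQuery();             // assumed no-arg constructor
retrieveQuery.setQueryTextsList(Arrays.asList("play count")); // assumed setter
List<RetrieveQueryResult> results =
        metaEmbeddingService.retrieveQuery(retrieveQuery, 5, modelIdToDataSetIds, detectDataSetIds);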
@@ -0,0 +1,397 @@
|
||||
package com.tencent.supersonic.headless.chat.knowledge;
|
||||
|
||||
import static com.hankcs.hanlp.utility.Predefine.logger;
|
||||
|
||||
import com.hankcs.hanlp.HanLP;
|
||||
import com.hankcs.hanlp.collection.trie.DoubleArrayTrie;
|
||||
import com.hankcs.hanlp.collection.trie.bintrie.BinTrie;
|
||||
import com.hankcs.hanlp.corpus.io.ByteArray;
|
||||
import com.hankcs.hanlp.corpus.io.IOUtil;
|
||||
import com.hankcs.hanlp.corpus.tag.Nature;
|
||||
import com.hankcs.hanlp.dictionary.CoreDictionary;
|
||||
import com.hankcs.hanlp.dictionary.DynamicCustomDictionary;
|
||||
import com.hankcs.hanlp.dictionary.other.CharTable;
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.hankcs.hanlp.utility.LexiconUtility;
|
||||
import com.hankcs.hanlp.utility.TextUtility;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper;
|
||||
|
||||
import java.io.BufferedOutputStream;
|
||||
import java.io.BufferedReader;
|
||||
import java.io.DataOutputStream;
|
||||
import java.io.File;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStreamReader;
|
||||
import java.util.Arrays;
|
||||
import java.util.Comparator;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.PriorityQueue;
|
||||
import java.util.TreeMap;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
public class MultiCustomDictionary extends DynamicCustomDictionary {
|
||||
|
||||
public static int MAX_SIZE = 10;
|
||||
public static Boolean removeDuplicates = true;
|
||||
public static ConcurrentHashMap<String, PriorityQueue<Term>> NATURE_TO_VALUES = new ConcurrentHashMap<>();
|
||||
private static boolean addToSuggesterTrie = true;
|
||||
|
||||
public MultiCustomDictionary() {
|
||||
this(HanLP.Config.CustomDictionaryPath);
|
||||
}
|
||||
|
||||
public MultiCustomDictionary(String... path) {
|
||||
super(path);
|
||||
}
|
||||
|
||||
/***
|
||||
* load dictionary
|
||||
* @param path
|
||||
* @param defaultNature
|
||||
* @param map
|
||||
* @param customNatureCollector
|
||||
* @param addToSuggeterTrie
|
||||
* @return
|
||||
*/
|
||||
public static boolean load(String path, Nature defaultNature, TreeMap<String, CoreDictionary.Attribute> map,
|
||||
LinkedHashSet<Nature> customNatureCollector, boolean addToSuggeterTrie) {
|
||||
try {
|
||||
String splitter = "\\s";
|
||||
if (path.endsWith(".csv")) {
|
||||
splitter = ",";
|
||||
}
|
||||
|
||||
BufferedReader br = new BufferedReader(new InputStreamReader(IOUtil.newInputStream(path), "UTF-8"));
|
||||
boolean firstLine = true;
|
||||
|
||||
while (true) {
|
||||
String[] param;
|
||||
do {
|
||||
String line;
|
||||
if ((line = br.readLine()) == null) {
|
||||
br.close();
|
||||
return true;
|
||||
}
|
||||
|
||||
if (firstLine) {
|
||||
line = IOUtil.removeUTF8BOM(line);
|
||||
firstLine = false;
|
||||
}
|
||||
|
||||
param = line.split(splitter);
|
||||
} while (param[0].length() == 0);
|
||||
|
||||
if (HanLP.Config.Normalization) {
|
||||
param[0] = CharTable.convert(param[0]);
|
||||
}
|
||||
|
||||
int natureCount = (param.length - 1) / 2;
|
||||
CoreDictionary.Attribute attribute;
|
||||
boolean isLetters = isLetters(param[0]);
|
||||
String original = null;
|
||||
String word = getWordBySpace(param[0]);
|
||||
if (isLetters) {
|
||||
original = word;
|
||||
word = word.toLowerCase();
|
||||
}
|
||||
if (natureCount == 0) {
|
||||
attribute = new CoreDictionary.Attribute(defaultNature);
|
||||
} else {
|
||||
attribute = new CoreDictionary.Attribute(natureCount);
|
||||
|
||||
for (int i = 0; i < natureCount; ++i) {
|
||||
attribute.nature[i] = LexiconUtility.convertStringToNature(param[1 + 2 * i],
|
||||
customNatureCollector);
|
||||
attribute.frequency[i] = Integer.parseInt(param[2 + 2 * i]);
|
||||
attribute.totalFrequency += attribute.frequency[i];
|
||||
}
|
||||
}
|
||||
attribute.original = original;
|
||||
|
||||
if (removeDuplicates && map.containsKey(word)) {
|
||||
attribute = DictionaryAttributeUtil.getAttribute(map.get(word), attribute);
|
||||
}
|
||||
map.put(word, attribute);
|
||||
if (addToSuggeterTrie) {
|
||||
SearchService.put(word, attribute);
|
||||
}
|
||||
for (int i = 0; i < attribute.nature.length; i++) {
|
||||
Nature nature = attribute.nature[i];
|
||||
PriorityQueue<Term> priorityQueue = NATURE_TO_VALUES.get(nature.toString());
|
||||
if (Objects.isNull(priorityQueue)) {
|
||||
priorityQueue = new PriorityQueue<>(MAX_SIZE,
|
||||
Comparator.comparingInt(Term::getFrequency).reversed());
|
||||
NATURE_TO_VALUES.put(nature.toString(), priorityQueue);
|
||||
}
|
||||
Term term = new Term(word, nature);
|
||||
term.setFrequency(attribute.frequency[i]);
|
||||
if (!priorityQueue.contains(term) && priorityQueue.size() < MAX_SIZE) {
|
||||
priorityQueue.add(term);
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (Exception var12) {
|
||||
logger.severe("自定义词典" + path + "读取错误!" + var12);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public boolean load(String... path) {
|
||||
this.path = path;
|
||||
long start = System.currentTimeMillis();
|
||||
if (!this.loadMainDictionary(path[0])) {
|
||||
logger.warning("自定义词典" + Arrays.toString(path) + "加载失败");
|
||||
return false;
|
||||
} else {
|
||||
logger.info(
|
||||
"自定义词典加载成功:" + this.dat.size() + "个词条,耗时" + (System.currentTimeMillis() - start) + "ms");
|
||||
this.path = path;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
/***
|
||||
* load main dictionary
|
||||
* @param mainPath
|
||||
* @param path
|
||||
* @param dat
|
||||
* @param isCache
|
||||
* @param addToSuggestTrie
|
||||
* @return
|
||||
*/
|
||||
public static boolean loadMainDictionary(String mainPath, String[] path,
|
||||
DoubleArrayTrie<CoreDictionary.Attribute> dat, boolean isCache,
|
||||
boolean addToSuggestTrie) {
|
||||
logger.info("自定义词典开始加载:" + mainPath);
|
||||
if (loadDat(mainPath, dat)) {
|
||||
return true;
|
||||
} else {
|
||||
TreeMap<String, CoreDictionary.Attribute> map = new TreeMap<>();
|
||||
LinkedHashSet<Nature> customNatureCollector = new LinkedHashSet<>();
|
||||
|
||||
try {
|
||||
for (String p : path) {
|
||||
Nature defaultNature = Nature.n;
|
||||
File file = new File(p);
|
||||
String fileName = file.getName();
|
||||
int cut = fileName.lastIndexOf(32);
|
||||
if (cut > 0) {
|
||||
String nature = fileName.substring(cut + 1);
|
||||
p = file.getParent() + File.separator + fileName.substring(0, cut);
|
||||
|
||||
try {
|
||||
defaultNature = LexiconUtility.convertStringToNature(nature, customNatureCollector);
|
||||
} catch (Exception var16) {
|
||||
logger.severe("配置文件【" + p + "】写错了!" + var16);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
logger.info("以默认词性[" + defaultNature + "]加载自定义词典" + p + "中……");
|
||||
boolean success = load(p, defaultNature, map, customNatureCollector, addToSuggestTrie);
|
||||
if (!success) {
|
||||
logger.warning("失败:" + p);
|
||||
}
|
||||
}
|
||||
|
||||
if (map.size() == 0) {
|
||||
logger.warning("没有加载到任何词条");
|
||||
map.put("未##它", null);
|
||||
}
|
||||
|
||||
logger.info("正在构建DoubleArrayTrie……");
|
||||
dat.build(map);
|
||||
if (addToSuggestTrie) {
|
||||
// SearchService.save();
|
||||
}
|
||||
if (isCache) {
|
||||
// Cache as a .dat file so the next load is much faster
|
||||
logger.info("正在缓存词典为dat文件……");
|
||||
// Cache the attribute values
|
||||
List<CoreDictionary.Attribute> attributeList = new LinkedList<CoreDictionary.Attribute>();
|
||||
for (Map.Entry<String, CoreDictionary.Attribute> entry : map.entrySet()) {
|
||||
attributeList.add(entry.getValue());
|
||||
}
|
||||
|
||||
DataOutputStream out = new DataOutputStream(
|
||||
new BufferedOutputStream(IOUtil.newOutputStream(mainPath + ".bin")));
|
||||
if (customNatureCollector.isEmpty()) {
|
||||
for (int i = Nature.begin.ordinal() + 1; i < Nature.values().length; ++i) {
|
||||
Nature nature = Nature.values()[i];
|
||||
if (Objects.nonNull(nature)) {
|
||||
customNatureCollector.add(nature);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
IOUtil.writeCustomNature(out, customNatureCollector);
|
||||
out.writeInt(attributeList.size());
|
||||
|
||||
for (CoreDictionary.Attribute attribute : attributeList) {
|
||||
attribute.save(out);
|
||||
}
|
||||
|
||||
dat.save(out);
|
||||
out.close();
|
||||
}
|
||||
} catch (FileNotFoundException var17) {
|
||||
logger.severe("自定义词典" + mainPath + "不存在!" + var17);
|
||||
return false;
|
||||
} catch (IOException var18) {
|
||||
logger.severe("自定义词典" + mainPath + "读取错误!" + var18);
|
||||
return false;
|
||||
} catch (Exception var19) {
|
||||
logger.warning("自定义词典" + mainPath + "缓存失败!\n" + TextUtility.exceptionToString(var19));
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
public boolean loadMainDictionary(String mainPath) {
|
||||
return loadMainDictionary(mainPath, this.path, this.dat, true, addToSuggesterTrie);
|
||||
}
|
||||
|
||||
public static boolean loadDat(String path, DoubleArrayTrie<CoreDictionary.Attribute> dat) {
|
||||
return loadDat(path, HanLP.Config.CustomDictionaryPath, dat);
|
||||
}
|
||||
|
||||
public static boolean loadDat(String path, String[] customDicPath, DoubleArrayTrie<CoreDictionary.Attribute> dat) {
|
||||
try {
|
||||
if (HanLP.Config.CustomDictionaryAutoRefreshCache
|
||||
&& DynamicCustomDictionary.isDicNeedUpdate(path, customDicPath)) {
|
||||
return false;
|
||||
} else {
|
||||
ByteArray byteArray = ByteArray.createByteArray(path + ".bin");
|
||||
if (byteArray == null) {
|
||||
return false;
|
||||
} else {
|
||||
int size = byteArray.nextInt();
|
||||
if (size < 0) {
|
||||
while (true) {
|
||||
++size;
|
||||
if (size > 0) {
|
||||
size = byteArray.nextInt();
|
||||
break;
|
||||
}
|
||||
|
||||
Nature.create(byteArray.nextString());
|
||||
}
|
||||
}
|
||||
|
||||
CoreDictionary.Attribute[] attributes = new CoreDictionary.Attribute[size];
|
||||
Nature[] natureIndexArray = Nature.values();
|
||||
|
||||
for (int i = 0; i < size; ++i) {
|
||||
int currentTotalFrequency = byteArray.nextInt();
|
||||
int length = byteArray.nextInt();
|
||||
attributes[i] = new CoreDictionary.Attribute(length);
|
||||
attributes[i].totalFrequency = currentTotalFrequency;
|
||||
|
||||
for (int j = 0; j < length; ++j) {
|
||||
attributes[i].nature[j] = natureIndexArray[byteArray.nextInt()];
|
||||
attributes[i].frequency[j] = byteArray.nextInt();
|
||||
}
|
||||
}
|
||||
|
||||
return dat.load(byteArray, attributes);
|
||||
}
|
||||
}
|
||||
} catch (Exception var11) {
|
||||
logger.warning("读取失败,问题发生在" + TextUtility.exceptionToString(var11));
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public static boolean isLetters(String str) {
|
||||
char[] chars = str.toCharArray();
|
||||
if (chars.length <= 1) {
|
||||
return false;
|
||||
}
|
||||
for (int i = 0; i < chars.length; i++) {
|
||||
if ((chars[i] >= 'A' && chars[i] <= 'Z')) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public static boolean isLowerLetter(String str) {
|
||||
char[] chars = str.toCharArray();
|
||||
for (int i = 0; i < chars.length; i++) {
|
||||
if ((chars[i] >= 'a' && chars[i] <= 'z')) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public static String getWordBySpace(String word) {
|
||||
if (word.contains(HanlpHelper.SPACE_SPILT)) {
|
||||
return word.replace(HanlpHelper.SPACE_SPILT, " ");
|
||||
}
|
||||
return word;
|
||||
}
|
||||
|
||||
public boolean reload() {
|
||||
if (this.path != null && this.path.length != 0) {
|
||||
IOUtil.deleteFile(this.path[0] + ".bin");
|
||||
Boolean loadCacheOk = this.loadDat(this.path[0], this.path, this.dat);
|
||||
if (!loadCacheOk) {
|
||||
return this.loadMainDictionary(this.path[0], this.path, this.dat, true, addToSuggesterTrie);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
|
||||
}
|
||||
|
||||
public synchronized boolean insert(String word, String natureWithFrequency) {
|
||||
if (word == null) {
|
||||
return false;
|
||||
} else {
|
||||
if (HanLP.Config.Normalization) {
|
||||
word = CharTable.convert(word);
|
||||
}
|
||||
CoreDictionary.Attribute att = natureWithFrequency == null ? new CoreDictionary.Attribute(Nature.nz, 1)
|
||||
: CoreDictionary.Attribute.create(natureWithFrequency);
|
||||
boolean isLetters = isLetters(word);
|
||||
word = getWordBySpace(word);
|
||||
String original = null;
|
||||
if (isLetters) {
|
||||
original = word;
|
||||
word = word.toLowerCase();
|
||||
}
|
||||
if (att == null) {
|
||||
return false;
|
||||
} else if (this.dat.containsKey(word)) {
|
||||
att.original = original;
|
||||
att = DictionaryAttributeUtil.getAttribute(this.dat.get(word), att);
|
||||
this.dat.set(word, att);
|
||||
// return true;
|
||||
} else {
|
||||
if (this.trie == null) {
|
||||
this.trie = new BinTrie();
|
||||
}
|
||||
att.original = original;
|
||||
if (this.trie.containsKey(word)) {
|
||||
att = DictionaryAttributeUtil.getAttribute(this.trie.get(word), att);
|
||||
}
|
||||
this.trie.put(word, att);
|
||||
// return true;
|
||||
}
|
||||
if (addToSuggesterTrie) {
|
||||
SearchService.put(word, att);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
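A short, hedged sketch of adding one entry to the dictionary class above; the word and the nature-with-frequency string are illustrative and simply follow the "<nature> <frequency>" layout that load() and insert() parse:

// Illustrative only: the nature token and frequency are made up.
MultiCustomDictionary dictionary = new MultiCustomDictionary(HanLP.Config.CustomDictionaryPath);
boolean inserted = dictionary.insert("play count", "_1_100_metric 100000");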
@@ -0,0 +1,198 @@
|
||||
package com.tencent.supersonic.headless.chat.knowledge;
|
||||
|
||||
import com.hankcs.hanlp.collection.trie.bintrie.BaseNode;
|
||||
import com.hankcs.hanlp.collection.trie.bintrie.BinTrie;
|
||||
import com.hankcs.hanlp.corpus.tag.Nature;
|
||||
import com.hankcs.hanlp.dictionary.CoreDictionary;
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.PriorityQueue;
|
||||
import java.util.Set;
|
||||
import java.util.TreeMap;
|
||||
import java.util.TreeSet;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
public class SearchService {
|
||||
|
||||
public static final int SEARCH_SIZE = 200;
|
||||
private static BinTrie<List<String>> trie;
|
||||
private static BinTrie<List<String>> suffixTrie;
|
||||
|
||||
static {
|
||||
trie = new BinTrie<>();
|
||||
suffixTrie = new BinTrie<>();
|
||||
}
|
||||
|
||||
/***
|
||||
* prefix Search
|
||||
* @param key
|
||||
* @return
|
||||
*/
|
||||
public static List<HanlpMapResult> prefixSearch(String key, int limit, Map<Long, List<Long>> modelIdToDataSetIds,
|
||||
Set<Long> detectDataSetIds) {
|
||||
return prefixSearch(key, limit, trie, modelIdToDataSetIds, detectDataSetIds);
|
||||
}
|
||||
|
||||
public static List<HanlpMapResult> prefixSearch(String key, int limit, BinTrie<List<String>> binTrie,
|
||||
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
|
||||
Set<Map.Entry<String, List<String>>> result = search(key, binTrie);
|
||||
List<HanlpMapResult> hanlpMapResults = result.stream().map(
|
||||
entry -> {
|
||||
String name = entry.getKey().replace("#", " ");
|
||||
return new HanlpMapResult(name, entry.getValue(), key);
|
||||
}
|
||||
).sorted((a, b) -> -(b.getName().length() - a.getName().length()))
|
||||
.collect(Collectors.toList());
|
||||
hanlpMapResults = transformAndFilterByDataSet(hanlpMapResults, modelIdToDataSetIds,
|
||||
detectDataSetIds, limit);
|
||||
return hanlpMapResults;
|
||||
}
|
||||
|
||||
/***
|
||||
* suffix Search
|
||||
* @param key
|
||||
* @return
|
||||
*/
|
||||
public static List<HanlpMapResult> suffixSearch(String key, int limit, Map<Long, List<Long>> modelIdToDataSetIds,
|
||||
Set<Long> detectDataSetIds) {
|
||||
String reverseDetectSegment = StringUtils.reverse(key);
|
||||
return suffixSearch(reverseDetectSegment, limit, suffixTrie, modelIdToDataSetIds, detectDataSetIds);
|
||||
}
|
||||
|
||||
public static List<HanlpMapResult> suffixSearch(String key, int limit, BinTrie<List<String>> binTrie,
|
||||
Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
|
||||
Set<Map.Entry<String, List<String>>> result = search(key, binTrie);
|
||||
List<HanlpMapResult> hanlpMapResults = result.stream().map(
|
||||
entry -> {
|
||||
String name = entry.getKey().replace("#", " ");
|
||||
List<String> natures = entry.getValue().stream()
|
||||
.map(nature -> nature.replaceAll(DictWordType.SUFFIX.getType(), ""))
|
||||
.collect(Collectors.toList());
|
||||
name = StringUtils.reverse(name);
|
||||
return new HanlpMapResult(name, natures, key);
|
||||
}
|
||||
).sorted((a, b) -> -(b.getName().length() - a.getName().length()))
|
||||
.collect(Collectors.toList());
|
||||
return transformAndFilterByDataSet(hanlpMapResults, modelIdToDataSetIds, detectDataSetIds, limit);
|
||||
}
|
||||
|
||||
private static List<HanlpMapResult> transformAndFilterByDataSet(List<HanlpMapResult> hanlpMapResults,
|
||||
Map<Long, List<Long>> modelIdToDataSetIds,
|
||||
Set<Long> detectDataSetIds, int limit) {
|
||||
return hanlpMapResults.stream().peek(hanlpMapResult -> {
|
||||
List<String> natures = hanlpMapResult.getNatures().stream()
|
||||
.map(nature -> NatureHelper.changeModel2DataSet(nature, modelIdToDataSetIds))
|
||||
.flatMap(Collection::stream)
|
||||
.filter(nature -> {
|
||||
if (CollectionUtils.isEmpty(detectDataSetIds)) {
|
||||
return true;
|
||||
}
|
||||
Long dataSetId = NatureHelper.getDataSetId(nature);
|
||||
if (dataSetId != null) {
|
||||
return detectDataSetIds.contains(dataSetId);
|
||||
}
|
||||
return false;
|
||||
}).collect(Collectors.toList());
|
||||
hanlpMapResult.setNatures(natures);
|
||||
}).filter(hanlpMapResult -> !CollectionUtils.isEmpty(hanlpMapResult.getNatures()))
|
||||
.limit(limit).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private static Set<Map.Entry<String, List<String>>> search(String key,
|
||||
BinTrie<List<String>> binTrie) {
|
||||
key = key.toLowerCase();
|
||||
Set<Map.Entry<String, List<String>>> entrySet = new TreeSet<Map.Entry<String, List<String>>>();
|
||||
|
||||
StringBuilder sb = new StringBuilder();
|
||||
if (StringUtils.isNotBlank(key)) {
|
||||
sb = new StringBuilder(key.substring(0, key.length() - 1));
|
||||
}
|
||||
BaseNode branch = binTrie;
|
||||
char[] chars = key.toCharArray();
|
||||
for (char aChar : chars) {
|
||||
if (branch == null) {
|
||||
return entrySet;
|
||||
}
|
||||
branch = branch.getChild(aChar);
|
||||
}
|
||||
|
||||
if (branch == null) {
|
||||
return entrySet;
|
||||
}
|
||||
branch.walkLimit(sb, entrySet);
|
||||
return entrySet;
|
||||
}
|
||||
|
||||
public static void clear() {
|
||||
log.info("clear all trie");
|
||||
trie = new BinTrie<>();
|
||||
suffixTrie = new BinTrie<>();
|
||||
}
|
||||
|
||||
public static void put(String key, CoreDictionary.Attribute attribute) {
|
||||
trie.put(key, getValue(attribute.nature));
|
||||
}
|
||||
|
||||
public static void loadSuffix(List<DictWord> suffixes) {
|
||||
if (CollectionUtils.isEmpty(suffixes)) {
|
||||
return;
|
||||
}
|
||||
TreeMap<String, CoreDictionary.Attribute> map = new TreeMap<>();
|
||||
for (DictWord suffix : suffixes) {
|
||||
CoreDictionary.Attribute attributeNew = suffix.getNatureWithFrequency() == null
|
||||
? new CoreDictionary.Attribute(Nature.nz, 1)
|
||||
: CoreDictionary.Attribute.create(suffix.getNatureWithFrequency());
|
||||
if (map.containsKey(suffix.getWord())) {
|
||||
attributeNew = DictionaryAttributeUtil.getAttribute(map.get(suffix.getWord()), attributeNew);
|
||||
}
|
||||
map.put(suffix.getWord(), attributeNew);
|
||||
}
|
||||
for (Map.Entry<String, CoreDictionary.Attribute> stringAttributeEntry : map.entrySet()) {
|
||||
putSuffix(stringAttributeEntry.getKey(), stringAttributeEntry.getValue());
|
||||
}
|
||||
}
|
||||
|
||||
public static void putSuffix(String key, CoreDictionary.Attribute attribute) {
|
||||
Nature[] nature = attribute.nature;
|
||||
suffixTrie.put(key, getValue(nature));
|
||||
}
|
||||
|
||||
private static List<String> getValue(Nature[] nature) {
|
||||
return Arrays.stream(nature).map(entry -> entry.toString()).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public static void remove(DictWord dictWord, Nature[] natures) {
|
||||
trie.remove(dictWord.getWord());
|
||||
if (Objects.nonNull(natures) && natures.length > 0) {
|
||||
trie.put(dictWord.getWord(), getValue(natures));
|
||||
}
|
||||
if (dictWord.getNature().contains(DictWordType.METRIC.getType()) || dictWord.getNature()
|
||||
.contains(DictWordType.DIMENSION.getType())) {
|
||||
suffixTrie.remove(dictWord.getWord());
|
||||
}
|
||||
}
|
||||
|
||||
public static List<String> getDimensionValue(DimensionValueReq dimensionValueReq) {
|
||||
String nature = DictWordType.NATURE_SPILT + dimensionValueReq.getModelId() + DictWordType.NATURE_SPILT
|
||||
+ dimensionValueReq.getElementID();
|
||||
PriorityQueue<Term> terms = MultiCustomDictionary.NATURE_TO_VALUES.get(nature);
|
||||
if (org.apache.commons.collections.CollectionUtils.isEmpty(terms)) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
return terms.stream().map(term -> term.getWord()).collect(Collectors.toList());
|
||||
}
|
||||
}
|
||||
|
||||
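For orientation, a hedged fragment showing how the prefix search above might be called once the tries are populated; the ids and the partial key are illustrative:

// Illustrative lookup against the in-memory prefix trie.
Map<Long, List<Long>> modelIdToDataSetIds = new HashMap<>();
modelIdToDataSetIds.put(1L, Arrays.asList(10L));
Set<Long> detectDataSetIds = new HashSet<>(Arrays.asList(10L));
List<HanlpMapResult> matches =
        SearchService.prefixSearch("play", 10, modelIdToDataSetIds, detectDataSetIds);
matches.forEach(m -> System.out.println(m.getName() + " -> " + m.getNatures()));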
@@ -0,0 +1,40 @@
|
||||
package com.tencent.supersonic.headless.chat.knowledge.builder;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.DictWord;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
/**
|
||||
* base word nature
|
||||
*/
|
||||
@Slf4j
|
||||
public abstract class BaseWordBuilder {
|
||||
|
||||
public static final Long DEFAULT_FREQUENCY = 100000L;
|
||||
|
||||
public List<DictWord> getDictWords(List<SchemaElement> schemaElements) {
|
||||
List<DictWord> dictWords = new ArrayList<>();
|
||||
try {
|
||||
dictWords = getDictWordsWithException(schemaElements);
|
||||
} catch (Exception e) {
|
||||
log.error("getWordNatureList error,", e);
|
||||
}
|
||||
return dictWords;
|
||||
}
|
||||
|
||||
protected List<DictWord> getDictWordsWithException(List<SchemaElement> schemaElements) {
|
||||
|
||||
List<DictWord> dictWords = new ArrayList<>();
|
||||
|
||||
for (SchemaElement schemaElement : schemaElements) {
|
||||
dictWords.addAll(doGet(schemaElement.getName(), schemaElement));
|
||||
}
|
||||
return dictWords;
|
||||
}
|
||||
|
||||
protected abstract List<DictWord> doGet(String word, SchemaElement schemaElement);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
package com.tencent.supersonic.headless.chat.knowledge.builder;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.DictWord;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
public abstract class BaseWordWithAliasBuilder extends BaseWordBuilder {
|
||||
|
||||
public abstract DictWord getOneWordNature(String word, SchemaElement schemaElement, boolean isSuffix);
|
||||
|
||||
public List<DictWord> getOneWordNatureAlias(SchemaElement schemaElement, boolean isSuffix) {
|
||||
List<DictWord> dictWords = new ArrayList<>();
|
||||
if (CollectionUtils.isEmpty(schemaElement.getAlias())) {
|
||||
return dictWords;
|
||||
}
|
||||
|
||||
for (String alias : schemaElement.getAlias()) {
|
||||
dictWords.add(getOneWordNature(alias, schemaElement, isSuffix));
|
||||
}
|
||||
return dictWords;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package com.tencent.supersonic.headless.chat.knowledge.builder;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.DictWord;
|
||||
|
||||
import java.util.List;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
/**
|
||||
* dimension word nature
|
||||
*/
|
||||
@Service
|
||||
public class DimensionWordBuilder extends BaseWordWithAliasBuilder {
|
||||
|
||||
@Override
|
||||
public List<DictWord> doGet(String word, SchemaElement schemaElement) {
|
||||
List<DictWord> result = Lists.newArrayList();
|
||||
result.add(getOneWordNature(word, schemaElement, false));
|
||||
result.addAll(getOneWordNatureAlias(schemaElement, false));
|
||||
String reverseWord = StringUtils.reverse(word);
|
||||
if (StringUtils.isNotEmpty(word) && !word.equalsIgnoreCase(reverseWord)) {
|
||||
result.add(getOneWordNature(reverseWord, schemaElement, true));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
public DictWord getOneWordNature(String word, SchemaElement schemaElement, boolean isSuffix) {
|
||||
DictWord dictWord = new DictWord();
|
||||
dictWord.setWord(word);
|
||||
Long modelId = schemaElement.getModel();
|
||||
String nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + schemaElement.getId()
|
||||
+ DictWordType.DIMENSION.getType();
|
||||
if (isSuffix) {
|
||||
nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + schemaElement.getId()
|
||||
+ DictWordType.SUFFIX.getType() + DictWordType.DIMENSION.getType();
|
||||
}
|
||||
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature));
|
||||
return dictWord;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
package com.tencent.supersonic.headless.chat.knowledge.builder;
|
||||
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.DictWord;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class EntityWordBuilder extends BaseWordWithAliasBuilder {
|
||||
|
||||
@Override
|
||||
public List<DictWord> doGet(String word, SchemaElement schemaElement) {
|
||||
List<DictWord> result = Lists.newArrayList();
|
||||
if (Objects.isNull(schemaElement)) {
|
||||
return result;
|
||||
}
|
||||
result.add(getOneWordNature(word, schemaElement, false));
|
||||
result.addAll(getOneWordNatureAlias(schemaElement, false));
|
||||
return result;
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public DictWord getOneWordNature(String word, SchemaElement schemaElement, boolean isSuffix) {
|
||||
String nature = DictWordType.NATURE_SPILT + schemaElement.getModel()
|
||||
+ DictWordType.NATURE_SPILT + schemaElement.getId() + DictWordType.ENTITY.getType();
|
||||
DictWord dictWord = new DictWord();
|
||||
dictWord.setWord(word);
|
||||
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY * 2, nature));
|
||||
return dictWord;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package com.tencent.supersonic.headless.chat.knowledge.builder;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.DictWord;
|
||||
|
||||
import java.util.List;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
/**
|
||||
* Metric DictWord
|
||||
*/
|
||||
@Service
|
||||
public class MetricWordBuilder extends BaseWordWithAliasBuilder {
|
||||
|
||||
@Override
|
||||
public List<DictWord> doGet(String word, SchemaElement schemaElement) {
|
||||
List<DictWord> result = Lists.newArrayList();
|
||||
result.add(getOneWordNature(word, schemaElement, false));
|
||||
result.addAll(getOneWordNatureAlias(schemaElement, false));
|
||||
String reverseWord = StringUtils.reverse(word);
|
||||
if (!word.equalsIgnoreCase(reverseWord)) {
|
||||
result.add(getOneWordNature(reverseWord, schemaElement, true));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
public DictWord getOneWordNature(String word, SchemaElement schemaElement, boolean isSuffix) {
|
||||
DictWord dictWord = new DictWord();
|
||||
dictWord.setWord(word);
|
||||
Long modelId = schemaElement.getModel();
|
||||
String nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + schemaElement.getId()
|
||||
+ DictWordType.METRIC.getType();
|
||||
if (isSuffix) {
|
||||
nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + schemaElement.getId()
|
||||
+ DictWordType.SUFFIX.getType() + DictWordType.METRIC.getType();
|
||||
}
|
||||
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature));
|
||||
return dictWord;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
package com.tencent.supersonic.headless.chat.knowledge.builder;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.DictWord;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
/**
|
||||
* model word nature
|
||||
*/
|
||||
@Service
|
||||
@Slf4j
|
||||
public class ModelWordBuilder extends BaseWordWithAliasBuilder {
|
||||
|
||||
@Override
|
||||
public List<DictWord> doGet(String word, SchemaElement schemaElement) {
|
||||
List<DictWord> result = Lists.newArrayList();
|
||||
if (Objects.isNull(schemaElement)) {
|
||||
return result;
|
||||
}
|
||||
result.add(getOneWordNature(word, schemaElement, false));
|
||||
result.addAll(getOneWordNatureAlias(schemaElement, false));
|
||||
return result;
|
||||
}
|
||||
|
||||
public DictWord getOneWordNature(String word, SchemaElement schemaElement, boolean isSuffix) {
|
||||
DictWord dictWord = new DictWord();
|
||||
dictWord.setWord(word);
|
||||
String nature = DictWordType.NATURE_SPILT + schemaElement.getDataSet();
|
||||
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature));
|
||||
return dictWord;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package com.tencent.supersonic.headless.chat.knowledge.builder;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.DictWord;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Term DictWord
|
||||
*/
|
||||
@Service
|
||||
public class TermWordBuilder extends BaseWordWithAliasBuilder {
|
||||
|
||||
@Override
|
||||
public List<DictWord> doGet(String word, SchemaElement schemaElement) {
|
||||
List<DictWord> result = Lists.newArrayList();
|
||||
result.add(getOneWordNature(word, schemaElement, false));
|
||||
result.addAll(getOneWordNatureAlias(schemaElement, false));
|
||||
String reverseWord = StringUtils.reverse(word);
|
||||
if (!word.equalsIgnoreCase(reverseWord)) {
|
||||
result.add(getOneWordNature(reverseWord, schemaElement, true));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
public DictWord getOneWordNature(String word, SchemaElement schemaElement, boolean isSuffix) {
|
||||
DictWord dictWord = new DictWord();
|
||||
dictWord.setWord(word);
|
||||
Long dataSet = schemaElement.getDataSet();
|
||||
String nature = DictWordType.NATURE_SPILT + dataSet + DictWordType.NATURE_SPILT + schemaElement.getId()
|
||||
+ DictWordType.TERM.getType();
|
||||
if (isSuffix) {
|
||||
nature = DictWordType.NATURE_SPILT + dataSet + DictWordType.NATURE_SPILT + schemaElement.getId()
|
||||
+ DictWordType.SUFFIX.getType() + DictWordType.TERM.getType();
|
||||
}
|
||||
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature));
|
||||
return dictWord;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
package com.tencent.supersonic.headless.chat.knowledge.builder;
|
||||
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.DictWord;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class ValueWordBuilder extends BaseWordWithAliasBuilder {
|
||||
|
||||
@Override
|
||||
public List<DictWord> doGet(String word, SchemaElement schemaElement) {
|
||||
List<DictWord> result = Lists.newArrayList();
|
||||
if (Objects.nonNull(schemaElement)) {
|
||||
result.addAll(getOneWordNatureAlias(schemaElement, false));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
public DictWord getOneWordNature(String word, SchemaElement schemaElement, boolean isSuffix) {
|
||||
DictWord dictWord = new DictWord();
|
||||
Long modelId = schemaElement.getModel();
|
||||
String nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + schemaElement.getId();
|
||||
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature));
|
||||
dictWord.setWord(word);
|
||||
return dictWord;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
package com.tencent.supersonic.headless.chat.knowledge.builder;
|
||||
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
/**
|
||||
* DictWord Strategy Factory
|
||||
*/
|
||||
public class WordBuilderFactory {
|
||||
|
||||
private static Map<DictWordType, BaseWordBuilder> wordNatures = new ConcurrentHashMap<>();
|
||||
|
||||
static {
|
||||
wordNatures.put(DictWordType.DIMENSION, new DimensionWordBuilder());
|
||||
wordNatures.put(DictWordType.METRIC, new MetricWordBuilder());
|
||||
wordNatures.put(DictWordType.DATASET, new ModelWordBuilder());
|
||||
wordNatures.put(DictWordType.ENTITY, new EntityWordBuilder());
|
||||
wordNatures.put(DictWordType.VALUE, new ValueWordBuilder());
|
||||
wordNatures.put(DictWordType.TERM, new TermWordBuilder());
|
||||
}
|
||||
|
||||
public static BaseWordBuilder get(DictWordType strategyType) {
|
||||
return wordNatures.get(strategyType);
|
||||
}
|
||||
}
|
||||
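A hedged sketch of driving the factory above for a single metric element; the SchemaElement setters and values are assumptions mirroring the getters used by the builders:

// Illustrative: build dictionary words for one metric element.
SchemaElement metric = new SchemaElement();   // assumed no-arg constructor
metric.setId(100L);                           // assumed setters matching getId()/getModel()/getName()
metric.setModel(1L);
metric.setName("play count");
List<DictWord> words = WordBuilderFactory.get(DictWordType.METRIC)
        .getDictWords(Collections.singletonList(metric));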
@@ -0,0 +1,41 @@
|
||||
package com.tencent.supersonic.headless.chat.knowledge.file;
|
||||
|
||||
import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
import java.io.FileNotFoundException;
|
||||
|
||||
|
||||
@Data
|
||||
@Configuration
|
||||
@Slf4j
|
||||
public class ChatLocalFileConfig {
|
||||
|
||||
|
||||
@Value("${dict.directory.latest:/data/dictionary/custom}")
|
||||
private String dictDirectoryLatest;
|
||||
|
||||
@Value("${dict.directory.backup:./dict/backup}")
|
||||
private String dictDirectoryBackup;
|
||||
|
||||
public String getDictDirectoryLatest() {
|
||||
return getResourceDir() + dictDirectoryLatest;
|
||||
}
|
||||
|
||||
public String getDictDirectoryBackup() {
|
||||
return dictDirectoryBackup;
|
||||
}
|
||||
|
||||
private String getResourceDir() {
|
||||
String hanlpPropertiesPath = "";
|
||||
try {
|
||||
hanlpPropertiesPath = HanlpHelper.getHanlpPropertiesPath();
|
||||
} catch (FileNotFoundException e) {
|
||||
log.warn("getResourceDir, e:", e);
|
||||
}
|
||||
return hanlpPropertiesPath;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
package com.tencent.supersonic.headless.chat.knowledge.file;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface FileHandler {
|
||||
|
||||
/**
|
||||
* backup files to a specific directory
|
||||
* config: dict.directory.backup
|
||||
*
|
||||
* @param fileName
|
||||
*/
|
||||
void backupFile(String fileName);
|
||||
|
||||
/**
|
||||
* create a directory
|
||||
*
|
||||
* @param path
|
||||
*/
|
||||
void createDir(String path);
|
||||
|
||||
Boolean existPath(String path);
|
||||
|
||||
/**
|
||||
* write data to a specific file,
|
||||
* config dir: dict.directory.latest
|
||||
*
|
||||
* @param data
|
||||
* @param fileName
|
||||
* @param append
|
||||
*/
|
||||
void writeFile(List<String> data, String fileName, Boolean append);
|
||||
|
||||
/**
|
||||
* get the knowledge file root directory
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
String getDictRootPath();
|
||||
|
||||
/**
|
||||
* delete dictionary file
|
||||
* automatic backup
|
||||
*
|
||||
* @param fileName
|
||||
* @return
|
||||
*/
|
||||
Boolean deleteDictFile(String fileName);
|
||||
|
||||
/**
|
||||
* delete files directly without backup
|
||||
*
|
||||
* @param fileName
|
||||
*/
|
||||
void deleteFile(String fileName);
|
||||
|
||||
}
|
||||
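The call pattern this interface implies, sketched against an injected implementation; the file name and dictionary line below are illustrative:

// Illustrative: write a dictionary file, back it up, then delete it (deleteDictFile backs up first).
fileHandler.writeFile(Arrays.asList("play_count _1_100 100000"), "dic_value.txt", false);
fileHandler.backupFile("dic_value.txt");       // copies dict.directory.latest -> dict.directory.backup
fileHandler.deleteDictFile("dic_value.txt");   // backup + remove from dict.directory.latest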
@@ -0,0 +1,132 @@
|
||||
package com.tencent.supersonic.headless.chat.knowledge.file;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.io.File;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.nio.file.StandardCopyOption;
|
||||
import java.nio.file.StandardOpenOption;
|
||||
|
||||
import java.io.BufferedWriter;
|
||||
import java.io.IOException;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.List;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class FileHandlerImpl implements FileHandler {
|
||||
public static final String FILE_SPILT = File.separator;
|
||||
|
||||
private final LocalFileConfig localFileConfig;
|
||||
public FileHandlerImpl(LocalFileConfig localFileConfig) {
|
||||
this.localFileConfig = localFileConfig;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void backupFile(String fileName) {
|
||||
String dictDirectoryBackup = localFileConfig.getDictDirectoryBackup();
|
||||
if (!existPath(dictDirectoryBackup)) {
|
||||
createDir(dictDirectoryBackup);
|
||||
}
|
||||
|
||||
String source = localFileConfig.getDictDirectoryLatest() + FILE_SPILT + fileName;
|
||||
String target = dictDirectoryBackup + FILE_SPILT + fileName;
|
||||
Path sourcePath = Paths.get(source);
|
||||
Path targetPath = Paths.get(target);
|
||||
try {
|
||||
Files.copy(sourcePath, targetPath, StandardCopyOption.REPLACE_EXISTING);
|
||||
log.info("backupFile successfully! path:{}", targetPath.toAbsolutePath());
|
||||
} catch (IOException e) {
|
||||
log.info("Failed to copy file: " + e.getMessage());
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public void createDir(String directoryPath) {
|
||||
Path path = Paths.get(directoryPath);
|
||||
try {
|
||||
Files.createDirectories(path);
|
||||
log.info("Directory created successfully!");
|
||||
} catch (IOException e) {
|
||||
log.info("Failed to create directory: " + e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteFile(String filePath) {
|
||||
Path path = Paths.get(filePath);
|
||||
try {
|
||||
Files.delete(path);
|
||||
log.info("File:{} deleted successfully!", getAbsolutePath(filePath));
|
||||
} catch (IOException e) {
|
||||
log.warn("Failed to delete file:{}, e:", getAbsolutePath(filePath), e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Boolean existPath(String pathStr) {
|
||||
Path path = Paths.get(pathStr);
|
||||
if (Files.exists(path)) {
|
||||
log.info("path:{} exists!", getAbsolutePath(pathStr));
|
||||
return true;
|
||||
} else {
|
||||
log.info("path:{} not exists!", getAbsolutePath(pathStr));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeFile(List<String> lines, String fileName, Boolean append) {
|
||||
if (CollectionUtils.isEmpty(lines)) {
|
||||
log.info("lines is empty");
|
||||
return;
|
||||
}
|
||||
String dictDirectoryLatest = localFileConfig.getDictDirectoryLatest();
|
||||
if (!existPath(dictDirectoryLatest)) {
|
||||
createDir(dictDirectoryLatest);
|
||||
}
|
||||
String filePath = dictDirectoryLatest + FILE_SPILT + fileName;
|
||||
if (existPath(filePath)) {
|
||||
backupFile(fileName);
|
||||
}
|
||||
try (BufferedWriter writer = getWriter(filePath, append)) {
|
||||
if (!CollectionUtils.isEmpty(lines)) {
|
||||
for (String line : lines) {
|
||||
writer.write(line);
|
||||
writer.newLine();
|
||||
}
|
||||
}
|
||||
log.info("File:{} written successfully!", getAbsolutePath(filePath));
|
||||
} catch (IOException e) {
|
||||
log.info("Failed to write file:{}, e:", getAbsolutePath(filePath), e);
|
||||
}
|
||||
}
|
||||
|
||||
public String getAbsolutePath(String path) {
|
||||
return Paths.get(path).toAbsolutePath().toString();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getDictRootPath() {
|
||||
return Paths.get(localFileConfig.getDictDirectoryLatest()).toAbsolutePath().toString();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Boolean deleteDictFile(String fileName) {
|
||||
backupFile(fileName);
|
||||
deleteFile(localFileConfig.getDictDirectoryLatest() + FILE_SPILT + fileName);
|
||||
return true;
|
||||
}
|
||||
|
||||
private BufferedWriter getWriter(String filePath, Boolean append) throws IOException {
|
||||
if (append) {
|
||||
return Files.newBufferedWriter(Paths.get(filePath), StandardCharsets.UTF_8, StandardOpenOption.APPEND);
|
||||
}
|
||||
return Files.newBufferedWriter(Paths.get(filePath), StandardCharsets.UTF_8);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
package com.tencent.supersonic.headless.chat.knowledge.file;
|
||||
|
||||
import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
import java.io.FileNotFoundException;
|
||||
|
||||
@Data
|
||||
@Configuration
|
||||
@Slf4j
|
||||
public class LocalFileConfig {
|
||||
|
||||
|
||||
@Value("${dict.directory.latest:/data/dictionary/custom}")
|
||||
private String dictDirectoryLatest;
|
||||
|
||||
@Value("${dict.directory.backup:./data/dictionary/backup}")
|
||||
private String dictDirectoryBackup;
|
||||
|
||||
public String getDictDirectoryLatest() {
|
||||
return getDictDirectoryPrefixDir() + dictDirectoryLatest;
|
||||
}
|
||||
|
||||
public String getDictDirectoryBackup() {
|
||||
return getDictDirectoryPrefixDir() + dictDirectoryBackup;
|
||||
}
|
||||
|
||||
private String getDictDirectoryPrefixDir() {
|
||||
try {
|
||||
return HanlpHelper.getHanlpPropertiesPath();
|
||||
} catch (FileNotFoundException e) {
|
||||
log.error("getDictDirectoryPrefixDir error: ", e);
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package com.tencent.supersonic.headless.chat.knowledge.helper;
|
||||
|
||||
import com.hankcs.hanlp.HanLP.Config;
|
||||
import com.hankcs.hanlp.dictionary.DynamicCustomDictionary;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@Slf4j
|
||||
public class FileHelper {
|
||||
|
||||
public static final String FILE_SPILT = File.separator;
|
||||
|
||||
public static void deleteCacheFile(String[] path) throws IOException {
|
||||
|
||||
String customPath = getCustomPath(path);
|
||||
File customFolder = new File(customPath);
|
||||
|
||||
File[] customSubFiles = getFileList(customFolder, ".bin");
|
||||
|
||||
for (File file : customSubFiles) {
|
||||
try {
|
||||
file.delete();
|
||||
log.info("customPath:{},delete file:{}", customPath, file);
|
||||
} catch (Exception e) {
|
||||
log.error("delete " + file, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static File[] getFileList(File customFolder, String suffix) {
|
||||
File[] customSubFiles = customFolder.listFiles(file -> {
|
||||
if (file.isDirectory()) {
|
||||
return false;
|
||||
}
|
||||
if (file.getName().toLowerCase().endsWith(suffix)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
return customSubFiles;
|
||||
}
|
||||
|
||||
private static String getCustomPath(String[] path) {
|
||||
return path[0].substring(0, path[0].lastIndexOf(FILE_SPILT)) + FILE_SPILT;
|
||||
}
|
||||
|
||||
/**
|
||||
* reset path
|
||||
*
|
||||
* @param customDictionary
|
||||
*/
|
||||
public static void resetCustomPath(DynamicCustomDictionary customDictionary) {
|
||||
String[] path = Config.CustomDictionaryPath;
|
||||
|
||||
String customPath = getCustomPath(path);
|
||||
File customFolder = new File(customPath);
|
||||
|
||||
File[] customSubFiles = getFileList(customFolder, ".txt");
|
||||
|
||||
List<String> fileList = new ArrayList<>();
|
||||
|
||||
for (File file : customSubFiles) {
|
||||
if (file.isFile()) {
|
||||
fileList.add(file.getAbsolutePath());
|
||||
}
|
||||
}
|
||||
|
||||
log.debug("CustomDictionaryPath:{}", fileList);
|
||||
Config.CustomDictionaryPath = fileList.toArray(new String[0]);
|
||||
customDictionary.path = (Config.CustomDictionaryPath == null || Config.CustomDictionaryPath.length == 0) ? path
|
||||
: Config.CustomDictionaryPath;
|
||||
if (Config.CustomDictionaryPath == null || Config.CustomDictionaryPath.length == 0) {
|
||||
Config.CustomDictionaryPath = path;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,236 @@
|
||||
package com.tencent.supersonic.headless.chat.knowledge.helper;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.hankcs.hanlp.HanLP;
|
||||
import com.hankcs.hanlp.corpus.tag.Nature;
|
||||
import com.hankcs.hanlp.dictionary.CoreDictionary;
|
||||
import com.hankcs.hanlp.dictionary.DynamicCustomDictionary;
|
||||
import com.hankcs.hanlp.seg.Segment;
|
||||
import com.hankcs.hanlp.seg.common.Term;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.DictWord;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.HadoopFileIOAdapter;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.MapResult;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.MultiCustomDictionary;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.SearchService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import org.springframework.util.ResourceUtils;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* HanLP helper
|
||||
*/
|
||||
@Slf4j
|
||||
public class HanlpHelper {
|
||||
|
||||
public static final String FILE_SPILT = File.separator;
|
||||
public static final String SPACE_SPILT = "#";
|
||||
private static volatile DynamicCustomDictionary CustomDictionary;
|
||||
private static volatile Segment segment;
|
||||
|
||||
static {
|
||||
// reset hanlp config
|
||||
try {
|
||||
resetHanlpConfig();
|
||||
} catch (FileNotFoundException e) {
|
||||
log.error("resetHanlpConfig error", e);
|
||||
}
|
||||
}
|
||||
|
||||
public static Segment getSegment() {
|
||||
if (segment == null) {
|
||||
synchronized (HanlpHelper.class) {
|
||||
if (segment == null) {
|
||||
segment = HanLP.newSegment()
|
||||
.enableIndexMode(true).enableIndexMode(4)
|
||||
.enableCustomDictionary(true).enableCustomDictionaryForcing(true).enableOffset(true)
|
||||
.enableJapaneseNameRecognize(false).enableNameRecognize(false)
|
||||
.enableAllNamedEntityRecognize(false)
|
||||
.enableJapaneseNameRecognize(false).enableNumberQuantifierRecognize(false)
|
||||
.enablePlaceRecognize(false)
|
||||
.enableOrganizationRecognize(false).enableCustomDictionary(getDynamicCustomDictionary());
|
||||
}
|
||||
}
|
||||
}
|
||||
return segment;
|
||||
}
|
||||
|
||||
public static DynamicCustomDictionary getDynamicCustomDictionary() {
|
||||
if (CustomDictionary == null) {
|
||||
synchronized (HanlpHelper.class) {
|
||||
if (CustomDictionary == null) {
|
||||
CustomDictionary = new MultiCustomDictionary(HanLP.Config.CustomDictionaryPath);
|
||||
}
|
||||
}
|
||||
}
|
||||
return CustomDictionary;
|
||||
}
|
||||
|
||||
/***
|
||||
* reload custom dictionary
|
||||
*/
|
||||
public static boolean reloadCustomDictionary() throws IOException {
|
||||
|
||||
log.info("reloadCustomDictionary start");
|
||||
|
||||
final long startTime = System.currentTimeMillis();
|
||||
|
||||
if (HanLP.Config.CustomDictionaryPath == null || HanLP.Config.CustomDictionaryPath.length == 0) {
|
||||
return false;
|
||||
}
|
||||
if (HanLP.Config.IOAdapter instanceof HadoopFileIOAdapter) {
|
||||
// 1.delete hdfs file
|
||||
HdfsFileHelper.deleteCacheFile(HanLP.Config.CustomDictionaryPath);
|
||||
// 2.query txt files,update CustomDictionaryPath
|
||||
HdfsFileHelper.resetCustomPath(getDynamicCustomDictionary());
|
||||
} else {
|
||||
FileHelper.deleteCacheFile(HanLP.Config.CustomDictionaryPath);
|
||||
FileHelper.resetCustomPath(getDynamicCustomDictionary());
|
||||
}
|
||||
// 3.clear trie
|
||||
SearchService.clear();
|
||||
|
||||
boolean reload = getDynamicCustomDictionary().reload();
|
||||
log.info("reloadCustomDictionary end ,cost:{},reload:{}", System.currentTimeMillis() - startTime, reload);
|
||||
return reload;
|
||||
}
|
||||
|
||||
private static void resetHanlpConfig() throws FileNotFoundException {
|
||||
if (HanLP.Config.IOAdapter instanceof HadoopFileIOAdapter) {
|
||||
return;
|
||||
}
|
||||
String hanlpPropertiesPath = getHanlpPropertiesPath();
|
||||
|
||||
HanLP.Config.CustomDictionaryPath = Arrays.stream(HanLP.Config.CustomDictionaryPath)
|
||||
.map(path -> hanlpPropertiesPath + FILE_SPILT + path)
|
||||
.toArray(String[]::new);
|
||||
log.info("hanlpPropertiesPath:{},CustomDictionaryPath:{}", hanlpPropertiesPath,
|
||||
HanLP.Config.CustomDictionaryPath);
|
||||
|
||||
HanLP.Config.CoreDictionaryPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.CoreDictionaryPath;
|
||||
HanLP.Config.CoreDictionaryTransformMatrixDictionaryPath = hanlpPropertiesPath + FILE_SPILT
|
||||
+ HanLP.Config.CoreDictionaryTransformMatrixDictionaryPath;
|
||||
HanLP.Config.BiGramDictionaryPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.BiGramDictionaryPath;
|
||||
HanLP.Config.CoreStopWordDictionaryPath =
|
||||
hanlpPropertiesPath + FILE_SPILT + HanLP.Config.CoreStopWordDictionaryPath;
|
||||
HanLP.Config.CoreSynonymDictionaryDictionaryPath = hanlpPropertiesPath + FILE_SPILT
|
||||
+ HanLP.Config.CoreSynonymDictionaryDictionaryPath;
|
||||
HanLP.Config.PersonDictionaryPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.PersonDictionaryPath;
|
||||
HanLP.Config.PersonDictionaryTrPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.PersonDictionaryTrPath;
|
||||
|
||||
HanLP.Config.PinyinDictionaryPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.PinyinDictionaryPath;
|
||||
HanLP.Config.TranslatedPersonDictionaryPath = hanlpPropertiesPath + FILE_SPILT
|
||||
+ HanLP.Config.TranslatedPersonDictionaryPath;
|
||||
HanLP.Config.JapanesePersonDictionaryPath = hanlpPropertiesPath + FILE_SPILT
|
||||
+ HanLP.Config.JapanesePersonDictionaryPath;
|
||||
HanLP.Config.PlaceDictionaryPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.PlaceDictionaryPath;
|
||||
HanLP.Config.PlaceDictionaryTrPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.PlaceDictionaryTrPath;
|
||||
HanLP.Config.OrganizationDictionaryPath = hanlpPropertiesPath + FILE_SPILT
|
||||
+ HanLP.Config.OrganizationDictionaryPath;
|
||||
HanLP.Config.OrganizationDictionaryTrPath = hanlpPropertiesPath + FILE_SPILT
|
||||
+ HanLP.Config.OrganizationDictionaryTrPath;
|
||||
HanLP.Config.CharTypePath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.CharTypePath;
|
||||
HanLP.Config.CharTablePath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.CharTablePath;
|
||||
HanLP.Config.PartOfSpeechTagDictionary =
|
||||
hanlpPropertiesPath + FILE_SPILT + HanLP.Config.PartOfSpeechTagDictionary;
|
||||
HanLP.Config.WordNatureModelPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.WordNatureModelPath;
|
||||
HanLP.Config.MaxEntModelPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.MaxEntModelPath;
|
||||
HanLP.Config.NNParserModelPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.NNParserModelPath;
|
||||
HanLP.Config.PerceptronParserModelPath =
|
||||
hanlpPropertiesPath + FILE_SPILT + HanLP.Config.PerceptronParserModelPath;
|
||||
HanLP.Config.CRFSegmentModelPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.CRFSegmentModelPath;
|
||||
HanLP.Config.HMMSegmentModelPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.HMMSegmentModelPath;
|
||||
HanLP.Config.CRFCWSModelPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.CRFCWSModelPath;
|
||||
HanLP.Config.CRFPOSModelPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.CRFPOSModelPath;
|
||||
HanLP.Config.CRFNERModelPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.CRFNERModelPath;
|
||||
HanLP.Config.PerceptronCWSModelPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.PerceptronCWSModelPath;
|
||||
HanLP.Config.PerceptronPOSModelPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.PerceptronPOSModelPath;
|
||||
HanLP.Config.PerceptronNERModelPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.PerceptronNERModelPath;
|
||||
}
|
||||
|
||||
public static String getHanlpPropertiesPath() throws FileNotFoundException {
|
||||
return ResourceUtils.getFile("classpath:hanlp.properties").getParent();
|
||||
}
|
||||
|
||||
public static boolean addToCustomDictionary(DictWord dictWord) {
|
||||
log.debug("dictWord:{}", dictWord);
|
||||
return getDynamicCustomDictionary().insert(dictWord.getWord(), dictWord.getNatureWithFrequency());
|
||||
}
|
||||
|
||||
public static void removeFromCustomDictionary(DictWord dictWord) {
|
||||
log.debug("dictWord:{}", dictWord);
|
||||
CoreDictionary.Attribute attribute = getDynamicCustomDictionary().get(dictWord.getWord());
|
||||
if (attribute == null) {
|
||||
return;
|
||||
}
|
||||
log.info("get attribute:{}", attribute);
|
||||
getDynamicCustomDictionary().remove(dictWord.getWord());
|
||||
StringBuilder sb = new StringBuilder();
|
||||
List<Nature> natureList = new ArrayList<>();
|
||||
for (int i = 0; i < attribute.nature.length; i++) {
|
||||
if (!attribute.nature[i].toString().equals(dictWord.getNature())) {
|
||||
sb.append(attribute.nature[i].toString() + " ");
|
||||
sb.append(attribute.frequency[i] + " ");
|
||||
natureList.add((attribute.nature[i]));
|
||||
}
|
||||
}
|
||||
String natureWithFrequency = sb.toString();
|
||||
int len = natureWithFrequency.length();
|
||||
log.info("filtered natureWithFrequency:{}", natureWithFrequency);
|
||||
if (StringUtils.isNotBlank(natureWithFrequency)) {
|
||||
getDynamicCustomDictionary().add(dictWord.getWord(), natureWithFrequency.substring(0, len - 1));
|
||||
}
|
||||
SearchService.remove(dictWord, natureList.toArray(new Nature[0]));
|
||||
}
|
||||
|
||||
public static <T extends MapResult> void transLetterOriginal(List<T> mapResults) {
|
||||
if (CollectionUtils.isEmpty(mapResults)) {
|
||||
return;
|
||||
}
|
||||
for (T mapResult : mapResults) {
|
||||
if (MultiCustomDictionary.isLowerLetter(mapResult.getName())) {
|
||||
if (CustomDictionary.contains(mapResult.getName())) {
|
||||
CoreDictionary.Attribute attribute = CustomDictionary.get(mapResult.getName());
|
||||
if (attribute != null && attribute.original != null) {
|
||||
mapResult.setName(attribute.original);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static List<S2Term> getTerms(String text, Map<Long, List<Long>> modelIdToDataSetIds) {
|
||||
return getSegment().seg(text.toLowerCase()).stream()
|
||||
.filter(term -> term.getNature().startsWith(DictWordType.NATURE_SPILT))
|
||||
.map(term -> transform2ApiTerm(term, modelIdToDataSetIds))
|
||||
.flatMap(Collection::stream)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public static List<S2Term> transform2ApiTerm(Term term, Map<Long, List<Long>> modelIdToDataSetIds) {
|
||||
List<S2Term> s2Terms = Lists.newArrayList();
|
||||
List<String> natures = NatureHelper.changeModel2DataSet(String.valueOf(term.getNature()), modelIdToDataSetIds);
|
||||
for (String nature : natures) {
|
||||
S2Term s2Term = new S2Term();
|
||||
BeanUtils.copyProperties(term, s2Term);
|
||||
s2Term.setNature(Nature.create(nature));
|
||||
s2Term.setFrequency(term.getFrequency());
|
||||
s2Terms.add(s2Term);
|
||||
}
|
||||
return s2Terms;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
package com.tencent.supersonic.headless.chat.knowledge.helper;
|
||||
|
||||
import com.hankcs.hanlp.HanLP.Config;
|
||||
import com.hankcs.hanlp.dictionary.DynamicCustomDictionary;
|
||||
import com.hankcs.hanlp.utility.Predefine;
|
||||
import java.io.IOException;
|
||||
import java.net.URI;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.hadoop.conf.Configuration;
|
||||
import org.apache.hadoop.fs.FileStatus;
|
||||
import org.apache.hadoop.fs.FileSystem;
|
||||
import org.apache.hadoop.fs.Path;
|
||||
|
||||
/**
|
||||
* Hdfs File Helper
|
||||
*/
|
||||
@Slf4j
|
||||
public class HdfsFileHelper {
|
||||
|
||||
/***
|
||||
* delete cache file
|
||||
* @param path
|
||||
* @throws IOException
|
||||
*/
|
||||
public static void deleteCacheFile(String[] path) throws IOException {
|
||||
FileSystem fs = FileSystem.get(URI.create(path[0]), new Configuration());
|
||||
String cacheFilePath = path[0] + Predefine.BIN_EXT;
|
||||
log.info("delete cache file:{}", cacheFilePath);
|
||||
try {
|
||||
fs.delete(new Path(cacheFilePath), false);
|
||||
} catch (Exception e) {
|
||||
log.error("delete:" + cacheFilePath, e);
|
||||
}
|
||||
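// also delete any sibling *.bin cache files located in the same directory as the dictionary path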
int customBase = cacheFilePath.lastIndexOf(FileHelper.FILE_SPILT);
|
||||
String customPath = cacheFilePath.substring(0, customBase) + FileHelper.FILE_SPILT + "*.bin";
|
||||
List<String> fileList = getFileList(fs, new Path(customPath));
|
||||
for (String file : fileList) {
|
||||
try {
|
||||
fs.delete(new Path(file), false);
|
||||
log.info("delete cache file:{}", file);
|
||||
} catch (Exception e) {
|
||||
log.error("delete " + file, e);
|
||||
}
|
||||
}
|
||||
log.info("fileList:{}", fileList);
|
||||
}
|
||||
|
||||
/**
|
||||
* reset path
|
||||
*
|
||||
* @param customDictionary
|
||||
* @throws IOException
|
||||
*/
|
||||
public static void resetCustomPath(DynamicCustomDictionary customDictionary) throws IOException {
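// re-point HanLP's CustomDictionaryPath at every *.txt dictionary found beside the configured
// path on HDFS, falling back to the original path when the glob matches nothing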
|
||||
String[] path = Config.CustomDictionaryPath;
|
||||
FileSystem fs = FileSystem.get(URI.create(path[0]), new Configuration());
|
||||
String cacheFilePath = path[0] + Predefine.BIN_EXT;
|
||||
int customBase = cacheFilePath.lastIndexOf(FileHelper.FILE_SPILT);
|
||||
String customPath = cacheFilePath.substring(0, customBase) + FileHelper.FILE_SPILT + "*.txt";
|
||||
log.info("customPath:{}", customPath);
|
||||
List<String> fileList = getFileList(fs, new Path(customPath));
|
||||
log.info("CustomDictionaryPath:{}", fileList);
|
||||
Config.CustomDictionaryPath = fileList.toArray(new String[0]);
|
||||
customDictionary.path = (Config.CustomDictionaryPath == null || Config.CustomDictionaryPath.length == 0) ? path
|
||||
: Config.CustomDictionaryPath;
|
||||
if (Config.CustomDictionaryPath == null || Config.CustomDictionaryPath.length == 0) {
|
||||
Config.CustomDictionaryPath = path;
|
||||
}
|
||||
}
|
||||
|
||||
public static List<String> getFileList(FileSystem fs, Path folderPath) throws IOException {
|
||||
List<String> paths = new ArrayList<>();
|
||||
FileStatus[] fileStatuses = fs.globStatus(folderPath);
|
||||
for (int i = 0; i < fileStatuses.length; i++) {
|
||||
FileStatus fileStatus = fileStatuses[i];
|
||||
paths.add(fileStatus.getPath().toString());
|
||||
}
|
||||
return paths;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,253 @@
|
||||
package com.tencent.supersonic.headless.chat.knowledge.helper;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.hankcs.hanlp.corpus.tag.Nature;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.DataSetInfoStat;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Nature parsing helper
|
||||
*/
|
||||
@Slf4j
|
||||
public class NatureHelper {
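// NOTE: as the parsing below suggests, a nature string is expected to look like
// <NATURE_SPILT><modelOrDataSetId><NATURE_SPILT><elementId><typeSuffix>,
// e.g. "_12_345_metric" (assuming NATURE_SPILT is "_"); split[1] carries the model/dataSet id,
// split[2] the element id, and the suffix maps to a DictWordType.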
|
||||
|
||||
public static SchemaElementType convertToElementType(String nature) {
|
||||
DictWordType dictWordType = DictWordType.getNatureType(nature);
|
||||
if (Objects.isNull(dictWordType)) {
|
||||
return null;
|
||||
}
|
||||
SchemaElementType result = null;
|
||||
switch (dictWordType) {
|
||||
case METRIC:
|
||||
result = SchemaElementType.METRIC;
|
||||
break;
|
||||
case DIMENSION:
|
||||
result = SchemaElementType.DIMENSION;
|
||||
break;
|
||||
case ENTITY:
|
||||
result = SchemaElementType.ENTITY;
|
||||
break;
|
||||
case DATASET:
|
||||
result = SchemaElementType.DATASET;
|
||||
break;
|
||||
case VALUE:
|
||||
result = SchemaElementType.VALUE;
|
||||
break;
|
||||
case TERM:
|
||||
result = SchemaElementType.TERM;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private static boolean isDataSetOrEntity(S2Term term, Integer model) {
|
||||
return (DictWordType.NATURE_SPILT + model).equals(term.nature.toString()) || term.nature.toString()
|
||||
.endsWith(DictWordType.ENTITY.getType());
|
||||
}
|
||||
|
||||
public static Integer getDataSetByNature(Nature nature) {
|
||||
if (nature.startsWith(DictWordType.NATURE_SPILT)) {
|
||||
String[] dimensionValues = nature.toString().split(DictWordType.NATURE_SPILT);
|
||||
if (StringUtils.isNumeric(dimensionValues[1])) {
|
||||
return Integer.valueOf(dimensionValues[1]);
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
public static Long getDataSetId(String nature) {
|
||||
try {
|
||||
String[] split = nature.split(DictWordType.NATURE_SPILT);
|
||||
if (split.length <= 1) {
|
||||
return null;
|
||||
}
|
||||
return Long.valueOf(split[1]);
|
||||
} catch (NumberFormatException e) {
|
||||
log.error("", e);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private static Long getModelId(String nature) {
|
||||
try {
|
||||
String[] split = nature.split(DictWordType.NATURE_SPILT);
|
||||
if (split.length <= 1) {
|
||||
return null;
|
||||
}
|
||||
return Long.valueOf(split[1]);
|
||||
} catch (NumberFormatException e) {
|
||||
log.error("", e);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private static Nature changeModel2DataSet(String nature, Long dataSetId) {
|
||||
try {
|
||||
String[] split = nature.split(DictWordType.NATURE_SPILT);
|
||||
if (split.length <= 1) {
|
||||
return null;
|
||||
}
|
||||
split[1] = String.valueOf(dataSetId);
|
||||
return Nature.create(StringUtils.join(split, DictWordType.NATURE_SPILT));
|
||||
} catch (NumberFormatException e) {
|
||||
log.error("", e);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public static List<String> changeModel2DataSet(String nature, Map<Long, List<Long>> modelIdToDataSetIds) {
|
||||
// the prefix id of a term nature is already a dataSetId, so no transformation is needed
|
||||
if (SchemaElementType.TERM.equals(NatureHelper.convertToElementType(nature))) {
|
||||
return Lists.newArrayList(nature);
|
||||
}
|
||||
Long modelId = getModelId(nature);
|
||||
List<Long> dataSetIds = modelIdToDataSetIds.get(modelId);
|
||||
if (CollectionUtils.isEmpty(dataSetIds)) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
return dataSetIds.stream().map(dataSetId -> String.valueOf(changeModel2DataSet(nature, dataSetId)))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public static boolean isDimensionValueDataSetId(String nature) {
|
||||
if (StringUtils.isEmpty(nature)) {
|
||||
return false;
|
||||
}
|
||||
if (!nature.startsWith(DictWordType.NATURE_SPILT)) {
|
||||
return false;
|
||||
}
|
||||
String[] split = nature.split(DictWordType.NATURE_SPILT);
|
||||
if (split.length <= 1) {
|
||||
return false;
|
||||
}
|
||||
return !nature.endsWith(DictWordType.METRIC.getType()) && !nature.endsWith(
|
||||
DictWordType.DIMENSION.getType()) && !nature.endsWith(DictWordType.TERM.getType())
|
||||
&& StringUtils.isNumeric(split[1]);
|
||||
}
|
||||
|
||||
public static boolean isTermNature(String nature) {
|
||||
if (StringUtils.isEmpty(nature)) {
|
||||
return false;
|
||||
}
|
||||
if (!nature.startsWith(DictWordType.NATURE_SPILT)) {
|
||||
return false;
|
||||
}
|
||||
String[] split = nature.split(DictWordType.NATURE_SPILT);
|
||||
if (split.length <= 1) {
|
||||
return false;
|
||||
}
|
||||
return nature.endsWith(DictWordType.TERM.getType());
|
||||
}
|
||||
|
||||
public static DataSetInfoStat getDataSetStat(List<S2Term> terms) {
|
||||
return DataSetInfoStat.builder()
|
||||
.dataSetCount(getDataSetCount(terms))
|
||||
.dimensionDataSetCount(getDimensionCount(terms))
|
||||
.metricDataSetCount(getMetricCount(terms))
|
||||
.dimensionValueDataSetCount(getDimensionValueCount(terms))
|
||||
.build();
|
||||
}
|
||||
|
||||
private static long getDataSetCount(List<S2Term> terms) {
|
||||
return terms.stream().filter(term -> isDataSetOrEntity(term, getDataSetByNature(term.nature))).count();
|
||||
}
|
||||
|
||||
private static long getDimensionValueCount(List<S2Term> terms) {
|
||||
return terms.stream().filter(term -> isDimensionValueDataSetId(term.nature.toString())).count();
|
||||
}
|
||||
|
||||
private static long getDimensionCount(List<S2Term> terms) {
|
||||
return terms.stream().filter(term -> term.nature.startsWith(DictWordType.NATURE_SPILT) && term.nature.toString()
|
||||
.endsWith(DictWordType.DIMENSION.getType())).count();
|
||||
}
|
||||
|
||||
private static long getMetricCount(List<S2Term> terms) {
|
||||
return terms.stream().filter(term -> term.nature.startsWith(DictWordType.NATURE_SPILT) && term.nature.toString()
|
||||
.endsWith(DictWordType.METRIC.getType())).count();
|
||||
}
|
||||
|
||||
/**
|
||||
* Count, for each dataSet/model, the number of terms of each nature type:
* modelId -> (natureType, natureCount)
|
||||
*
|
||||
* @param terms
|
||||
* @return
|
||||
*/
|
||||
public static Map<Long, Map<DictWordType, Integer>> getDataSetToNatureStat(List<S2Term> terms) {
|
||||
Map<Long, Map<DictWordType, Integer>> modelToNature = new HashMap<>();
|
||||
terms.stream().filter(
|
||||
term -> term.nature.startsWith(DictWordType.NATURE_SPILT)
|
||||
).forEach(term -> {
|
||||
DictWordType dictWordType = DictWordType.getNatureType(String.valueOf(term.nature));
|
||||
Long model = getDataSetId(String.valueOf(term.nature));
|
||||
|
||||
Map<DictWordType, Integer> natureTypeMap = new HashMap<>();
|
||||
natureTypeMap.put(dictWordType, 1);
|
||||
|
||||
Map<DictWordType, Integer> original = modelToNature.get(model);
|
||||
if (Objects.isNull(original)) {
|
||||
modelToNature.put(model, natureTypeMap);
|
||||
} else {
|
||||
Integer count = original.get(dictWordType);
|
||||
if (Objects.isNull(count)) {
|
||||
count = 1;
|
||||
} else {
|
||||
count = count + 1;
|
||||
}
|
||||
original.put(dictWordType, count);
|
||||
}
|
||||
});
|
||||
return modelToNature;
|
||||
}
|
||||
|
||||
public static List<Long> selectPossibleDataSets(List<S2Term> terms) {
|
||||
Map<Long, Map<DictWordType, Integer>> modelToNatureStat = getDataSetToNatureStat(terms);
|
||||
Integer maxDataSetTypeSize = modelToNatureStat.entrySet().stream()
|
||||
.max(Comparator.comparingInt(o -> o.getValue().size())).map(entry -> entry.getValue().size())
|
||||
.orElse(null);
|
||||
if (Objects.isNull(maxDataSetTypeSize) || maxDataSetTypeSize == 0) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
return modelToNatureStat.entrySet().stream().filter(entry -> entry.getValue().size() == maxDataSetTypeSize)
|
||||
.map(entry -> entry.getKey()).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public static Long getElementID(String nature) {
|
||||
String[] split = nature.split(DictWordType.NATURE_SPILT);
|
||||
if (split.length >= 3) {
|
||||
return Long.valueOf(split[2]);
|
||||
}
|
||||
return 0L;
|
||||
}
|
||||
|
||||
public static Set<Long> getModelIds(Map<Long, List<Long>> modelIdToDataSetIds, Set<Long> detectDataSetIds) {
|
||||
Set<Long> detectModelIds = modelIdToDataSetIds.keySet();
|
||||
if (!CollectionUtils.isEmpty(detectDataSetIds)) {
|
||||
detectModelIds = modelIdToDataSetIds.entrySet().stream().filter(entry -> {
|
||||
List<Long> dataSetIds = entry.getValue().stream().filter(detectDataSetIds::contains)
|
||||
.collect(Collectors.toList());
|
||||
return !CollectionUtils.isEmpty(dataSetIds);
|
||||
}).map(entry -> entry.getKey()).collect(Collectors.toSet());
|
||||
}
|
||||
return detectModelIds;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,182 @@
|
||||
package com.tencent.supersonic.headless.chat.mapper;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.HashSet;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.function.Predicate;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
public abstract class BaseMapper implements SchemaMapper {
|
||||
|
||||
@Override
|
||||
public void map(QueryContext queryContext) {
|
||||
|
||||
String simpleName = this.getClass().getSimpleName();
|
||||
long startTime = System.currentTimeMillis();
|
||||
log.debug("before {},mapInfo:{}", simpleName,
|
||||
queryContext.getMapInfo().getDataSetElementMatches());
|
||||
|
||||
try {
|
||||
doMap(queryContext);
|
||||
filter(queryContext);
|
||||
} catch (Exception e) {
|
||||
log.error("work error", e);
|
||||
}
|
||||
|
||||
long cost = System.currentTimeMillis() - startTime;
|
||||
log.debug("after {},cost:{},mapInfo:{}", simpleName, cost,
|
||||
queryContext.getMapInfo().getDataSetElementMatches());
|
||||
}
|
||||
|
||||
private void filter(QueryContext queryContext) {
|
||||
filterByDataSetId(queryContext);
|
||||
filterTermByDetectWordLen(queryContext);
|
||||
switch (queryContext.getQueryDataType()) {
|
||||
case TAG:
|
||||
filterByQueryDataType(queryContext, element -> !(element.getIsTag() > 0));
|
||||
break;
|
||||
case METRIC:
|
||||
filterByQueryDataType(queryContext, element -> !SchemaElementType.METRIC.equals(element.getType()));
|
||||
break;
|
||||
case DIMENSION:
|
||||
filterByQueryDataType(queryContext, element -> {
|
||||
boolean isDimensionOrValue = SchemaElementType.DIMENSION.equals(element.getType())
|
||||
|| SchemaElementType.VALUE.equals(element.getType());
|
||||
return !isDimensionOrValue;
|
||||
});
|
||||
break;
|
||||
case ALL:
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
private static void filterByDataSetId(QueryContext queryContext) {
|
||||
Set<Long> dataSetIds = queryContext.getDataSetIds();
|
||||
if (CollectionUtils.isEmpty(dataSetIds)) {
|
||||
return;
|
||||
}
|
||||
Set<Long> dataSetIdInMapInfo = new HashSet<>(queryContext.getMapInfo().getDataSetElementMatches().keySet());
|
||||
for (Long dataSetId : dataSetIdInMapInfo) {
|
||||
if (!dataSetIds.contains(dataSetId)) {
|
||||
queryContext.getMapInfo().getDataSetElementMatches().remove(dataSetId);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static void filterTermByDetectWordLen(QueryContext queryContext) {
|
||||
Map<Long, List<SchemaElementMatch>> dataSetElementMatches =
|
||||
queryContext.getMapInfo().getDataSetElementMatches();
|
||||
for (Map.Entry<Long, List<SchemaElementMatch>> entry : dataSetElementMatches.entrySet()) {
|
||||
List<SchemaElementMatch> value = entry.getValue();
|
||||
if (!CollectionUtils.isEmpty(value)) {
|
||||
value.removeIf(schemaElementMatch -> {
|
||||
if (!SchemaElementType.TERM.equals(schemaElementMatch.getElement().getType())) {
|
||||
return false;
|
||||
}
|
||||
return StringUtils.length(schemaElementMatch.getDetectWord()) <= 1;
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static void filterByQueryDataType(QueryContext queryContext, Predicate<SchemaElement> needRemovePredicate) {
|
||||
queryContext.getMapInfo().getDataSetElementMatches().values().stream().forEach(
|
||||
schemaElementMatches -> schemaElementMatches.removeIf(
|
||||
schemaElementMatch -> {
|
||||
SchemaElement element = schemaElementMatch.getElement();
|
||||
SchemaElementType type = element.getType();
|
||||
if (SchemaElementType.ENTITY.equals(type) || SchemaElementType.DATASET.equals(type)
|
||||
|| SchemaElementType.ID.equals(type)) {
|
||||
return false;
|
||||
}
|
||||
return needRemovePredicate.test(element);
|
||||
}
|
||||
));
|
||||
}
|
||||
|
||||
public abstract void doMap(QueryContext queryContext);
|
||||
|
||||
public void addToSchemaMap(SchemaMapInfo schemaMap, Long dataSetId, SchemaElementMatch newElementMatch) {
|
||||
Map<Long, List<SchemaElementMatch>> dataSetElementMatches = schemaMap.getDataSetElementMatches();
|
||||
List<SchemaElementMatch> schemaElementMatches = dataSetElementMatches.putIfAbsent(dataSetId, new ArrayList<>());
|
||||
if (schemaElementMatches == null) {
|
||||
schemaElementMatches = dataSetElementMatches.get(dataSetId);
|
||||
}
|
||||
// remove duplicates, keeping whichever match has the higher similarity
|
||||
AtomicBoolean needAddNew = new AtomicBoolean(true);
|
||||
schemaElementMatches.removeIf(
|
||||
existElementMatch -> {
|
||||
if (isEquals(existElementMatch, newElementMatch)) {
|
||||
if (newElementMatch.getSimilarity() > existElementMatch.getSimilarity()) {
|
||||
return true;
|
||||
} else {
|
||||
needAddNew.set(false);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
);
|
||||
if (needAddNew.get()) {
|
||||
schemaElementMatches.add(newElementMatch);
|
||||
}
|
||||
}
|
||||
|
||||
private static boolean isEquals(SchemaElementMatch existElementMatch, SchemaElementMatch newElementMatch) {
|
||||
SchemaElement existElement = existElementMatch.getElement();
|
||||
SchemaElement newElement = newElementMatch.getElement();
|
||||
if (!existElement.equals(newElement)) {
|
||||
return false;
|
||||
}
|
||||
if (SchemaElementType.VALUE.equals(newElement.getType())) {
|
||||
return existElementMatch.getWord().equalsIgnoreCase(newElementMatch.getWord());
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
public SchemaElement getSchemaElement(Long dataSetId, SchemaElementType elementType, Long elementID,
|
||||
SemanticSchema semanticSchema) {
|
||||
SchemaElement element = new SchemaElement();
|
||||
DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId);
|
||||
if (Objects.isNull(dataSetSchema)) {
|
||||
return null;
|
||||
}
|
||||
SchemaElement elementDb = dataSetSchema.getElement(elementType, elementID);
|
||||
if (Objects.isNull(elementDb)) {
|
||||
return null;
|
||||
}
|
||||
BeanUtils.copyProperties(elementDb, element);
|
||||
element.setAlias(getAlias(elementDb));
|
||||
return element;
|
||||
}
|
||||
|
||||
public List<String> getAlias(SchemaElement element) {
|
||||
if (!SchemaElementType.VALUE.equals(element.getType())) {
|
||||
return element.getAlias();
|
||||
}
|
||||
if (org.apache.commons.collections.CollectionUtils.isNotEmpty(element.getAlias()) && StringUtils.isNotEmpty(
|
||||
element.getName())) {
|
||||
return element.getAlias().stream()
|
||||
.filter(aliasItem -> aliasItem.contains(element.getName()))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
return element.getAlias();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,166 @@
|
||||
package com.tencent.supersonic.headless.chat.mapper;
|
||||
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
||||
|
||||
@Autowired
|
||||
protected MapperHelper mapperHelper;
|
||||
|
||||
@Autowired
|
||||
protected MapperConfig mapperConfig;
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms,
|
||||
Set<Long> detectDataSetIds) {
|
||||
String text = queryContext.getQueryText();
|
||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
log.debug("terms:{},,detectDataSetIds:{}", terms, detectDataSetIds);
|
||||
|
||||
List<T> detects = detect(queryContext, terms, detectDataSetIds);
|
||||
Map<MatchText, List<T>> result = new HashMap<>();
|
||||
|
||||
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
||||
return result;
|
||||
}
|
||||
|
||||
public List<T> detect(QueryContext queryContext, List<S2Term> terms, Set<Long> detectDataSetIds) {
|
||||
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(terms);
|
||||
String text = queryContext.getQueryText();
|
||||
Set<T> results = new HashSet<>();
|
||||
|
||||
Set<String> detectSegments = new HashSet<>();
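// slide over the query text: startIndex advances by term-aligned steps, and every term-aligned
// end index yields a candidate detectSegment that is matched in detectByStep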
|
||||
|
||||
for (Integer startIndex = 0; startIndex <= text.length() - 1; ) {
|
||||
|
||||
for (Integer index = startIndex; index <= text.length(); ) {
|
||||
int offset = mapperHelper.getStepOffset(terms, startIndex);
|
||||
index = mapperHelper.getStepIndex(regOffsetToLength, index);
|
||||
if (index <= text.length()) {
|
||||
String detectSegment = text.substring(startIndex, index).trim();
|
||||
detectSegments.add(detectSegment);
|
||||
detectByStep(queryContext, results, detectDataSetIds, detectSegment, offset);
|
||||
}
|
||||
}
|
||||
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
|
||||
}
|
||||
detectByBatch(queryContext, results, detectDataSetIds, detectSegments);
|
||||
return new ArrayList<>(results);
|
||||
}
|
||||
|
||||
protected void detectByBatch(QueryContext queryContext, Set<T> results, Set<Long> detectDataSetIds,
|
||||
Set<String> detectSegments) {
|
||||
}
|
||||
|
||||
public Map<Integer, Integer> getRegOffsetToLength(List<S2Term> terms) {
|
||||
return terms.stream().sorted(Comparator.comparing(S2Term::length))
|
||||
.collect(Collectors.toMap(S2Term::getOffset, term -> term.word.length(),
|
||||
(value1, value2) -> value2));
|
||||
}
|
||||
|
||||
public void selectResultInOneRound(Set<T> existResults, List<T> oneRoundResults) {
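// merge one round of results into the accumulated set: for an equivalent existing result,
// needDelete() decides whether it is replaced by the newer one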
|
||||
if (CollectionUtils.isEmpty(oneRoundResults)) {
|
||||
return;
|
||||
}
|
||||
for (T oneRoundResult : oneRoundResults) {
|
||||
if (existResults.contains(oneRoundResult)) {
|
||||
boolean isDeleted = existResults.removeIf(
|
||||
existResult -> {
|
||||
boolean delete = needDelete(oneRoundResult, existResult);
|
||||
if (delete) {
|
||||
log.info("deleted existResult:{}", existResult);
|
||||
}
|
||||
return delete;
|
||||
}
|
||||
);
|
||||
if (isDeleted) {
|
||||
log.info("deleted, add oneRoundResult:{}", oneRoundResult);
|
||||
existResults.add(oneRoundResult);
|
||||
}
|
||||
} else {
|
||||
existResults.add(oneRoundResult);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public List<T> getMatches(QueryContext queryContext, List<S2Term> terms) {
|
||||
Set<Long> dataSetIds = queryContext.getDataSetIds();
|
||||
terms = filterByDataSetId(terms, dataSetIds);
|
||||
Map<MatchText, List<T>> matchResult = match(queryContext, terms, dataSetIds);
|
||||
List<T> matches = new ArrayList<>();
|
||||
if (Objects.isNull(matchResult)) {
|
||||
return matches;
|
||||
}
|
||||
Optional<List<T>> first = matchResult.entrySet().stream()
|
||||
.filter(entry -> CollectionUtils.isNotEmpty(entry.getValue()))
|
||||
.map(entry -> entry.getValue()).findFirst();
|
||||
|
||||
if (first.isPresent()) {
|
||||
matches = first.get();
|
||||
}
|
||||
return matches;
|
||||
}
|
||||
|
||||
public List<S2Term> filterByDataSetId(List<S2Term> terms, Set<Long> dataSetIds) {
|
||||
logTerms(terms);
|
||||
if (CollectionUtils.isNotEmpty(dataSetIds)) {
|
||||
terms = terms.stream().filter(term -> {
|
||||
Long dataSetId = NatureHelper.getDataSetId(term.getNature().toString());
|
||||
if (Objects.nonNull(dataSetId)) {
|
||||
return dataSetIds.contains(dataSetId);
|
||||
}
|
||||
return false;
|
||||
}).collect(Collectors.toList());
|
||||
log.info("terms filter by dataSetId:{}", dataSetIds);
|
||||
logTerms(terms);
|
||||
}
|
||||
return terms;
|
||||
}
|
||||
|
||||
public void logTerms(List<S2Term> terms) {
|
||||
if (CollectionUtils.isEmpty(terms)) {
|
||||
return;
|
||||
}
|
||||
for (S2Term term : terms) {
|
||||
log.debug("word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency());
|
||||
}
|
||||
}
|
||||
|
||||
public abstract boolean needDelete(T oneRoundResult, T existResult);
|
||||
|
||||
public abstract String getMapKey(T a);
|
||||
|
||||
public abstract void detectByStep(QueryContext queryContext, Set<T> existResults, Set<Long> detectDataSetIds,
|
||||
String detectSegment, int offset);
|
||||
|
||||
public double getThreshold(Double threshold, Double minThreshold, MapModeEnum mapModeEnum) {
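// dynamically relax the configured threshold: each reduction step is a quarter of
// (threshold - minThreshold), scaled by the map mode's threshold factor and floored at minThreshold.
// e.g. assuming threshold=0.3, minThreshold=0.25 and a mode factor of 2:
// decreaseAmount = 0.0125, so the effective threshold becomes 0.3 - 2 * 0.0125 = 0.275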
|
||||
double decreaseAmount = (threshold - minThreshold) / 4;
|
||||
double divideThreshold = threshold - mapModeEnum.threshold * decreaseAmount;
|
||||
return divideThreshold >= minThreshold ? divideThreshold : minThreshold;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,117 @@
|
||||
package com.tencent.supersonic.headless.chat.mapper;
|
||||
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.DatabaseMapResult;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* DatabaseMatchStrategy uses SQL LIKE operator to match schema elements.
|
||||
* It currently supports fuzzy matching against names and aliases.
|
||||
*/
|
||||
@Service
|
||||
@Slf4j
|
||||
public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult> {
|
||||
|
||||
private List<SchemaElement> allElements;
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<DatabaseMapResult>> match(QueryContext queryContext, List<S2Term> terms,
|
||||
Set<Long> detectDataSetIds) {
|
||||
this.allElements = getSchemaElements(queryContext);
|
||||
return super.match(queryContext, terms, detectDataSetIds);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean needDelete(DatabaseMapResult oneRoundResult, DatabaseMapResult existResult) {
|
||||
return getMapKey(oneRoundResult).equals(getMapKey(existResult))
|
||||
&& existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getMapKey(DatabaseMapResult a) {
|
||||
return a.getName() + Constants.UNDERLINE + a.getSchemaElement().getId()
|
||||
+ Constants.UNDERLINE + a.getSchemaElement().getName();
|
||||
}
|
||||
|
||||
public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> existResults, Set<Long> detectDataSetIds,
|
||||
String detectSegment, int offset) {
|
||||
if (StringUtils.isBlank(detectSegment)) {
|
||||
return;
|
||||
}
|
||||
|
||||
Double metricDimensionThresholdConfig = getThreshold(queryContext);
|
||||
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(allElements);
|
||||
|
||||
for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
|
||||
String name = entry.getKey();
|
||||
if (!name.contains(detectSegment)
|
||||
|| mapperHelper.getSimilarity(detectSegment, name) < metricDimensionThresholdConfig) {
|
||||
continue;
|
||||
}
|
||||
Set<SchemaElement> schemaElements = entry.getValue();
|
||||
if (!CollectionUtils.isEmpty(detectDataSetIds)) {
|
||||
schemaElements = schemaElements.stream()
|
||||
.filter(schemaElement -> detectDataSetIds.contains(schemaElement.getDataSet()))
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
for (SchemaElement schemaElement : schemaElements) {
|
||||
DatabaseMapResult databaseMapResult = new DatabaseMapResult();
|
||||
databaseMapResult.setDetectWord(detectSegment);
|
||||
databaseMapResult.setName(schemaElement.getName());
|
||||
databaseMapResult.setSchemaElement(schemaElement);
|
||||
existResults.add(databaseMapResult);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private List<SchemaElement> getSchemaElements(QueryContext queryContext) {
|
||||
List<SchemaElement> allElements = new ArrayList<>();
|
||||
allElements.addAll(queryContext.getSemanticSchema().getDimensions());
|
||||
allElements.addAll(queryContext.getSemanticSchema().getMetrics());
|
||||
return allElements;
|
||||
}
|
||||
|
||||
private Double getThreshold(QueryContext queryContext) {
|
||||
Double threshold = Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD));
|
||||
Double minThreshold = Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD_MIN));
|
||||
|
||||
Map<Long, List<SchemaElementMatch>> modelElementMatches = queryContext.getMapInfo().getDataSetElementMatches();
|
||||
|
||||
boolean existElement = modelElementMatches.entrySet().stream().anyMatch(entry -> entry.getValue().size() >= 1);
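// if no schema element has been matched so far, halve the threshold to recall more candidates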
|
||||
|
||||
if (!existElement) {
|
||||
threshold = threshold / 2;
|
||||
log.debug("ModelElementMatches:{},not exist Element threshold reduce by half:{}",
|
||||
modelElementMatches, threshold);
|
||||
}
|
||||
return getThreshold(threshold, minThreshold, queryContext.getMapModeEnum());
|
||||
}
|
||||
|
||||
private Map<String, Set<SchemaElement>> getNameToItems(List<SchemaElement> models) {
|
||||
return models.stream().collect(
|
||||
Collectors.toMap(SchemaElement::getName, a -> {
|
||||
Set<SchemaElement> result = new HashSet<>();
|
||||
result.add(a);
|
||||
return result;
|
||||
}, (k1, k2) -> {
|
||||
k1.addAll(k2);
|
||||
return k1;
|
||||
}));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
package com.tencent.supersonic.headless.chat.mapper;
|
||||
|
||||
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.builder.BaseWordBuilder;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper;
|
||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
/***
|
||||
* A mapper that recognizes schema elements with vector embedding.
|
||||
*/
|
||||
@Slf4j
|
||||
public class EmbeddingMapper extends BaseMapper {
|
||||
|
||||
@Override
|
||||
public void doMap(QueryContext queryContext) {
|
||||
// 1. retrieve candidate elements from the embedding store using queryText
|
||||
String queryText = queryContext.getQueryText();
|
||||
List<S2Term> terms = HanlpHelper.getTerms(queryText, queryContext.getModelIdToDataSetIds());
|
||||
|
||||
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
|
||||
List<EmbeddingResult> matchResults = matchStrategy.getMatches(queryContext, terms);
|
||||
|
||||
HanlpHelper.transLetterOriginal(matchResults);
|
||||
|
||||
// 2. build SchemaElementMatch from the retrieved metadata
|
||||
for (EmbeddingResult matchResult : matchResults) {
|
||||
Long elementId = Retrieval.getLongId(matchResult.getId());
|
||||
Long dataSetId = Retrieval.getLongId(matchResult.getMetadata().get("dataSetId"));
|
||||
if (Objects.isNull(dataSetId)) {
|
||||
continue;
|
||||
}
|
||||
SchemaElementType elementType = SchemaElementType.valueOf(matchResult.getMetadata().get("type"));
|
||||
SchemaElement schemaElement = getSchemaElement(dataSetId, elementType, elementId,
|
||||
queryContext.getSemanticSchema());
|
||||
if (schemaElement == null) {
|
||||
continue;
|
||||
}
|
||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||
.element(schemaElement)
|
||||
.frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
|
||||
.word(matchResult.getName())
|
||||
.similarity(1 - matchResult.getDistance())
|
||||
.detectWord(matchResult.getDetectWord())
|
||||
.build();
|
||||
//3. add to mapInfo
|
||||
addToSchemaMap(queryContext.getMapInfo(), dataSetId, schemaElementMatch);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,135 @@
|
||||
package com.tencent.supersonic.headless.chat.mapper;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.MetaEmbeddingService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING_MAPPER_NUMBER;
|
||||
import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING_MAPPER_ROUND_NUMBER;
|
||||
import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING_MAPPER_THRESHOLD;
|
||||
import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING_MAPPER_THRESHOLD_MIN;
|
||||
|
||||
/**
|
||||
* EmbeddingMatchStrategy uses vector database to perform
|
||||
* similarity search against the embeddings of schema elements.
|
||||
*/
|
||||
@Service
|
||||
@Slf4j
|
||||
public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
||||
|
||||
@Autowired
|
||||
private MetaEmbeddingService metaEmbeddingService;
|
||||
|
||||
@Override
|
||||
public boolean needDelete(EmbeddingResult oneRoundResult, EmbeddingResult existResult) {
|
||||
return getMapKey(oneRoundResult).equals(getMapKey(existResult))
|
||||
&& existResult.getDistance() > oneRoundResult.getDistance();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getMapKey(EmbeddingResult a) {
|
||||
return a.getName() + Constants.UNDERLINE + a.getId();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults,
|
||||
Set<Long> detectDataSetIds, String detectSegment, int offset) {
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results,
|
||||
Set<Long> detectDataSetIds, Set<String> detectSegments) {
|
||||
int embeddingMapperMin = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_MIN));
int embeddingMapperMax = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_MAX));
|
||||
int embeddingMapperBatch = Integer.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH));
|
||||
|
||||
List<String> queryTextsList = detectSegments.stream()
|
||||
.map(detectSegment -> detectSegment.trim())
|
||||
.filter(detectSegment -> StringUtils.isNotBlank(detectSegment)
|
||||
&& detectSegment.length() >= embedddingMapperMin
|
||||
&& detectSegment.length() <= embedddingMapperMax)
|
||||
.collect(Collectors.toList());
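// query the vector store in batches to bound the size of each embedding retrieval request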
|
||||
|
||||
List<List<String>> queryTextsSubList = Lists.partition(queryTextsList,
|
||||
embeddingMapperBatch);
|
||||
|
||||
for (List<String> queryTextsSub : queryTextsSubList) {
|
||||
detectByQueryTextsSub(results, detectDataSetIds, queryTextsSub, queryContext);
|
||||
}
|
||||
}
|
||||
|
||||
private void detectByQueryTextsSub(Set<EmbeddingResult> results, Set<Long> detectDataSetIds,
|
||||
List<String> queryTextsSub, QueryContext queryContext) {
|
||||
Map<Long, List<Long>> modelIdToDataSetIds = queryContext.getModelIdToDataSetIds();
|
||||
double embeddingThreshold = Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD));
|
||||
double embeddingThresholdMin = Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD_MIN));
|
||||
double threshold = getThreshold(embeddingThreshold, embeddingThresholdMin, queryContext.getMapModeEnum());
|
||||
|
||||
// step1. build query params
|
||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
|
||||
|
||||
// step2. retrieveQuery by detectSegment
|
||||
int embeddingNumber = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_NUMBER));
|
||||
List<RetrieveQueryResult> retrieveQueryResults = metaEmbeddingService.retrieveQuery(
|
||||
retrieveQuery, embeddingNumber, modelIdToDataSetIds, detectDataSetIds);
|
||||
|
||||
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
|
||||
return;
|
||||
}
|
||||
// step3. build EmbeddingResults
|
||||
List<EmbeddingResult> collect = retrieveQueryResults.stream()
|
||||
.map(retrieveQueryResult -> {
|
||||
List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
|
||||
if (CollectionUtils.isNotEmpty(retrievals)) {
|
||||
retrievals.removeIf(retrieval -> {
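// retrievals whose text is contained in the detect segment are always kept; otherwise drop
// those whose distance exceeds (1 - threshold), i.e. whose similarity falls below the dynamic threshold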
|
||||
if (!retrieveQueryResult.getQuery().contains(retrieval.getQuery())) {
|
||||
return retrieval.getDistance() > 1 - threshold;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
}
|
||||
return retrieveQueryResult;
|
||||
})
|
||||
.filter(retrieveQueryResult -> CollectionUtils.isNotEmpty(retrieveQueryResult.getRetrieval()))
|
||||
.flatMap(retrieveQueryResult -> retrieveQueryResult.getRetrieval().stream()
|
||||
.map(retrieval -> {
|
||||
EmbeddingResult embeddingResult = new EmbeddingResult();
|
||||
BeanUtils.copyProperties(retrieval, embeddingResult);
|
||||
embeddingResult.setDetectWord(retrieveQueryResult.getQuery());
|
||||
embeddingResult.setName(retrieval.getQuery());
|
||||
Map<String, String> convertedMap = retrieval.getMetadata().entrySet().stream()
|
||||
.collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().toString()));
|
||||
embeddingResult.setMetadata(convertedMap);
|
||||
return embeddingResult;
|
||||
}))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
// step4. select mapResult in one round
|
||||
int embeddingRoundNumber = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_ROUND_NUMBER));
|
||||
int roundNumber = embeddingRoundNumber * queryTextsSub.size();
|
||||
List<EmbeddingResult> oneRoundResults = collect.stream()
|
||||
.sorted(Comparator.comparingDouble(EmbeddingResult::getDistance))
|
||||
.limit(roundNumber)
|
||||
.collect(Collectors.toList());
|
||||
selectResultInOneRound(results, oneRoundResults);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
package com.tencent.supersonic.headless.chat.mapper;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* A mapper that converts VALUE matches on an entity dimension into ID-typed matches.
|
||||
*/
|
||||
@Slf4j
|
||||
public class EntityMapper extends BaseMapper {
|
||||
|
||||
@Override
|
||||
public void doMap(QueryContext queryContext) {
|
||||
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
||||
for (Long dataSetId : schemaMapInfo.getMatchedDataSetInfos()) {
|
||||
List<SchemaElementMatch> schemaElementMatchList = schemaMapInfo.getMatchedElements(dataSetId);
|
||||
if (CollectionUtils.isEmpty(schemaElementMatchList)) {
|
||||
continue;
|
||||
}
|
||||
SchemaElement entity = getEntity(dataSetId, queryContext);
|
||||
if (entity == null || entity.getId() == null) {
|
||||
continue;
|
||||
}
|
||||
List<SchemaElementMatch> valueSchemaElements = schemaElementMatchList.stream()
|
||||
.filter(schemaElementMatch -> SchemaElementType.VALUE.equals(
|
||||
schemaElementMatch.getElement().getType()))
|
||||
.collect(Collectors.toList());
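// for VALUE matches that hit this dataSet's entity dimension, register an additional ENTITY
// match (if not already present) and re-type the original match as ID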
|
||||
for (SchemaElementMatch schemaElementMatch : valueSchemaElements) {
|
||||
if (!entity.getId().equals(schemaElementMatch.getElement().getId())) {
|
||||
continue;
|
||||
}
|
||||
if (!checkExistSameEntitySchemaElements(schemaElementMatch, schemaElementMatchList)) {
|
||||
SchemaElementMatch entitySchemaElementMath = new SchemaElementMatch();
|
||||
BeanUtils.copyProperties(schemaElementMatch, entitySchemaElementMath);
|
||||
entitySchemaElementMath.setElement(entity);
|
||||
schemaElementMatchList.add(entitySchemaElementMath);
|
||||
}
|
||||
schemaElementMatch.getElement().setType(SchemaElementType.ID);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private boolean checkExistSameEntitySchemaElements(SchemaElementMatch valueSchemaElementMatch,
|
||||
List<SchemaElementMatch> schemaElementMatchList) {
|
||||
List<SchemaElementMatch> entitySchemaElements = schemaElementMatchList.stream().filter(schemaElementMatch ->
|
||||
SchemaElementType.ENTITY.equals(schemaElementMatch.getElement().getType()))
|
||||
.collect(Collectors.toList());
|
||||
for (SchemaElementMatch schemaElementMatch : entitySchemaElements) {
|
||||
if (schemaElementMatch.getElement().getId().equals(valueSchemaElementMatch.getElement().getId())) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private SchemaElement getEntity(Long dataSetId, QueryContext queryContext) {
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
DataSetSchema modelSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId);
|
||||
if (modelSchema != null && modelSchema.getEntity() != null) {
|
||||
return modelSchema.getEntity();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,140 @@
|
||||
package com.tencent.supersonic.headless.chat.mapper;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.HanlpMapResult;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.KnowledgeBaseService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.MAPPER_DETECTION_MAX_SIZE;
|
||||
import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.MAPPER_DETECTION_SIZE;
|
||||
import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.MAPPER_DIMENSION_VALUE_SIZE;
|
||||
|
||||
/**
|
||||
* HanlpDictMatchStrategy uses <a href="https://www.hanlp.com/">HanLP</a> to
|
||||
* match schema elements. It currently supports prefix and suffix matching
|
||||
* against names, values and aliases.
|
||||
*/
|
||||
@Service
|
||||
@Slf4j
|
||||
public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
|
||||
@Autowired
|
||||
private KnowledgeBaseService knowledgeBaseService;
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> terms,
|
||||
Set<Long> detectDataSetIds) {
|
||||
String text = queryContext.getQueryText();
|
||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
log.debug("terms:{},detectModelIds:{}", terms, detectDataSetIds);
|
||||
|
||||
List<HanlpMapResult> detects = detect(queryContext, terms, detectDataSetIds);
|
||||
Map<MatchText, List<HanlpMapResult>> result = new HashMap<>();
|
||||
|
||||
result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects);
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean needDelete(HanlpMapResult oneRoundResult, HanlpMapResult existResult) {
|
||||
return getMapKey(oneRoundResult).equals(getMapKey(existResult))
|
||||
&& existResult.getDetectWord().length() < oneRoundResult.getDetectWord().length();
|
||||
}
|
||||
|
||||
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectDataSetIds,
|
||||
String detectSegment, int offset) {
|
||||
// step1. prefix search
|
||||
Integer oneDetectionMaxSize = Integer.valueOf(mapperConfig.getParameterValue(MAPPER_DETECTION_MAX_SIZE));
|
||||
LinkedHashSet<HanlpMapResult> hanlpMapResults = knowledgeBaseService.prefixSearch(detectSegment,
|
||||
oneDetectionMaxSize, queryContext.getModelIdToDataSetIds(), detectDataSetIds)
|
||||
.stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
// step2. suffix search
|
||||
LinkedHashSet<HanlpMapResult> suffixHanlpMapResults = knowledgeBaseService.suffixSearch(detectSegment,
|
||||
oneDetectionMaxSize, queryContext.getModelIdToDataSetIds(), detectDataSetIds)
|
||||
.stream().collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
|
||||
hanlpMapResults.addAll(suffixHanlpMapResults);
|
||||
|
||||
if (CollectionUtils.isEmpty(hanlpMapResults)) {
|
||||
return;
|
||||
}
|
||||
// step3. sort the merged prefix/suffix results by name length
|
||||
hanlpMapResults = hanlpMapResults.stream().sorted((a, b) -> -(b.getName().length() - a.getName().length()))
|
||||
.collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
|
||||
// step4. filter by similarity
|
||||
hanlpMapResults = hanlpMapResults.stream()
|
||||
.filter(term -> mapperHelper.getSimilarity(detectSegment, term.getName())
|
||||
>= getThresholdMatch(term.getNatures(), queryContext))
|
||||
.filter(term -> CollectionUtils.isNotEmpty(term.getNatures()))
|
||||
.collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
|
||||
log.debug("detectSegment:{},after isSimilarity parseResults:{}", detectSegment, hanlpMapResults);
|
||||
|
||||
hanlpMapResults = hanlpMapResults.stream().map(parseResult -> {
|
||||
parseResult.setOffset(offset);
|
||||
parseResult.setSimilarity(mapperHelper.getSimilarity(detectSegment, parseResult.getName()));
|
||||
return parseResult;
|
||||
}).collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
|
||||
// step5. take at most M dimension-value results plus up to N-M metric/dimension results per round.
|
||||
int oneDetectionValueSize = Integer.valueOf(mapperConfig.getParameterValue(MAPPER_DIMENSION_VALUE_SIZE));
|
||||
List<HanlpMapResult> dimensionValues = hanlpMapResults.stream()
|
||||
.filter(entry -> mapperHelper.existDimensionValues(entry.getNatures()))
|
||||
.limit(oneDetectionValueSize)
|
||||
.collect(Collectors.toList());
|
||||
|
||||
Integer oneDetectionSize = Integer.valueOf(mapperConfig.getParameterValue(MAPPER_DETECTION_SIZE));
|
||||
List<HanlpMapResult> oneRoundResults = new ArrayList<>();
|
||||
|
||||
// add the dimensionValue if it exists
|
||||
if (CollectionUtils.isNotEmpty(dimensionValues)) {
|
||||
oneRoundResults.addAll(dimensionValues);
|
||||
}
|
||||
// fill the rest of the list with other results, excluding the dimensionValue if it was added
|
||||
if (oneRoundResults.size() < oneDetectionSize) {
|
||||
List<HanlpMapResult> additionalResults = hanlpMapResults.stream()
|
||||
.filter(entry -> !mapperHelper.existDimensionValues(entry.getNatures())
|
||||
&& !oneRoundResults.contains(entry))
|
||||
.limit(oneDetectionSize - oneRoundResults.size())
|
||||
.collect(Collectors.toList());
|
||||
oneRoundResults.addAll(additionalResults);
|
||||
}
|
||||
// step6. select mapResult in one round
|
||||
selectResultInOneRound(existResults, oneRoundResults);
|
||||
}
|
||||
|
||||
public String getMapKey(HanlpMapResult a) {
|
||||
return a.getName() + Constants.UNDERLINE + String.join(Constants.UNDERLINE, a.getNatures());
|
||||
}
|
||||
|
||||
public double getThresholdMatch(List<String> natures, QueryContext queryContext) {
|
||||
Double threshold = Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD));
|
||||
Double minThreshold = Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD_MIN));
|
||||
if (mapperHelper.existDimensionValues(natures)) {
|
||||
threshold = Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_VALUE_THRESHOLD));
|
||||
minThreshold = Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_VALUE_THRESHOLD_MIN));
|
||||
}
|
||||
|
||||
return getThreshold(threshold, minThreshold, queryContext.getMapModeEnum());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,121 @@
|
||||
package com.tencent.supersonic.headless.chat.mapper;
|
||||
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.DatabaseMapResult;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.HanlpMapResult;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.builder.BaseWordBuilder;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper;
|
||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/***
|
||||
* A mapper that recognizes schema elements by keyword matching.
|
||||
* It leverages two matching strategies: HanlpDictMatchStrategy and DatabaseMatchStrategy.
|
||||
*/
|
||||
@Slf4j
|
||||
public class KeywordMapper extends BaseMapper {
|
||||
|
||||
@Override
|
||||
public void doMap(QueryContext queryContext) {
|
||||
String queryText = queryContext.getQueryText();
|
||||
// 1. HanLP dictionary match
|
||||
List<S2Term> terms = HanlpHelper.getTerms(queryText, queryContext.getModelIdToDataSetIds());
|
||||
HanlpDictMatchStrategy hanlpMatchStrategy = ContextUtils.getBean(HanlpDictMatchStrategy.class);
|
||||
|
||||
List<HanlpMapResult> hanlpMapResults = hanlpMatchStrategy.getMatches(queryContext, terms);
|
||||
convertHanlpMapResultToMapInfo(hanlpMapResults, queryContext, terms);
|
||||
|
||||
// 2. database match
|
||||
DatabaseMatchStrategy databaseMatchStrategy = ContextUtils.getBean(DatabaseMatchStrategy.class);
|
||||
|
||||
List<DatabaseMapResult> databaseResults = databaseMatchStrategy.getMatches(queryContext, terms);
|
||||
convertDatabaseMapResultToMapInfo(queryContext, databaseResults);
|
||||
}
|
||||
|
||||
private void convertHanlpMapResultToMapInfo(List<HanlpMapResult> mapResults, QueryContext queryContext,
|
||||
List<S2Term> terms) {
|
||||
if (CollectionUtils.isEmpty(mapResults)) {
|
||||
return;
|
||||
}
|
||||
HanlpHelper.transLetterOriginal(mapResults);
|
||||
Map<String, Long> wordNatureToFrequency = terms.stream().collect(
|
||||
Collectors.toMap(entry -> entry.getWord() + entry.getNature(),
|
||||
term -> Long.valueOf(term.getFrequency()), (value1, value2) -> value2));
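// look up segmentation frequency by (word + nature) so each schema match carries the term frequency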
|
||||
|
||||
for (HanlpMapResult hanlpMapResult : mapResults) {
|
||||
for (String nature : hanlpMapResult.getNatures()) {
|
||||
Long dataSetId = NatureHelper.getDataSetId(nature);
|
||||
if (Objects.isNull(dataSetId)) {
|
||||
continue;
|
||||
}
|
||||
SchemaElementType elementType = NatureHelper.convertToElementType(nature);
|
||||
if (Objects.isNull(elementType)) {
|
||||
continue;
|
||||
}
|
||||
Long elementID = NatureHelper.getElementID(nature);
|
||||
SchemaElement element = getSchemaElement(dataSetId, elementType,
|
||||
elementID, queryContext.getSemanticSchema());
|
||||
if (element == null) {
|
||||
continue;
|
||||
}
|
||||
Long frequency = wordNatureToFrequency.get(hanlpMapResult.getName() + nature);
|
||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||
.element(element)
|
||||
.frequency(frequency)
|
||||
.word(hanlpMapResult.getName())
|
||||
.similarity(hanlpMapResult.getSimilarity())
|
||||
.detectWord(hanlpMapResult.getDetectWord())
|
||||
.build();
|
||||
|
||||
addToSchemaMap(queryContext.getMapInfo(), dataSetId, schemaElementMatch);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void convertDatabaseMapResultToMapInfo(QueryContext queryContext, List<DatabaseMapResult> mapResults) {
|
||||
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
|
||||
for (DatabaseMapResult match : mapResults) {
|
||||
SchemaElement schemaElement = match.getSchemaElement();
|
||||
Set<Long> regElementSet = getRegElementSet(queryContext.getMapInfo(), schemaElement);
|
||||
if (regElementSet.contains(schemaElement.getId())) {
|
||||
continue;
|
||||
}
|
||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||
.element(schemaElement)
|
||||
.word(schemaElement.getName())
|
||||
.detectWord(match.getDetectWord())
|
||||
.frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
|
||||
.similarity(mapperHelper.getSimilarity(match.getDetectWord(), schemaElement.getName()))
|
||||
.build();
|
||||
log.info("add to schema, elementMatch {}", schemaElementMatch);
|
||||
addToSchemaMap(queryContext.getMapInfo(), schemaElement.getDataSet(), schemaElementMatch);
|
||||
}
|
||||
}
|
||||
|
||||
private Set<Long> getRegElementSet(SchemaMapInfo schemaMap, SchemaElement schemaElement) {
|
||||
List<SchemaElementMatch> elements = schemaMap.getMatchedElements(schemaElement.getDataSet());
|
||||
if (CollectionUtils.isEmpty(elements)) {
|
||||
return new HashSet<>();
|
||||
}
|
||||
return elements.stream()
|
||||
.filter(elementMatch ->
|
||||
SchemaElementType.METRIC.equals(elementMatch.getElement().getType())
|
||||
|| SchemaElementType.DIMENSION.equals(elementMatch.getElement().getType()))
|
||||
.map(elementMatch -> elementMatch.getElement().getId())
|
||||
.collect(Collectors.toSet());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,110 @@
|
||||
package com.tencent.supersonic.headless.chat.mapper;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.config.ParameterConfig;
|
||||
import com.tencent.supersonic.common.pojo.Parameter;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Service("HeadlessMapperConfig")
|
||||
public class MapperConfig extends ParameterConfig {
|
||||
|
||||
public static final Parameter MAPPER_DETECTION_SIZE =
|
||||
new Parameter("s2.mapper.detection.size", "8",
|
||||
"一次探测返回结果个数",
|
||||
"在每次探测后, 将前后缀匹配的结果合并, 并根据相似度阈值过滤后的结果个数",
|
||||
"number", "Mapper相关配置");
|
||||
|
||||
public static final Parameter MAPPER_DETECTION_MAX_SIZE =
|
||||
new Parameter("s2.mapper.detection.max.size", "20",
|
||||
"一次探测前后缀匹配结果返回个数",
|
||||
"单次前后缀匹配返回的结果个数",
|
||||
"number", "Mapper相关配置");
|
||||
|
||||
public static final Parameter MAPPER_NAME_THRESHOLD =
|
||||
new Parameter("s2.mapper.name.threshold", "0.3",
|
||||
"指标名、维度名文本相似度阈值",
|
||||
"文本片段和匹配到的指标、维度名计算出来的编辑距离阈值, 若超出该阈值, 则舍弃",
|
||||
"number", "Mapper相关配置");
|
||||
|
||||
public static final Parameter MAPPER_NAME_THRESHOLD_MIN =
|
||||
new Parameter("s2.mapper.name.min.threshold", "0.25",
|
||||
"指标名、维度名最小文本相似度阈值",
|
||||
"指标名、维度名相似度阈值在动态调整中的最低值",
|
||||
"number", "Mapper相关配置");
|
||||
|
||||
public static final Parameter MAPPER_DIMENSION_VALUE_SIZE =
|
||||
new Parameter("s2.mapper.value.size", "1",
|
||||
"一次探测返回维度值结果个数",
|
||||
"在每次探测后, 将前后缀匹配的结果合并, 并根据相似度阈值过滤后的维度值结果个数",
|
||||
"number", "Mapper相关配置");
|
||||
|
||||
public static final Parameter MAPPER_VALUE_THRESHOLD =
|
||||
new Parameter("s2.mapper.value.threshold", "0.5",
|
||||
"维度值文本相似度阈值",
|
||||
"文本片段和匹配到的维度值计算出来的编辑距离阈值, 若超出该阈值, 则舍弃",
|
||||
"number", "Mapper相关配置");
|
||||
|
||||
public static final Parameter MAPPER_VALUE_THRESHOLD_MIN =
|
||||
new Parameter("s2.mapper.value.min.threshold", "0.3",
|
||||
"维度值最小文本相似度阈值",
|
||||
"维度值相似度阈值在动态调整中的最低值",
|
||||
"number", "Mapper相关配置");
|
||||
|
||||
public static final Parameter EMBEDDING_MAPPER_MIN =
|
||||
new Parameter("s2.mapper.embedding.word.min", "4",
|
||||
"用于向量召回最小的文本长度",
|
||||
"为提高向量召回效率, 小于该长度的文本不进行向量语义召回",
|
||||
"number", "Mapper相关配置");
|
||||
|
||||
public static final Parameter EMBEDDING_MAPPER_MAX =
|
||||
new Parameter("s2.mapper.embedding.word.max", "5",
|
||||
"用于向量召回最大的文本长度",
|
||||
"为提高向量召回效率, 大于该长度的文本不进行向量语义召回",
|
||||
"number", "Mapper相关配置");
|
||||
|
||||
public static final Parameter EMBEDDING_MAPPER_BATCH =
|
||||
new Parameter("s2.mapper.embedding.batch", "50",
|
||||
"批量向量召回文本请求个数",
|
||||
"每次进行向量语义召回的原始文本片段个数",
|
||||
"number", "Mapper相关配置");
|
||||
|
||||
public static final Parameter EMBEDDING_MAPPER_NUMBER =
|
||||
new Parameter("s2.mapper.embedding.number", "5",
|
||||
"批量向量召回文本返回结果个数",
|
||||
"每个文本进行向量语义召回的文本结果个数",
|
||||
"number", "Mapper相关配置");
|
||||
|
||||
public static final Parameter EMBEDDING_MAPPER_THRESHOLD =
|
||||
new Parameter("s2.mapper.embedding.threshold", "0.99",
|
||||
"向量召回相似度阈值",
|
||||
"相似度小于该阈值的则舍弃",
|
||||
"number", "Mapper相关配置");
|
||||
|
||||
public static final Parameter EMBEDDING_MAPPER_THRESHOLD_MIN =
|
||||
new Parameter("s2.mapper.embedding.min.threshold", "0.9",
|
||||
"向量召回最小相似度阈值",
|
||||
"向量召回相似度阈值在动态调整中的最低值",
|
||||
"number", "Mapper相关配置");
|
||||
|
||||
public static final Parameter EMBEDDING_MAPPER_ROUND_NUMBER =
|
||||
new Parameter("s2.mapper.embedding.round.number", "10",
|
||||
"向量召回最小相似度阈值",
|
||||
"向量召回相似度阈值在动态调整中的最低值",
|
||||
"number", "Mapper相关配置");
|
||||
|
||||
@Override
|
||||
public List<Parameter> getSysParameters() {
|
||||
return Lists.newArrayList(
|
||||
MAPPER_DETECTION_SIZE,
|
||||
MAPPER_DETECTION_MAX_SIZE,
|
||||
MAPPER_NAME_THRESHOLD,
|
||||
MAPPER_NAME_THRESHOLD_MIN,
|
||||
MAPPER_DIMENSION_VALUE_SIZE,
|
||||
MAPPER_VALUE_THRESHOLD,
|
||||
MAPPER_VALUE_THRESHOLD_MIN
|
||||
);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,78 @@
package com.tencent.supersonic.headless.chat.mapper;

import com.hankcs.hanlp.algorithm.EditDistance;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;

import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

@Data
@Service
@Slf4j
public class MapperHelper {

    public Integer getStepIndex(Map<Integer, Integer> regOffsetToLength, Integer index) {
        Integer subRegLength = regOffsetToLength.get(index);
        if (Objects.nonNull(subRegLength)) {
            index = index + subRegLength;
        } else {
            index++;
        }
        return index;
    }

    public Integer getStepOffset(List<S2Term> termList, Integer index) {
        List<Integer> offsetList = termList.stream().sorted(Comparator.comparing(S2Term::getOffset))
                .map(term -> term.getOffset()).collect(Collectors.toList());

        for (int j = 0; j < termList.size() - 1; j++) {
            if (offsetList.get(j) <= index && offsetList.get(j + 1) > index) {
                return offsetList.get(j);
            }
        }
        return index;
    }

    /**
     * Whether any nature corresponds to a dimension value.
     * @param natures
     * @return
     */
    public boolean existDimensionValues(List<String> natures) {
        for (String nature : natures) {
            if (NatureHelper.isDimensionValueDataSetId(nature)) {
                return true;
            }
        }
        return false;
    }

    public boolean existTerms(List<String> natures) {
        for (String nature : natures) {
            if (NatureHelper.isTermNature(nature)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Compute text similarity as 1 - editDistance / max(length).
     * @param detectSegment
     * @param matchName
     * @return
     */
    public double getSimilarity(String detectSegment, String matchName) {
        String detectSegmentLower = detectSegment == null ? null : detectSegment.toLowerCase();
        String matchNameLower = matchName == null ? null : matchName.toLowerCase();
        return 1 - (double) EditDistance.compute(detectSegmentLower, matchNameLower) / Math.max(matchName.length(),
                detectSegment.length());
    }
}
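For reference, a minimal standalone sketch of the similarity formula used by getSimilarity above (1 - edit distance / max length); the class name, main method, and sample strings are illustrative only and not part of this commit:

import com.tencent.supersonic.headless.chat.mapper.MapperHelper;

public class SimilaritySketch {
    public static void main(String[] args) {
        // MapperHelper has no injected dependencies, so it can be instantiated directly here.
        MapperHelper mapperHelper = new MapperHelper();
        // editDistance("revenu", "revenue") = 1, max length = 7
        // -> similarity = 1 - 1/7 ≈ 0.857, well above the default
        //    s2.mapper.name.threshold of 0.3 defined in MapperConfig.
        double similarity = mapperHelper.getSimilarity("revenu", "revenue");
        System.out.println(similarity);
    }
}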
@@ -0,0 +1,19 @@
package com.tencent.supersonic.headless.chat.mapper;


import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.chat.QueryContext;

import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * MatchStrategy encapsulates a concrete matching algorithm
 * executed during the query or search process.
 */
public interface MatchStrategy<T> {

    Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms, Set<Long> detectDataSetIds);

}
@@ -0,0 +1,34 @@
package com.tencent.supersonic.headless.chat.mapper;

import lombok.Builder;
import lombok.Data;
import lombok.ToString;

import java.util.Objects;

@Data
@ToString
@Builder
public class MatchText {

    private String regText;

    private String detectSegment;

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        MatchText that = (MatchText) o;
        return Objects.equals(regText, that.regText) && Objects.equals(detectSegment, that.detectSegment);
    }

    @Override
    public int hashCode() {
        return Objects.hash(regText, detectSegment);
    }
}
@@ -0,0 +1,20 @@
package com.tencent.supersonic.headless.chat.mapper;

import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import lombok.Data;
import lombok.ToString;

import java.io.Serializable;

@Data
@ToString
public class ModelWithSemanticType implements Serializable {

    private Long model;
    private SchemaElementType schemaElementType;

    public ModelWithSemanticType(Long model, SchemaElementType schemaElementType) {
        this.model = model;
        this.schemaElementType = schemaElementType;
    }
}
@@ -0,0 +1,94 @@
|
||||
package com.tencent.supersonic.headless.chat.mapper;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.builder.BaseWordBuilder;
|
||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
public class QueryFilterMapper extends BaseMapper {
|
||||
|
||||
private double similarity = 1.0;
|
||||
|
||||
@Override
|
||||
public void doMap(QueryContext queryContext) {
|
||||
Set<Long> dataSetIds = queryContext.getDataSetIds();
|
||||
if (CollectionUtils.isEmpty(dataSetIds)) {
|
||||
return;
|
||||
}
|
||||
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
||||
clearOtherSchemaElementMatch(dataSetIds, schemaMapInfo);
|
||||
for (Long dataSetId : dataSetIds) {
|
||||
List<SchemaElementMatch> schemaElementMatches = schemaMapInfo.getMatchedElements(dataSetId);
|
||||
if (schemaElementMatches == null) {
|
||||
schemaElementMatches = Lists.newArrayList();
|
||||
schemaMapInfo.setMatchedElements(dataSetId, schemaElementMatches);
|
||||
}
|
||||
addValueSchemaElementMatch(dataSetId, queryContext, schemaElementMatches);
|
||||
}
|
||||
}
|
||||
|
||||
private void clearOtherSchemaElementMatch(Set<Long> viewIds, SchemaMapInfo schemaMapInfo) {
|
||||
for (Map.Entry<Long, List<SchemaElementMatch>> entry : schemaMapInfo.getDataSetElementMatches().entrySet()) {
|
||||
if (!viewIds.contains(entry.getKey())) {
|
||||
entry.getValue().clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void addValueSchemaElementMatch(Long dataSetId, QueryContext queryContext,
|
||||
List<SchemaElementMatch> candidateElementMatches) {
|
||||
QueryFilters queryFilters = queryContext.getQueryFilters();
|
||||
if (queryFilters == null || CollectionUtils.isEmpty(queryFilters.getFilters())) {
|
||||
return;
|
||||
}
|
||||
for (QueryFilter filter : queryFilters.getFilters()) {
|
||||
if (checkExistSameValueSchemaElementMatch(filter, candidateElementMatches)) {
|
||||
continue;
|
||||
}
|
||||
SchemaElement element = SchemaElement.builder()
|
||||
.id(filter.getElementID())
|
||||
.name(String.valueOf(filter.getValue()))
|
||||
.type(SchemaElementType.VALUE)
|
||||
.bizName(filter.getBizName())
|
||||
.dataSet(dataSetId)
|
||||
.build();
|
||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||
.element(element)
|
||||
.frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
|
||||
.word(String.valueOf(filter.getValue()))
|
||||
.similarity(similarity)
|
||||
.detectWord(Constants.EMPTY)
|
||||
.build();
|
||||
candidateElementMatches.add(schemaElementMatch);
|
||||
}
|
||||
queryContext.getMapInfo().setMatchedElements(dataSetId, candidateElementMatches);
|
||||
}
|
||||
|
||||
private boolean checkExistSameValueSchemaElementMatch(QueryFilter queryFilter,
|
||||
List<SchemaElementMatch> schemaElementMatches) {
|
||||
List<SchemaElementMatch> valueSchemaElements = schemaElementMatches.stream().filter(schemaElementMatch ->
|
||||
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()))
|
||||
.collect(Collectors.toList());
|
||||
for (SchemaElementMatch schemaElementMatch : valueSchemaElements) {
|
||||
if (schemaElementMatch.getElement().getId().equals(queryFilter.getElementID())
|
||||
&& schemaElementMatch.getWord().equals(String.valueOf(queryFilter.getValue()))) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,13 @@
package com.tencent.supersonic.headless.chat.mapper;


import com.tencent.supersonic.headless.chat.QueryContext;

/**
 * A schema mapper identifies references to schema elements (metrics/dimensions/entities/values)
 * in user queries. It matches the query text against the knowledge base.
 */
public interface SchemaMapper {

    void map(QueryContext queryContext);
}
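As an illustration of the contract above, a minimal no-op implementation might look like the following; the class is hypothetical and not part of this commit:

import com.tencent.supersonic.headless.chat.QueryContext;
import com.tencent.supersonic.headless.chat.mapper.SchemaMapper;

public class NoOpSchemaMapper implements SchemaMapper {

    @Override
    public void map(QueryContext queryContext) {
        // A real mapper would inspect queryContext.getQueryText() and add
        // SchemaElementMatch entries to queryContext.getMapInfo();
        // this placeholder deliberately contributes nothing.
    }
}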
@@ -0,0 +1,102 @@
|
||||
package com.tencent.supersonic.headless.chat.mapper;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.HanlpMapResult;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.KnowledgeBaseService;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.SearchService;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* SearchMatchStrategy encapsulates a concrete matching algorithm
|
||||
* executed during search process.
|
||||
*/
|
||||
@Service
|
||||
public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
|
||||
private static final int SEARCH_SIZE = 3;
|
||||
|
||||
@Autowired
|
||||
private KnowledgeBaseService knowledgeBaseService;
|
||||
|
||||
@Override
|
||||
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> originals,
|
||||
Set<Long> detectDataSetIds) {
|
||||
String text = queryContext.getQueryText();
|
||||
Map<Integer, Integer> regOffsetToLength = getRegOffsetToLength(originals);
|
||||
|
||||
List<Integer> detectIndexList = Lists.newArrayList();
|
||||
|
||||
for (Integer index = 0; index < text.length(); ) {
|
||||
|
||||
if (index < text.length()) {
|
||||
detectIndexList.add(index);
|
||||
}
|
||||
Integer regLength = regOffsetToLength.get(index);
|
||||
if (Objects.nonNull(regLength)) {
|
||||
index = index + regLength;
|
||||
} else {
|
||||
index++;
|
||||
}
|
||||
}
|
||||
Map<MatchText, List<HanlpMapResult>> regTextMap = new ConcurrentHashMap<>();
|
||||
detectIndexList.stream().parallel().forEach(detectIndex -> {
|
||||
String regText = text.substring(0, detectIndex);
|
||||
String detectSegment = text.substring(detectIndex);
|
||||
|
||||
if (StringUtils.isNotEmpty(detectSegment)) {
|
||||
List<HanlpMapResult> hanlpMapResults = knowledgeBaseService.prefixSearch(detectSegment,
|
||||
SearchService.SEARCH_SIZE, queryContext.getModelIdToDataSetIds(), detectDataSetIds);
|
||||
List<HanlpMapResult> suffixHanlpMapResults = knowledgeBaseService.suffixSearch(
|
||||
detectSegment, SEARCH_SIZE, queryContext.getModelIdToDataSetIds(), detectDataSetIds);
|
||||
hanlpMapResults.addAll(suffixHanlpMapResults);
|
||||
// for search, drop results whose only natures are entity names
|
||||
hanlpMapResults = hanlpMapResults.stream().filter(entry -> {
|
||||
List<String> natures = entry.getNatures().stream()
|
||||
.filter(nature -> !nature.endsWith(DictWordType.ENTITY.getType()))
|
||||
.collect(Collectors.toList());
|
||||
if (CollectionUtils.isEmpty(natures)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}).collect(Collectors.toList());
|
||||
MatchText matchText = MatchText.builder()
|
||||
.regText(regText)
|
||||
.detectSegment(detectSegment)
|
||||
.build();
|
||||
regTextMap.put(matchText, hanlpMapResults);
|
||||
}
|
||||
}
|
||||
);
|
||||
return regTextMap;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean needDelete(HanlpMapResult oneRoundResult, HanlpMapResult existResult) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getMapKey(HanlpMapResult a) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectDataSetIds,
|
||||
String detectSegment, int offset) {
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
package com.tencent.supersonic.headless.chat.parser;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.config.ParameterConfig;
|
||||
import com.tencent.supersonic.common.pojo.Parameter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Service("HeadlessParserConfig")
|
||||
@Slf4j
|
||||
public class ParserConfig extends ParameterConfig {
|
||||
|
||||
public static final Parameter PARSER_STRATEGY_TYPE =
|
||||
new Parameter("s2.parser.s2sql.strategy", "ONE_PASS_SELF_CONSISTENCY",
|
||||
"LLM解析生成S2SQL策略",
|
||||
"ONE_PASS_SELF_CONSISTENCY: 通过投票方式一步生成sql"
|
||||
+ "\nTWO_PASS_AUTO_COT_SELF_CONSISTENCY: 通过思维链且投票方式两步生成sql",
|
||||
"list", "Parser相关配置", Lists.newArrayList(
|
||||
"ONE_PASS_SELF_CONSISTENCY", "TWO_PASS_AUTO_COT_SELF_CONSISTENCY"));
|
||||
|
||||
public static final Parameter PARSER_LINKING_VALUE_ENABLE =
|
||||
new Parameter("s2.parser.linking.value.enable", "true",
|
||||
"是否将Mapper探测识别到的维度值提供给大模型", "为了数据安全考虑, 这里可进行开关选择",
|
||||
"bool", "Parser相关配置");
|
||||
|
||||
public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD =
|
||||
new Parameter("s2.parser.text.length.threshold", "10",
|
||||
"用户输入文本长短阈值", "文本超过该阈值为长文本",
|
||||
"number", "Parser相关配置");
|
||||
|
||||
public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD_SHORT =
|
||||
new Parameter("s2.parser.text.threshold.short", "0.5",
|
||||
"短文本匹配阈值",
|
||||
"由于请求大模型耗时较长, 因此如果有规则类型的Query得分达到阈值,则跳过大模型的调用,"
|
||||
+ "\n如果是短文本, 若query得分/文本长度>该阈值, 则跳过当前parser",
|
||||
"number", "Parser相关配置");
|
||||
|
||||
public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD_LONG =
|
||||
new Parameter("s2.parser.text.threshold.long", "0.8",
|
||||
"长文本匹配阈值", "如果是长文本, 若query得分/文本长度>该阈值, 则跳过当前parser",
|
||||
"number", "Parser相关配置");
|
||||
|
||||
public static final Parameter PARSER_EXEMPLAR_RECALL_NUMBER =
|
||||
new Parameter("s2.parser.exemplar-recall.number", "10",
|
||||
"exemplar召回个数", "",
|
||||
"number", "Parser相关配置");
|
||||
|
||||
public static final Parameter PARSER_FEW_SHOT_NUMBER =
|
||||
new Parameter("s2.parser.few-shot.number", "3",
|
||||
"few-shot样例个数", "样例越多效果可能越好,但token消耗越大",
|
||||
"number", "Parser相关配置");
|
||||
|
||||
public static final Parameter PARSER_SELF_CONSISTENCY_NUMBER =
|
||||
new Parameter("s2.parser.self-consistency.number", "1",
|
||||
"self-consistency执行个数", "执行越多效果可能越好,但token消耗越大",
|
||||
"number", "Parser相关配置");
|
||||
|
||||
public static final Parameter PARSER_SHOW_COUNT =
|
||||
new Parameter("s2.parser.show.count", "3",
|
||||
"解析结果展示个数", "前端展示的解析个数",
|
||||
"number", "Parser相关配置");
|
||||
|
||||
public static final Parameter PARSER_S2SQL_ENABLE =
|
||||
new Parameter("s2.parser.s2sql.switch", "true",
|
||||
"", "",
|
||||
"bool", "Parser相关配置");
|
||||
|
||||
@Override
|
||||
public List<Parameter> getSysParameters() {
|
||||
return Lists.newArrayList(
|
||||
PARSER_LINKING_VALUE_ENABLE,
|
||||
PARSER_FEW_SHOT_NUMBER,
|
||||
PARSER_SELF_CONSISTENCY_NUMBER,
|
||||
PARSER_SHOW_COUNT
|
||||
);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
package com.tencent.supersonic.headless.chat.parser;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* QueryTypeParser resolves the query type as METRIC, TAG, or ID.
|
||||
*/
|
||||
@Slf4j
|
||||
public class QueryTypeParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
|
||||
List<SemanticQuery> candidateQueries = queryContext.getCandidateQueries();
|
||||
User user = queryContext.getUser();
|
||||
|
||||
for (SemanticQuery semanticQuery : candidateQueries) {
|
||||
// 1.init S2SQL
|
||||
semanticQuery.initS2Sql(queryContext.getSemanticSchema(), user);
|
||||
// 2.set queryType
|
||||
QueryType queryType = getQueryType(queryContext, semanticQuery);
|
||||
semanticQuery.getParseInfo().setQueryType(queryType);
|
||||
}
|
||||
}
|
||||
|
||||
private QueryType getQueryType(QueryContext queryContext, SemanticQuery semanticQuery) {
|
||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||
if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getS2SQL())) {
|
||||
return QueryType.DETAIL;
|
||||
}
|
||||
//1. entity queryType
|
||||
Long dataSetId = parseInfo.getDataSetId();
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
|
||||
List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getS2SQL());
|
||||
List<String> whereFilterByTimeFields = filterByTimeFields(whereFields);
|
||||
if (CollectionUtils.isNotEmpty(whereFilterByTimeFields)) {
|
||||
Set<String> ids = semanticSchema.getEntities(dataSetId).stream().map(SchemaElement::getName)
|
||||
.collect(Collectors.toSet());
|
||||
if (CollectionUtils.isNotEmpty(ids) && ids.stream()
|
||||
.anyMatch(whereFilterByTimeFields::contains)) {
|
||||
return QueryType.ID;
|
||||
}
|
||||
}
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||
selectFields.addAll(whereFields);
|
||||
List<String> selectWhereFilterByTimeFields = filterByTimeFields(selectFields);
|
||||
if (CollectionUtils.isNotEmpty(selectWhereFilterByTimeFields)) {
|
||||
Set<String> tags = semanticSchema.getTags(dataSetId).stream().map(SchemaElement::getName)
|
||||
.collect(Collectors.toSet());
|
||||
//If all the fields in the SELECT/WHERE statement are of tag type.
|
||||
if (CollectionUtils.isNotEmpty(tags)
|
||||
&& tags.containsAll(selectWhereFilterByTimeFields)) {
|
||||
return QueryType.DETAIL;
|
||||
}
|
||||
}
|
||||
}
|
||||
//2. metric queryType
|
||||
if (selectContainsMetric(sqlInfo, dataSetId, semanticSchema)) {
|
||||
return QueryType.METRIC;
|
||||
}
|
||||
return QueryType.DETAIL;
|
||||
}
|
||||
|
||||
private static List<String> filterByTimeFields(List<String> whereFields) {
|
||||
List<String> selectAndWhereFilterByTimeFields = whereFields
|
||||
.stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
|
||||
.collect(Collectors.toList());
|
||||
return selectAndWhereFilterByTimeFields;
|
||||
}
|
||||
|
||||
private static boolean selectContainsMetric(SqlInfo sqlInfo, Long dataSetId, SemanticSchema semanticSchema) {
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||
List<SchemaElement> metrics = semanticSchema.getMetrics(dataSetId);
|
||||
if (CollectionUtils.isNotEmpty(metrics)) {
|
||||
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
|
||||
return selectFields.stream().anyMatch(metricNameSet::contains);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,59 @@
package com.tencent.supersonic.headless.chat.parser;

import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.QueryContext;
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
import lombok.extern.slf4j.Slf4j;

import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_TEXT_LENGTH_THRESHOLD;
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_TEXT_LENGTH_THRESHOLD_LONG;
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_TEXT_LENGTH_THRESHOLD_SHORT;


/**
 * This checker can be used by semantic parsers to check whether the query intent
 * has already been satisfied by the current candidate queries. If so, the current
 * parser can be skipped.
 */
@Slf4j
public class SatisfactionChecker {

    // check all the parse info in candidate
    public static boolean isSkip(QueryContext queryContext) {
        for (SemanticQuery query : queryContext.getCandidateQueries()) {
            if (query.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) {
                continue;
            }
            if (checkThreshold(queryContext.getQueryText(), query.getParseInfo())) {
                return true;
            }
        }
        return false;
    }

    private static boolean checkThreshold(String queryText, SemanticParseInfo semanticParseInfo) {
        int queryTextLength = queryText.replaceAll(" ", "").length();
        double degree = semanticParseInfo.getScore() / queryTextLength;
        ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
        int textLengthThreshold =
                Integer.valueOf(parserConfig.getParameterValue(PARSER_TEXT_LENGTH_THRESHOLD));
        double longTextLengthThreshold =
                Double.valueOf(parserConfig.getParameterValue(PARSER_TEXT_LENGTH_THRESHOLD_LONG));
        double shortTextLengthThreshold =
                Double.valueOf(parserConfig.getParameterValue(PARSER_TEXT_LENGTH_THRESHOLD_SHORT));

        if (queryTextLength > textLengthThreshold) {
            if (degree < longTextLengthThreshold) {
                return false;
            }
        } else if (degree < shortTextLengthThreshold) {
            return false;
        }
        log.info("queryMode:{}, degree:{}, parse info:{}",
                semanticParseInfo.getQueryMode(), degree, semanticParseInfo);
        return true;
    }

}
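To make the threshold arithmetic in checkThreshold concrete, a small sketch under the default ParserConfig values (text length threshold 10, short-text threshold 0.5); the query text and score below are made up for illustration and not part of this commit:

public class SatisfactionSketch {
    public static void main(String[] args) {
        String queryText = "近七天访问次数";   // 7 characters after removing spaces -> short text (<= 10)
        double score = 6.0;                    // made-up score attached by a rule-based candidate query
        double degree = score / queryText.replaceAll(" ", "").length();  // 6.0 / 7 ≈ 0.857
        // 0.857 >= 0.5 (s2.parser.text.threshold.short), so checkThreshold would
        // return true and the LLM parser could be skipped for this query.
        System.out.println(degree >= 0.5);
    }
}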
@@ -0,0 +1,14 @@
package com.tencent.supersonic.headless.chat.parser;

import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.QueryContext;

/**
 * A semantic parser understands user queries and generates a semantic query statement.
 * SuperSonic leverages a combination of rule-based and LLM-based parsers,
 * each of which deals with specific scenarios.
 */
public interface SemanticParser {

    void parse(QueryContext queryContext, ChatContext chatContext);
}
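A minimal sketch of a custom parser implementing this interface; the class is hypothetical and not part of this commit, and a real parser would add candidate SemanticQuery objects to the QueryContext:

import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.QueryContext;
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class LoggingSemanticParser implements SemanticParser {

    @Override
    public void parse(QueryContext queryContext, ChatContext chatContext) {
        // Placeholder: only logs the incoming query; a real parser would build
        // candidate queries and add them via queryContext.getCandidateQueries().
        log.info("parsing queryText:{}", queryContext.getQueryText());
    }
}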
@@ -0,0 +1,9 @@
package com.tencent.supersonic.headless.chat.parser.llm;

import lombok.Data;

@Data
public class DataSetMatchResult {
    private Integer count = 0;
    private double maxSimilarity;
}
@@ -0,0 +1,12 @@
package com.tencent.supersonic.headless.chat.parser.llm;


import com.tencent.supersonic.headless.chat.QueryContext;

import java.util.Set;

public interface DataSetResolver {

    Long resolve(QueryContext queryContext, Set<Long> restrictiveModels);

}
@@ -0,0 +1,19 @@
package com.tencent.supersonic.headless.chat.parser.llm;

import lombok.Data;

@Data
public class Exemplar {

    private String question;

    private String questionAugmented;

    private String dbSchema;

    private String sql;

    private String generatedSchemaLinkingCoT;

    private String generatedSchemaLinkings;
}
@@ -0,0 +1,82 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
|
||||
import com.fasterxml.jackson.core.type.TypeReference;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.util.ComponentFactory;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.core.io.ClassPathResource;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class ExemplarManager {
|
||||
|
||||
private static final String EXAMPLE_JSON_FILE = "s2ql_exemplar.json";
|
||||
|
||||
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
|
||||
private TypeReference<List<Exemplar>> valueTypeRef = new TypeReference<List<Exemplar>>() {
|
||||
};
|
||||
|
||||
@Autowired
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
|
||||
public List<Exemplar> getExemplars() throws IOException {
|
||||
ClassPathResource resource = new ClassPathResource(EXAMPLE_JSON_FILE);
|
||||
InputStream inputStream = resource.getInputStream();
|
||||
return JsonUtil.INSTANCE.getObjectMapper().readValue(inputStream, valueTypeRef);
|
||||
}
|
||||
|
||||
public void addExemplars(List<Exemplar> exemplars, String collectionName) {
|
||||
List<EmbeddingQuery> queries = new ArrayList<>();
|
||||
for (int i = 0; i < exemplars.size(); i++) {
|
||||
Exemplar exemplar = exemplars.get(i);
|
||||
String question = exemplar.getQuestion();
|
||||
Map<String, Object> metaDataMap = JsonUtil.toMap(JsonUtil.toString(exemplar), String.class, Object.class);
|
||||
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
|
||||
embeddingQuery.setQueryId(String.valueOf(i));
|
||||
embeddingQuery.setQuery(question);
|
||||
embeddingQuery.setMetadata(metaDataMap);
|
||||
queries.add(embeddingQuery);
|
||||
}
|
||||
s2EmbeddingStore.addQuery(collectionName, queries);
|
||||
}
|
||||
|
||||
public List<Map<String, String>> recallExemplars(String queryText, int maxResults) {
|
||||
String collectionName = embeddingConfig.getText2sqlCollectionName();
|
||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(Collections.singletonList(queryText))
|
||||
.queryEmbeddings(null).build();
|
||||
|
||||
List<RetrieveQueryResult> resultList = s2EmbeddingStore.retrieveQuery(collectionName, retrieveQuery,
|
||||
maxResults);
|
||||
List<Map<String, String>> result = new ArrayList<>();
|
||||
if (CollectionUtils.isEmpty(resultList)) {
|
||||
return result;
|
||||
}
|
||||
for (Retrieval retrieval : resultList.get(0).getRetrieval()) {
|
||||
if (Objects.nonNull(retrieval.getMetadata()) && !retrieval.getMetadata().isEmpty()) {
|
||||
Map<String, String> convertedMap = retrieval.getMetadata().entrySet().stream()
|
||||
.collect(Collectors.toMap(Map.Entry::getKey, entry -> String.valueOf(entry.getValue())));
|
||||
result.add(convertedMap);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,130 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
public class HeuristicDataSetResolver implements DataSetResolver {
|
||||
|
||||
protected static Long selectDataSetBySchemaElementMatchScore(Map<Long, SemanticQuery> dataSetQueryModes,
|
||||
SchemaMapInfo schemaMap) {
|
||||
//dataSet count priority
|
||||
Long dataSetIdByDataSetCount = getDataSetIdByMatchDataSetScore(schemaMap);
|
||||
if (Objects.nonNull(dataSetIdByDataSetCount)) {
|
||||
log.info("selectDataSet by dataSet count:{}", dataSetIdByDataSetCount);
|
||||
return dataSetIdByDataSetCount;
|
||||
}
|
||||
|
||||
Map<Long, DataSetMatchResult> dataSetTypeMap = getDataSetTypeMap(schemaMap);
|
||||
if (dataSetTypeMap.size() == 1) {
|
||||
Long dataSetSelect = new ArrayList<>(dataSetTypeMap.entrySet()).get(0).getKey();
|
||||
if (dataSetQueryModes.containsKey(dataSetSelect)) {
|
||||
log.info("selectDataSet with only one DataSet [{}]", dataSetSelect);
|
||||
return dataSetSelect;
|
||||
}
|
||||
} else {
|
||||
Entry<Long, DataSetMatchResult> maxDataSet = dataSetTypeMap.entrySet().stream()
|
||||
.filter(entry -> dataSetQueryModes.containsKey(entry.getKey()))
|
||||
.sorted((o1, o2) -> {
|
||||
int difference = o2.getValue().getCount() - o1.getValue().getCount();
|
||||
if (difference == 0) {
|
||||
return (int) ((o2.getValue().getMaxSimilarity()
|
||||
- o1.getValue().getMaxSimilarity()) * 100);
|
||||
}
|
||||
return difference;
|
||||
}).findFirst().orElse(null);
|
||||
if (maxDataSet != null) {
|
||||
log.info("selectDataSet with multiple DataSets [{}]", maxDataSet.getKey());
|
||||
return maxDataSet.getKey();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private static Long getDataSetIdByMatchDataSetScore(SchemaMapInfo schemaMap) {
|
||||
Map<Long, List<SchemaElementMatch>> dataSetElementMatches = schemaMap.getDataSetElementMatches();
|
||||
// calculate dataSet match score, matched element gets 1.0 point, and inherit element gets 0.5 point
|
||||
Map<Long, Double> dataSetIdToDataSetScore = new HashMap<>();
|
||||
if (Objects.nonNull(dataSetElementMatches)) {
|
||||
for (Entry<Long, List<SchemaElementMatch>> dataSetElementMatch : dataSetElementMatches.entrySet()) {
|
||||
Long dataSetId = dataSetElementMatch.getKey();
|
||||
List<Double> dataSetMatchesScore = dataSetElementMatch.getValue().stream()
|
||||
.filter(elementMatch -> elementMatch.getSimilarity() >= 1)
|
||||
.filter(elementMatch -> SchemaElementType.DATASET.equals(elementMatch.getElement().getType()))
|
||||
.map(elementMatch -> elementMatch.isInherited() ? 0.5 : 1.0).collect(Collectors.toList());
|
||||
|
||||
if (!CollectionUtils.isEmpty(dataSetMatchesScore)) {
|
||||
// get sum of dataSet match score
|
||||
double score = dataSetMatchesScore.stream().mapToDouble(Double::doubleValue).sum();
|
||||
dataSetIdToDataSetScore.put(dataSetId, score);
|
||||
}
|
||||
}
|
||||
Entry<Long, Double> maxDataSetScore = dataSetIdToDataSetScore.entrySet().stream()
|
||||
.max(Comparator.comparingDouble(Entry::getValue)).orElse(null);
|
||||
log.info("maxDataSetCount:{},dataSetIdToDataSetCount:{}", maxDataSetScore, dataSetIdToDataSetScore);
|
||||
if (Objects.nonNull(maxDataSetScore)) {
|
||||
return maxDataSetScore.getKey();
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public static Map<Long, DataSetMatchResult> getDataSetTypeMap(SchemaMapInfo schemaMap) {
|
||||
Map<Long, DataSetMatchResult> dataSetCount = new HashMap<>();
|
||||
for (Entry<Long, List<SchemaElementMatch>> entry : schemaMap.getDataSetElementMatches().entrySet()) {
|
||||
List<SchemaElementMatch> schemaElementMatches = schemaMap.getMatchedElements(entry.getKey());
|
||||
if (schemaElementMatches != null && schemaElementMatches.size() > 0) {
|
||||
if (!dataSetCount.containsKey(entry.getKey())) {
|
||||
dataSetCount.put(entry.getKey(), new DataSetMatchResult());
|
||||
}
|
||||
DataSetMatchResult dataSetMatchResult = dataSetCount.get(entry.getKey());
|
||||
Set<SchemaElementType> schemaElementTypes = new HashSet<>();
|
||||
schemaElementMatches.stream()
|
||||
.forEach(schemaElementMatch -> schemaElementTypes.add(
|
||||
schemaElementMatch.getElement().getType()));
|
||||
SchemaElementMatch schemaElementMatchMax = schemaElementMatches.stream()
|
||||
.sorted((o1, o2) ->
|
||||
((int) ((o2.getSimilarity() - o1.getSimilarity()) * 100))
|
||||
).findFirst().orElse(null);
|
||||
if (schemaElementMatchMax != null) {
|
||||
dataSetMatchResult.setMaxSimilarity(schemaElementMatchMax.getSimilarity());
|
||||
}
|
||||
dataSetMatchResult.setCount(schemaElementTypes.size());
|
||||
|
||||
}
|
||||
}
|
||||
return dataSetCount;
|
||||
}
|
||||
|
||||
public Long resolve(QueryContext queryContext, Set<Long> agentDataSetIds) {
|
||||
SchemaMapInfo mapInfo = queryContext.getMapInfo();
|
||||
Set<Long> matchedDataSets = mapInfo.getMatchedDataSetInfos();
|
||||
if (CollectionUtils.isNotEmpty(agentDataSetIds)) {
|
||||
matchedDataSets.retainAll(agentDataSetIds);
|
||||
}
|
||||
Map<Long, SemanticQuery> dataSetQueryModes = new HashMap<>();
|
||||
for (Long dataSetIds : matchedDataSets) {
|
||||
dataSetQueryModes.put(dataSetIds, null);
|
||||
}
|
||||
if (dataSetQueryModes.size() == 1) {
|
||||
return dataSetQueryModes.keySet().stream().findFirst().get();
|
||||
}
|
||||
return selectDataSetBySchemaElementMatchScore(dataSetQueryModes, mapInfo);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,43 @@
package com.tencent.supersonic.headless.chat.parser.llm;

import lombok.extern.slf4j.Slf4j;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

@Slf4j
public class InputFormat {

    public static final String SEPERATOR = "\n\n";

    public static String format(String template, List<String> templateKey,
                                List<Map<String, String>> toFormatList) {
        List<String> result = new ArrayList<>();

        for (Map<String, String> formatItem : toFormatList) {
            Map<String, String> retrievalMeta = subDict(formatItem, templateKey);
            result.add(format(template, retrievalMeta));
        }

        return String.join(SEPERATOR, result);
    }

    public static String format(String input, Map<String, String> replacements) {
        for (Map.Entry<String, String> entry : replacements.entrySet()) {
            input = input.replace(entry.getKey(), entry.getValue());
        }
        return input;
    }

    private static Map<String, String> subDict(Map<String, String> dict, List<String> keys) {
        Map<String, String> subDict = new HashMap<>();
        for (String key : keys) {
            if (dict.containsKey(key)) {
                subDict.put(key, dict.get(key));
            }
        }
        return subDict;
    }
}
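A small usage sketch of the formatter above; the template and exemplar maps are made-up values, shown only to illustrate how placeholder keys are substituted and how rendered blocks are joined with the blank-line separator (not part of this commit):

import com.tencent.supersonic.headless.chat.parser.llm.InputFormat;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class InputFormatSketch {
    public static void main(String[] args) {
        String template = "Question: {question}\nSQL: {sql}";
        List<String> templateKeys = Arrays.asList("{question}", "{sql}");

        Map<String, String> exemplar = new HashMap<>();
        exemplar.put("{question}", "近7天每天的访问次数");
        exemplar.put("{sql}", "SELECT ...");
        exemplar.put("dbSchema", "dropped because it is not in templateKeys");

        // Each map is trimmed to the template keys, substituted into the template,
        // and the rendered blocks are joined with SEPERATOR ("\n\n").
        System.out.println(InputFormat.format(template, templateKeys, Arrays.asList(exemplar, exemplar)));
    }
}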
@@ -0,0 +1,26 @@
package com.tencent.supersonic.headless.chat.parser.llm;


import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

/**
 * LLMProxy based on the Java version of LangChain (langchain4j).
 */
@Slf4j
@Component
public class JavaLLMProxy implements LLMProxy {

    public LLMResp text2sql(LLMReq llmReq) {

        SqlGenStrategy sqlGenStrategy = SqlGenStrategyFactory.get(llmReq.getSqlGenType());
        String modelName = llmReq.getSchema().getDataSetName();
        LLMResp result = sqlGenStrategy.generate(llmReq);
        result.setQuery(llmReq.getQueryText());
        result.setModelName(modelName);
        return result;
    }

}
@@ -0,0 +1,29 @@
package com.tencent.supersonic.headless.chat.parser.llm;


import lombok.Data;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration;

@Configuration
@Data
public class LLMParserConfig {

    @Value("${s2.parser.url:}")
    private String url;

    @Value("${s2.query2sql.path:/query2sql}")
    private String queryToSqlPath;

    @Value("${s2.dimension.topn:10}")
    private Integer dimensionTopN;

    @Value("${s2.metric.topn:10}")
    private Integer metricTopN;

    @Value("${s2.tag.topn:20}")
    private Integer tagTopN;

    @Value("${s2.all.model:false}")
    private Boolean allModel;
}
@@ -0,0 +1,16 @@
package com.tencent.supersonic.headless.chat.parser.llm;


import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;

/**
 * LLMProxy encapsulates functions performed by LLMs so that multiple
 * orchestration frameworks (e.g. LangChain in Python, LangChain4j in Java)
 * can be used.
 */
public interface LLMProxy {

    LLMResp text2sql(LLMReq llmReq);

}
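To illustrate the pluggability described above, a trivial stand-in implementation might look like this; it is hypothetical, not part of this commit, and assumes LLMResp exposes a no-arg constructor alongside the setters used as a mutable POJO in JavaLLMProxy above:

import com.tencent.supersonic.headless.chat.parser.llm.LLMProxy;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;

public class EchoLLMProxy implements LLMProxy {

    @Override
    public LLMResp text2sql(LLMReq llmReq) {
        // Does not call any LLM: simply echoes the query back, e.g. for tests.
        LLMResp resp = new LLMResp();   // assumes a default constructor exists
        resp.setQuery(llmReq.getQueryText());
        return resp;
    }
}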
@@ -0,0 +1,279 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
|
||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
||||
import com.tencent.supersonic.headless.chat.utils.S2SqlDateHelper;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.headless.chat.parser.SatisfactionChecker;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_LINKING_VALUE_ENABLE;
|
||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_STRATEGY_TYPE;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
public class LLMRequestService {
|
||||
|
||||
@Autowired
|
||||
private LLMParserConfig llmParserConfig;
|
||||
|
||||
@Autowired
|
||||
private ParserConfig parserConfig;
|
||||
|
||||
public boolean isSkip(QueryContext queryCtx) {
|
||||
if (!queryCtx.getText2SQLType().enableLLM()) {
|
||||
log.info("not enable llm, skip");
|
||||
return true;
|
||||
}
|
||||
|
||||
if (SatisfactionChecker.isSkip(queryCtx)) {
|
||||
log.info("skip {}, queryText:{}", LLMSqlParser.class, queryCtx.getQueryText());
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
public Long getDataSetId(QueryContext queryCtx) {
|
||||
DataSetResolver dataSetResolver = ComponentFactory.getModelResolver();
|
||||
return dataSetResolver.resolve(queryCtx, queryCtx.getDataSetIds());
|
||||
}
|
||||
|
||||
public LLMReq getLlmReq(QueryContext queryCtx, Long dataSetId,
|
||||
SemanticSchema semanticSchema, List<LLMReq.ElementValue> linkingValues) {
|
||||
Map<Long, String> dataSetIdToName = semanticSchema.getDataSetIdToName();
|
||||
String queryText = queryCtx.getQueryText();
|
||||
|
||||
LLMReq llmReq = new LLMReq();
|
||||
|
||||
llmReq.setQueryText(queryText);
|
||||
LLMReq.FilterCondition filterCondition = new LLMReq.FilterCondition();
|
||||
llmReq.setFilterCondition(filterCondition);
|
||||
|
||||
LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema();
|
||||
llmSchema.setDataSetId(dataSetId);
|
||||
llmSchema.setDataSetName(dataSetIdToName.get(dataSetId));
|
||||
llmSchema.setDomainName(dataSetIdToName.get(dataSetId));
|
||||
|
||||
List<String> fieldNameList = getFieldNameList(queryCtx, dataSetId, llmParserConfig);
|
||||
fieldNameList.add(TimeDimensionEnum.DAY.getChName());
|
||||
llmSchema.setFieldNameList(fieldNameList);
|
||||
|
||||
llmSchema.setMetrics(getMatchedMetrics(queryCtx, dataSetId));
|
||||
llmSchema.setDimensions(getMatchedDimensions(queryCtx, dataSetId));
|
||||
llmSchema.setTerms(getTerms(queryCtx, dataSetId));
|
||||
llmReq.setSchema(llmSchema);
|
||||
|
||||
String priorExts = getPriorExts(queryCtx, fieldNameList);
|
||||
llmReq.setPriorExts(priorExts);
|
||||
|
||||
List<LLMReq.ElementValue> linking = new ArrayList<>();
|
||||
boolean linkingValueEnabled = Boolean.valueOf(parserConfig.getParameterValue(PARSER_LINKING_VALUE_ENABLE));
|
||||
|
||||
if (linkingValueEnabled) {
|
||||
linking.addAll(linkingValues);
|
||||
}
|
||||
llmReq.setLinking(linking);
|
||||
|
||||
String currentDate = S2SqlDateHelper.getReferenceDate(queryCtx, dataSetId);
|
||||
if (StringUtils.isEmpty(currentDate)) {
|
||||
currentDate = DateUtils.getBeforeDate(0);
|
||||
}
|
||||
llmReq.setCurrentDate(currentDate);
|
||||
llmReq.setSqlGenType(LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE)));
|
||||
llmReq.setLlmConfig(queryCtx.getLlmConfig());
|
||||
|
||||
return llmReq;
|
||||
}
|
||||
|
||||
public LLMResp runText2SQL(LLMReq llmReq) {
|
||||
return ComponentFactory.getLLMProxy().text2sql(llmReq);
|
||||
}
|
||||
|
||||
protected List<String> getFieldNameList(QueryContext queryCtx, Long dataSetId,
|
||||
LLMParserConfig llmParserConfig) {
|
||||
|
||||
Set<String> results = getTopNFieldNames(queryCtx, dataSetId, llmParserConfig);
|
||||
|
||||
Set<String> fieldNameList = getMatchedFieldNames(queryCtx, dataSetId);
|
||||
|
||||
results.addAll(fieldNameList);
|
||||
return new ArrayList<>(results);
|
||||
}
|
||||
|
||||
protected List<LLMReq.Term> getTerms(QueryContext queryCtx, Long dataSetId) {
|
||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
return matchedElements.stream()
|
||||
.filter(schemaElementMatch -> {
|
||||
SchemaElementType elementType = schemaElementMatch.getElement().getType();
|
||||
return SchemaElementType.TERM.equals(elementType);
|
||||
}).map(schemaElementMatch -> {
|
||||
LLMReq.Term term = new LLMReq.Term();
|
||||
term.setName(schemaElementMatch.getElement().getName());
|
||||
term.setDescription(schemaElementMatch.getElement().getDescription());
|
||||
term.setAlias(schemaElementMatch.getElement().getAlias());
|
||||
return term;
|
||||
}).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private String getPriorExts(QueryContext queryContext, List<String> fieldNameList) {
|
||||
StringBuilder extraInfoSb = new StringBuilder();
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
Map<String, String> fieldNameToDataFormatType = semanticSchema.getMetrics()
|
||||
.stream().filter(metricSchemaResp -> Objects.nonNull(metricSchemaResp.getDataFormatType()))
|
||||
.flatMap(metricSchemaResp -> {
|
||||
Set<Pair<String, String>> result = new HashSet<>();
|
||||
String dataFormatType = metricSchemaResp.getDataFormatType();
|
||||
result.add(Pair.of(metricSchemaResp.getName(), dataFormatType));
|
||||
List<String> aliasList = metricSchemaResp.getAlias();
|
||||
if (!CollectionUtils.isEmpty(aliasList)) {
|
||||
for (String alias : aliasList) {
|
||||
result.add(Pair.of(alias, dataFormatType));
|
||||
}
|
||||
}
|
||||
return result.stream();
|
||||
}).collect(Collectors.toMap(Pair::getLeft, Pair::getRight, (k1, k2) -> k1));
|
||||
|
||||
for (String fieldName : fieldNameList) {
|
||||
String dataFormatType = fieldNameToDataFormatType.get(fieldName);
|
||||
if (DataFormatTypeEnum.DECIMAL.getName().equalsIgnoreCase(dataFormatType)
|
||||
|| DataFormatTypeEnum.PERCENT.getName().equalsIgnoreCase(dataFormatType)) {
|
||||
String format = String.format("%s的计量单位是%s", fieldName, "小数; ");
|
||||
extraInfoSb.append(format);
|
||||
}
|
||||
}
|
||||
return extraInfoSb.toString();
|
||||
}
|
||||
|
||||
public List<LLMReq.ElementValue> getValues(QueryContext queryCtx, Long dataSetId) {
|
||||
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, dataSetId);
|
||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
Set<LLMReq.ElementValue> valueMatches = matchedElements
|
||||
.stream()
|
||||
.filter(elementMatch -> !elementMatch.isInherited())
|
||||
.filter(schemaElementMatch -> {
|
||||
SchemaElementType type = schemaElementMatch.getElement().getType();
|
||||
return SchemaElementType.VALUE.equals(type) || SchemaElementType.ID.equals(type);
|
||||
})
|
||||
.map(elementMatch -> {
|
||||
LLMReq.ElementValue elementValue = new LLMReq.ElementValue();
|
||||
elementValue.setFieldName(itemIdToName.get(elementMatch.getElement().getId()));
|
||||
elementValue.setFieldValue(elementMatch.getWord());
|
||||
return elementValue;
|
||||
}).collect(Collectors.toSet());
|
||||
return new ArrayList<>(valueMatches);
|
||||
}
|
||||
|
||||
protected Map<Long, String> getItemIdToName(QueryContext queryCtx, Long dataSetId) {
|
||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||
List<SchemaElement> elements = semanticSchema.getDimensions(dataSetId);
|
||||
return elements.stream()
|
||||
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
|
||||
}
|
||||
|
||||
private Set<String> getTopNFieldNames(QueryContext queryCtx, Long dataSetId, LLMParserConfig llmParserConfig) {
|
||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||
Set<String> results = new HashSet<>();
|
||||
Set<String> dimensions = semanticSchema.getDimensions(dataSetId).stream()
|
||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||
.limit(llmParserConfig.getDimensionTopN())
|
||||
.map(entry -> entry.getName())
|
||||
.collect(Collectors.toSet());
|
||||
results.addAll(dimensions);
|
||||
Set<String> metrics = semanticSchema.getMetrics(dataSetId).stream()
|
||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||
.limit(llmParserConfig.getMetricTopN())
|
||||
.map(entry -> entry.getName())
|
||||
.collect(Collectors.toSet());
|
||||
results.addAll(metrics);
|
||||
return results;
|
||||
}
|
||||
|
||||
protected List<SchemaElement> getMatchedMetrics(QueryContext queryCtx, Long dataSetId) {
|
||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
List<SchemaElement> schemaElements = matchedElements.stream()
|
||||
.filter(schemaElementMatch -> {
|
||||
SchemaElementType elementType = schemaElementMatch.getElement().getType();
|
||||
return SchemaElementType.METRIC.equals(elementType);
|
||||
})
|
||||
.map(schemaElementMatch -> {
|
||||
return schemaElementMatch.getElement();
|
||||
})
|
||||
.collect(Collectors.toList());
|
||||
return schemaElements;
|
||||
}
|
||||
|
||||
protected List<SchemaElement> getMatchedDimensions(QueryContext queryCtx, Long dataSetId) {
|
||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||
return Collections.emptyList();
|
||||
}
|
||||
List<SchemaElement> schemaElements = matchedElements.stream()
|
||||
.filter(schemaElementMatch -> {
|
||||
SchemaElementType elementType = schemaElementMatch.getElement().getType();
|
||||
return SchemaElementType.DIMENSION.equals(elementType);
|
||||
})
|
||||
.map(schemaElementMatch -> {
|
||||
return schemaElementMatch.getElement();
|
||||
})
|
||||
.collect(Collectors.toList());
|
||||
return schemaElements;
|
||||
}
|
||||
|
||||
protected Set<String> getMatchedFieldNames(QueryContext queryCtx, Long dataSetId) {
|
||||
Map<Long, String> itemIdToName = getItemIdToName(queryCtx, dataSetId);
|
||||
List<SchemaElementMatch> matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId);
|
||||
if (CollectionUtils.isEmpty(matchedElements)) {
|
||||
return new HashSet<>();
|
||||
}
|
||||
Set<String> fieldNameList = matchedElements.stream()
|
||||
.filter(schemaElementMatch -> {
|
||||
SchemaElementType elementType = schemaElementMatch.getElement().getType();
|
||||
return SchemaElementType.METRIC.equals(elementType)
|
||||
|| SchemaElementType.DIMENSION.equals(elementType)
|
||||
|| SchemaElementType.VALUE.equals(elementType);
|
||||
})
|
||||
.map(schemaElementMatch -> {
|
||||
SchemaElement element = schemaElementMatch.getElement();
|
||||
SchemaElementType elementType = element.getType();
|
||||
if (SchemaElementType.VALUE.equals(elementType)) {
|
||||
return itemIdToName.get(element.getId());
|
||||
}
|
||||
return schemaElementMatch.getWord();
|
||||
})
|
||||
.collect(Collectors.toSet());
|
||||
return fieldNameList;
|
||||
}
|
||||
}
|
||||
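For context, getTopNFieldNames above simply keeps the most frequently used schema elements; a made-up illustration of the selection (the limits come from LLMParserConfig):

// dimensions are sorted by useCnt desc and limited to llmParserConfig.getDimensionTopN();
// metrics are handled the same way with getMetricTopN(). For example, with
// useCnt 省份=120, 歌手=80, 语种=3 and dimensionTopN=2, the dimension part is {省份, 歌手}.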
@@ -0,0 +1,58 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlEqualHelper;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.LLMSemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlResp;
|
||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.MapUtils;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
public class LLMResponseService {
|
||||
|
||||
public SemanticParseInfo addParseInfo(QueryContext queryCtx, ParseResult parseResult, String s2SQL, Double weight) {
|
||||
if (Objects.isNull(weight)) {
|
||||
weight = 0D;
|
||||
}
|
||||
LLMSemanticQuery semanticQuery = QueryManager.createLLMQuery(LLMSqlQuery.QUERY_MODE);
|
||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||
parseInfo.setDataSet(queryCtx.getSemanticSchema().getDataSet(parseResult.getDataSetId()));
|
||||
parseInfo.getElementMatches().addAll(queryCtx.getMapInfo().getMatchedElements(parseInfo.getDataSetId()));
|
||||
|
||||
Map<String, Object> properties = new HashMap<>();
|
||||
properties.put(Constants.CONTEXT, parseResult);
|
||||
properties.put("type", "internal");
|
||||
parseInfo.setProperties(properties);
|
||||
parseInfo.setScore(queryCtx.getQueryText().length() * (1 + weight));
|
||||
parseInfo.setQueryMode(semanticQuery.getQueryMode());
|
||||
parseInfo.getSqlInfo().setS2SQL(s2SQL);
|
||||
queryCtx.getCandidateQueries().add(semanticQuery);
|
||||
return parseInfo;
|
||||
}
|
||||
|
||||
public Map<String, LLMSqlResp> getDeduplicationSqlResp(LLMResp llmResp) {
|
||||
if (MapUtils.isEmpty(llmResp.getSqlRespMap())) {
|
||||
return llmResp.getSqlRespMap();
|
||||
}
|
||||
Map<String, LLMSqlResp> result = new HashMap<>();
|
||||
for (Map.Entry<String, LLMSqlResp> entry : llmResp.getSqlRespMap().entrySet()) {
|
||||
String key = entry.getKey();
|
||||
if (result.keySet().stream().anyMatch(existKey -> SqlEqualHelper.equals(existKey, key))) {
|
||||
continue;
|
||||
}
|
||||
result.put(key, entry.getValue());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
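A minimal sketch of the deduplication above (the SQL strings and weights are made up; whether the two keys collapse depends on SqlEqualHelper.equals):

LLMResp llmResp = new LLMResp();
Map<String, LLMSqlResp> sqlRespMap = new LinkedHashMap<>();
sqlRespMap.put("SELECT dept, SUM(pv) FROM t GROUP BY dept",
        LLMSqlResp.builder().sqlWeight(0.6).build());
sqlRespMap.put("select dept, sum(pv) from t group by dept",
        LLMSqlResp.builder().sqlWeight(0.4).build());
llmResp.setSqlRespMap(sqlRespMap);
// If SqlEqualHelper treats the two statements as equivalent, only the first key is kept.
Map<String, LLMSqlResp> deduped = new LLMResponseService().getDeduplicationSqlResp(llmResp);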
@@ -0,0 +1,78 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlResp;
|
||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.MapUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* LLMSqlParser uses a large language model to understand query semantics and
|
||||
* generate S2SQL statements to be executed by the semantic query engine.
|
||||
*/
|
||||
@Slf4j
|
||||
public class LLMSqlParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryCtx, ChatContext chatCtx) {
|
||||
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
|
||||
//1.determine whether to skip this parser.
|
||||
if (requestService.isSkip(queryCtx)) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
//2.get dataSetId from queryCtx and chatCtx.
|
||||
Long dataSetId = requestService.getDataSetId(queryCtx);
|
||||
if (dataSetId == null) {
|
||||
return;
|
||||
}
|
||||
log.info("Generate query statement for dataSetId:{}", dataSetId);
|
||||
|
||||
//3.invoke LLM service to do parsing.
|
||||
List<LLMReq.ElementValue> linkingValues = requestService.getValues(queryCtx, dataSetId);
|
||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||
LLMReq llmReq = requestService.getLlmReq(queryCtx, dataSetId, semanticSchema, linkingValues);
|
||||
LLMResp llmResp = requestService.runText2SQL(llmReq);
|
||||
if (Objects.isNull(llmResp)) {
|
||||
return;
|
||||
}
|
||||
|
||||
//4. deduplicate the S2SQL result list and build parseInfo
|
||||
LLMResponseService responseService = ContextUtils.getBean(LLMResponseService.class);
|
||||
Map<String, LLMSqlResp> deduplicationSqlResp = responseService.getDeduplicationSqlResp(llmResp);
|
||||
ParseResult parseResult = ParseResult.builder()
|
||||
.dataSetId(dataSetId)
|
||||
.llmReq(llmReq)
|
||||
.llmResp(llmResp)
|
||||
.linkingValues(linkingValues)
|
||||
.build();
|
||||
|
||||
if (MapUtils.isEmpty(deduplicationSqlResp)) {
|
||||
if (StringUtils.isNotBlank(llmResp.getSqlOutput())) {
|
||||
responseService.addParseInfo(queryCtx, parseResult, llmResp.getSqlOutput(), 1D);
|
||||
}
|
||||
} else {
|
||||
deduplicationSqlResp.forEach((sql, sqlResp) -> {
|
||||
if (StringUtils.isNotBlank(sql)) {
|
||||
responseService.addParseInfo(queryCtx, parseResult, sql, sqlResp.getSqlWeight());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("Failed to parse query:", e);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,100 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
|
||||
@Override
|
||||
public LLMResp generate(LLMReq llmReq) {
|
||||
//1.recall exemplars
|
||||
keyPipelineLog.info("OnePassSCSqlGenStrategy llmReq:\n{}", llmReq);
|
||||
List<List<Map<String, String>>> exemplarsList = promptHelper.getFewShotExemplars(llmReq);
|
||||
|
||||
//2.generate sql generation prompt for each self-consistency inference
|
||||
Map<Prompt, List<Map<String, String>>> prompt2Exemplar = new HashMap<>();
|
||||
for (List<Map<String, String>> exemplars : exemplarsList) {
|
||||
Prompt prompt = generatePrompt(llmReq, exemplars);
|
||||
prompt2Exemplar.put(prompt, exemplars);
|
||||
}
|
||||
|
||||
//3.perform multiple self-consistency inferences in parallel
|
||||
Map<Prompt, String> prompt2Output = new ConcurrentHashMap<>();
|
||||
prompt2Exemplar.keySet().parallelStream().forEach(prompt -> {
|
||||
keyPipelineLog.info("OnePassSCSqlGenStrategy reqPrompt:\n{}", prompt.toSystemMessage());
|
||||
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toSystemMessage());
|
||||
String result = response.content().text();
|
||||
prompt2Output.put(prompt, result);
|
||||
keyPipelineLog.info("OnePassSCSqlGenStrategy modelResp:\n{}", result);
|
||||
}
|
||||
);
|
||||
|
||||
//4.format response.
|
||||
Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(
|
||||
Lists.newArrayList(prompt2Output.values()));
|
||||
LLMResp llmResp = new LLMResp();
|
||||
llmResp.setQuery(llmReq.getQueryText());
|
||||
//TODO: should use the same few-shot exemplars as the ones chosen by the self-consistency vote
|
||||
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(exemplarsList.get(0), sqlMapPair.getRight()));
|
||||
|
||||
return llmResp;
|
||||
}
|
||||
|
||||
private Prompt generatePrompt(LLMReq llmReq, List<Map<String, String>> fewshotExampleList) {
|
||||
String instruction = ""
|
||||
+ "#Role: You are a data analyst experienced in SQL languages.\n"
|
||||
+ "#Task: You will be provided a natural language query asked by business users,"
|
||||
+ "please convert it to a SQL query so that relevant answer could be returned to the user "
|
||||
+ "by executing the SQL query against underlying database.\n"
|
||||
+ "#Rules:\n"
|
||||
+ "1.ALWAYS use `数据日期` as the date field.\n"
|
||||
+ "2.ALWAYS use `datediff()` as the date function.\n"
|
||||
+ "3.DO NOT specify date filter in the where clause if not explicitly mentioned in the query.\n"
|
||||
+ "4.ONLY respond with the converted SQL statement.\n"
|
||||
+ "#Exemplars:\n%s"
|
||||
+ "#UserQuery: %s "
|
||||
+ "#DatabaseMetadata: %s "
|
||||
+ "#SQL: ";
|
||||
|
||||
StringBuilder exemplarsStr = new StringBuilder();
|
||||
for (Map<String, String> example : fewshotExampleList) {
|
||||
String metadata = example.get("dbSchema");
|
||||
String question = example.get("questionAugmented");
|
||||
String sql = example.get("sql");
|
||||
String exemplarStr = String.format("#UserQuery: %s #DatabaseMetadata: %s #SQL: %s\n",
|
||||
question, metadata, sql);
|
||||
exemplarsStr.append(exemplarStr);
|
||||
}
|
||||
|
||||
Pair<String, String> questionPrompt = promptHelper.transformQuestionPrompt(llmReq);
|
||||
String dataSemanticsStr = promptHelper.buildMetadataStr(llmReq);
|
||||
String questionAugmented = questionPrompt.getRight();
|
||||
String promptStr = String.format(instruction, exemplarsStr, questionAugmented, dataSemanticsStr);
|
||||
|
||||
return PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
SqlGenStrategyFactory.addSqlGenerationForFactory(LLMReq.SqlGenType.ONE_PASS_SELF_CONSISTENCY, this);
|
||||
}
|
||||
}
|
||||
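A condensed view of the one-pass self-consistency flow implemented above; askModel is a placeholder for the chatLanguageModel.generate(...) call, not a project API:

// N shuffled exemplar sets -> N prompts -> N independent model answers -> majority vote.
List<String> outputs = new ArrayList<>();
for (List<Map<String, String>> exemplars : promptHelper.getFewShotExemplars(llmReq)) {
    Prompt prompt = generatePrompt(llmReq, exemplars); // same template as above
    outputs.add(askModel(prompt));                     // placeholder for the LLM call
}
String votedSql = OutputFormat.selfConsistencyVote(outputs).getLeft();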
@@ -0,0 +1,123 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
public class OutputFormat {
|
||||
|
||||
public static String getSchemaLink(String schemaLink) {
|
||||
String result = "";
|
||||
try {
|
||||
result = schemaLink.trim();
|
||||
String pattern = "Schema_links:(.*)";
|
||||
Pattern regexPattern = Pattern.compile(pattern, Pattern.DOTALL);
|
||||
Matcher matcher = regexPattern.matcher(result);
|
||||
if (matcher.find()) {
|
||||
return matcher.group(1).trim();
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("", e);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
public static String getSql(String sqlOutput) {
|
||||
String sql = "";
|
||||
try {
|
||||
sqlOutput = sqlOutput.trim();
|
||||
String pattern = "SQL:(.*)";
|
||||
Pattern regexPattern = Pattern.compile(pattern);
|
||||
Matcher matcher = regexPattern.matcher(sqlOutput);
|
||||
if (matcher.find()) {
|
||||
return matcher.group(1);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("", e);
|
||||
}
|
||||
return sql;
|
||||
}
|
||||
|
||||
public static String getSchemaLinks(String text) {
|
||||
String schemaLinks = "";
|
||||
try {
|
||||
text = text.trim();
|
||||
String pattern = "Schema_links:(\\[.*?\\])|Schema_links: (\\[.*?\\])";
|
||||
Pattern regexPattern = Pattern.compile(pattern);
|
||||
Matcher matcher = regexPattern.matcher(text);
|
||||
|
||||
if (matcher.find()) {
|
||||
if (matcher.group(1) != null) {
|
||||
schemaLinks = matcher.group(1);
|
||||
} else if (matcher.group(2) != null) {
|
||||
schemaLinks = matcher.group(2);
|
||||
}
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("", e);
|
||||
}
|
||||
|
||||
return schemaLinks;
|
||||
}
|
||||
|
||||
public static Pair<String, Map<String, Double>> selfConsistencyVote(List<String> inputList) {
|
||||
Map<String, Integer> inputCounts = new HashMap<>();
|
||||
for (String input : inputList) {
|
||||
inputCounts.put(input, inputCounts.getOrDefault(input, 0) + 1);
|
||||
}
|
||||
|
||||
String inputMax = null;
|
||||
int maxCount = 0;
|
||||
int inputSize = inputList.size();
|
||||
Map<String, Double> votePercentage = new HashMap<>();
|
||||
for (Map.Entry<String, Integer> entry : inputCounts.entrySet()) {
|
||||
String input = entry.getKey();
|
||||
int count = entry.getValue();
|
||||
if (count > maxCount) {
|
||||
inputMax = input;
|
||||
maxCount = count;
|
||||
}
|
||||
double percentage = (double) count / inputSize;
|
||||
votePercentage.put(input, percentage);
|
||||
}
|
||||
return Pair.of(inputMax, votePercentage);
|
||||
}
|
||||
|
||||
public static List<String> formatList(List<String> toFormatList) {
|
||||
List<String> results = new ArrayList<>();
|
||||
for (String toFormat : toFormatList) {
|
||||
List<String> items = new ArrayList<>();
|
||||
String[] split = toFormat.replace("[", "").replace("]", "").split(",");
|
||||
for (String item : split) {
|
||||
items.add(item.trim());
|
||||
}
|
||||
Collections.sort(items);
|
||||
String result = "[" + String.join(",", items) + "]";
|
||||
results.add(result);
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
public static Map<String, LLMSqlResp> buildSqlRespMap(List<Map<String, String>> sqlExamples,
|
||||
Map<String, Double> sqlMap) {
|
||||
if (sqlMap == null) {
|
||||
return new HashMap<>();
|
||||
}
|
||||
return sqlMap.entrySet().stream()
|
||||
.collect(Collectors.toMap(
|
||||
Map.Entry::getKey,
|
||||
entry -> LLMSqlResp.builder().sqlWeight(entry.getValue()).fewShots(sqlExamples).build())
|
||||
);
|
||||
}
|
||||
|
||||
}
|
||||
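A small worked example of selfConsistencyVote (the SQL strings are made up):

Pair<String, Map<String, Double>> vote = OutputFormat.selfConsistencyVote(
        Arrays.asList("SELECT a FROM t", "SELECT a FROM t", "SELECT b FROM t"));
vote.getLeft();   // "SELECT a FROM t", the answer produced most often
vote.getRight();  // {"SELECT a FROM t"=0.67, "SELECT b FROM t"=0.33} (approximately)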
@@ -0,0 +1,28 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class ParseResult {
|
||||
|
||||
private Long dataSetId;
|
||||
|
||||
private LLMReq llmReq;
|
||||
|
||||
private LLMResp llmResp;
|
||||
|
||||
private QueryReq request;
|
||||
|
||||
private List<LLMReq.ElementValue> linkingValues;
|
||||
}
|
||||
@@ -0,0 +1,130 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER;
|
||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_FEW_SHOT_NUMBER;
|
||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_SELF_CONSISTENCY_NUMBER;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
public class PromptHelper {
|
||||
|
||||
@Autowired
|
||||
private ParserConfig parserConfig;
|
||||
|
||||
@Autowired
|
||||
private ExemplarManager exemplarManager;
|
||||
|
||||
public List<List<Map<String, String>>> getFewShotExemplars(LLMReq llmReq) {
|
||||
int exemplarRecallNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER));
|
||||
int fewShotNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_FEW_SHOT_NUMBER));
|
||||
int selfConsistencyNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_SELF_CONSISTENCY_NUMBER));
|
||||
|
||||
List<Map<String, String>> exemplars = exemplarManager.recallExemplars(llmReq.getQueryText(),
|
||||
exemplarRecallNumber);
|
||||
List<List<Map<String, String>>> results = new ArrayList<>();
|
||||
|
||||
// use a random selection of exemplars for each self-consistency inference
|
||||
for (int i = 0; i < selfConsistencyNumber; i++) {
|
||||
List<Map<String, String>> shuffledList = new ArrayList<>(exemplars);
|
||||
Collections.shuffle(shuffledList);
|
||||
results.add(shuffledList.subList(0, fewShotNumber));
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
public Pair<String, String> transformQuestionPrompt(LLMReq llmReq) {
|
||||
String tableName = llmReq.getSchema().getDataSetName();
|
||||
List<String> fieldNameList = llmReq.getSchema().getFieldNameList();
|
||||
List<LLMReq.ElementValue> linkedValues = llmReq.getLinking();
|
||||
String currentDate = llmReq.getCurrentDate();
|
||||
String priorExts = llmReq.getPriorExts();
|
||||
|
||||
String dbSchema = "Table: " + tableName + ", Columns = " + fieldNameList;
|
||||
|
||||
List<String> priorLinkingList = new ArrayList<>();
|
||||
for (LLMReq.ElementValue value : linkedValues) {
|
||||
String fieldName = value.getFieldName();
|
||||
String fieldValue = value.getFieldValue();
|
||||
priorLinkingList.add("‘" + fieldValue + "‘是一个‘" + fieldName + "‘");
|
||||
}
|
||||
String currentDataStr = "当前的日期是" + currentDate;
|
||||
String linkingListStr = String.join(",", priorLinkingList);
|
||||
String termStr = buildTermStr(llmReq);
|
||||
String questionAugmented = String.format("%s (补充信息:%s;%s;%s;%s)", llmReq.getQueryText(),
|
||||
linkingListStr, currentDataStr, termStr, priorExts);
|
||||
|
||||
return Pair.of(dbSchema, questionAugmented);
|
||||
}
|
||||
|
||||
public String buildMetadataStr(LLMReq llmReq) {
|
||||
String tableStr = llmReq.getSchema().getDataSetName();
|
||||
StringBuilder metricStr = new StringBuilder();
|
||||
StringBuilder dimensionStr = new StringBuilder();
|
||||
|
||||
llmReq.getSchema().getMetrics().stream().forEach(
|
||||
metric -> {
|
||||
metricStr.append(metric.getName());
|
||||
if (StringUtils.isNotEmpty(metric.getDescription())) {
|
||||
metricStr.append(" COMMENT '" + metric.getDescription() + "'");
|
||||
}
|
||||
if (StringUtils.isNotEmpty(metric.getDefaultAgg())) {
|
||||
metricStr.append(" AGGREGATE '" + metric.getDefaultAgg().toUpperCase() + "'");
|
||||
}
|
||||
metricStr.append(",");
|
||||
}
|
||||
);
|
||||
|
||||
llmReq.getSchema().getDimensions().stream().forEach(
|
||||
dimension -> {
|
||||
dimensionStr.append(dimension.getName());
|
||||
if (StringUtils.isNotEmpty(dimension.getDescription())) {
|
||||
dimensionStr.append(" COMMENT '" + dimension.getDescription() + "'");
|
||||
}
|
||||
dimensionStr.append(",");
|
||||
}
|
||||
);
|
||||
|
||||
String template = "Table: %s, Metrics: [%s], Dimensions: [%s]";
|
||||
|
||||
|
||||
return String.format(template, tableStr, metricStr, dimensionStr);
|
||||
}
|
||||
|
||||
private String buildTermStr(LLMReq llmReq) {
|
||||
List<LLMReq.Term> terms = llmReq.getSchema().getTerms();
|
||||
StringBuilder termsDesc = new StringBuilder();
|
||||
if (!CollectionUtils.isEmpty(terms)) {
|
||||
termsDesc.append("相关业务术语:");
|
||||
for (int idx = 0; idx < terms.size(); idx++) {
|
||||
LLMReq.Term term = terms.get(idx);
|
||||
String name = term.getName();
|
||||
String description = term.getDescription();
|
||||
List<String> alias = term.getAlias();
|
||||
String descPart = StringUtils.isBlank(description) ? "" : String.format(",它通常是指<%s>", description);
|
||||
String aliasPart = CollectionUtils.isEmpty(alias) ? "" : String.format(",类似的表达还有%s", alias);
|
||||
termsDesc.append(String.format("%d.<%s>是业务术语%s%s;", idx + 1, name, descPart, aliasPart));
|
||||
}
|
||||
if (termsDesc.length() > 0) {
|
||||
termsDesc.setLength(termsDesc.length() - 1);
|
||||
}
|
||||
}
|
||||
|
||||
return termsDesc.toString();
|
||||
}
|
||||
|
||||
}
|
||||
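For orientation, the strings built above take roughly this shape (all values are made up):

// buildMetadataStr(llmReq) ->
//   "Table: 歌曲库, Metrics: [播放量 COMMENT '总播放次数' AGGREGATE 'SUM',], Dimensions: [省份,歌手,]"
// transformQuestionPrompt(llmReq) -> Pair.of(dbSchema, questionAugmented), e.g.
//   left : "Table: 歌曲库, Columns = [播放量, 省份, 歌手]"
//   right: "广东的播放量 (补充信息:‘广东‘是一个‘省份‘;当前的日期是2024-05-01;;)"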
@@ -0,0 +1,61 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.MapUtils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.http.HttpEntity;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
|
||||
import java.net.URL;
|
||||
import java.util.ArrayList;
|
||||
|
||||
/**
|
||||
* PythonLLMProxy sends requests to a LangChain-based Python service.
|
||||
*/
|
||||
@Slf4j
|
||||
@Component
|
||||
public class PythonLLMProxy implements LLMProxy {
|
||||
|
||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
|
||||
public LLMResp text2sql(LLMReq llmReq) {
|
||||
long startTime = System.currentTimeMillis();
|
||||
log.info("requestLLM request, llmReq:{}", llmReq);
|
||||
keyPipelineLog.info("PythonLLMProxy llmReq:{}", llmReq);
|
||||
try {
|
||||
LLMParserConfig llmParserConfig = ContextUtils.getBean(LLMParserConfig.class);
|
||||
|
||||
URL url = new URL(new URL(llmParserConfig.getUrl()), llmParserConfig.getQueryToSqlPath());
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||
HttpEntity<String> entity = new HttpEntity<>(JsonUtil.toString(llmReq), headers);
|
||||
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
|
||||
ResponseEntity<LLMResp> responseEntity = restTemplate.exchange(url.toString(), HttpMethod.POST, entity,
|
||||
LLMResp.class);
|
||||
|
||||
LLMResp llmResp = responseEntity.getBody();
|
||||
log.info("requestLLM response,cost:{}, questUrl:{} \n entity:{} \n body:{}",
|
||||
System.currentTimeMillis() - startTime, url, entity, llmResp);
|
||||
keyPipelineLog.info("PythonLLMProxy llmResp:{}", llmResp);
|
||||
|
||||
if (MapUtils.isEmpty(llmResp.getSqlRespMap())) {
|
||||
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(new ArrayList<>(), llmResp.getSqlWeight()));
|
||||
}
|
||||
return llmResp;
|
||||
} catch (Exception e) {
|
||||
log.error("requestLLM error", e);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.headless.chat.utils.ComponentFactory;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.boot.CommandLineRunner;
|
||||
import org.springframework.core.annotation.Order;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
@Order(0)
|
||||
public class SqlEmbeddingListener implements CommandLineRunner {
|
||||
|
||||
@Autowired
|
||||
private ExemplarManager exemplarManager;
|
||||
@Autowired
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
|
||||
@Override
|
||||
public void run(String... args) {
|
||||
initSqlExamples();
|
||||
}
|
||||
|
||||
public void initSqlExamples() {
|
||||
try {
|
||||
if (ComponentFactory.getLLMProxy() instanceof JavaLLMProxy) {
|
||||
List<Exemplar> exemplars = exemplarManager.getExemplars();
|
||||
String collectionName = embeddingConfig.getText2sqlCollectionName();
|
||||
exemplarManager.addExemplars(exemplars, collectionName);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("initSqlExamples error", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.config.LLMConfig;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.common.util.S2ChatModelProvider;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.beans.factory.InitializingBean;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
/**
|
||||
* SqlGenStrategy abstracts the SQL generation step so that
|
||||
* different LLM prompting strategies can be implemented.
|
||||
*/
|
||||
@Service
|
||||
public abstract class SqlGenStrategy implements InitializingBean {
|
||||
|
||||
protected static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
|
||||
@Autowired
|
||||
protected PromptHelper promptHelper;
|
||||
|
||||
protected ChatLanguageModel getChatLanguageModel(LLMConfig llmConfig) {
|
||||
return S2ChatModelProvider.provide(llmConfig);
|
||||
}
|
||||
|
||||
abstract LLMResp generate(LLMReq llmReq);
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
|
||||
public class SqlGenStrategyFactory {
|
||||
|
||||
private static Map<LLMReq.SqlGenType, SqlGenStrategy> sqlGenStrategyMap = new ConcurrentHashMap<>();
|
||||
|
||||
public static SqlGenStrategy get(LLMReq.SqlGenType strategyType) {
|
||||
return sqlGenStrategyMap.get(strategyType);
|
||||
}
|
||||
|
||||
public static void addSqlGenerationForFactory(LLMReq.SqlGenType strategy, SqlGenStrategy sqlGenStrategy) {
|
||||
sqlGenStrategyMap.put(strategy, sqlGenStrategy);
|
||||
}
|
||||
}
|
||||
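A minimal usage sketch of the factory above; strategies register themselves in afterPropertiesSet() once Spring has initialized them, and llmReq stands for an already-assembled request:

SqlGenStrategy strategy = SqlGenStrategyFactory.get(LLMReq.SqlGenType.ONE_PASS_SELF_CONSISTENCY);
LLMResp llmResp = strategy.generate(llmReq); // generate(..) is package-private, so call it from within the parser package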
@@ -0,0 +1,119 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
|
||||
@Service
|
||||
@Deprecated
|
||||
public class TwoPassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
|
||||
@Override
|
||||
public LLMResp generate(LLMReq llmReq) {
|
||||
//1.recall exemplars
|
||||
keyPipelineLog.info("TwoPassSCSqlGenStrategy llmReq:{}", llmReq);
|
||||
|
||||
List<List<Map<String, String>>> exampleListPool = promptHelper.getFewShotExemplars(llmReq);
|
||||
|
||||
//2.generate schema linking prompt for each self-consistency inference
|
||||
List<String> linkingPromptPool = new ArrayList<>();
|
||||
for (List<Map<String, String>> exampleList : exampleListPool) {
|
||||
String prompt = generateLinkingPrompt(llmReq, exampleList);
|
||||
linkingPromptPool.add(prompt);
|
||||
}
|
||||
|
||||
List<String> linkingResults = new CopyOnWriteArrayList<>();
|
||||
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getLlmConfig());
|
||||
linkingPromptPool.parallelStream().forEach(
|
||||
linkingPrompt -> {
|
||||
Prompt prompt = PromptTemplate.from(JsonUtil.toString(linkingPrompt)).apply(new HashMap<>());
|
||||
keyPipelineLog.info("TwoPassSCSqlGenStrategy step one reqPrompt:{}", prompt.toSystemMessage());
|
||||
Response<AiMessage> linkingResult = chatLanguageModel.generate(prompt.toSystemMessage());
|
||||
String result = linkingResult.content().text();
|
||||
keyPipelineLog.info("TwoPassSCSqlGenStrategy step one modelResp:{}", result);
|
||||
linkingResults.add(OutputFormat.getSchemaLink(result));
|
||||
}
|
||||
);
|
||||
List<String> sortedList = OutputFormat.formatList(linkingResults);
|
||||
|
||||
//3.generate sql generation prompt for each self-consistency inference
|
||||
List<String> sqlPromptPool = new ArrayList<>();
|
||||
for (int i = 0; i < sortedList.size(); i++) {
|
||||
String schemaLinkStr = sortedList.get(i);
|
||||
List<Map<String, String>> fewshotExampleList = exampleListPool.get(i);
|
||||
String sqlPrompt = generateSqlPrompt(llmReq, schemaLinkStr, fewshotExampleList);
|
||||
sqlPromptPool.add(sqlPrompt);
|
||||
}
|
||||
|
||||
//4.perform multiple self-consistency inferences in parallel
|
||||
List<String> sqlTaskPool = new CopyOnWriteArrayList<>();
|
||||
sqlPromptPool.parallelStream().forEach(sqlPrompt -> {
|
||||
Prompt linkingPrompt = PromptTemplate.from(JsonUtil.toString(sqlPrompt)).apply(new HashMap<>());
|
||||
keyPipelineLog.info("TwoPassSCSqlGenStrategy step two reqPrompt:{}", linkingPrompt.toSystemMessage());
|
||||
Response<AiMessage> sqlResult = chatLanguageModel.generate(linkingPrompt.toSystemMessage());
|
||||
String result = sqlResult.content().text();
|
||||
keyPipelineLog.info("TwoPassSCSqlGenStrategy step two modelResp:{}", result);
|
||||
sqlTaskPool.add(result);
|
||||
});
|
||||
|
||||
//5.format response.
|
||||
Pair<String, Map<String, Double>> sqlMapPair = OutputFormat.selfConsistencyVote(sqlTaskPool);
|
||||
LLMResp llmResp = new LLMResp();
|
||||
llmResp.setQuery(llmReq.getQueryText());
|
||||
//TODO: should use the same few-shot exemplars as the ones chosen by the self-consistency vote
|
||||
llmResp.setSqlRespMap(OutputFormat.buildSqlRespMap(exampleListPool.get(0), sqlMapPair.getRight()));
|
||||
return llmResp;
|
||||
}
|
||||
|
||||
private String generateLinkingPrompt(LLMReq llmReq, List<Map<String, String>> exampleList) {
|
||||
String instruction = "# Find the schema_links for generating SQL queries for each question "
|
||||
+ "based on the database schema and Foreign keys.";
|
||||
|
||||
List<String> exampleKeys = Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkingCoT");
|
||||
String exampleTemplate = "dbSchema\nQ: questionAugmented\nA: generatedSchemaLinkingCoT";
|
||||
String exampleFormat = InputFormat.format(exampleTemplate, exampleKeys, exampleList);
|
||||
|
||||
Pair<String, String> questionPrompt = promptHelper.transformQuestionPrompt(llmReq);
|
||||
String dbSchema = questionPrompt.getLeft();
|
||||
String questionAugmented = questionPrompt.getRight();
|
||||
String newCaseTemplate = "%s\nQ: %s\nA: Let’s think step by step. In the question \"%s\", we are asked:";
|
||||
String newCasePrompt = String.format(newCaseTemplate, dbSchema, questionAugmented, questionAugmented);
|
||||
|
||||
return instruction + InputFormat.SEPERATOR + exampleFormat + InputFormat.SEPERATOR + newCasePrompt;
|
||||
}
|
||||
|
||||
private String generateSqlPrompt(LLMReq llmReq, String schemaLinkStr,
|
||||
List<Map<String, String>> fewshotExampleList) {
|
||||
String instruction = "# Use the the schema links to generate the SQL queries for each of the questions.";
|
||||
List<String> exampleKeys = Arrays.asList("questionAugmented", "dbSchema", "generatedSchemaLinkings", "sql");
|
||||
String exampleTemplate = "dbSchema\nQ: questionAugmented\n" + "Schema_links: generatedSchemaLinkings\n"
|
||||
+ "SQL: sql";
|
||||
|
||||
String schemaLinkingPrompt = InputFormat.format(exampleTemplate, exampleKeys, fewshotExampleList);
|
||||
Pair<String, String> questionPrompt = promptHelper.transformQuestionPrompt(llmReq);
|
||||
String dbSchema = questionPrompt.getLeft();
|
||||
String questionAugmented = questionPrompt.getRight();
|
||||
String newCaseTemplate = "%s\nQ: %s\nSchema_links: %s\nSQL: ";
|
||||
String newCasePrompt = String.format(newCaseTemplate, dbSchema, questionAugmented, schemaLinkStr);
|
||||
return instruction + InputFormat.SEPERATOR + schemaLinkingPrompt + InputFormat.SEPERATOR + newCasePrompt;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterPropertiesSet() {
|
||||
SqlGenStrategyFactory.addSqlGenerationForFactory(LLMReq.SqlGenType.TWO_PASS_AUTO_COT_SELF_CONSISTENCY, this);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.rule;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.AbstractMap;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.COUNT;
|
||||
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.DISTINCT;
|
||||
|
||||
/**
|
||||
* AggregateTypeParser extracts the aggregation type specified in the user query
|
||||
* based on keyword matching.
|
||||
* Currently, it supports 7 types of aggregation: max, min, sum, avg, topN,
|
||||
* distinct count, count.
|
||||
*/
|
||||
@Slf4j
|
||||
public class AggregateTypeParser implements SemanticParser {
|
||||
|
||||
private static final Map<AggregateTypeEnum, Pattern> REGX_MAP = Stream.of(
|
||||
new AbstractMap.SimpleEntry<>(AggregateTypeEnum.MAX, Pattern.compile("(?i)(最大值|最大|max|峰值|最高|最多)")),
|
||||
new AbstractMap.SimpleEntry<>(AggregateTypeEnum.MIN, Pattern.compile("(?i)(最小值|最小|min|最低|最少)")),
|
||||
new AbstractMap.SimpleEntry<>(AggregateTypeEnum.SUM, Pattern.compile("(?i)(汇总|总和|sum)")),
|
||||
new AbstractMap.SimpleEntry<>(AggregateTypeEnum.AVG, Pattern.compile("(?i)(平均值|日均|平均|avg)")),
|
||||
new AbstractMap.SimpleEntry<>(AggregateTypeEnum.TOPN, Pattern.compile("(?i)(top)")),
|
||||
new AbstractMap.SimpleEntry<>(DISTINCT, Pattern.compile("(?i)(uv)")),
|
||||
new AbstractMap.SimpleEntry<>(COUNT, Pattern.compile("(?i)(总数|pv)")),
|
||||
new AbstractMap.SimpleEntry<>(AggregateTypeEnum.NONE, Pattern.compile("(?i)(明细)"))
|
||||
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (k1, k2) -> k2));
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
String queryText = queryContext.getQueryText();
|
||||
AggregateConf aggregateConf = resolveAggregateConf(queryText);
|
||||
|
||||
for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) {
|
||||
if (!AggregateTypeEnum.NONE.equals(semanticQuery.getParseInfo().getAggType())) {
|
||||
continue;
|
||||
}
|
||||
semanticQuery.getParseInfo().setAggType(aggregateConf.type);
|
||||
int detectWordLength = 0;
|
||||
if (StringUtils.isNotEmpty(aggregateConf.detectWord)) {
|
||||
detectWordLength = aggregateConf.detectWord.length();
|
||||
}
|
||||
semanticQuery.getParseInfo().setScore(semanticQuery.getParseInfo().getScore() + detectWordLength);
|
||||
}
|
||||
}
|
||||
|
||||
public AggregateTypeEnum resolveAggregateType(String queryText) {
|
||||
AggregateConf aggregateConf = resolveAggregateConf(queryText);
|
||||
return aggregateConf.type;
|
||||
}
|
||||
|
||||
private AggregateConf resolveAggregateConf(String queryText) {
|
||||
Map<AggregateTypeEnum, Integer> aggregateCount = new HashMap<>(REGX_MAP.size());
|
||||
Map<AggregateTypeEnum, String> aggregateWord = new HashMap<>(REGX_MAP.size());
|
||||
|
||||
|
||||
for (Map.Entry<AggregateTypeEnum, Pattern> entry : REGX_MAP.entrySet()) {
|
||||
Matcher matcher = entry.getValue().matcher(queryText);
|
||||
int count = 0;
|
||||
String detectWord = null;
|
||||
while (matcher.find()) {
|
||||
count++;
|
||||
detectWord = matcher.group();
|
||||
}
|
||||
if (count > 0) {
|
||||
aggregateCount.put(entry.getKey(), count);
|
||||
aggregateWord.put(entry.getKey(), detectWord);
|
||||
}
|
||||
}
|
||||
|
||||
AggregateTypeEnum type = aggregateCount.entrySet().stream().max(Map.Entry.comparingByValue())
|
||||
.map(entry -> entry.getKey()).orElse(AggregateTypeEnum.NONE);
|
||||
String detectWord = aggregateWord.get(type);
|
||||
return new AggregateConf(type, detectWord);
|
||||
}
|
||||
|
||||
@AllArgsConstructor
|
||||
class AggregateConf {
|
||||
public AggregateTypeEnum type;
|
||||
public String detectWord;
|
||||
}
|
||||
|
||||
}
|
||||
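A quick sketch of the keyword matching above (the query strings are made up):

AggregateTypeParser parser = new AggregateTypeParser();
parser.resolveAggregateType("近7天播放量最大值"); // AggregateTypeEnum.MAX, "最大值" matched
parser.resolveAggregateType("周杰伦的歌曲总数");   // AggregateTypeEnum.COUNT, "总数" matched
parser.resolveAggregateType("各省份播放量明细");   // AggregateTypeEnum.NONE, "明细" maps to NONE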
@@ -0,0 +1,123 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.rule;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricModelQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricSemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricIdQuery;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.AbstractMap;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
/**
|
||||
* ContextInheritParser tries to inherit certain schema elements from context
|
||||
* so that in multi-turn conversations users don't need to mention certain keywords
|
||||
* repeatedly.
|
||||
*/
|
||||
@Slf4j
|
||||
public class ContextInheritParser implements SemanticParser {
|
||||
|
||||
private static final Map<SchemaElementType, List<SchemaElementType>> MUTUAL_EXCLUSIVE_MAP = Stream.of(
|
||||
new AbstractMap.SimpleEntry<>(SchemaElementType.METRIC, Arrays.asList(SchemaElementType.METRIC)),
|
||||
new AbstractMap.SimpleEntry<>(
|
||||
SchemaElementType.DIMENSION, Arrays.asList(SchemaElementType.DIMENSION, SchemaElementType.VALUE)),
|
||||
new AbstractMap.SimpleEntry<>(
|
||||
SchemaElementType.VALUE, Arrays.asList(SchemaElementType.VALUE, SchemaElementType.DIMENSION)),
|
||||
new AbstractMap.SimpleEntry<>(SchemaElementType.ENTITY, Arrays.asList(SchemaElementType.ENTITY)),
|
||||
new AbstractMap.SimpleEntry<>(SchemaElementType.DATASET, Arrays.asList(SchemaElementType.DATASET)),
|
||||
new AbstractMap.SimpleEntry<>(SchemaElementType.ID, Arrays.asList(SchemaElementType.ID))
|
||||
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
if (!shouldInherit(queryContext)) {
|
||||
return;
|
||||
}
|
||||
Long dataSetId = getMatchedDataSet(queryContext, chatContext);
|
||||
if (dataSetId == null) {
|
||||
return;
|
||||
}
|
||||
|
||||
List<SchemaElementMatch> elementMatches = queryContext.getMapInfo().getMatchedElements(dataSetId);
|
||||
|
||||
List<SchemaElementMatch> matchesToInherit = new ArrayList<>();
|
||||
for (SchemaElementMatch match : chatContext.getParseInfo().getElementMatches()) {
|
||||
SchemaElementType matchType = match.getElement().getType();
|
||||
// mutually exclusive element types should not be inherited
|
||||
RuleSemanticQuery ruleQuery = QueryManager.getRuleQuery(chatContext.getParseInfo().getQueryMode());
|
||||
if (!containsTypes(elementMatches, matchType, ruleQuery)) {
|
||||
match.setInherited(true);
|
||||
matchesToInherit.add(match);
|
||||
}
|
||||
}
|
||||
elementMatches.addAll(matchesToInherit);
|
||||
|
||||
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(dataSetId, elementMatches, queryContext);
|
||||
for (RuleSemanticQuery query : queries) {
|
||||
query.fillParseInfo(queryContext, chatContext);
|
||||
if (existSameQuery(query.getParseInfo().getDataSetId(), query.getQueryMode(), queryContext)) {
|
||||
continue;
|
||||
}
|
||||
queryContext.getCandidateQueries().add(query);
|
||||
}
|
||||
}
|
||||
|
||||
private boolean existSameQuery(Long dataSetId, String queryMode, QueryContext queryContext) {
|
||||
for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) {
|
||||
if (semanticQuery.getQueryMode().equals(queryMode)
|
||||
&& semanticQuery.getParseInfo().getDataSetId().equals(dataSetId)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private boolean containsTypes(List<SchemaElementMatch> matches, SchemaElementType matchType,
|
||||
RuleSemanticQuery ruleQuery) {
|
||||
List<SchemaElementType> types = MUTUAL_EXCLUSIVE_MAP.get(matchType);
|
||||
|
||||
return matches.stream().anyMatch(m -> {
|
||||
SchemaElementType type = m.getElement().getType();
|
||||
if (Objects.nonNull(ruleQuery) && ruleQuery instanceof MetricSemanticQuery
|
||||
&& !(ruleQuery instanceof MetricIdQuery)) {
|
||||
return types.contains(type);
|
||||
}
|
||||
return type.equals(matchType);
|
||||
});
|
||||
}
|
||||
|
||||
protected boolean shouldInherit(QueryContext queryContext) {
|
||||
// inherit from context only when all candidate queries are of MetricModel mode
|
||||
List<SemanticQuery> metricModelQueries = queryContext.getCandidateQueries().stream()
|
||||
.filter(query -> query instanceof MetricModelQuery).collect(
|
||||
Collectors.toList());
|
||||
return metricModelQueries.size() == queryContext.getCandidateQueries().size();
|
||||
}
|
||||
|
||||
protected Long getMatchedDataSet(QueryContext queryContext, ChatContext chatContext) {
|
||||
Long dataSetId = chatContext.getParseInfo().getDataSetId();
|
||||
if (dataSetId == null) {
|
||||
return null;
|
||||
}
|
||||
Set<Long> queryDataSets = queryContext.getMapInfo().getMatchedDataSetInfos();
|
||||
if (queryDataSets.contains(dataSetId)) {
|
||||
return dataSetId;
|
||||
}
|
||||
return dataSetId;
|
||||
}
|
||||
|
||||
}
|
||||
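To illustrate the mutual-exclusion rule above (the two turns are made up):

// Previous turn: "广东的播放量" matched VALUE 广东 (of dimension 省份) and METRIC 播放量.
// Current turn : "那浙江呢" matched only VALUE 浙江.
// VALUE excludes {VALUE, DIMENSION}, so 广东 is not inherited (浙江 replaces it),
// while 播放量 has no conflicting METRIC match in the current turn and is inherited.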
@@ -0,0 +1,44 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.rule;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* RuleSqlParser resolves a specific SemanticQuery according to the co-occurrence
|
||||
* of certain schema element types.
|
||||
*/
|
||||
@Slf4j
|
||||
public class RuleSqlParser implements SemanticParser {
|
||||
|
||||
private static List<SemanticParser> auxiliaryParsers = Arrays.asList(
|
||||
new ContextInheritParser(),
|
||||
new TimeRangeParser(),
|
||||
new AggregateTypeParser()
|
||||
);
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
if (!queryContext.getText2SQLType().enableRule()) {
|
||||
return;
|
||||
}
|
||||
SchemaMapInfo mapInfo = queryContext.getMapInfo();
|
||||
// iterate all schemaElementMatches to resolve query mode
|
||||
for (Long dataSetId : mapInfo.getMatchedDataSetInfos()) {
|
||||
List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(dataSetId);
|
||||
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(dataSetId, elementMatches, queryContext);
|
||||
for (RuleSemanticQuery query : queries) {
|
||||
query.fillParseInfo(queryContext, chatContext);
|
||||
queryContext.getCandidateQueries().add(query);
|
||||
}
|
||||
}
|
||||
|
||||
auxiliaryParsers.stream().forEach(p -> p.parse(queryContext, chatContext));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,210 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.rule;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import com.tencent.supersonic.headless.chat.QueryContext;
|
||||
import com.xkzhangsan.time.nlp.TimeNLP;
|
||||
import com.xkzhangsan.time.nlp.TimeNLPUtil;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.logging.log4j.util.Strings;
|
||||
|
||||
import java.text.DateFormat;
|
||||
import java.text.ParseException;
|
||||
import java.text.SimpleDateFormat;
|
||||
import java.time.LocalDate;
|
||||
import java.util.Date;
|
||||
import java.util.List;
|
||||
import java.util.Stack;
|
||||
import java.util.regex.Matcher;
|
||||
import java.util.regex.Pattern;
|
||||
|
||||
/**
|
||||
* TimeRangeParser extracts the time range specified in the user query
|
||||
* based on keyword matching.
|
||||
* Currently, it supports two kinds of expression:
|
||||
* 1. Recent unit: 近N天/周/月/年、过去N天/周/月/年
|
||||
* 2. Concrete date: 2023年11月15日、20231115
|
||||
*/
|
||||
@Slf4j
|
||||
public class TimeRangeParser implements SemanticParser {
|
||||
|
||||
private static final Pattern RECENT_PATTERN_CN = Pattern.compile(
|
||||
".*(?<periodStr>(近|过去)((?<enNum>\\d+)|(?<zhNum>[一二三四五六七八九十百千万亿]+))个?(?<zhPeriod>[天周月年])).*");
|
||||
private static final Pattern DATE_PATTERN_NUMBER = Pattern.compile("(\\d{8})");
|
||||
private static final DateFormat DATE_FORMAT_NUMBER = new SimpleDateFormat("yyyyMMdd");
|
||||
private static final DateFormat DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd");
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
String queryText = queryContext.getQueryText();
|
||||
DateConf dateConf = parseRecent(queryText);
|
||||
if (dateConf == null) {
|
||||
dateConf = parseDateNumber(queryText);
|
||||
}
|
||||
if (dateConf == null) {
|
||||
dateConf = parseDateCN(queryText);
|
||||
}
|
||||
|
||||
if (dateConf != null) {
|
||||
if (queryContext.getCandidateQueries().size() > 0) {
|
||||
for (SemanticQuery query : queryContext.getCandidateQueries()) {
|
||||
query.getParseInfo().setDateInfo(dateConf);
|
||||
query.getParseInfo().setScore(query.getParseInfo().getScore()
|
||||
+ dateConf.getDetectWord().length());
|
||||
}
|
||||
} else if (QueryManager.containsRuleQuery(chatContext.getParseInfo().getQueryMode())) {
|
||||
RuleSemanticQuery semanticQuery = QueryManager.createRuleQuery(
|
||||
chatContext.getParseInfo().getQueryMode());
|
||||
// inherit parse info from context
|
||||
chatContext.getParseInfo().setDateInfo(dateConf);
|
||||
chatContext.getParseInfo().setScore(chatContext.getParseInfo().getScore()
|
||||
+ dateConf.getDetectWord().length());
|
||||
semanticQuery.setParseInfo(chatContext.getParseInfo());
|
||||
queryContext.getCandidateQueries().add(semanticQuery);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private DateConf parseDateCN(String queryText) {
|
||||
Date startDate = null;
|
||||
Date endDate;
|
||||
String detectWord = null;
|
||||
|
||||
List<TimeNLP> times = TimeNLPUtil.parse(queryText);
|
||||
if (times.size() > 0) {
|
||||
startDate = times.get(0).getTime();
|
||||
detectWord = times.get(0).getTimeExpression();
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (times.size() > 1) {
|
||||
endDate = times.get(1).getTime();
|
||||
detectWord += "~" + times.get(0).getTimeExpression();
|
||||
} else {
|
||||
endDate = startDate;
|
||||
}
|
||||
|
||||
return getDateConf(startDate, endDate, detectWord);
|
||||
}
|
||||
|
||||
private DateConf parseDateNumber(String queryText) {
|
||||
String startDate;
|
||||
String endDate = null;
|
||||
String detectWord = null;
|
||||
|
||||
Matcher dateMatcher = DATE_PATTERN_NUMBER.matcher(queryText);
|
||||
if (dateMatcher.find()) {
|
||||
startDate = dateMatcher.group();
|
||||
detectWord = startDate;
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (dateMatcher.find()) {
|
||||
endDate = dateMatcher.group();
|
||||
detectWord += "~" + endDate;
|
||||
}
|
||||
|
||||
endDate = endDate != null ? endDate : startDate;
|
||||
|
||||
try {
|
||||
return getDateConf(DATE_FORMAT_NUMBER.parse(startDate), DATE_FORMAT_NUMBER.parse(endDate), detectWord);
|
||||
} catch (ParseException e) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private DateConf parseRecent(String queryText) {
|
||||
Matcher m = RECENT_PATTERN_CN.matcher(queryText);
|
||||
if (m.matches()) {
|
||||
int num = 0;
|
||||
String enNum = m.group("enNum");
|
||||
String zhNum = m.group("zhNum");
|
||||
if (enNum != null) {
|
||||
num = Integer.parseInt(enNum);
|
||||
} else if (zhNum != null) {
|
||||
num = zhNumParse(zhNum);
|
||||
}
|
||||
if (num > 0) {
|
||||
DateConf info = new DateConf();
|
||||
String zhPeriod = m.group("zhPeriod");
|
||||
int days;
|
||||
switch (zhPeriod) {
|
||||
case "周":
|
||||
days = 7;
|
||||
info.setPeriod(Constants.WEEK);
|
||||
break;
|
||||
case "月":
|
||||
days = 30;
|
||||
info.setPeriod(Constants.MONTH);
|
||||
break;
|
||||
case "年":
|
||||
days = 365;
|
||||
info.setPeriod(Constants.YEAR);
|
||||
break;
|
||||
default:
|
||||
days = 1;
|
||||
info.setPeriod(Constants.DAY);
|
||||
}
|
||||
days = days * num;
|
||||
info.setDateMode(DateConf.DateMode.RECENT);
|
||||
String detectWord = "近" + num + zhPeriod;
|
||||
if (Strings.isNotEmpty(m.group("periodStr"))) {
|
||||
detectWord = m.group("periodStr");
|
||||
}
|
||||
info.setDetectWord(detectWord);
|
||||
info.setStartDate(LocalDate.now().minusDays(days).toString());
|
||||
info.setEndDate(LocalDate.now().minusDays(1).toString());
|
||||
info.setUnit(num);
|
||||
|
||||
return info;
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
private int zhNumParse(String zhNumStr) {
|
||||
Stack<Integer> stack = new Stack<>();
|
||||
String numStr = "一二三四五六七八九";
|
||||
String unitStr = "十百千万亿";
|
||||
|
||||
String[] ssArr = zhNumStr.split("");
|
||||
for (String e : ssArr) {
|
||||
int numIndex = numStr.indexOf(e);
|
||||
int unitIndex = unitStr.indexOf(e);
|
||||
if (numIndex != -1) {
|
||||
stack.push(numIndex + 1);
|
||||
} else if (unitIndex != -1) {
|
||||
int unitNum = (int) Math.pow(10, unitIndex + 1);
|
||||
if (stack.isEmpty()) {
|
||||
stack.push(unitNum);
|
||||
} else {
|
||||
stack.push(stack.pop() * unitNum);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return stack.stream().mapToInt(s -> s).sum();
|
||||
}
|
||||
|
||||
private DateConf getDateConf(Date startDate, Date endDate, String detectWord) {
|
||||
if (startDate == null || endDate == null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
DateConf info = new DateConf();
|
||||
info.setDateMode(DateConf.DateMode.BETWEEN);
|
||||
info.setStartDate(DATE_FORMAT.format(startDate));
|
||||
info.setEndDate(DATE_FORMAT.format(endDate));
|
||||
info.setDetectWord(detectWord);
|
||||
return info;
|
||||
}
|
||||
|
||||
}
|
||||
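Two made-up examples of what the parsing above produces (absolute dates depend on the current day):

// "近三十天的播放量": RECENT_PATTERN_CN matches with zhNum="三十" (zhNumParse -> 30) and
// zhPeriod="天", giving a RECENT DateConf over [today-30, yesterday], period DAY, detectWord "近三十天".
// "20240101到20240131的播放量": handled by parseDateNumber, giving a BETWEEN DateConf
// with startDate 2024-01-01, endDate 2024-01-31, detectWord "20240101~20240131".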
@@ -0,0 +1,90 @@
|
||||
|
||||
package com.tencent.supersonic.headless.chat.query;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Aggregator;
|
||||
import com.tencent.supersonic.common.pojo.Filter;
|
||||
import com.tencent.supersonic.common.pojo.Order;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
|
||||
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
|
||||
import lombok.ToString;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_S2SQL_ENABLE;
|
||||
|
||||
@Slf4j
|
||||
@ToString
|
||||
public abstract class BaseSemanticQuery implements SemanticQuery, Serializable {
|
||||
|
||||
protected SemanticParseInfo parseInfo = new SemanticParseInfo();
|
||||
|
||||
@Override
|
||||
public SemanticParseInfo getParseInfo() {
|
||||
return parseInfo;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setParseInfo(SemanticParseInfo parseInfo) {
|
||||
this.parseInfo = parseInfo;
|
||||
}
|
||||
|
||||
protected QueryStructReq convertQueryStruct() {
|
||||
return QueryReqBuilder.buildStructReq(parseInfo);
|
||||
}
|
||||
|
||||
protected void convertBizNameToName(SemanticSchema semanticSchema, QueryStructReq queryStructReq) {
|
||||
Map<String, String> bizNameToName = semanticSchema.getBizNameToName(queryStructReq.getDataSetId());
|
||||
bizNameToName.putAll(TimeDimensionEnum.getNameToNameMap());
|
||||
|
||||
List<Order> orders = queryStructReq.getOrders();
|
||||
if (CollectionUtils.isNotEmpty(orders)) {
|
||||
for (Order order : orders) {
|
||||
order.setColumn(bizNameToName.get(order.getColumn()));
|
||||
}
|
||||
}
|
||||
List<Aggregator> aggregators = queryStructReq.getAggregators();
|
||||
if (CollectionUtils.isNotEmpty(aggregators)) {
|
||||
for (Aggregator aggregator : aggregators) {
|
||||
aggregator.setColumn(bizNameToName.get(aggregator.getColumn()));
|
||||
}
|
||||
}
|
||||
List<String> groups = queryStructReq.getGroups();
|
||||
if (CollectionUtils.isNotEmpty(groups)) {
|
||||
groups = groups.stream().map(bizNameToName::get).collect(Collectors.toList());
|
||||
queryStructReq.setGroups(groups);
|
||||
}
|
||||
List<Filter> dimensionFilters = queryStructReq.getDimensionFilters();
|
||||
if (CollectionUtils.isNotEmpty(dimensionFilters)) {
|
||||
dimensionFilters.forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
|
||||
}
|
||||
List<Filter> metricFilters = queryStructReq.getMetricFilters();
|
||||
if (CollectionUtils.isNotEmpty(metricFilters)) {
metricFilters.forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName())));
}
}

protected void initS2SqlByStruct(SemanticSchema semanticSchema) {
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
boolean s2sqlEnable = Boolean.valueOf(parserConfig.getParameterValue(PARSER_S2SQL_ENABLE));
if (!s2sqlEnable) {
return;
}
QueryStructReq queryStructReq = convertQueryStruct();
convertBizNameToName(semanticSchema, queryStructReq);
QuerySqlReq querySQLReq = queryStructReq.convert();
parseInfo.getSqlInfo().setS2SQL(querySQLReq.getSql());
parseInfo.getSqlInfo().setCorrectS2SQL(querySQLReq.getSql());
}

}
@@ -0,0 +1,87 @@
package com.tencent.supersonic.headless.chat.query;

import com.tencent.supersonic.headless.chat.query.llm.LLMSemanticQuery;
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricSemanticQuery;
import com.tencent.supersonic.headless.chat.query.rule.detail.DetailSemanticQuery;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;

public class QueryManager {

private static Map<String, RuleSemanticQuery> ruleQueryMap = new ConcurrentHashMap<>();
private static Map<String, LLMSemanticQuery> llmQueryMap = new ConcurrentHashMap<>();

public static void register(SemanticQuery query) {
if (query instanceof RuleSemanticQuery) {
ruleQueryMap.put(query.getQueryMode(), (RuleSemanticQuery) query);
} else if (query instanceof LLMSemanticQuery) {
llmQueryMap.put(query.getQueryMode(), (LLMSemanticQuery) query);
}
}

public static SemanticQuery createQuery(String queryMode) {
if (containsRuleQuery(queryMode)) {
return createRuleQuery(queryMode);
}
return createLLMQuery(queryMode);
}

public static RuleSemanticQuery createRuleQuery(String queryMode) {
RuleSemanticQuery semanticQuery = ruleQueryMap.get(queryMode);
return (RuleSemanticQuery) getSemanticQuery(queryMode, semanticQuery);
}

public static LLMSemanticQuery createLLMQuery(String queryMode) {
LLMSemanticQuery semanticQuery = llmQueryMap.get(queryMode);
return (LLMSemanticQuery) getSemanticQuery(queryMode, semanticQuery);
}

private static SemanticQuery getSemanticQuery(String queryMode, SemanticQuery semanticQuery) {
if (Objects.isNull(semanticQuery)) {
throw new RuntimeException("no supported queryMode :" + queryMode);
}
try {
return semanticQuery.getClass().getDeclaredConstructor().newInstance();
} catch (Exception e) {
throw new RuntimeException("no supported queryMode :" + queryMode, e);
}
}

public static boolean containsRuleQuery(String queryMode) {
if (queryMode == null) {
return false;
}
return ruleQueryMap.containsKey(queryMode);
}

public static boolean isMetricQuery(String queryMode) {
if (queryMode == null || !ruleQueryMap.containsKey(queryMode)) {
return false;
}
return ruleQueryMap.get(queryMode) instanceof MetricSemanticQuery;
}

public static boolean isTagQuery(String queryMode) {
if (queryMode == null || !ruleQueryMap.containsKey(queryMode)) {
return false;
}
return ruleQueryMap.get(queryMode) instanceof DetailSemanticQuery;
}

public static RuleSemanticQuery getRuleQuery(String queryMode) {
if (queryMode == null) {
return null;
}
return ruleQueryMap.get(queryMode);
}

public static List<RuleSemanticQuery> getRuleQueries() {
return new ArrayList<>(ruleQueryMap.values());
}

}
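Editor's note: the snippet below is a hedged usage sketch, not part of this commit. It shows how the registry above is meant to be used: query implementations register themselves in their constructors, and callers obtain a fresh instance per parse through createQuery. The class name QueryManagerUsageSketch is hypothetical; LLM_S2SQL is the mode defined by LLMSqlQuery later in this diff.

package com.tencent.supersonic.headless.chat.query;

import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;

public class QueryManagerUsageSketch {

    public static void main(String[] args) {
        // Spring normally instantiates the @Component beans, which triggers registration;
        // here one query is constructed by hand purely to illustrate the mechanism.
        new LLMSqlQuery();

        // createQuery returns a new instance so per-request parse info is never shared.
        SemanticQuery query = QueryManager.createQuery(LLMSqlQuery.QUERY_MODE);
        System.out.println(query.getQueryMode()); // LLM_S2SQL
        System.out.println(QueryManager.containsRuleQuery(LLMSqlQuery.QUERY_MODE)); // false, handled as an LLM query
    }
}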
@@ -0,0 +1,23 @@
package com.tencent.supersonic.headless.chat.query;

import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import org.apache.calcite.sql.parser.SqlParseException;

/**
* A semantic query executes a specific type of query based on the results of semantic parsing.
*/
public interface SemanticQuery {

String getQueryMode();

SemanticQueryReq buildSemanticQueryReq() throws SqlParseException;

void initS2Sql(SemanticSchema semanticSchema, User user);

SemanticParseInfo getParseInfo();

void setParseInfo(SemanticParseInfo parseInfo);
}
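As an editor's illustration (not part of the commit), the sketch below spells out the call sequence the interface implies: the parser fills the parse info, initS2Sql derives the S2SQL, and buildSemanticQueryReq produces the request handed to the semantic layer. The helper class and method are hypothetical; the schema and user are assumed to come from the QueryContext in real usage.

package com.tencent.supersonic.headless.chat.query;

import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import org.apache.calcite.sql.parser.SqlParseException;

public class SemanticQueryLifecycleSketch {

    // Typical lifecycle of a SemanticQuery instance obtained from QueryManager.createQuery().
    public static SemanticQueryReq execute(SemanticQuery query, SemanticParseInfo parseInfo,
                                           SemanticSchema schema, User user) throws SqlParseException {
        query.setParseInfo(parseInfo); // parse results produced upstream
        query.initS2Sql(schema, user); // rule queries derive S2SQL from the struct request
        return query.buildSemanticQueryReq();
    }
}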
@@ -0,0 +1,8 @@
package com.tencent.supersonic.headless.chat.query.llm;

import com.tencent.supersonic.headless.chat.query.BaseSemanticQuery;
import lombok.extern.slf4j.Slf4j;

@Slf4j
public abstract class LLMSemanticQuery extends BaseSemanticQuery {
}
@@ -0,0 +1,91 @@
package com.tencent.supersonic.headless.chat.query.llm.s2sql;

import com.fasterxml.jackson.annotation.JsonValue;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.config.LLMConfig;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import lombok.Data;

import java.util.List;

@Data
public class LLMReq {

private String queryText;

private FilterCondition filterCondition;

private LLMSchema schema;

private List<ElementValue> linking;

private String currentDate;

private String priorExts;

private SqlGenType sqlGenType;

private LLMConfig llmConfig;

@Data
public static class ElementValue {

private String fieldName;

private String fieldValue;

}

@Data
public static class LLMSchema {

private String domainName;

private String dataSetName;

private Long dataSetId;

private List<String> fieldNameList;

private List<SchemaElement> metrics;

private List<SchemaElement> dimensions;

private List<Term> terms;

}

@Data
public static class FilterCondition {

private String tableName;
}

@Data
public static class Term {

private String name;

private String description;

private List<String> alias = Lists.newArrayList();

}

public enum SqlGenType {
ONE_PASS_SELF_CONSISTENCY("1_pass_self_consistency"),
TWO_PASS_AUTO_COT_SELF_CONSISTENCY("2_pass_auto_cot_self_consistency");

private String name;

SqlGenType(String name) {
this.name = name;
}

@JsonValue
public String getName() {
return name;
}

}
}
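A hedged construction sketch (not in the commit) of how an LLMReq might be assembled before invoking text-to-SQL generation. The class name and all field values below are made up for illustration; the setters are generated by Lombok's @Data.

package com.tencent.supersonic.headless.chat.query.llm.s2sql;

import java.util.Arrays;

public class LLMReqSketch {

    public static void main(String[] args) {
        LLMReq req = new LLMReq();
        req.setQueryText("top 10 songs by play count last week"); // illustrative question
        req.setCurrentDate("2024-06-01");
        req.setSqlGenType(LLMReq.SqlGenType.ONE_PASS_SELF_CONSISTENCY);

        // schema context handed to the LLM: which data set and fields it may reference
        LLMReq.LLMSchema schema = new LLMReq.LLMSchema();
        schema.setDataSetId(1L);
        schema.setDataSetName("song_library");
        schema.setFieldNameList(Arrays.asList("song_name", "play_count", "imp_date"));
        req.setSchema(schema);

        // linked values recognized during mapping
        LLMReq.ElementValue linking = new LLMReq.ElementValue();
        linking.setFieldName("song_name");
        linking.setFieldValue("Shape of You");
        req.setLinking(Arrays.asList(linking));

        System.out.println(req); // Lombok @Data supplies toString()
    }
}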
@@ -0,0 +1,26 @@
package com.tencent.supersonic.headless.chat.query.llm.s2sql;

import lombok.Data;

import java.util.List;
import java.util.Map;

@Data
public class LLMResp {

private String query;

private String modelName;

private String sqlOutput;

private List<String> fields;

private Map<String, LLMSqlResp> sqlRespMap;

/**
* Only kept for compatibility with the Python code; to be removed later.
*/
private Map<String, Double> sqlWeight;

}
@@ -0,0 +1,40 @@
package com.tencent.supersonic.headless.chat.query.llm.s2sql;

import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
import com.tencent.supersonic.headless.chat.query.QueryManager;
import com.tencent.supersonic.headless.chat.query.llm.LLMSemanticQuery;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

@Slf4j
@Component
public class LLMSqlQuery extends LLMSemanticQuery {

public static final String QUERY_MODE = "LLM_S2SQL";

public LLMSqlQuery() {
QueryManager.register(this);
}

@Override
public String getQueryMode() {
return QUERY_MODE;
}

@Override
public SemanticQueryReq buildSemanticQueryReq() {
String querySql = parseInfo.getSqlInfo().getCorrectS2SQL();
return QueryReqBuilder.buildS2SQLReq(querySql, parseInfo.getDataSetId());
}

@Override
public void initS2Sql(SemanticSchema semanticSchema, User user) {
SqlInfo sqlInfo = parseInfo.getSqlInfo();
sqlInfo.setCorrectS2SQL(sqlInfo.getS2SQL());
}
}
@@ -0,0 +1,22 @@
package com.tencent.supersonic.headless.chat.query.llm.s2sql;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.util.List;
import java.util.Map;

@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class LLMSqlResp {

private double sqlWeight;

private List<Map<String, String>> fewShots;

}
@@ -0,0 +1,45 @@
package com.tencent.supersonic.headless.chat.query.rule;

import lombok.Data;

@Data
public class QueryMatchOption {

private OptionType schemaElementOption;
private RequireNumberType requireNumberType;
private Integer requireNumber;

public static QueryMatchOption build(OptionType schemaElementOption,
RequireNumberType requireNumberType, Integer requireNumber) {
QueryMatchOption queryMatchOption = new QueryMatchOption();
queryMatchOption.requireNumber = requireNumber;
queryMatchOption.requireNumberType = requireNumberType;
queryMatchOption.schemaElementOption = schemaElementOption;
return queryMatchOption;
}

public static QueryMatchOption optional() {
QueryMatchOption queryMatchOption = new QueryMatchOption();
queryMatchOption.setSchemaElementOption(OptionType.OPTIONAL);
queryMatchOption.setRequireNumber(0);
queryMatchOption.setRequireNumberType(RequireNumberType.AT_LEAST);
return queryMatchOption;
}

public static QueryMatchOption unused() {
QueryMatchOption queryMatchOption = new QueryMatchOption();
queryMatchOption.setSchemaElementOption(OptionType.UNUSED);
queryMatchOption.setRequireNumber(0);
queryMatchOption.setRequireNumberType(RequireNumberType.EQUAL);
return queryMatchOption;
}

public enum RequireNumberType {
AT_MOST, AT_LEAST, EQUAL
}

public enum OptionType {
REQUIRED, OPTIONAL, UNUSED
}

}
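An editor's sketch (not part of the commit) of the three ways an element option is produced; REQUIRED with AT_LEAST n means "at least n matches of this element type must be present". The class name is hypothetical.

package com.tencent.supersonic.headless.chat.query.rule;

public class QueryMatchOptionSketch {

    public static void main(String[] args) {
        // explicit option: the element type is mandatory and must appear at least once
        QueryMatchOption required = QueryMatchOption.build(
                QueryMatchOption.OptionType.REQUIRED,
                QueryMatchOption.RequireNumberType.AT_LEAST, 1);

        QueryMatchOption optional = QueryMatchOption.optional(); // OPTIONAL, AT_LEAST 0
        QueryMatchOption unused = QueryMatchOption.unused();     // UNUSED, EQUAL 0

        System.out.println(required); // Lombok @Data supplies toString()
        System.out.println(optional);
        System.out.println(unused);
    }
}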
@@ -0,0 +1,106 @@
package com.tencent.supersonic.headless.chat.query.rule;

import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import lombok.Data;
import lombok.ToString;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

@Data
@ToString
public class QueryMatcher {

private HashMap<SchemaElementType, QueryMatchOption> elementOptionMap = new HashMap<>();
private boolean supportCompare;
private boolean supportOrderBy;
private List<AggregateTypeEnum> orderByTypes = Arrays.asList(AggregateTypeEnum.MAX, AggregateTypeEnum.MIN,
AggregateTypeEnum.TOPN);

public QueryMatcher() {
for (SchemaElementType type : SchemaElementType.values()) {
if (type.equals(SchemaElementType.DATASET)) {
elementOptionMap.put(type, QueryMatchOption.optional());
} else {
elementOptionMap.put(type, QueryMatchOption.unused());
}
}
}

public QueryMatcher addOption(SchemaElementType type, QueryMatchOption.OptionType option,
QueryMatchOption.RequireNumberType requireNumberType, Integer requireNumber) {
elementOptionMap.put(type, QueryMatchOption.build(option, requireNumberType, requireNumber));
return this;
}

/**
* Matches schema elements against the current query according to the configured options.
*
* @param candidateElementMatches the candidate schema element matches to examine
* @return a list of all matched schema elements, or an empty list if the options cannot be satisfied
*/
public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches) {
List<SchemaElementMatch> elementMatches = new ArrayList<>();
HashMap<SchemaElementType, Integer> schemaElementTypeCount = new HashMap<>();
for (SchemaElementMatch schemaElementMatch : candidateElementMatches) {
SchemaElementType schemaElementType = schemaElementMatch.getElement().getType();
if (schemaElementTypeCount.containsKey(schemaElementType)) {
schemaElementTypeCount.put(schemaElementType, schemaElementTypeCount.get(schemaElementType) + 1);
} else {
schemaElementTypeCount.put(schemaElementType, 1);
}
}

// check if current query options are satisfied, return immediately if not
for (Map.Entry<SchemaElementType, QueryMatchOption> e : elementOptionMap.entrySet()) {
SchemaElementType elementType = e.getKey();
QueryMatchOption elementOption = e.getValue();
if (!isMatch(elementOption, getCount(schemaElementTypeCount, elementType))) {
return new ArrayList<>();
}
}

// add element match if its element type is not declared as unused
for (SchemaElementMatch elementMatch : candidateElementMatches) {
QueryMatchOption elementOption = elementOptionMap.get(elementMatch.getElement().getType());
if (Objects.nonNull(elementOption) && !elementOption.getSchemaElementOption()
.equals(QueryMatchOption.OptionType.UNUSED)) {
elementMatches.add(elementMatch);
}
}

return elementMatches;
}

private int getCount(HashMap<SchemaElementType, Integer> schemaElementTypeCount,
SchemaElementType schemaElementType) {
if (schemaElementTypeCount.containsKey(schemaElementType)) {
return schemaElementTypeCount.get(schemaElementType);
}
return 0;
}

private boolean isMatch(QueryMatchOption queryMatchOption, int count) {
// check if required but empty
if (queryMatchOption.getSchemaElementOption().equals(QueryMatchOption.OptionType.REQUIRED) && count <= 0) {
return false;
}
if (queryMatchOption.getRequireNumberType().equals(QueryMatchOption.RequireNumberType.AT_LEAST)
&& count < queryMatchOption.getRequireNumber()) {
return false;
}
if (queryMatchOption.getRequireNumberType().equals(QueryMatchOption.RequireNumberType.AT_MOST)
&& count > queryMatchOption.getRequireNumber()) {
return false;
}
return true;
}
}
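The sketch below (an editor's illustration, not in the commit) shows how a concrete rule query is expected to declare its requirements: the default matcher marks DATASET optional and everything else unused, and a subclass constructor overrides the element types it cares about. The class name and query mode are hypothetical; the pattern mirrors DetailFilterQuery and DetailIdQuery later in this diff, and RuleSemanticQuery is defined in the next file.

package com.tencent.supersonic.headless.chat.query.rule;

import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.DIMENSION;
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.METRIC;
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.OptionType.OPTIONAL;
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;

public class MetricGroupBySketchQuery extends RuleSemanticQuery {

    public static final String QUERY_MODE = "SKETCH_METRIC_GROUPBY"; // hypothetical mode

    public MetricGroupBySketchQuery() {
        super();
        // at least one metric is mandatory; dimensions are kept when present
        queryMatcher.addOption(METRIC, REQUIRED, AT_LEAST, 1)
                .addOption(DIMENSION, OPTIONAL, AT_LEAST, 0);
    }

    @Override
    public String getQueryMode() {
        return QUERY_MODE;
    }
}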
@@ -0,0 +1,248 @@
package com.tencent.supersonic.headless.chat.query.rule;

import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.chat.QueryContext;
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
import com.tencent.supersonic.headless.chat.query.BaseSemanticQuery;
import com.tencent.supersonic.headless.chat.query.QueryManager;
import com.tencent.supersonic.headless.chat.ChatContext;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.ToString;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;

@Slf4j
@ToString
public abstract class RuleSemanticQuery extends BaseSemanticQuery {

protected QueryMatcher queryMatcher = new QueryMatcher();

public RuleSemanticQuery() {
QueryManager.register(this);
}

public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches,
QueryContext queryCtx) {
return queryMatcher.match(candidateElementMatches);
}

@Override
public void initS2Sql(SemanticSchema semanticSchema, User user) {
initS2SqlByStruct(semanticSchema);
}

public void fillParseInfo(QueryContext queryContext, ChatContext chatContext) {
parseInfo.setQueryMode(getQueryMode());
SemanticSchema semanticSchema = queryContext.getSemanticSchema();

fillSchemaElement(parseInfo, semanticSchema);
fillScore(parseInfo);
fillDateConf(parseInfo, chatContext.getParseInfo());
}

private void fillDateConf(SemanticParseInfo queryParseInfo, SemanticParseInfo chatParseInfo) {
if (queryParseInfo.getDateInfo() != null || chatParseInfo.getDateInfo() == null) {
return;
}

if ((QueryManager.isTagQuery(queryParseInfo.getQueryMode())
&& QueryManager.isTagQuery(chatParseInfo.getQueryMode()))
|| (QueryManager.isMetricQuery(queryParseInfo.getQueryMode())
&& QueryManager.isMetricQuery(chatParseInfo.getQueryMode()))) {
// inherit date info from context
queryParseInfo.setDateInfo(chatParseInfo.getDateInfo());
queryParseInfo.getDateInfo().setInherited(true);
}
}

private void fillScore(SemanticParseInfo parseInfo) {
double totalScore = 0;

Map<SchemaElementType, SchemaElementMatch> maxSimilarityMatch = new HashMap<>();
for (SchemaElementMatch match : parseInfo.getElementMatches()) {
SchemaElementType type = match.getElement().getType();
if (!maxSimilarityMatch.containsKey(type)
|| match.getSimilarity() > maxSimilarityMatch.get(type).getSimilarity()) {
maxSimilarityMatch.put(type, match);
}
}

for (SchemaElementMatch match : maxSimilarityMatch.values()) {
totalScore += match.getDetectWord().length() * match.getSimilarity();
}

parseInfo.setScore(parseInfo.getScore() + totalScore);
}

private void fillSchemaElement(SemanticParseInfo parseInfo, SemanticSchema semanticSchema) {
Set<Long> dataSetIds = parseInfo.getElementMatches().stream().map(SchemaElementMatch::getElement)
.map(SchemaElement::getDataSet).collect(Collectors.toSet());
Long dataSetId = dataSetIds.iterator().next();
parseInfo.setDataSet(semanticSchema.getDataSet(dataSetId));
Map<Long, List<SchemaElementMatch>> dim2Values = new HashMap<>();
Map<Long, List<SchemaElementMatch>> id2Values = new HashMap<>();

for (SchemaElementMatch schemaMatch : parseInfo.getElementMatches()) {
SchemaElement element = schemaMatch.getElement();
element.setOrder(1 - schemaMatch.getSimilarity());
switch (element.getType()) {
case ID:
SchemaElement entityElement = semanticSchema.getElement(SchemaElementType.ENTITY, element.getId());
if (entityElement != null) {
if (id2Values.containsKey(element.getId())) {
id2Values.get(element.getId()).add(schemaMatch);
} else {
id2Values.put(element.getId(), new ArrayList<>(Arrays.asList(schemaMatch)));
}
}
break;
case VALUE:
SchemaElement dimElement = semanticSchema.getElement(SchemaElementType.DIMENSION, element.getId());
if (dimElement != null) {
if (dim2Values.containsKey(element.getId())) {
dim2Values.get(element.getId()).add(schemaMatch);
} else {
dim2Values.put(element.getId(), new ArrayList<>(Arrays.asList(schemaMatch)));
}
}
break;
case DIMENSION:
parseInfo.getDimensions().add(element);
break;
case METRIC:
parseInfo.getMetrics().add(element);
break;
case ENTITY:
parseInfo.setEntity(element);
break;
default:
}
}
addToFilters(id2Values, parseInfo, semanticSchema, SchemaElementType.ENTITY);
addToFilters(dim2Values, parseInfo, semanticSchema, SchemaElementType.DIMENSION);
}

private void addToFilters(Map<Long, List<SchemaElementMatch>> id2Values, SemanticParseInfo parseInfo,
SemanticSchema semanticSchema, SchemaElementType entity) {
if (id2Values == null || id2Values.isEmpty()) {
return;
}
for (Entry<Long, List<SchemaElementMatch>> entry : id2Values.entrySet()) {
SchemaElement dimension = semanticSchema.getElement(entity, entry.getKey());

if (entry.getValue().size() == 1) {
SchemaElementMatch schemaMatch = entry.getValue().get(0);
QueryFilter dimensionFilter = new QueryFilter();
dimensionFilter.setValue(schemaMatch.getWord());
dimensionFilter.setBizName(dimension.getBizName());
dimensionFilter.setName(dimension.getName());
dimensionFilter.setOperator(FilterOperatorEnum.EQUALS);
dimensionFilter.setElementID(schemaMatch.getElement().getId());
parseInfo.setEntity(semanticSchema.getElement(SchemaElementType.ENTITY, entry.getKey()));
parseInfo.getDimensionFilters().add(dimensionFilter);
} else {
QueryFilter dimensionFilter = new QueryFilter();
List<String> vals = new ArrayList<>();
entry.getValue().stream().forEach(i -> vals.add(i.getWord()));
dimensionFilter.setValue(vals);
dimensionFilter.setBizName(dimension.getBizName());
dimensionFilter.setName(dimension.getName());
dimensionFilter.setOperator(FilterOperatorEnum.IN);
dimensionFilter.setElementID(entry.getKey());
parseInfo.getDimensionFilters().add(dimensionFilter);
}
}
}

private void addToValues(SemanticSchema semanticSchema, SchemaElementType entity,
Map<Long, List<SchemaElementMatch>> id2Values, SchemaElementMatch schemaMatch) {
SchemaElement element = schemaMatch.getElement();
SchemaElement entityElement = semanticSchema.getElement(entity, element.getId());
if (entityElement != null) {
if (id2Values.containsKey(element.getId())) {
id2Values.get(element.getId()).add(schemaMatch);
} else {
id2Values.put(element.getId(), new ArrayList<>(Arrays.asList(schemaMatch)));
}
}
}

@Override
public SemanticQueryReq buildSemanticQueryReq() {
String queryMode = parseInfo.getQueryMode();

if (parseInfo.getDataSetId() == null || StringUtils.isEmpty(queryMode)
|| !QueryManager.containsRuleQuery(queryMode)) {
// reaching here indicates an upstream parsing error
log.error("not find QueryMode");
throw new RuntimeException("not find QueryMode");
}

QueryStructReq queryStructReq = convertQueryStruct();
return queryStructReq.convert(true);
}

protected boolean isMultiStructQuery() {
return false;
}

public SemanticQueryReq multiStructExecute() {
String queryMode = parseInfo.getQueryMode();

if (parseInfo.getDataSetId() == null || StringUtils.isEmpty(queryMode)
|| !QueryManager.containsRuleQuery(queryMode)) {
// reaching here indicates an upstream parsing error
log.error("not find QueryMode");
throw new RuntimeException("not find QueryMode");
}

return convertQueryMultiStruct();
}

@Override
public void setParseInfo(SemanticParseInfo parseInfo) {
this.parseInfo = parseInfo;
}

public static List<RuleSemanticQuery> resolve(Long dataSetId, List<SchemaElementMatch> candidateElementMatches,
QueryContext queryContext) {
List<RuleSemanticQuery> matchedQueries = new ArrayList<>();
for (RuleSemanticQuery semanticQuery : QueryManager.getRuleQueries()) {
List<SchemaElementMatch> matches = semanticQuery.match(candidateElementMatches, queryContext);

if (matches.size() > 0) {
RuleSemanticQuery query = QueryManager.createRuleQuery(semanticQuery.getQueryMode());
query.getParseInfo().getElementMatches().addAll(matches);
matchedQueries.add(query);
}
}
return matchedQueries;
}

protected QueryStructReq convertQueryStruct() {
return QueryReqBuilder.buildStructReq(parseInfo);
}

protected QueryMultiStructReq convertQueryMultiStruct() {
return QueryReqBuilder.buildMultiStructReq(parseInfo);
}

}
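Finally, a hedged sketch (not part of the commit) of how a parser component might drive resolve(): every registered rule query is matched against the mapped schema elements, and the survivors get their parse info filled from the query and chat contexts. The helper class and method are hypothetical.

package com.tencent.supersonic.headless.chat.query.rule;

import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.QueryContext;

import java.util.List;

public class RuleQueryResolveSketch {

    public static List<RuleSemanticQuery> resolveAndFill(Long dataSetId,
            List<SchemaElementMatch> candidateElementMatches,
            QueryContext queryContext, ChatContext chatContext) {
        // every registered rule query gets a chance to match the candidate elements
        List<RuleSemanticQuery> queries =
                RuleSemanticQuery.resolve(dataSetId, candidateElementMatches, queryContext);
        for (RuleSemanticQuery query : queries) {
            // fills query mode, data set, dimensions/metrics/filters, score and inherited date info
            query.fillParseInfo(queryContext, chatContext);
        }
        return queries; // typically appended to queryContext's candidate queries by the parser
    }
}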
@@ -0,0 +1,26 @@
package com.tencent.supersonic.headless.chat.query.rule.detail;

import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.VALUE;
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;

@Slf4j
@Component
public class DetailFilterQuery extends DetailListQuery {

public static final String QUERY_MODE = "DETAIL_LIST_FILTER";

public DetailFilterQuery() {
super();
queryMatcher.addOption(VALUE, REQUIRED, AT_LEAST, 1);
}

@Override
public String getQueryMode() {
return QUERY_MODE;
}

}
@@ -0,0 +1,24 @@
package com.tencent.supersonic.headless.chat.query.rule.detail;

import org.springframework.stereotype.Component;

import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ID;
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
import static com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;

@Component
public class DetailIdQuery extends DetailListQuery {

public static final String QUERY_MODE = "DETAIL_ID";

public DetailIdQuery() {
super();
queryMatcher.addOption(ID, REQUIRED, AT_LEAST, 1);
}

@Override
public String getQueryMode() {
return QUERY_MODE;
}

}