mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-14 13:47:09 +00:00
(improvement)(Headless) Abstracted tags from dimensions and metrics. (#828)
This commit is contained in:
@@ -18,7 +18,7 @@ public class KnowledgeService {
|
||||
public void updateSemanticKnowledge(List<DictWord> natures) {
|
||||
|
||||
List<DictWord> prefixes = natures.stream()
|
||||
.filter(entry -> !entry.getNatureWithFrequency().contains(DictWordType.SUFFIX.getTypeWithSpilt()))
|
||||
.filter(entry -> !entry.getNatureWithFrequency().contains(DictWordType.SUFFIX.getType()))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
for (DictWord nature : prefixes) {
|
||||
@@ -26,7 +26,7 @@ public class KnowledgeService {
|
||||
}
|
||||
|
||||
List<DictWord> suffixes = natures.stream()
|
||||
.filter(entry -> entry.getNatureWithFrequency().contains(DictWordType.SUFFIX.getTypeWithSpilt()))
|
||||
.filter(entry -> entry.getNatureWithFrequency().contains(DictWordType.SUFFIX.getType()))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
SearchService.loadSuffix(suffixes);
|
||||
|
||||
@@ -88,7 +88,7 @@ public class SearchService {
|
||||
entry -> {
|
||||
String name = entry.getKey().replace("#", " ");
|
||||
List<String> natures = entry.getValue().stream()
|
||||
.map(nature -> nature.replaceAll(DictWordType.SUFFIX.getTypeWithSpilt(), ""))
|
||||
.map(nature -> nature.replaceAll(DictWordType.SUFFIX.getType(), ""))
|
||||
.collect(Collectors.toList());
|
||||
name = StringUtils.reverse(name);
|
||||
return new HanlpMapResult(name, natures, key);
|
||||
@@ -169,8 +169,8 @@ public class SearchService {
|
||||
if (Objects.nonNull(natures) && natures.length > 0) {
|
||||
trie.put(dictWord.getWord(), getValue(natures));
|
||||
}
|
||||
if (dictWord.getNature().contains(DictWordType.METRIC.getTypeWithSpilt()) || dictWord.getNature()
|
||||
.contains(DictWordType.DIMENSION.getTypeWithSpilt())) {
|
||||
if (dictWord.getNature().contains(DictWordType.METRIC.getType()) || dictWord.getNature()
|
||||
.contains(DictWordType.DIMENSION.getType())) {
|
||||
suffixTrie.remove(dictWord.getWord());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,10 +31,10 @@ public class DimensionWordBuilder extends BaseWordWithAliasBuilder {
|
||||
dictWord.setWord(word);
|
||||
Long modelId = schemaElement.getModel();
|
||||
String nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + schemaElement.getId()
|
||||
+ DictWordType.DIMENSION.getTypeWithSpilt();
|
||||
+ DictWordType.DIMENSION.getType();
|
||||
if (isSuffix) {
|
||||
nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + schemaElement.getId()
|
||||
+ DictWordType.SUFFIX.getTypeWithSpilt() + DictWordType.DIMENSION.getTypeWithSpilt();
|
||||
+ DictWordType.SUFFIX.getType() + DictWordType.DIMENSION.getType();
|
||||
}
|
||||
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature));
|
||||
return dictWord;
|
||||
|
||||
@@ -29,7 +29,7 @@ public class EntityWordBuilder extends BaseWordWithAliasBuilder {
|
||||
@Override
|
||||
public DictWord getOneWordNature(String word, SchemaElement schemaElement, boolean isSuffix) {
|
||||
String nature = DictWordType.NATURE_SPILT + schemaElement.getModel()
|
||||
+ DictWordType.NATURE_SPILT + schemaElement.getId() + DictWordType.ENTITY.getTypeWithSpilt();
|
||||
+ DictWordType.NATURE_SPILT + schemaElement.getId() + DictWordType.ENTITY.getType();
|
||||
DictWord dictWord = new DictWord();
|
||||
dictWord.setWord(word);
|
||||
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY * 2, nature));
|
||||
|
||||
@@ -31,10 +31,10 @@ public class MetricWordBuilder extends BaseWordWithAliasBuilder {
|
||||
dictWord.setWord(word);
|
||||
Long modelId = schemaElement.getModel();
|
||||
String nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + schemaElement.getId()
|
||||
+ DictWordType.METRIC.getTypeWithSpilt();
|
||||
+ DictWordType.METRIC.getType();
|
||||
if (isSuffix) {
|
||||
nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + schemaElement.getId()
|
||||
+ DictWordType.SUFFIX.getTypeWithSpilt() + DictWordType.METRIC.getTypeWithSpilt();
|
||||
+ DictWordType.SUFFIX.getType() + DictWordType.METRIC.getType();
|
||||
}
|
||||
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature));
|
||||
return dictWord;
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
package com.tencent.supersonic.headless.core.chat.knowledge.builder;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.DictWord;
|
||||
|
||||
import java.util.List;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
public class TagWordBuilder extends BaseWordWithAliasBuilder {
|
||||
|
||||
@Override
|
||||
public List<DictWord> doGet(String word, SchemaElement schemaElement) {
|
||||
List<DictWord> result = Lists.newArrayList();
|
||||
result.add(getOneWordNature(word, schemaElement, false));
|
||||
result.addAll(getOneWordNatureAlias(schemaElement, false));
|
||||
String reverseWord = StringUtils.reverse(word);
|
||||
if (!word.equalsIgnoreCase(reverseWord)) {
|
||||
result.add(getOneWordNature(reverseWord, schemaElement, true));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
public DictWord getOneWordNature(String word, SchemaElement schemaElement, boolean isSuffix) {
|
||||
DictWord dictWord = new DictWord();
|
||||
dictWord.setWord(word);
|
||||
Long modelId = schemaElement.getModel();
|
||||
String nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + schemaElement.getId()
|
||||
+ DictWordType.TAG.getTypeWithSpilt();
|
||||
if (isSuffix) {
|
||||
nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + schemaElement.getId()
|
||||
+ DictWordType.SUFFIX.getTypeWithSpilt() + DictWordType.TAG.getTypeWithSpilt();
|
||||
}
|
||||
dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature));
|
||||
return dictWord;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -18,7 +18,6 @@ public class WordBuilderFactory {
|
||||
wordNatures.put(DictWordType.DATASET, new ModelWordBuilder());
|
||||
wordNatures.put(DictWordType.ENTITY, new EntityWordBuilder());
|
||||
wordNatures.put(DictWordType.VALUE, new ValueWordBuilder());
|
||||
wordNatures.put(DictWordType.TAG, new TagWordBuilder());
|
||||
}
|
||||
|
||||
public static BaseWordBuilder get(DictWordType strategyType) {
|
||||
|
||||
@@ -46,12 +46,6 @@ public class NatureHelper {
|
||||
case VALUE:
|
||||
result = SchemaElementType.VALUE;
|
||||
break;
|
||||
case TAG:
|
||||
result = SchemaElementType.TAG;
|
||||
break;
|
||||
case TAG_VALUE:
|
||||
result = SchemaElementType.TAG_VALUE;
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@@ -60,7 +54,7 @@ public class NatureHelper {
|
||||
|
||||
private static boolean isDataSetOrEntity(S2Term term, Integer model) {
|
||||
return (DictWordType.NATURE_SPILT + model).equals(term.nature.toString()) || term.nature.toString()
|
||||
.endsWith(DictWordType.ENTITY.getTypeWithSpilt());
|
||||
.endsWith(DictWordType.ENTITY.getType());
|
||||
}
|
||||
|
||||
public static Integer getDataSetByNature(Nature nature) {
|
||||
@@ -134,8 +128,8 @@ public class NatureHelper {
|
||||
if (split.length <= 1) {
|
||||
return false;
|
||||
}
|
||||
return !nature.endsWith(DictWordType.METRIC.getTypeWithSpilt()) && !nature.endsWith(
|
||||
DictWordType.DIMENSION.getTypeWithSpilt())
|
||||
return !nature.endsWith(DictWordType.METRIC.getType()) && !nature.endsWith(
|
||||
DictWordType.DIMENSION.getType())
|
||||
&& StringUtils.isNumeric(split[1]);
|
||||
}
|
||||
|
||||
@@ -158,12 +152,12 @@ public class NatureHelper {
|
||||
|
||||
private static long getDimensionCount(List<S2Term> terms) {
|
||||
return terms.stream().filter(term -> term.nature.startsWith(DictWordType.NATURE_SPILT) && term.nature.toString()
|
||||
.endsWith(DictWordType.DIMENSION.getTypeWithSpilt())).count();
|
||||
.endsWith(DictWordType.DIMENSION.getType())).count();
|
||||
}
|
||||
|
||||
private static long getMetricCount(List<S2Term> terms) {
|
||||
return terms.stream().filter(term -> term.nature.startsWith(DictWordType.NATURE_SPILT) && term.nature.toString()
|
||||
.endsWith(DictWordType.METRIC.getTypeWithSpilt())).count();
|
||||
.endsWith(DictWordType.METRIC.getType())).count();
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -33,10 +33,8 @@ public class EntityMapper extends BaseMapper {
|
||||
continue;
|
||||
}
|
||||
List<SchemaElementMatch> valueSchemaElements = schemaElementMatchList.stream()
|
||||
.filter(schemaElementMatch ->
|
||||
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|
||||
|| SchemaElementType.TAG_VALUE.equals(schemaElementMatch.getElement().getType()
|
||||
))
|
||||
.filter(schemaElementMatch -> SchemaElementType.VALUE.equals(
|
||||
schemaElementMatch.getElement().getType()))
|
||||
.collect(Collectors.toList());
|
||||
for (SchemaElementMatch schemaElementMatch : valueSchemaElements) {
|
||||
if (!entity.getId().equals(schemaElementMatch.getElement().getId())) {
|
||||
|
||||
@@ -71,8 +71,7 @@ public class KeywordMapper extends BaseMapper {
|
||||
if (element == null) {
|
||||
continue;
|
||||
}
|
||||
if (element.getType().equals(SchemaElementType.VALUE) || element.getType()
|
||||
.equals(SchemaElementType.TAG_VALUE)) {
|
||||
if (element.getType().equals(SchemaElementType.VALUE)) {
|
||||
element.setName(hanlpMapResult.getName());
|
||||
}
|
||||
Long frequency = wordNatureToFrequency.get(hanlpMapResult.getName() + nature);
|
||||
|
||||
@@ -19,12 +19,12 @@ import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
public class QueryFilterMapper implements SchemaMapper {
|
||||
public class QueryFilterMapper extends BaseMapper {
|
||||
|
||||
private double similarity = 1.0;
|
||||
|
||||
@Override
|
||||
public void map(QueryContext queryContext) {
|
||||
public void doMap(QueryContext queryContext) {
|
||||
Set<Long> dataSetIds = queryContext.getDataSetIds();
|
||||
if (CollectionUtils.isEmpty(dataSetIds)) {
|
||||
return;
|
||||
|
||||
@@ -65,7 +65,7 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
||||
// remove entity name where search
|
||||
hanlpMapResults = hanlpMapResults.stream().filter(entry -> {
|
||||
List<String> natures = entry.getNatures().stream()
|
||||
.filter(nature -> !nature.endsWith(DictWordType.ENTITY.getTypeWithSpilt()))
|
||||
.filter(nature -> !nature.endsWith(DictWordType.ENTITY.getType()))
|
||||
.collect(Collectors.toList());
|
||||
if (CollectionUtils.isEmpty(natures)) {
|
||||
return false;
|
||||
|
||||
@@ -1,13 +1,25 @@
|
||||
package com.tencent.supersonic.headless.core.chat.parser;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlSelectHelper;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||
import com.tencent.supersonic.headless.core.chat.query.SemanticQuery;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.headless.core.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.headless.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
/**
|
||||
* QueryTypeParser resolves query type as either METRIC or TAG, or ID.
|
||||
@@ -25,9 +37,50 @@ public class QueryTypeParser implements SemanticParser {
|
||||
// 1.init S2SQL
|
||||
semanticQuery.initS2Sql(queryContext.getSemanticSchema(), user);
|
||||
// 2.set queryType
|
||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||
parseInfo.setQueryType(queryContext.getQueryType(parseInfo.getDataSetId()));
|
||||
QueryType queryType = getQueryType(queryContext, semanticQuery);
|
||||
semanticQuery.getParseInfo().setQueryType(queryType);
|
||||
}
|
||||
}
|
||||
|
||||
private QueryType getQueryType(QueryContext queryContext, SemanticQuery semanticQuery) {
|
||||
SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
|
||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||
if (Objects.isNull(sqlInfo) || StringUtils.isBlank(sqlInfo.getS2SQL())) {
|
||||
return QueryType.ID;
|
||||
}
|
||||
//1. entity queryType
|
||||
Long dataSetId = parseInfo.getDataSetId();
|
||||
SemanticSchema semanticSchema = queryContext.getSemanticSchema();
|
||||
if (semanticQuery instanceof RuleSemanticQuery || semanticQuery instanceof LLMSqlQuery) {
|
||||
//If all the fields in the SELECT statement are of tag type.
|
||||
List<String> whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getS2SQL())
|
||||
.stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
if (CollectionUtils.isNotEmpty(whereFields)) {
|
||||
Set<String> ids = semanticSchema.getEntities(dataSetId).stream().map(SchemaElement::getName)
|
||||
.collect(Collectors.toSet());
|
||||
if (CollectionUtils.isNotEmpty(ids) && ids.stream().anyMatch(whereFields::contains)) {
|
||||
return QueryType.ID;
|
||||
}
|
||||
Set<String> tags = semanticSchema.getTags(dataSetId).stream().map(SchemaElement::getName)
|
||||
.collect(Collectors.toSet());
|
||||
if (CollectionUtils.isNotEmpty(tags) && tags.containsAll(whereFields)) {
|
||||
return QueryType.TAG;
|
||||
}
|
||||
}
|
||||
}
|
||||
//2. metric queryType
|
||||
List<String> selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getS2SQL());
|
||||
List<SchemaElement> metrics = semanticSchema.getMetrics(dataSetId);
|
||||
if (CollectionUtils.isNotEmpty(metrics)) {
|
||||
Set<String> metricNameSet = metrics.stream().map(SchemaElement::getName).collect(Collectors.toSet());
|
||||
boolean containMetric = selectFields.stream().anyMatch(metricNameSet::contains);
|
||||
if (containMetric) {
|
||||
return QueryType.METRIC;
|
||||
}
|
||||
}
|
||||
return QueryType.ID;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package com.tencent.supersonic.headless.core.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
@@ -17,13 +16,6 @@ import com.tencent.supersonic.headless.core.config.OptimizationConfig;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.headless.core.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.headless.core.utils.S2SqlDateHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashSet;
|
||||
@@ -32,10 +24,17 @@ import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
public class LLMRequestService {
|
||||
|
||||
@Autowired
|
||||
private LLMParserConfig llmParserConfig;
|
||||
@Autowired
|
||||
@@ -62,7 +61,7 @@ public class LLMRequestService {
|
||||
}
|
||||
|
||||
public LLMReq getLlmReq(QueryContext queryCtx, Long dataSetId,
|
||||
SemanticSchema semanticSchema, List<LLMReq.ElementValue> linkingValues) {
|
||||
SemanticSchema semanticSchema, List<LLMReq.ElementValue> linkingValues) {
|
||||
Map<Long, String> dataSetIdToName = semanticSchema.getDataSetIdToName();
|
||||
String queryText = queryCtx.getQueryText();
|
||||
|
||||
@@ -154,8 +153,7 @@ public class LLMRequestService {
|
||||
.filter(elementMatch -> !elementMatch.isInherited())
|
||||
.filter(schemaElementMatch -> {
|
||||
SchemaElementType type = schemaElementMatch.getElement().getType();
|
||||
return SchemaElementType.VALUE.equals(type) || SchemaElementType.TAG_VALUE.equals(type)
|
||||
|| SchemaElementType.ID.equals(type);
|
||||
return SchemaElementType.VALUE.equals(type) || SchemaElementType.ID.equals(type);
|
||||
})
|
||||
.map(elementMatch -> {
|
||||
ElementValue elementValue = new ElementValue();
|
||||
@@ -169,9 +167,6 @@ public class LLMRequestService {
|
||||
protected Map<Long, String> getItemIdToName(QueryContext queryCtx, Long dataSetId) {
|
||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||
List<SchemaElement> elements = semanticSchema.getDimensions(dataSetId);
|
||||
if (QueryType.TAG.equals(queryCtx.getQueryType(dataSetId))) {
|
||||
elements = semanticSchema.getTags(dataSetId);
|
||||
}
|
||||
return elements.stream()
|
||||
.collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2));
|
||||
}
|
||||
@@ -179,27 +174,18 @@ public class LLMRequestService {
|
||||
private Set<String> getTopNFieldNames(QueryContext queryCtx, Long dataSetId, LLMParserConfig llmParserConfig) {
|
||||
SemanticSchema semanticSchema = queryCtx.getSemanticSchema();
|
||||
Set<String> results = new HashSet<>();
|
||||
if (QueryType.TAG.equals(queryCtx.getQueryType(dataSetId))) {
|
||||
Set<String> tags = semanticSchema.getTags(dataSetId).stream()
|
||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||
.limit(llmParserConfig.getDimensionTopN())
|
||||
.map(entry -> entry.getName())
|
||||
.collect(Collectors.toSet());
|
||||
results.addAll(tags);
|
||||
} else {
|
||||
Set<String> dimensions = semanticSchema.getDimensions(dataSetId).stream()
|
||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||
.limit(llmParserConfig.getDimensionTopN())
|
||||
.map(entry -> entry.getName())
|
||||
.collect(Collectors.toSet());
|
||||
results.addAll(dimensions);
|
||||
Set<String> metrics = semanticSchema.getMetrics(dataSetId).stream()
|
||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||
.limit(llmParserConfig.getMetricTopN())
|
||||
.map(entry -> entry.getName())
|
||||
.collect(Collectors.toSet());
|
||||
results.addAll(metrics);
|
||||
}
|
||||
Set<String> dimensions = semanticSchema.getDimensions(dataSetId).stream()
|
||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||
.limit(llmParserConfig.getDimensionTopN())
|
||||
.map(entry -> entry.getName())
|
||||
.collect(Collectors.toSet());
|
||||
results.addAll(dimensions);
|
||||
Set<String> metrics = semanticSchema.getMetrics(dataSetId).stream()
|
||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||
.limit(llmParserConfig.getMetricTopN())
|
||||
.map(entry -> entry.getName())
|
||||
.collect(Collectors.toSet());
|
||||
results.addAll(metrics);
|
||||
return results;
|
||||
}
|
||||
|
||||
@@ -214,15 +200,12 @@ public class LLMRequestService {
|
||||
SchemaElementType elementType = schemaElementMatch.getElement().getType();
|
||||
return SchemaElementType.METRIC.equals(elementType)
|
||||
|| SchemaElementType.DIMENSION.equals(elementType)
|
||||
|| SchemaElementType.VALUE.equals(elementType)
|
||||
|| SchemaElementType.TAG.equals(elementType)
|
||||
|| SchemaElementType.TAG_VALUE.equals(elementType);
|
||||
|| SchemaElementType.VALUE.equals(elementType);
|
||||
})
|
||||
.map(schemaElementMatch -> {
|
||||
SchemaElement element = schemaElementMatch.getElement();
|
||||
SchemaElementType elementType = element.getType();
|
||||
if (SchemaElementType.VALUE.equals(elementType) || SchemaElementType.TAG_VALUE.equals(
|
||||
elementType)) {
|
||||
if (SchemaElementType.VALUE.equals(elementType)) {
|
||||
return itemIdToName.get(element.getId());
|
||||
}
|
||||
return schemaElementMatch.getWord();
|
||||
|
||||
@@ -38,7 +38,6 @@ public class ContextInheritParser implements SemanticParser {
|
||||
new AbstractMap.SimpleEntry<>(
|
||||
SchemaElementType.VALUE, Arrays.asList(SchemaElementType.VALUE, SchemaElementType.DIMENSION)),
|
||||
new AbstractMap.SimpleEntry<>(SchemaElementType.ENTITY, Arrays.asList(SchemaElementType.ENTITY)),
|
||||
new AbstractMap.SimpleEntry<>(SchemaElementType.TAG, Arrays.asList(SchemaElementType.TAG)),
|
||||
new AbstractMap.SimpleEntry<>(SchemaElementType.DATASET, Arrays.asList(SchemaElementType.DATASET)),
|
||||
new AbstractMap.SimpleEntry<>(SchemaElementType.ID, Arrays.asList(SchemaElementType.ID))
|
||||
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
package com.tencent.supersonic.headless.core.chat.parser.rule;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.core.chat.parser.SemanticParser;
|
||||
import com.tencent.supersonic.headless.core.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.headless.core.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.headless.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||
@@ -35,24 +32,10 @@ public class RuleSqlParser implements SemanticParser {
|
||||
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(dataSetId, elementMatches, queryContext);
|
||||
for (RuleSemanticQuery query : queries) {
|
||||
query.fillParseInfo(queryContext, chatContext);
|
||||
SemanticParseInfo parseInfo = query.getParseInfo();
|
||||
QueryType queryType = queryContext.getQueryType(parseInfo.getDataSetId());
|
||||
if (isRightQuery(parseInfo, queryType)) {
|
||||
queryContext.getCandidateQueries().add(query);
|
||||
}
|
||||
queryContext.getCandidateQueries().add(query);
|
||||
}
|
||||
}
|
||||
|
||||
auxiliaryParsers.stream().forEach(p -> p.parse(queryContext, chatContext));
|
||||
}
|
||||
|
||||
private boolean isRightQuery(SemanticParseInfo parseInfo, QueryType queryType) {
|
||||
if (QueryType.TAG.equals(queryType) && QueryManager.isTagQuery(parseInfo.getQueryMode())) {
|
||||
return true;
|
||||
}
|
||||
if (QueryType.METRIC.equals(queryType) && QueryManager.isMetricQuery(parseInfo.getQueryMode())) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package com.tencent.supersonic.headless.core.chat.query.rule;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
@@ -18,10 +17,6 @@ import com.tencent.supersonic.headless.core.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.headless.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.headless.core.utils.QueryReqBuilder;
|
||||
import lombok.ToString;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
@@ -30,6 +25,9 @@ import java.util.Map;
|
||||
import java.util.Map.Entry;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.ToString;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
@Slf4j
|
||||
@ToString
|
||||
@@ -42,7 +40,7 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
}
|
||||
|
||||
public List<SchemaElementMatch> match(List<SchemaElementMatch> candidateElementMatches,
|
||||
QueryContext queryCtx) {
|
||||
QueryContext queryCtx) {
|
||||
return queryMatcher.match(candidateElementMatches);
|
||||
}
|
||||
|
||||
@@ -101,22 +99,31 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
parseInfo.setDataSet(semanticSchema.getDataSet(dataSetId));
|
||||
Map<Long, List<SchemaElementMatch>> dim2Values = new HashMap<>();
|
||||
Map<Long, List<SchemaElementMatch>> id2Values = new HashMap<>();
|
||||
Map<Long, List<SchemaElementMatch>> tag2Values = new HashMap<>();
|
||||
|
||||
for (SchemaElementMatch schemaMatch : parseInfo.getElementMatches()) {
|
||||
SchemaElement element = schemaMatch.getElement();
|
||||
element.setOrder(1 - schemaMatch.getSimilarity());
|
||||
switch (element.getType()) {
|
||||
case ID:
|
||||
addToValues(semanticSchema, SchemaElementType.ENTITY, id2Values, schemaMatch);
|
||||
break;
|
||||
case TAG_VALUE:
|
||||
addToValues(semanticSchema, SchemaElementType.TAG, tag2Values, schemaMatch);
|
||||
SchemaElement entityElement = semanticSchema.getElement(SchemaElementType.ENTITY, element.getId());
|
||||
if (entityElement != null) {
|
||||
if (id2Values.containsKey(element.getId())) {
|
||||
id2Values.get(element.getId()).add(schemaMatch);
|
||||
} else {
|
||||
id2Values.put(element.getId(), new ArrayList<>(Arrays.asList(schemaMatch)));
|
||||
}
|
||||
}
|
||||
break;
|
||||
case VALUE:
|
||||
addToValues(semanticSchema, SchemaElementType.DIMENSION, dim2Values, schemaMatch);
|
||||
SchemaElement dimElement = semanticSchema.getElement(SchemaElementType.DIMENSION, element.getId());
|
||||
if (dimElement != null) {
|
||||
if (dim2Values.containsKey(element.getId())) {
|
||||
dim2Values.get(element.getId()).add(schemaMatch);
|
||||
} else {
|
||||
dim2Values.put(element.getId(), new ArrayList<>(Arrays.asList(schemaMatch)));
|
||||
}
|
||||
}
|
||||
break;
|
||||
case TAG:
|
||||
case DIMENSION:
|
||||
parseInfo.getDimensions().add(element);
|
||||
break;
|
||||
@@ -129,10 +136,8 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
addToFilters(id2Values, parseInfo, semanticSchema, SchemaElementType.ENTITY);
|
||||
addToFilters(dim2Values, parseInfo, semanticSchema, SchemaElementType.DIMENSION);
|
||||
addToFilters(tag2Values, parseInfo, semanticSchema, SchemaElementType.TAG);
|
||||
}
|
||||
|
||||
private void addToFilters(Map<Long, List<SchemaElementMatch>> id2Values, SemanticParseInfo parseInfo,
|
||||
@@ -220,8 +225,6 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
public static List<RuleSemanticQuery> resolve(Long dataSetId, List<SchemaElementMatch> candidateElementMatches,
|
||||
QueryContext queryContext) {
|
||||
List<RuleSemanticQuery> matchedQueries = new ArrayList<>();
|
||||
candidateElementMatches = filterByQueryType(dataSetId, candidateElementMatches, queryContext);
|
||||
|
||||
for (RuleSemanticQuery semanticQuery : QueryManager.getRuleQueries()) {
|
||||
List<SchemaElementMatch> matches = semanticQuery.match(candidateElementMatches, queryContext);
|
||||
|
||||
@@ -231,30 +234,9 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
matchedQueries.add(query);
|
||||
}
|
||||
}
|
||||
|
||||
return matchedQueries;
|
||||
}
|
||||
|
||||
private static List<SchemaElementMatch> filterByQueryType(Long dataSetId,
|
||||
List<SchemaElementMatch> candidateElementMatches, QueryContext queryContext) {
|
||||
QueryType queryType = queryContext.getQueryType(dataSetId);
|
||||
if (QueryType.TAG.equals(queryType)) {
|
||||
candidateElementMatches = candidateElementMatches.stream()
|
||||
.filter(elementMatch -> !(SchemaElementType.METRIC.equals(elementMatch.getElement().getType())
|
||||
|| SchemaElementType.DIMENSION.equals(elementMatch.getElement().getType())
|
||||
|| SchemaElementType.VALUE.equals(elementMatch.getElement().getType()))
|
||||
)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
if (QueryType.METRIC.equals(queryType)) {
|
||||
candidateElementMatches = candidateElementMatches.stream()
|
||||
.filter(elementMatch -> !(SchemaElementType.TAG.equals(elementMatch.getElement().getType())
|
||||
|| SchemaElementType.TAG_VALUE.equals(elementMatch.getElement().getType())))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
return candidateElementMatches;
|
||||
}
|
||||
|
||||
protected QueryStructReq convertQueryStruct() {
|
||||
return QueryReqBuilder.buildStructReq(parseInfo);
|
||||
}
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
package com.tencent.supersonic.headless.core.chat.query.rule.tag;
|
||||
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.TAG;
|
||||
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ID;
|
||||
import static com.tencent.supersonic.headless.core.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
|
||||
import static com.tencent.supersonic.headless.core.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
|
||||
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.ID;
|
||||
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Component
|
||||
public class TagDetailQuery extends TagSemanticQuery {
|
||||
@@ -14,8 +13,7 @@ public class TagDetailQuery extends TagSemanticQuery {
|
||||
|
||||
public TagDetailQuery() {
|
||||
super();
|
||||
queryMatcher.addOption(TAG, REQUIRED, AT_LEAST, 1)
|
||||
.addOption(ID, REQUIRED, AT_LEAST, 1);
|
||||
queryMatcher.addOption(ID, REQUIRED, AT_LEAST, 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
package com.tencent.supersonic.headless.core.chat.query.rule.tag;
|
||||
|
||||
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.VALUE;
|
||||
import static com.tencent.supersonic.headless.core.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
|
||||
import static com.tencent.supersonic.headless.core.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
|
||||
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.TAG;
|
||||
import static com.tencent.supersonic.headless.api.pojo.SchemaElementType.TAG_VALUE;
|
||||
import static com.tencent.supersonic.headless.core.chat.query.rule.QueryMatchOption.OptionType.OPTIONAL;
|
||||
import static com.tencent.supersonic.headless.core.chat.query.rule.QueryMatchOption.OptionType.REQUIRED;
|
||||
import static com.tencent.supersonic.headless.core.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class TagFilterQuery extends TagListQuery {
|
||||
@@ -17,8 +15,7 @@ public class TagFilterQuery extends TagListQuery {
|
||||
|
||||
public TagFilterQuery() {
|
||||
super();
|
||||
queryMatcher.addOption(TAG, OPTIONAL, AT_LEAST, 0);
|
||||
queryMatcher.addOption(TAG_VALUE, REQUIRED, AT_LEAST, 1);
|
||||
queryMatcher.addOption(VALUE, REQUIRED, AT_LEAST, 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
Reference in New Issue
Block a user