(improvement)(Headless) Add 'topFields' to the map interface and return dataSet and field names in Chinese. (#892)

This commit is contained in:
lexluo09
2024-04-07 21:29:53 +08:00
committed by GitHub
parent 5f6e9ae194
commit 07b5eb47b6
11 changed files with 175 additions and 14 deletions

View File

@@ -18,6 +18,7 @@ import java.util.List;
public class SchemaElement implements Serializable {
private Long dataSet;
private String dataSetName;
private Long model;
private Long id;
private String name;

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.headless.api.pojo;
import com.google.common.collect.Lists;
import org.apache.commons.collections4.CollectionUtils;
import java.util.HashMap;
import java.util.List;
@@ -30,4 +31,18 @@ public class SchemaMapInfo {
public void setMatchedElements(Long dataSet, List<SchemaElementMatch> elementMatches) {
dataSetElementMatches.put(dataSet, elementMatches);
}
public Map<Long, String> generateDataSetMap() {
Map<Long, String> dataSetIdToName = new HashMap<>();
for (Map.Entry<Long, List<SchemaElementMatch>> entry : dataSetElementMatches.entrySet()) {
List<SchemaElementMatch> values = entry.getValue();
if (CollectionUtils.isNotEmpty(values)) {
SchemaElementMatch schemaElementMatch = values.get(0);
String dataSetName = schemaElementMatch.getElement().getDataSetName();
dataSetIdToName.put(entry.getKey(), dataSetName);
}
}
return dataSetIdToName;
}
}

View File

@@ -0,0 +1,18 @@
package com.tencent.supersonic.headless.api.pojo.response;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import lombok.Data;
import java.util.List;
import java.util.Map;
@Data
public class MapInfoResp {
private String queryText;
private Map<String, List<SchemaElementMatch>> mapFields;
private Map<String, List<SchemaElementMatch>> topFields;
}

View File

@@ -51,9 +51,7 @@ public abstract class BaseMapper implements SchemaMapper {
AtomicBoolean needAddNew = new AtomicBoolean(true);
schemaElementMatches.removeIf(
existElementMatch -> {
SchemaElement existElement = existElementMatch.getElement();
SchemaElement newElement = newElementMatch.getElement();
if (existElement.equals(newElement)) {
if (isEquals(existElementMatch, newElementMatch)) {
if (newElementMatch.getSimilarity() > existElementMatch.getSimilarity()) {
return true;
} else {
@@ -68,8 +66,20 @@ public abstract class BaseMapper implements SchemaMapper {
}
}
private static boolean isEquals(SchemaElementMatch existElementMatch, SchemaElementMatch newElementMatch) {
SchemaElement existElement = existElementMatch.getElement();
SchemaElement newElement = newElementMatch.getElement();
if (!existElement.equals(newElement)) {
return false;
}
if (SchemaElementType.VALUE.equals(newElement.getType())) {
return existElementMatch.getWord().equalsIgnoreCase(newElementMatch.getWord());
}
return true;
}
public SchemaElement getSchemaElement(Long dataSetId, SchemaElementType elementType, Long elementID,
SemanticSchema semanticSchema) {
SemanticSchema semanticSchema) {
SchemaElement element = new SchemaElement();
DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId);
if (Objects.isNull(dataSetSchema)) {

View File

@@ -6,6 +6,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.chat.knowledge.builder.BaseWordBuilder;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.core.chat.knowledge.DatabaseMapResult;
import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult;
@@ -46,7 +47,7 @@ public class KeywordMapper extends BaseMapper {
}
private void convertHanlpMapResultToMapInfo(List<HanlpMapResult> mapResults, QueryContext queryContext,
List<S2Term> terms) {
List<S2Term> terms) {
if (CollectionUtils.isEmpty(mapResults)) {
return;
}
@@ -71,9 +72,6 @@ public class KeywordMapper extends BaseMapper {
if (element == null) {
continue;
}
if (element.getType().equals(SchemaElementType.VALUE)) {
element.setName(hanlpMapResult.getName());
}
Long frequency = wordNatureToFrequency.get(hanlpMapResult.getName() + nature);
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
.element(element)
@@ -100,7 +98,7 @@ public class KeywordMapper extends BaseMapper {
.element(schemaElement)
.word(schemaElement.getName())
.detectWord(match.getDetectWord())
.frequency(10000L)
.frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
.similarity(mapperHelper.getSimilarity(match.getDetectWord(), schemaElement.getName()))
.build();
log.info("add to schema, elementMatch {}", schemaElementMatch);

View File

@@ -117,6 +117,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
});
SchemaMapInfo mapInfo = queryCtx.getMapInfo();
mapResp.setMapInfo(mapInfo);
mapResp.setQueryText(queryReq.getQueryText());
return mapResp;
}

View File

@@ -12,6 +12,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaValueMap;
import com.tencent.supersonic.headless.api.pojo.response.DataSetSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
@@ -19,6 +20,7 @@ import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.logging.log4j.util.Strings;
import org.springframework.beans.BeanUtils;
import org.springframework.util.CollectionUtils;
@@ -30,6 +32,7 @@ public class DataSetSchemaBuilder {
dataSetSchema.setQueryConfig(resp.getQueryConfig());
SchemaElement dataSet = SchemaElement.builder()
.dataSet(resp.getId())
.dataSetName(resp.getName())
.id(resp.getId())
.name(resp.getName())
.bizName(resp.getBizName())
@@ -66,6 +69,7 @@ public class DataSetSchemaBuilder {
if (metric.getIsTag() == 1) {
SchemaElement tagToAdd = SchemaElement.builder()
.dataSet(resp.getId())
.dataSetName(resp.getName())
.model(metric.getModelId())
.id(metric.getId())
.name(metric.getName())
@@ -96,6 +100,7 @@ public class DataSetSchemaBuilder {
if (dim.getIsTag() == 1) {
SchemaElement tagToAdd = SchemaElement.builder()
.dataSet(resp.getId())
.dataSetName(resp.getName())
.model(dim.getModelId())
.id(dim.getId())
.name(dim.getName())
@@ -143,6 +148,7 @@ public class DataSetSchemaBuilder {
}
SchemaElement dimToAdd = SchemaElement.builder()
.dataSet(resp.getId())
.dataSetName(resp.getName())
.model(dim.getModelId())
.id(dim.getId())
.name(dim.getName())
@@ -174,6 +180,7 @@ public class DataSetSchemaBuilder {
}
SchemaElement dimValueToAdd = SchemaElement.builder()
.dataSet(resp.getId())
.dataSetName(resp.getName())
.model(dim.getModelId())
.id(dim.getId())
.name(dim.getName())
@@ -195,6 +202,7 @@ public class DataSetSchemaBuilder {
SchemaElement metricToAdd = SchemaElement.builder()
.dataSet(resp.getId())
.dataSetName(resp.getName())
.model(metric.getModelId())
.id(metric.getId())
.name(metric.getName())