(improvement)(Headless) Add 'topFields' to the map interface and return dataSet and field names in Chinese. (#892)

This commit is contained in:
lexluo09
2024-04-07 21:29:53 +08:00
committed by GitHub
parent 5f6e9ae194
commit 07b5eb47b6
11 changed files with 175 additions and 14 deletions

View File

@@ -10,6 +10,7 @@ public class ChatParseReq {
private String queryText;
private Integer chatId;
private Integer agentId;
private Integer topN = 10;
private User user;
private QueryFilters queryFilters;
private boolean saveAnswer = true;

View File

@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
@@ -17,7 +17,7 @@ public interface ChatService {
List<SearchResult> search(ChatParseReq chatParseReq);
MapResp performMapping(ChatParseReq chatParseReq);
MapInfoResp performMapping(ChatParseReq chatParseReq);
ParseResp performParsing(ChatParseReq chatParseReq);

View File

@@ -17,6 +17,7 @@ import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatManageService;
import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.chat.server.util.ComponentFactory;
import com.tencent.supersonic.chat.server.util.MapInfoConverter;
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.chat.server.util.SimilarQueryManager;
import com.tencent.supersonic.common.util.BeanMapper;
@@ -25,6 +26,7 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp;
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
@@ -34,6 +36,7 @@ import com.tencent.supersonic.headless.server.service.SearchService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.List;
@@ -62,7 +65,7 @@ public class ChatServiceImpl implements ChatService {
}
@Override
public MapResp performMapping(ChatParseReq chatParseReq) {
public MapInfoResp performMapping(ChatParseReq chatParseReq) {
return getMapResp(chatParseReq);
}
@@ -110,14 +113,15 @@ public class ChatServiceImpl implements ChatService {
return chatParseContext;
}
private MapResp getMapResp(ChatParseReq chatParseReq) {
private MapInfoResp getMapResp(ChatParseReq chatParseReq) {
ChatParseContext chatParseContext = new ChatParseContext();
BeanMapper.mapper(chatParseReq, chatParseContext);
AgentService agentService = ContextUtils.getBean(AgentService.class);
Agent agent = agentService.getAgent(chatParseReq.getAgentId());
chatParseContext.setAgent(agent);
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
return chatQueryService.performMapping(queryReq);
MapResp mapResp = chatQueryService.performMapping(queryReq);
return MapInfoConverter.convert(mapResp, chatParseReq.getTopN());
}
private ChatExecuteContext buildExecuteContext(ChatExecuteReq chatExecuteReq) {

View File

@@ -0,0 +1,105 @@
package com.tencent.supersonic.chat.server.util;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp;
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.core.chat.knowledge.builder.BaseWordBuilder;
import com.tencent.supersonic.headless.server.service.impl.SemanticService;
import org.apache.commons.collections4.CollectionUtils;
import org.springframework.beans.BeanUtils;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
/**
 * Converts the raw mapping result ({@link MapResp}) into the response returned by the map
 * interface ({@link MapInfoResp}), where matches are keyed by data set name instead of data set id.
 */
public class MapInfoConverter {

    /** Used when the caller supplies no {@code topN}; same value as the default declared on ChatParseReq. */
    private static final int DEFAULT_TOP_N = 10;

    private MapInfoConverter() {
        // static utility, not meant to be instantiated
    }

    /**
     * Builds a {@link MapInfoResp} from a mapping result.
     *
     * @param mapResp mapping result; when {@code null} an empty response is returned
     * @param topN    number of dimensions (time dimension included) and number of metrics to list
     *                per data set in {@code topFields}; {@code null} falls back to
     *                {@link #DEFAULT_TOP_N}, values below zero are treated as zero
     * @return the converted response; {@code mapFields} and {@code topFields} are left unset when
     *         {@code mapResp} carries no map info
     */
    public static MapInfoResp convert(MapResp mapResp, Integer topN) {
        MapInfoResp mapInfoResp = new MapInfoResp();
        if (Objects.isNull(mapResp)) {
            return mapInfoResp;
        }
        BeanUtils.copyProperties(mapResp, mapInfoResp);
        SchemaMapInfo mapInfo = mapResp.getMapInfo();
        if (Objects.isNull(mapInfo)) {
            // Nothing was mapped: only the copied scalar properties can be returned.
            return mapInfoResp;
        }
        Map<Long, String> dataSetMap = mapInfo.generateDataSetMap();
        mapInfoResp.setMapFields(getMapFields(mapInfo, dataSetMap));
        mapInfoResp.setTopFields(getTopFields(topN, mapInfo, dataSetMap));
        return mapInfoResp;
    }

    /**
     * Re-keys the matched elements by data set name. Data sets without matches or without a
     * resolvable name are skipped so the result never contains a {@code null} key.
     */
    private static Map<String, List<SchemaElementMatch>> getMapFields(SchemaMapInfo mapInfo,
            Map<Long, String> dataSetMap) {
        Map<String, List<SchemaElementMatch>> result = new HashMap<>();
        for (Map.Entry<Long, List<SchemaElementMatch>> entry : mapInfo.getDataSetElementMatches().entrySet()) {
            String dataSetName = dataSetMap.get(entry.getKey());
            List<SchemaElementMatch> values = entry.getValue();
            if (dataSetName != null && CollectionUtils.isNotEmpty(values)) {
                result.put(dataSetName, values);
            }
        }
        return result;
    }

    /**
     * Collects, for every mapped data set that has a name, its most used dimensions (one slot is
     * reserved for the day time dimension) followed by its most used metrics.
     */
    private static Map<String, List<SchemaElementMatch>> getTopFields(Integer topN,
            SchemaMapInfo mapInfo,
            Map<Long, String> dataSetMap) {
        // Stream.limit rejects negative sizes and a null Integer cannot be unboxed, so normalise first.
        int metricLimit = Objects.isNull(topN) ? DEFAULT_TOP_N : Math.max(topN, 0);
        int dimensionLimit = Math.max(metricLimit - 1, 0);

        Map<String, List<SchemaElementMatch>> result = new HashMap<>();
        SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
        SemanticSchema semanticSchema = semanticService.getSemanticSchema();
        for (Long dataSetId : mapInfo.getDataSetElementMatches().keySet()) {
            String dataSetName = dataSetMap.get(dataSetId);
            if (dataSetName == null) {
                // No name could be derived for this data set; a null map key would not serialise.
                continue;
            }
            // top dimensions, most used first
            List<SchemaElementMatch> fields = semanticSchema.getDimensions(dataSetId).stream()
                    .sorted(byUseCntDescending())
                    .limit(dimensionLimit)
                    .map(mergeFunction())
                    .collect(Collectors.toCollection(ArrayList::new));
            fields.add(getTimeDimension(dataSetId, dataSetName));
            // top metrics, most used first
            List<SchemaElementMatch> metrics = semanticSchema.getMetrics(dataSetId).stream()
                    .sorted(byUseCntDescending())
                    .limit(metricLimit)
                    .map(mergeFunction())
                    .collect(Collectors.toList());
            fields.addAll(metrics);
            // distinct() drops duplicates while keeping the ranking order established above.
            result.put(dataSetName, fields.stream().distinct().collect(Collectors.toList()));
        }
        return result;
    }

    /**
     * Orders schema elements by use count, highest first; elements whose use count is unset
     * sort last instead of failing the comparison.
     */
    private static Comparator<SchemaElement> byUseCntDescending() {
        return Comparator.comparing(SchemaElement::getUseCnt, Comparator.nullsLast(Comparator.reverseOrder()));
    }

    /**
     * Builds the match that represents the day time dimension of a data set.
     *
     * @param dataSetId   id of the data set the dimension belongs to
     * @param dataSetName display name of that data set
     * @return a match whose word and detect word are the day dimension's Chinese name
     */
    private static SchemaElementMatch getTimeDimension(Long dataSetId, String dataSetName) {
        SchemaElement element = SchemaElement.builder().dataSet(dataSetId).dataSetName(dataSetName)
                .type(SchemaElementType.DIMENSION).bizName(TimeDimensionEnum.DAY.getName()).build();
        return SchemaElementMatch.builder().element(element)
                .detectWord(TimeDimensionEnum.DAY.getChName()).word(TimeDimensionEnum.DAY.getChName())
                .similarity(1L).frequency(BaseWordBuilder.DEFAULT_FREQUENCY).build();
    }

    /**
     * Returns the function that wraps a schema element into a match using the element's own name
     * as both word and detect word, with full similarity and the default frequency.
     */
    private static Function<SchemaElement, SchemaElementMatch> mergeFunction() {
        return schemaElement -> SchemaElementMatch.builder().element(schemaElement)
                .frequency(BaseWordBuilder.DEFAULT_FREQUENCY).word(schemaElement.getName()).similarity(1)
                .detectWord(schemaElement.getName()).build();
    }
}

View File

@@ -18,6 +18,7 @@ import java.util.List;
public class SchemaElement implements Serializable {
private Long dataSet;
private String dataSetName;
private Long model;
private Long id;
private String name;

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.headless.api.pojo;
import com.google.common.collect.Lists;
import org.apache.commons.collections4.CollectionUtils;
import java.util.HashMap;
import java.util.List;
@@ -30,4 +31,18 @@ public class SchemaMapInfo {
public void setMatchedElements(Long dataSet, List<SchemaElementMatch> elementMatches) {
dataSetElementMatches.put(dataSet, elementMatches);
}
/**
 * Derives a data set id to data set name lookup from the matched elements.
 *
 * <p>The name is taken from the first match of a data set that carries an element with a
 * non-null data set name. Data sets with no matches, or whose matches expose no name, are
 * left out so callers never receive a {@code null} name.
 *
 * @return a new mutable map from data set id to data set name; never {@code null}
 */
public Map<Long, String> generateDataSetMap() {
    Map<Long, String> dataSetIdToName = new HashMap<>();
    for (Map.Entry<Long, List<SchemaElementMatch>> entry : dataSetElementMatches.entrySet()) {
        String dataSetName = firstDataSetName(entry.getValue());
        if (dataSetName != null) {
            dataSetIdToName.put(entry.getKey(), dataSetName);
        }
    }
    return dataSetIdToName;
}

/** Returns the first non-null data set name found among the matches, or {@code null} if there is none. */
private static String firstDataSetName(List<SchemaElementMatch> matches) {
    if (CollectionUtils.isEmpty(matches)) {
        return null;
    }
    for (SchemaElementMatch match : matches) {
        if (match == null || match.getElement() == null) {
            continue;
        }
        String dataSetName = match.getElement().getDataSetName();
        if (dataSetName != null) {
            return dataSetName;
        }
    }
    return null;
}
}

View File

@@ -0,0 +1,18 @@
package com.tencent.supersonic.headless.api.pojo.response;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import lombok.Data;
import java.util.List;
import java.util.Map;
/**
 * Response of the map interface: schema element matches grouped per data set,
 * with the data set name (rather than its id) used as the map key.
 */
@Data
public class MapInfoResp {
// Query text carried over from the underlying mapping response.
private String queryText;
// Elements matched for the query, grouped by data set name.
private Map<String, List<SchemaElementMatch>> mapFields;
// Most-used dimensions and metrics of each mapped data set, grouped by data set name.
private Map<String, List<SchemaElementMatch>> topFields;
}

View File

@@ -51,9 +51,7 @@ public abstract class BaseMapper implements SchemaMapper {
AtomicBoolean needAddNew = new AtomicBoolean(true);
schemaElementMatches.removeIf(
existElementMatch -> {
SchemaElement existElement = existElementMatch.getElement();
SchemaElement newElement = newElementMatch.getElement();
if (existElement.equals(newElement)) {
if (isEquals(existElementMatch, newElementMatch)) {
if (newElementMatch.getSimilarity() > existElementMatch.getSimilarity()) {
return true;
} else {
@@ -68,6 +66,18 @@ public abstract class BaseMapper implements SchemaMapper {
}
}
/**
 * Tells whether two matches refer to the same mapping target.
 *
 * <p>The underlying schema elements must be equal. For elements of type VALUE the matched
 * words must additionally be equal ignoring case; for every other type element equality
 * alone is sufficient.
 *
 * @param existElementMatch match already present in the list
 * @param newElementMatch   match about to be added
 * @return {@code true} when both matches are considered the same
 */
private static boolean isEquals(SchemaElementMatch existElementMatch, SchemaElementMatch newElementMatch) {
    SchemaElement current = existElementMatch.getElement();
    SchemaElement incoming = newElementMatch.getElement();
    if (current.equals(incoming)) {
        boolean valueTyped = SchemaElementType.VALUE.equals(incoming.getType());
        // Only value elements need their words compared on top of element equality.
        return !valueTyped || existElementMatch.getWord().equalsIgnoreCase(newElementMatch.getWord());
    }
    return false;
}
public SchemaElement getSchemaElement(Long dataSetId, SchemaElementType elementType, Long elementID,
SemanticSchema semanticSchema) {
SchemaElement element = new SchemaElement();

View File

@@ -6,6 +6,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.chat.knowledge.builder.BaseWordBuilder;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import com.tencent.supersonic.headless.core.chat.knowledge.DatabaseMapResult;
import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult;
@@ -71,9 +72,6 @@ public class KeywordMapper extends BaseMapper {
if (element == null) {
continue;
}
if (element.getType().equals(SchemaElementType.VALUE)) {
element.setName(hanlpMapResult.getName());
}
Long frequency = wordNatureToFrequency.get(hanlpMapResult.getName() + nature);
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
.element(element)
@@ -100,7 +98,7 @@ public class KeywordMapper extends BaseMapper {
.element(schemaElement)
.word(schemaElement.getName())
.detectWord(match.getDetectWord())
.frequency(10000L)
.frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
.similarity(mapperHelper.getSimilarity(match.getDetectWord(), schemaElement.getName()))
.build();
log.info("add to schema, elementMatch {}", schemaElementMatch);

View File

@@ -117,6 +117,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
});
SchemaMapInfo mapInfo = queryCtx.getMapInfo();
mapResp.setMapInfo(mapInfo);
mapResp.setQueryText(queryReq.getQueryText());
return mapResp;
}

View File

@@ -12,6 +12,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaValueMap;
import com.tencent.supersonic.headless.api.pojo.response.DataSetSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
@@ -19,6 +20,7 @@ import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.logging.log4j.util.Strings;
import org.springframework.beans.BeanUtils;
import org.springframework.util.CollectionUtils;
@@ -30,6 +32,7 @@ public class DataSetSchemaBuilder {
dataSetSchema.setQueryConfig(resp.getQueryConfig());
SchemaElement dataSet = SchemaElement.builder()
.dataSet(resp.getId())
.dataSetName(resp.getName())
.id(resp.getId())
.name(resp.getName())
.bizName(resp.getBizName())
@@ -66,6 +69,7 @@ public class DataSetSchemaBuilder {
if (metric.getIsTag() == 1) {
SchemaElement tagToAdd = SchemaElement.builder()
.dataSet(resp.getId())
.dataSetName(resp.getName())
.model(metric.getModelId())
.id(metric.getId())
.name(metric.getName())
@@ -96,6 +100,7 @@ public class DataSetSchemaBuilder {
if (dim.getIsTag() == 1) {
SchemaElement tagToAdd = SchemaElement.builder()
.dataSet(resp.getId())
.dataSetName(resp.getName())
.model(dim.getModelId())
.id(dim.getId())
.name(dim.getName())
@@ -143,6 +148,7 @@ public class DataSetSchemaBuilder {
}
SchemaElement dimToAdd = SchemaElement.builder()
.dataSet(resp.getId())
.dataSetName(resp.getName())
.model(dim.getModelId())
.id(dim.getId())
.name(dim.getName())
@@ -174,6 +180,7 @@ public class DataSetSchemaBuilder {
}
SchemaElement dimValueToAdd = SchemaElement.builder()
.dataSet(resp.getId())
.dataSetName(resp.getName())
.model(dim.getModelId())
.id(dim.getId())
.name(dim.getName())
@@ -195,6 +202,7 @@ public class DataSetSchemaBuilder {
SchemaElement metricToAdd = SchemaElement.builder()
.dataSet(resp.getId())
.dataSetName(resp.getName())
.model(metric.getModelId())
.id(metric.getId())
.name(metric.getName())