From 07b5eb47b641cd029ddd2436bd1f5b307b600366 Mon Sep 17 00:00:00 2001 From: lexluo09 <39718951+lexluo09@users.noreply.github.com> Date: Sun, 7 Apr 2024 21:29:53 +0800 Subject: [PATCH] (improvement)(Headless) Add 'topFields' to the map interface and return dataSet and field names in Chinese. (#892) --- .../chat/api/pojo/request/ChatParseReq.java | 1 + .../chat/server/service/ChatService.java | 4 +- .../server/service/impl/ChatServiceImpl.java | 10 +- .../chat/server/util/MapInfoConverter.java | 105 ++++++++++++++++++ .../headless/api/pojo/SchemaElement.java | 1 + .../headless/api/pojo/SchemaMapInfo.java | 15 +++ .../api/pojo/response/MapInfoResp.java | 18 +++ .../headless/core/chat/mapper/BaseMapper.java | 18 ++- .../core/chat/mapper/KeywordMapper.java | 8 +- .../service/impl/ChatQueryServiceImpl.java | 1 + .../server/utils/DataSetSchemaBuilder.java | 8 ++ 11 files changed, 175 insertions(+), 14 deletions(-) create mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/util/MapInfoConverter.java create mode 100644 headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/MapInfoResp.java diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatParseReq.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatParseReq.java index e35613fc1..8ddc57c88 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatParseReq.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatParseReq.java @@ -10,6 +10,7 @@ public class ChatParseReq { private String queryText; private Integer chatId; private Integer agentId; + private Integer topN = 10; private User user; private QueryFilters queryFilters; private boolean saveAnswer = true; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/ChatService.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/ChatService.java index d1f56e199..335e77aa2 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/ChatService.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/ChatService.java @@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq; import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq; -import com.tencent.supersonic.headless.api.pojo.response.MapResp; +import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp; import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.api.pojo.response.QueryResult; import com.tencent.supersonic.headless.api.pojo.response.SearchResult; @@ -17,7 +17,7 @@ public interface ChatService { List search(ChatParseReq chatParseReq); - MapResp performMapping(ChatParseReq chatParseReq); + MapInfoResp performMapping(ChatParseReq chatParseReq); ParseResp performParsing(ChatParseReq chatParseReq); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatServiceImpl.java index bb4c46b97..f131c5c17 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatServiceImpl.java @@ -17,6 +17,7 @@ import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.chat.server.service.ChatManageService; import com.tencent.supersonic.chat.server.service.ChatService; import com.tencent.supersonic.chat.server.util.ComponentFactory; +import com.tencent.supersonic.chat.server.util.MapInfoConverter; import com.tencent.supersonic.chat.server.util.QueryReqConverter; import com.tencent.supersonic.chat.server.util.SimilarQueryManager; import com.tencent.supersonic.common.util.BeanMapper; @@ -25,6 +26,7 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq; import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq; import com.tencent.supersonic.headless.api.pojo.request.QueryReq; +import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp; import com.tencent.supersonic.headless.api.pojo.response.MapResp; import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.api.pojo.response.QueryResult; @@ -34,6 +36,7 @@ import com.tencent.supersonic.headless.server.service.SearchService; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; + import java.util.List; @@ -62,7 +65,7 @@ public class ChatServiceImpl implements ChatService { } @Override - public MapResp performMapping(ChatParseReq chatParseReq) { + public MapInfoResp performMapping(ChatParseReq chatParseReq) { return getMapResp(chatParseReq); } @@ -110,14 +113,15 @@ public class ChatServiceImpl implements ChatService { return chatParseContext; } - private MapResp getMapResp(ChatParseReq chatParseReq) { + private MapInfoResp getMapResp(ChatParseReq chatParseReq) { ChatParseContext chatParseContext = new ChatParseContext(); BeanMapper.mapper(chatParseReq, chatParseContext); AgentService agentService = ContextUtils.getBean(AgentService.class); Agent agent = agentService.getAgent(chatParseReq.getAgentId()); chatParseContext.setAgent(agent); QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext); - return chatQueryService.performMapping(queryReq); + MapResp mapResp = chatQueryService.performMapping(queryReq); + return MapInfoConverter.convert(mapResp, chatParseReq.getTopN()); } private ChatExecuteContext buildExecuteContext(ChatExecuteReq chatExecuteReq) { diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/MapInfoConverter.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/MapInfoConverter.java new file mode 100644 index 000000000..db697fbf3 --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/MapInfoConverter.java @@ -0,0 +1,105 @@ +package com.tencent.supersonic.chat.server.util; + +import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum; +import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.headless.api.pojo.SchemaElement; +import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; +import com.tencent.supersonic.headless.api.pojo.SchemaElementType; +import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; +import com.tencent.supersonic.headless.api.pojo.SemanticSchema; +import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp; +import com.tencent.supersonic.headless.api.pojo.response.MapResp; +import com.tencent.supersonic.headless.core.chat.knowledge.builder.BaseWordBuilder; +import com.tencent.supersonic.headless.server.service.impl.SemanticService; +import org.apache.commons.collections4.CollectionUtils; +import org.springframework.beans.BeanUtils; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; + +public class MapInfoConverter { + + public static MapInfoResp convert(MapResp mapResp, Integer topN) { + MapInfoResp mapInfoResp = new MapInfoResp(); + if (Objects.isNull(mapResp)) { + return mapInfoResp; + } + BeanUtils.copyProperties(mapResp, mapInfoResp); + Map dataSetMap = mapResp.getMapInfo().generateDataSetMap(); + mapInfoResp.setMapFields(getMapFields(mapResp.getMapInfo(), dataSetMap)); + mapInfoResp.setTopFields(getTopFields(topN, mapResp.getMapInfo(), dataSetMap)); + return mapInfoResp; + } + + private static Map> getMapFields(SchemaMapInfo mapInfo, + Map dataSetMap) { + Map> result = new HashMap<>(); + for (Map.Entry> entry : mapInfo.getDataSetElementMatches().entrySet()) { + List values = entry.getValue(); + if (CollectionUtils.isNotEmpty(values) && dataSetMap.containsKey(entry.getKey())) { + result.put(dataSetMap.get(entry.getKey()), values); + } + } + return result; + } + + private static Map> getTopFields(Integer topN, + SchemaMapInfo mapInfo, + Map dataSetMap) { + Set dataSetIds = mapInfo.getDataSetElementMatches().keySet(); + Map> result = new HashMap<>(); + + SemanticService semanticService = ContextUtils.getBean(SemanticService.class); + SemanticSchema semanticSchema = semanticService.getSemanticSchema(); + for (Long dataSetId : dataSetIds) { + String dataSetName = dataSetMap.get(dataSetId); + + //topN dimensions + Set dimensions = semanticSchema.getDimensions(dataSetId) + .stream().sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()) + .limit(topN - 1).map(mergeFunction()).collect(Collectors.toSet()); + + SchemaElementMatch timeDimensionMatch = getTimeDimension(dataSetId, dataSetName); + dimensions.add(timeDimensionMatch); + + //topN metrics + Set metrics = semanticSchema.getMetrics(dataSetId) + .stream().sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()) + .limit(topN).map(mergeFunction()).collect(Collectors.toSet()); + + dimensions.addAll(metrics); + result.put(dataSetName, new ArrayList<>(dimensions)); + } + return result; + } + + /*** + * get time dimension SchemaElementMatch + * @param dataSetId + * @param dataSetName + * @return + */ + private static SchemaElementMatch getTimeDimension(Long dataSetId, String dataSetName) { + SchemaElement element = SchemaElement.builder().dataSet(dataSetId).dataSetName(dataSetName) + .type(SchemaElementType.DIMENSION).bizName(TimeDimensionEnum.DAY.getName()).build(); + + SchemaElementMatch timeDimensionMatch = SchemaElementMatch.builder().element(element) + .detectWord(TimeDimensionEnum.DAY.getChName()).word(TimeDimensionEnum.DAY.getChName()) + .similarity(1L).frequency(BaseWordBuilder.DEFAULT_FREQUENCY).build(); + + return timeDimensionMatch; + } + + private static Function mergeFunction() { + return schemaElement -> SchemaElementMatch.builder().element(schemaElement) + .frequency(BaseWordBuilder.DEFAULT_FREQUENCY).word(schemaElement.getName()).similarity(1) + .detectWord(schemaElement.getName()).build(); + } +} diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java index 6fa343ecd..d680a3e22 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java @@ -18,6 +18,7 @@ import java.util.List; public class SchemaElement implements Serializable { private Long dataSet; + private String dataSetName; private Long model; private Long id; private String name; diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaMapInfo.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaMapInfo.java index 483b2a428..ee7c9d4e0 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaMapInfo.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaMapInfo.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.headless.api.pojo; import com.google.common.collect.Lists; +import org.apache.commons.collections4.CollectionUtils; import java.util.HashMap; import java.util.List; @@ -30,4 +31,18 @@ public class SchemaMapInfo { public void setMatchedElements(Long dataSet, List elementMatches) { dataSetElementMatches.put(dataSet, elementMatches); } + + public Map generateDataSetMap() { + Map dataSetIdToName = new HashMap<>(); + for (Map.Entry> entry : dataSetElementMatches.entrySet()) { + List values = entry.getValue(); + if (CollectionUtils.isNotEmpty(values)) { + SchemaElementMatch schemaElementMatch = values.get(0); + String dataSetName = schemaElementMatch.getElement().getDataSetName(); + dataSetIdToName.put(entry.getKey(), dataSetName); + } + } + return dataSetIdToName; + } + } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/MapInfoResp.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/MapInfoResp.java new file mode 100644 index 000000000..9d24a21d7 --- /dev/null +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/MapInfoResp.java @@ -0,0 +1,18 @@ +package com.tencent.supersonic.headless.api.pojo.response; + +import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; +import lombok.Data; + +import java.util.List; +import java.util.Map; + +@Data +public class MapInfoResp { + + private String queryText; + + private Map> mapFields; + + private Map> topFields; + +} diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/BaseMapper.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/BaseMapper.java index 6ccfa8562..ac28b08df 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/BaseMapper.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/BaseMapper.java @@ -51,9 +51,7 @@ public abstract class BaseMapper implements SchemaMapper { AtomicBoolean needAddNew = new AtomicBoolean(true); schemaElementMatches.removeIf( existElementMatch -> { - SchemaElement existElement = existElementMatch.getElement(); - SchemaElement newElement = newElementMatch.getElement(); - if (existElement.equals(newElement)) { + if (isEquals(existElementMatch, newElementMatch)) { if (newElementMatch.getSimilarity() > existElementMatch.getSimilarity()) { return true; } else { @@ -68,8 +66,20 @@ public abstract class BaseMapper implements SchemaMapper { } } + private static boolean isEquals(SchemaElementMatch existElementMatch, SchemaElementMatch newElementMatch) { + SchemaElement existElement = existElementMatch.getElement(); + SchemaElement newElement = newElementMatch.getElement(); + if (!existElement.equals(newElement)) { + return false; + } + if (SchemaElementType.VALUE.equals(newElement.getType())) { + return existElementMatch.getWord().equalsIgnoreCase(newElementMatch.getWord()); + } + return true; + } + public SchemaElement getSchemaElement(Long dataSetId, SchemaElementType elementType, Long elementID, - SemanticSchema semanticSchema) { + SemanticSchema semanticSchema) { SchemaElement element = new SchemaElement(); DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId); if (Objects.isNull(dataSetSchema)) { diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/KeywordMapper.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/KeywordMapper.java index d703c5d73..d531fe00a 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/KeywordMapper.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/chat/mapper/KeywordMapper.java @@ -6,6 +6,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.response.S2Term; +import com.tencent.supersonic.headless.core.chat.knowledge.builder.BaseWordBuilder; import com.tencent.supersonic.headless.core.pojo.QueryContext; import com.tencent.supersonic.headless.core.chat.knowledge.DatabaseMapResult; import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult; @@ -46,7 +47,7 @@ public class KeywordMapper extends BaseMapper { } private void convertHanlpMapResultToMapInfo(List mapResults, QueryContext queryContext, - List terms) { + List terms) { if (CollectionUtils.isEmpty(mapResults)) { return; } @@ -71,9 +72,6 @@ public class KeywordMapper extends BaseMapper { if (element == null) { continue; } - if (element.getType().equals(SchemaElementType.VALUE)) { - element.setName(hanlpMapResult.getName()); - } Long frequency = wordNatureToFrequency.get(hanlpMapResult.getName() + nature); SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder() .element(element) @@ -100,7 +98,7 @@ public class KeywordMapper extends BaseMapper { .element(schemaElement) .word(schemaElement.getName()) .detectWord(match.getDetectWord()) - .frequency(10000L) + .frequency(BaseWordBuilder.DEFAULT_FREQUENCY) .similarity(mapperHelper.getSimilarity(match.getDetectWord(), schemaElement.getName())) .build(); log.info("add to schema, elementMatch {}", schemaElementMatch); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java index 64ef10414..930139559 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java @@ -117,6 +117,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { }); SchemaMapInfo mapInfo = queryCtx.getMapInfo(); mapResp.setMapInfo(mapInfo); + mapResp.setQueryText(queryReq.getQueryText()); return mapResp; } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DataSetSchemaBuilder.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DataSetSchemaBuilder.java index 5f586154e..5949a5496 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DataSetSchemaBuilder.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DataSetSchemaBuilder.java @@ -12,6 +12,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaValueMap; import com.tencent.supersonic.headless.api.pojo.response.DataSetSchemaResp; import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp; import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp; + import java.util.ArrayList; import java.util.Arrays; import java.util.HashSet; @@ -19,6 +20,7 @@ import java.util.List; import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; + import org.apache.logging.log4j.util.Strings; import org.springframework.beans.BeanUtils; import org.springframework.util.CollectionUtils; @@ -30,6 +32,7 @@ public class DataSetSchemaBuilder { dataSetSchema.setQueryConfig(resp.getQueryConfig()); SchemaElement dataSet = SchemaElement.builder() .dataSet(resp.getId()) + .dataSetName(resp.getName()) .id(resp.getId()) .name(resp.getName()) .bizName(resp.getBizName()) @@ -66,6 +69,7 @@ public class DataSetSchemaBuilder { if (metric.getIsTag() == 1) { SchemaElement tagToAdd = SchemaElement.builder() .dataSet(resp.getId()) + .dataSetName(resp.getName()) .model(metric.getModelId()) .id(metric.getId()) .name(metric.getName()) @@ -96,6 +100,7 @@ public class DataSetSchemaBuilder { if (dim.getIsTag() == 1) { SchemaElement tagToAdd = SchemaElement.builder() .dataSet(resp.getId()) + .dataSetName(resp.getName()) .model(dim.getModelId()) .id(dim.getId()) .name(dim.getName()) @@ -143,6 +148,7 @@ public class DataSetSchemaBuilder { } SchemaElement dimToAdd = SchemaElement.builder() .dataSet(resp.getId()) + .dataSetName(resp.getName()) .model(dim.getModelId()) .id(dim.getId()) .name(dim.getName()) @@ -174,6 +180,7 @@ public class DataSetSchemaBuilder { } SchemaElement dimValueToAdd = SchemaElement.builder() .dataSet(resp.getId()) + .dataSetName(resp.getName()) .model(dim.getModelId()) .id(dim.getId()) .name(dim.getName()) @@ -195,6 +202,7 @@ public class DataSetSchemaBuilder { SchemaElement metricToAdd = SchemaElement.builder() .dataSet(resp.getId()) + .dataSetName(resp.getName()) .model(metric.getModelId()) .id(metric.getId()) .name(metric.getName())