(improvement)(Headless) Add 'topFields' to the map interface and return dataSet and field names in Chinese. (#892)

This commit is contained in:
lexluo09
2024-04-07 21:29:53 +08:00
committed by GitHub
parent 5f6e9ae194
commit 07b5eb47b6
11 changed files with 175 additions and 14 deletions

View File

@@ -10,6 +10,7 @@ public class ChatParseReq {
private String queryText;
private Integer chatId;
private Integer agentId;
private Integer topN = 10;
private User user;
private QueryFilters queryFilters;
private boolean saveAnswer = true;

View File

@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
@@ -17,7 +17,7 @@ public interface ChatService {
List<SearchResult> search(ChatParseReq chatParseReq);
MapResp performMapping(ChatParseReq chatParseReq);
MapInfoResp performMapping(ChatParseReq chatParseReq);
ParseResp performParsing(ChatParseReq chatParseReq);

View File

@@ -17,6 +17,7 @@ import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatManageService;
import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.chat.server.util.ComponentFactory;
import com.tencent.supersonic.chat.server.util.MapInfoConverter;
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.chat.server.util.SimilarQueryManager;
import com.tencent.supersonic.common.util.BeanMapper;
@@ -25,6 +26,7 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp;
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
@@ -34,6 +36,7 @@ import com.tencent.supersonic.headless.server.service.SearchService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.List;
@@ -62,7 +65,7 @@ public class ChatServiceImpl implements ChatService {
}
@Override
public MapResp performMapping(ChatParseReq chatParseReq) {
public MapInfoResp performMapping(ChatParseReq chatParseReq) {
return getMapResp(chatParseReq);
}
@@ -110,14 +113,15 @@ public class ChatServiceImpl implements ChatService {
return chatParseContext;
}
private MapResp getMapResp(ChatParseReq chatParseReq) {
private MapInfoResp getMapResp(ChatParseReq chatParseReq) {
ChatParseContext chatParseContext = new ChatParseContext();
BeanMapper.mapper(chatParseReq, chatParseContext);
AgentService agentService = ContextUtils.getBean(AgentService.class);
Agent agent = agentService.getAgent(chatParseReq.getAgentId());
chatParseContext.setAgent(agent);
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
return chatQueryService.performMapping(queryReq);
MapResp mapResp = chatQueryService.performMapping(queryReq);
return MapInfoConverter.convert(mapResp, chatParseReq.getTopN());
}
private ChatExecuteContext buildExecuteContext(ChatExecuteReq chatExecuteReq) {

View File

@@ -0,0 +1,105 @@
package com.tencent.supersonic.chat.server.util;
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp;
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.core.chat.knowledge.builder.BaseWordBuilder;
import com.tencent.supersonic.headless.server.service.impl.SemanticService;
import org.apache.commons.collections4.CollectionUtils;
import org.springframework.beans.BeanUtils;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
/**
 * Converts a {@link MapResp} into a {@link MapInfoResp}, keying the matched
 * elements by data set name instead of data set id and attaching the most-used
 * dimensions and metrics ("top fields") of every matched data set.
 */
public class MapInfoConverter {

    /** Fallback applied when the caller passes no topN; same value as the default on ChatParseReq. */
    private static final int DEFAULT_TOP_N = 10;

    /**
     * Builds the map-info response for a mapping result.
     *
     * @param mapResp mapping result; when {@code null} an empty response is returned
     * @param topN    number of top fields requested per data set; {@code null} falls back to
     *                {@link #DEFAULT_TOP_N}, values below zero are treated as zero
     * @return a new response carrying the copied properties of {@code mapResp} plus the
     *         name-keyed map fields and top fields (the latter two are left unset when
     *         {@code mapResp} carries no map info)
     */
    public static MapInfoResp convert(MapResp mapResp, Integer topN) {
        MapInfoResp mapInfoResp = new MapInfoResp();
        if (Objects.isNull(mapResp)) {
            return mapInfoResp;
        }
        BeanUtils.copyProperties(mapResp, mapInfoResp);
        SchemaMapInfo mapInfo = mapResp.getMapInfo();
        if (Objects.isNull(mapInfo)) {
            // Nothing was mapped, so there is nothing to re-key or rank.
            return mapInfoResp;
        }
        Map<Long, String> dataSetMap = mapInfo.generateDataSetMap();
        mapInfoResp.setMapFields(getMapFields(mapInfo, dataSetMap));
        mapInfoResp.setTopFields(getTopFields(topN, mapInfo, dataSetMap));
        return mapInfoResp;
    }

    /**
     * Re-keys the matched elements from data set id to data set name.
     * Data sets with no matches, or with no entry in {@code dataSetMap}, are skipped.
     *
     * @param mapInfo    mapping info holding the matches per data set id
     * @param dataSetMap data set id to data set name
     * @return matches keyed by data set name
     */
    private static Map<String, List<SchemaElementMatch>> getMapFields(SchemaMapInfo mapInfo,
            Map<Long, String> dataSetMap) {
        Map<String, List<SchemaElementMatch>> result = new HashMap<>();
        for (Map.Entry<Long, List<SchemaElementMatch>> entry : mapInfo.getDataSetElementMatches().entrySet()) {
            List<SchemaElementMatch> values = entry.getValue();
            if (CollectionUtils.isNotEmpty(values) && dataSetMap.containsKey(entry.getKey())) {
                result.put(dataSetMap.get(entry.getKey()), values);
            }
        }
        return result;
    }

    /**
     * Collects, for every matched data set, the most-used dimensions, the day time
     * dimension and the most-used metrics, in that order.
     *
     * <p>One dimension slot is reserved for the time dimension, so at most
     * {@code topN - 1} regular dimensions are returned next to up to {@code topN} metrics.
     * Data sets without an entry in {@code dataSetMap} are skipped, matching
     * {@link #getMapFields(SchemaMapInfo, Map)}.
     *
     * @param topN       requested number of top fields; {@code null} falls back to
     *                   {@link #DEFAULT_TOP_N}, negative values are treated as zero
     * @param mapInfo    mapping info whose data set ids are inspected
     * @param dataSetMap data set id to data set name
     * @return top fields keyed by data set name
     */
    private static Map<String, List<SchemaElementMatch>> getTopFields(Integer topN,
            SchemaMapInfo mapInfo,
            Map<Long, String> dataSetMap) {
        // Stream.limit rejects negative sizes, so clamp both limits at zero.
        int metricLimit = Objects.isNull(topN) ? DEFAULT_TOP_N : Math.max(topN, 0);
        int dimensionLimit = Math.max(metricLimit - 1, 0);

        Set<Long> dataSetIds = mapInfo.getDataSetElementMatches().keySet();
        Map<String, List<SchemaElementMatch>> result = new HashMap<>();
        SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
        SemanticSchema semanticSchema = semanticService.getSemanticSchema();
        for (Long dataSetId : dataSetIds) {
            if (!dataSetMap.containsKey(dataSetId)) {
                // No display name for this data set; avoid publishing it under a null key.
                continue;
            }
            String dataSetName = dataSetMap.get(dataSetId);
            // topN dimensions; an insertion-ordered set keeps the use-count ranking while de-duplicating.
            // NOTE(review): the comparator assumes getUseCnt never returns null — confirm.
            Set<SchemaElementMatch> fields = semanticSchema.getDimensions(dataSetId)
                    .stream().sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
                    .limit(dimensionLimit).map(mergeFunction())
                    .collect(Collectors.toCollection(LinkedHashSet::new));
            fields.add(getTimeDimension(dataSetId, dataSetName));
            // topN metrics
            Set<SchemaElementMatch> metrics = semanticSchema.getMetrics(dataSetId)
                    .stream().sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
                    .limit(metricLimit).map(mergeFunction())
                    .collect(Collectors.toCollection(LinkedHashSet::new));
            fields.addAll(metrics);
            result.put(dataSetName, new ArrayList<>(fields));
        }
        return result;
    }

    /**
     * Builds the match for the day time dimension of a data set.
     *
     * @param dataSetId   id of the data set the dimension belongs to
     * @param dataSetName name of that data set
     * @return a match on a DIMENSION element whose bizName is the DAY time dimension name and
     *         whose word and detect word are the DAY time dimension's Chinese name
     */
    private static SchemaElementMatch getTimeDimension(Long dataSetId, String dataSetName) {
        SchemaElement element = SchemaElement.builder().dataSet(dataSetId).dataSetName(dataSetName)
                .type(SchemaElementType.DIMENSION).bizName(TimeDimensionEnum.DAY.getName()).build();
        return SchemaElementMatch.builder().element(element)
                .detectWord(TimeDimensionEnum.DAY.getChName()).word(TimeDimensionEnum.DAY.getChName())
                .similarity(1L).frequency(BaseWordBuilder.DEFAULT_FREQUENCY).build();
    }

    /**
     * Returns the mapper that wraps a schema element in a match, using the element's
     * name as both word and detect word, similarity 1 and the default frequency.
     *
     * @return element-to-match mapping function
     */
    private static Function<SchemaElement, SchemaElementMatch> mergeFunction() {
        return schemaElement -> SchemaElementMatch.builder().element(schemaElement)
                .frequency(BaseWordBuilder.DEFAULT_FREQUENCY).word(schemaElement.getName()).similarity(1)
                .detectWord(schemaElement.getName()).build();
    }
}