mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 02:46:56 +00:00
(improvement)(Headless) Add 'topFields' to the map interface and return dataSet and field names in Chinese. (#892)
This commit is contained in:
@@ -10,6 +10,7 @@ public class ChatParseReq {
|
||||
private String queryText;
|
||||
private Integer chatId;
|
||||
private Integer agentId;
|
||||
private Integer topN = 10;
|
||||
private User user;
|
||||
private QueryFilters queryFilters;
|
||||
private boolean saveAnswer = true;
|
||||
|
||||
@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
|
||||
@@ -17,7 +17,7 @@ public interface ChatService {
|
||||
|
||||
List<SearchResult> search(ChatParseReq chatParseReq);
|
||||
|
||||
MapResp performMapping(ChatParseReq chatParseReq);
|
||||
MapInfoResp performMapping(ChatParseReq chatParseReq);
|
||||
|
||||
ParseResp performParsing(ChatParseReq chatParseReq);
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatManageService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatService;
|
||||
import com.tencent.supersonic.chat.server.util.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.server.util.MapInfoConverter;
|
||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||
import com.tencent.supersonic.chat.server.util.SimilarQueryManager;
|
||||
import com.tencent.supersonic.common.util.BeanMapper;
|
||||
@@ -25,6 +26,7 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
@@ -34,6 +36,7 @@ import com.tencent.supersonic.headless.server.service.SearchService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
||||
@@ -62,7 +65,7 @@ public class ChatServiceImpl implements ChatService {
|
||||
}
|
||||
|
||||
@Override
|
||||
public MapResp performMapping(ChatParseReq chatParseReq) {
|
||||
public MapInfoResp performMapping(ChatParseReq chatParseReq) {
|
||||
return getMapResp(chatParseReq);
|
||||
}
|
||||
|
||||
@@ -110,14 +113,15 @@ public class ChatServiceImpl implements ChatService {
|
||||
return chatParseContext;
|
||||
}
|
||||
|
||||
private MapResp getMapResp(ChatParseReq chatParseReq) {
|
||||
private MapInfoResp getMapResp(ChatParseReq chatParseReq) {
|
||||
ChatParseContext chatParseContext = new ChatParseContext();
|
||||
BeanMapper.mapper(chatParseReq, chatParseContext);
|
||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||
Agent agent = agentService.getAgent(chatParseReq.getAgentId());
|
||||
chatParseContext.setAgent(agent);
|
||||
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||
return chatQueryService.performMapping(queryReq);
|
||||
MapResp mapResp = chatQueryService.performMapping(queryReq);
|
||||
return MapInfoConverter.convert(mapResp, chatParseReq.getTopN());
|
||||
}
|
||||
|
||||
private ChatExecuteContext buildExecuteContext(ChatExecuteReq chatExecuteReq) {
|
||||
|
||||
@@ -0,0 +1,105 @@
|
||||
package com.tencent.supersonic.chat.server.util;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.TimeDimensionEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.builder.BaseWordBuilder;
|
||||
import com.tencent.supersonic.headless.server.service.impl.SemanticService;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class MapInfoConverter {
|
||||
|
||||
public static MapInfoResp convert(MapResp mapResp, Integer topN) {
|
||||
MapInfoResp mapInfoResp = new MapInfoResp();
|
||||
if (Objects.isNull(mapResp)) {
|
||||
return mapInfoResp;
|
||||
}
|
||||
BeanUtils.copyProperties(mapResp, mapInfoResp);
|
||||
Map<Long, String> dataSetMap = mapResp.getMapInfo().generateDataSetMap();
|
||||
mapInfoResp.setMapFields(getMapFields(mapResp.getMapInfo(), dataSetMap));
|
||||
mapInfoResp.setTopFields(getTopFields(topN, mapResp.getMapInfo(), dataSetMap));
|
||||
return mapInfoResp;
|
||||
}
|
||||
|
||||
private static Map<String, List<SchemaElementMatch>> getMapFields(SchemaMapInfo mapInfo,
|
||||
Map<Long, String> dataSetMap) {
|
||||
Map<String, List<SchemaElementMatch>> result = new HashMap<>();
|
||||
for (Map.Entry<Long, List<SchemaElementMatch>> entry : mapInfo.getDataSetElementMatches().entrySet()) {
|
||||
List<SchemaElementMatch> values = entry.getValue();
|
||||
if (CollectionUtils.isNotEmpty(values) && dataSetMap.containsKey(entry.getKey())) {
|
||||
result.put(dataSetMap.get(entry.getKey()), values);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private static Map<String, List<SchemaElementMatch>> getTopFields(Integer topN,
|
||||
SchemaMapInfo mapInfo,
|
||||
Map<Long, String> dataSetMap) {
|
||||
Set<Long> dataSetIds = mapInfo.getDataSetElementMatches().keySet();
|
||||
Map<String, List<SchemaElementMatch>> result = new HashMap<>();
|
||||
|
||||
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
||||
SemanticSchema semanticSchema = semanticService.getSemanticSchema();
|
||||
for (Long dataSetId : dataSetIds) {
|
||||
String dataSetName = dataSetMap.get(dataSetId);
|
||||
|
||||
//topN dimensions
|
||||
Set<SchemaElementMatch> dimensions = semanticSchema.getDimensions(dataSetId)
|
||||
.stream().sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||
.limit(topN - 1).map(mergeFunction()).collect(Collectors.toSet());
|
||||
|
||||
SchemaElementMatch timeDimensionMatch = getTimeDimension(dataSetId, dataSetName);
|
||||
dimensions.add(timeDimensionMatch);
|
||||
|
||||
//topN metrics
|
||||
Set<SchemaElementMatch> metrics = semanticSchema.getMetrics(dataSetId)
|
||||
.stream().sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||
.limit(topN).map(mergeFunction()).collect(Collectors.toSet());
|
||||
|
||||
dimensions.addAll(metrics);
|
||||
result.put(dataSetName, new ArrayList<>(dimensions));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/***
|
||||
* get time dimension SchemaElementMatch
|
||||
* @param dataSetId
|
||||
* @param dataSetName
|
||||
* @return
|
||||
*/
|
||||
private static SchemaElementMatch getTimeDimension(Long dataSetId, String dataSetName) {
|
||||
SchemaElement element = SchemaElement.builder().dataSet(dataSetId).dataSetName(dataSetName)
|
||||
.type(SchemaElementType.DIMENSION).bizName(TimeDimensionEnum.DAY.getName()).build();
|
||||
|
||||
SchemaElementMatch timeDimensionMatch = SchemaElementMatch.builder().element(element)
|
||||
.detectWord(TimeDimensionEnum.DAY.getChName()).word(TimeDimensionEnum.DAY.getChName())
|
||||
.similarity(1L).frequency(BaseWordBuilder.DEFAULT_FREQUENCY).build();
|
||||
|
||||
return timeDimensionMatch;
|
||||
}
|
||||
|
||||
private static Function<SchemaElement, SchemaElementMatch> mergeFunction() {
|
||||
return schemaElement -> SchemaElementMatch.builder().element(schemaElement)
|
||||
.frequency(BaseWordBuilder.DEFAULT_FREQUENCY).word(schemaElement.getName()).similarity(1)
|
||||
.detectWord(schemaElement.getName()).build();
|
||||
}
|
||||
}
|
||||
@@ -18,6 +18,7 @@ import java.util.List;
|
||||
public class SchemaElement implements Serializable {
|
||||
|
||||
private Long dataSet;
|
||||
private String dataSetName;
|
||||
private Long model;
|
||||
private Long id;
|
||||
private String name;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.headless.api.pojo;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
@@ -30,4 +31,18 @@ public class SchemaMapInfo {
|
||||
public void setMatchedElements(Long dataSet, List<SchemaElementMatch> elementMatches) {
|
||||
dataSetElementMatches.put(dataSet, elementMatches);
|
||||
}
|
||||
|
||||
public Map<Long, String> generateDataSetMap() {
|
||||
Map<Long, String> dataSetIdToName = new HashMap<>();
|
||||
for (Map.Entry<Long, List<SchemaElementMatch>> entry : dataSetElementMatches.entrySet()) {
|
||||
List<SchemaElementMatch> values = entry.getValue();
|
||||
if (CollectionUtils.isNotEmpty(values)) {
|
||||
SchemaElementMatch schemaElementMatch = values.get(0);
|
||||
String dataSetName = schemaElementMatch.getElement().getDataSetName();
|
||||
dataSetIdToName.put(entry.getKey(), dataSetName);
|
||||
}
|
||||
}
|
||||
return dataSetIdToName;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
package com.tencent.supersonic.headless.api.pojo.response;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
public class MapInfoResp {
|
||||
|
||||
private String queryText;
|
||||
|
||||
private Map<String, List<SchemaElementMatch>> mapFields;
|
||||
|
||||
private Map<String, List<SchemaElementMatch>> topFields;
|
||||
|
||||
}
|
||||
@@ -51,9 +51,7 @@ public abstract class BaseMapper implements SchemaMapper {
|
||||
AtomicBoolean needAddNew = new AtomicBoolean(true);
|
||||
schemaElementMatches.removeIf(
|
||||
existElementMatch -> {
|
||||
SchemaElement existElement = existElementMatch.getElement();
|
||||
SchemaElement newElement = newElementMatch.getElement();
|
||||
if (existElement.equals(newElement)) {
|
||||
if (isEquals(existElementMatch, newElementMatch)) {
|
||||
if (newElementMatch.getSimilarity() > existElementMatch.getSimilarity()) {
|
||||
return true;
|
||||
} else {
|
||||
@@ -68,8 +66,20 @@ public abstract class BaseMapper implements SchemaMapper {
|
||||
}
|
||||
}
|
||||
|
||||
private static boolean isEquals(SchemaElementMatch existElementMatch, SchemaElementMatch newElementMatch) {
|
||||
SchemaElement existElement = existElementMatch.getElement();
|
||||
SchemaElement newElement = newElementMatch.getElement();
|
||||
if (!existElement.equals(newElement)) {
|
||||
return false;
|
||||
}
|
||||
if (SchemaElementType.VALUE.equals(newElement.getType())) {
|
||||
return existElementMatch.getWord().equalsIgnoreCase(newElementMatch.getWord());
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
public SchemaElement getSchemaElement(Long dataSetId, SchemaElementType elementType, Long elementID,
|
||||
SemanticSchema semanticSchema) {
|
||||
SemanticSchema semanticSchema) {
|
||||
SchemaElement element = new SchemaElement();
|
||||
DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId);
|
||||
if (Objects.isNull(dataSetSchema)) {
|
||||
|
||||
@@ -6,6 +6,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.builder.BaseWordBuilder;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.DatabaseMapResult;
|
||||
import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult;
|
||||
@@ -46,7 +47,7 @@ public class KeywordMapper extends BaseMapper {
|
||||
}
|
||||
|
||||
private void convertHanlpMapResultToMapInfo(List<HanlpMapResult> mapResults, QueryContext queryContext,
|
||||
List<S2Term> terms) {
|
||||
List<S2Term> terms) {
|
||||
if (CollectionUtils.isEmpty(mapResults)) {
|
||||
return;
|
||||
}
|
||||
@@ -71,9 +72,6 @@ public class KeywordMapper extends BaseMapper {
|
||||
if (element == null) {
|
||||
continue;
|
||||
}
|
||||
if (element.getType().equals(SchemaElementType.VALUE)) {
|
||||
element.setName(hanlpMapResult.getName());
|
||||
}
|
||||
Long frequency = wordNatureToFrequency.get(hanlpMapResult.getName() + nature);
|
||||
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
|
||||
.element(element)
|
||||
@@ -100,7 +98,7 @@ public class KeywordMapper extends BaseMapper {
|
||||
.element(schemaElement)
|
||||
.word(schemaElement.getName())
|
||||
.detectWord(match.getDetectWord())
|
||||
.frequency(10000L)
|
||||
.frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
|
||||
.similarity(mapperHelper.getSimilarity(match.getDetectWord(), schemaElement.getName()))
|
||||
.build();
|
||||
log.info("add to schema, elementMatch {}", schemaElementMatch);
|
||||
|
||||
@@ -117,6 +117,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
});
|
||||
SchemaMapInfo mapInfo = queryCtx.getMapInfo();
|
||||
mapResp.setMapInfo(mapInfo);
|
||||
mapResp.setQueryText(queryReq.getQueryText());
|
||||
return mapResp;
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaValueMap;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DataSetSchemaResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DimSchemaResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MetricSchemaResp;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashSet;
|
||||
@@ -19,6 +20,7 @@ import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import org.apache.logging.log4j.util.Strings;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
@@ -30,6 +32,7 @@ public class DataSetSchemaBuilder {
|
||||
dataSetSchema.setQueryConfig(resp.getQueryConfig());
|
||||
SchemaElement dataSet = SchemaElement.builder()
|
||||
.dataSet(resp.getId())
|
||||
.dataSetName(resp.getName())
|
||||
.id(resp.getId())
|
||||
.name(resp.getName())
|
||||
.bizName(resp.getBizName())
|
||||
@@ -66,6 +69,7 @@ public class DataSetSchemaBuilder {
|
||||
if (metric.getIsTag() == 1) {
|
||||
SchemaElement tagToAdd = SchemaElement.builder()
|
||||
.dataSet(resp.getId())
|
||||
.dataSetName(resp.getName())
|
||||
.model(metric.getModelId())
|
||||
.id(metric.getId())
|
||||
.name(metric.getName())
|
||||
@@ -96,6 +100,7 @@ public class DataSetSchemaBuilder {
|
||||
if (dim.getIsTag() == 1) {
|
||||
SchemaElement tagToAdd = SchemaElement.builder()
|
||||
.dataSet(resp.getId())
|
||||
.dataSetName(resp.getName())
|
||||
.model(dim.getModelId())
|
||||
.id(dim.getId())
|
||||
.name(dim.getName())
|
||||
@@ -143,6 +148,7 @@ public class DataSetSchemaBuilder {
|
||||
}
|
||||
SchemaElement dimToAdd = SchemaElement.builder()
|
||||
.dataSet(resp.getId())
|
||||
.dataSetName(resp.getName())
|
||||
.model(dim.getModelId())
|
||||
.id(dim.getId())
|
||||
.name(dim.getName())
|
||||
@@ -174,6 +180,7 @@ public class DataSetSchemaBuilder {
|
||||
}
|
||||
SchemaElement dimValueToAdd = SchemaElement.builder()
|
||||
.dataSet(resp.getId())
|
||||
.dataSetName(resp.getName())
|
||||
.model(dim.getModelId())
|
||||
.id(dim.getId())
|
||||
.name(dim.getName())
|
||||
@@ -195,6 +202,7 @@ public class DataSetSchemaBuilder {
|
||||
|
||||
SchemaElement metricToAdd = SchemaElement.builder()
|
||||
.dataSet(resp.getId())
|
||||
.dataSetName(resp.getName())
|
||||
.model(metric.getModelId())
|
||||
.id(metric.getId())
|
||||
.name(metric.getName())
|
||||
|
||||
Reference in New Issue
Block a user