(improvement)(headless)take only M dimensionValue or N metric/dimension per rond. (#1032)

This commit is contained in:
lexluo09
2024-05-26 23:11:48 +08:00
committed by GitHub
parent 822879cd7b
commit 1fcd880042
5 changed files with 36 additions and 35 deletions

View File

@@ -6,12 +6,6 @@ import com.tencent.supersonic.headless.core.chat.knowledge.HanlpMapResult;
import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeBaseService; import com.tencent.supersonic.headless.core.chat.knowledge.KnowledgeBaseService;
import com.tencent.supersonic.headless.core.config.OptimizationConfig; import com.tencent.supersonic.headless.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.pojo.QueryContext; import com.tencent.supersonic.headless.core.pojo.QueryContext;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.HashMap; import java.util.HashMap;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
@@ -19,6 +13,11 @@ import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
/** /**
* HanlpDictMatchStrategy uses <a href="https://www.hanlp.com/">HanLP</a> to * HanlpDictMatchStrategy uses <a href="https://www.hanlp.com/">HanLP</a> to
@@ -40,7 +39,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
@Override @Override
public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> terms, public Map<MatchText, List<HanlpMapResult>> match(QueryContext queryContext, List<S2Term> terms,
Set<Long> detectDataSetIds) { Set<Long> detectDataSetIds) {
String text = queryContext.getQueryText(); String text = queryContext.getQueryText();
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) { if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
return null; return null;
@@ -62,7 +61,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
} }
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectDataSetIds, public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectDataSetIds,
String detectSegment, int offset) { String detectSegment, int offset) {
// step1. pre search // step1. pre search
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize(); Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
LinkedHashSet<HanlpMapResult> hanlpMapResults = knowledgeBaseService.prefixSearch(detectSegment, LinkedHashSet<HanlpMapResult> hanlpMapResults = knowledgeBaseService.prefixSearch(detectSegment,
@@ -97,19 +96,19 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
return parseResult; return parseResult;
}).collect(Collectors.toCollection(LinkedHashSet::new)); }).collect(Collectors.toCollection(LinkedHashSet::new));
// step5. take only one dimension or 10 metric/dimension value per rond. // step5. take only M dimensionValue or N metric/dimension per rond.
List<HanlpMapResult> dimensionMetrics = hanlpMapResults.stream() List<HanlpMapResult> dimensionValues = hanlpMapResults.stream()
.filter(entry -> mapperHelper.existDimensionValues(entry.getNatures())) .filter(entry -> mapperHelper.existDimensionValues(entry.getNatures()))
.collect(Collectors.toList()) .limit(optimizationConfig.getOneDetectionDimensionValueSize())
.stream()
.limit(1)
.collect(Collectors.toList()); .collect(Collectors.toList());
Integer oneDetectionSize = optimizationConfig.getOneDetectionSize(); Integer oneDetectionSize = optimizationConfig.getOneDetectionSize();
List<HanlpMapResult> oneRoundResults = hanlpMapResults.stream().limit(oneDetectionSize) List<HanlpMapResult> oneRoundResults = hanlpMapResults.stream().limit(oneDetectionSize)
.collect(Collectors.toList()); .collect(Collectors.toList());
if (CollectionUtils.isNotEmpty(dimensionMetrics)) {
oneRoundResults = dimensionMetrics; // add the dimensionValue/term if it exists dimensionValue
if (CollectionUtils.isNotEmpty(dimensionValues)) {
oneRoundResults = dimensionValues;
List<HanlpMapResult> termOneRoundResults = hanlpMapResults.stream() List<HanlpMapResult> termOneRoundResults = hanlpMapResults.stream()
.filter(hanlpMapResult -> mapperHelper.existTerms(hanlpMapResult.getNatures())) .filter(hanlpMapResult -> mapperHelper.existTerms(hanlpMapResult.getNatures()))
.collect(Collectors.toList()); .collect(Collectors.toList());

View File

@@ -16,10 +16,12 @@ public class OptimizationConfig {
@Value("${s2.one.detection.size:8}") @Value("${s2.one.detection.size:8}")
private Integer oneDetectionSize; private Integer oneDetectionSize;
@Value("${s2.one.detection.max.size:20}") @Value("${s2.one.detection.max.size:20}")
private Integer oneDetectionMaxSize; private Integer oneDetectionMaxSize;
@Value("${s2.one.detection.dimensionValue.size:1}")
private Integer oneDetectionDimensionValueSize;
@Value("${s2.metric.dimension.min.threshold:0.3}") @Value("${s2.metric.dimension.min.threshold:0.3}")
private Double metricDimensionMinThresholdConfig; private Double metricDimensionMinThresholdConfig;

View File

@@ -69,7 +69,7 @@ public class DimensionServiceImpl implements DimensionService {
private ModelService modelService; private ModelService modelService;
private AliasGenerateHelper chatGptHelper; private AliasGenerateHelper aliasGenerateHelper;
private DatabaseService databaseService; private DatabaseService databaseService;
@@ -85,14 +85,14 @@ public class DimensionServiceImpl implements DimensionService {
public DimensionServiceImpl(DimensionRepository dimensionRepository, public DimensionServiceImpl(DimensionRepository dimensionRepository,
ModelService modelService, ModelService modelService,
AliasGenerateHelper chatGptHelper, AliasGenerateHelper aliasGenerateHelper,
DatabaseService databaseService, DatabaseService databaseService,
ModelRelaService modelRelaService, ModelRelaService modelRelaService,
DataSetService dataSetService, DataSetService dataSetService,
TagMetaService tagMetaService) { TagMetaService tagMetaService) {
this.modelService = modelService; this.modelService = modelService;
this.dimensionRepository = dimensionRepository; this.dimensionRepository = dimensionRepository;
this.chatGptHelper = chatGptHelper; this.aliasGenerateHelper = aliasGenerateHelper;
this.databaseService = databaseService; this.databaseService = databaseService;
this.modelRelaService = modelRelaService; this.modelRelaService = modelRelaService;
this.dataSetService = dataSetService; this.dataSetService = dataSetService;
@@ -341,8 +341,8 @@ public class DimensionServiceImpl implements DimensionService {
@Override @Override
public List<String> mockAlias(DimensionReq dimensionReq, String mockType, User user) { public List<String> mockAlias(DimensionReq dimensionReq, String mockType, User user) {
String mockAlias = chatGptHelper.generateAlias(mockType, dimensionReq.getName(), dimensionReq.getBizName(), String mockAlias = aliasGenerateHelper.generateAlias(mockType, dimensionReq.getName(),
"", dimensionReq.getDescription(), false); dimensionReq.getBizName(), "", dimensionReq.getDescription(), false);
return JSONObject.parseObject(mockAlias, new TypeReference<List<String>>() { return JSONObject.parseObject(mockAlias, new TypeReference<List<String>>() {
}); });
} }
@@ -363,7 +363,7 @@ public class DimensionServiceImpl implements DimensionService {
String value = (String) stringObjectMap.get(dimensionReq.getBizName()); String value = (String) stringObjectMap.get(dimensionReq.getBizName());
valueList.add(value); valueList.add(value);
} }
String json = chatGptHelper.generateDimensionValueAlias(JSON.toJSONString(valueList)); String json = aliasGenerateHelper.generateDimensionValueAlias(JSON.toJSONString(valueList));
log.info("return llm res is :{}", json); log.info("return llm res is :{}", json);
JSONObject jsonObject = JSON.parseObject(json); JSONObject jsonObject = JSON.parseObject(json);

View File

@@ -97,7 +97,7 @@ public class MetricServiceImpl implements MetricService {
private DimensionService dimensionService; private DimensionService dimensionService;
private AliasGenerateHelper chatGptHelper; private AliasGenerateHelper aliasGenerateHelper;
private CollectService collectService; private CollectService collectService;
@@ -111,7 +111,7 @@ public class MetricServiceImpl implements MetricService {
public MetricServiceImpl(MetricRepository metricRepository, public MetricServiceImpl(MetricRepository metricRepository,
ModelService modelService, ModelService modelService,
AliasGenerateHelper chatGptHelper, AliasGenerateHelper aliasGenerateHelper,
CollectService collectService, CollectService collectService,
DataSetService dataSetService, DataSetService dataSetService,
ApplicationEventPublisher eventPublisher, ApplicationEventPublisher eventPublisher,
@@ -120,7 +120,7 @@ public class MetricServiceImpl implements MetricService {
@Lazy MetaDiscoveryService metaDiscoveryService) { @Lazy MetaDiscoveryService metaDiscoveryService) {
this.metricRepository = metricRepository; this.metricRepository = metricRepository;
this.modelService = modelService; this.modelService = modelService;
this.chatGptHelper = chatGptHelper; this.aliasGenerateHelper = aliasGenerateHelper;
this.eventPublisher = eventPublisher; this.eventPublisher = eventPublisher;
this.collectService = collectService; this.collectService = collectService;
this.dataSetService = dataSetService; this.dataSetService = dataSetService;
@@ -535,7 +535,7 @@ public class MetricServiceImpl implements MetricService {
@Override @Override
public List<String> mockAlias(MetricBaseReq metricReq, String mockType, User user) { public List<String> mockAlias(MetricBaseReq metricReq, String mockType, User user) {
String mockAlias = chatGptHelper.generateAlias(mockType, metricReq.getName(), metricReq.getBizName(), "", String mockAlias = aliasGenerateHelper.generateAlias(mockType, metricReq.getName(), metricReq.getBizName(), "",
metricReq.getDescription(), !"".equals(metricReq.getDataFormatType())); metricReq.getDescription(), !"".equals(metricReq.getDataFormatType()));
return JSONObject.parseObject(mockAlias, new TypeReference<List<String>>() { return JSONObject.parseObject(mockAlias, new TypeReference<List<String>>() {
}); });

View File

@@ -1,5 +1,8 @@
package com.tencent.supersonic.headless.server.service; package com.tencent.supersonic.headless.server.service;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.DataFormat; import com.tencent.supersonic.common.pojo.DataFormat;
@@ -7,7 +10,6 @@ import com.tencent.supersonic.common.pojo.enums.DataFormatTypeEnum;
import com.tencent.supersonic.common.pojo.enums.SensitiveLevelEnum; import com.tencent.supersonic.common.pojo.enums.SensitiveLevelEnum;
import com.tencent.supersonic.common.pojo.enums.StatusEnum; import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import com.tencent.supersonic.common.pojo.enums.TypeEnums; import com.tencent.supersonic.common.pojo.enums.TypeEnums;
import com.tencent.supersonic.headless.server.utils.AliasGenerateHelper;
import com.tencent.supersonic.headless.api.pojo.DrillDownDimension; import com.tencent.supersonic.headless.api.pojo.DrillDownDimension;
import com.tencent.supersonic.headless.api.pojo.MeasureParam; import com.tencent.supersonic.headless.api.pojo.MeasureParam;
import com.tencent.supersonic.headless.api.pojo.MetricDefineByMeasureParams; import com.tencent.supersonic.headless.api.pojo.MetricDefineByMeasureParams;
@@ -21,17 +23,14 @@ import com.tencent.supersonic.headless.server.persistence.dataobject.MetricDO;
import com.tencent.supersonic.headless.server.persistence.repository.MetricRepository; import com.tencent.supersonic.headless.server.persistence.repository.MetricRepository;
import com.tencent.supersonic.headless.server.service.impl.DataSetServiceImpl; import com.tencent.supersonic.headless.server.service.impl.DataSetServiceImpl;
import com.tencent.supersonic.headless.server.service.impl.MetricServiceImpl; import com.tencent.supersonic.headless.server.service.impl.MetricServiceImpl;
import com.tencent.supersonic.headless.server.utils.AliasGenerateHelper;
import com.tencent.supersonic.headless.server.utils.MetricConverter; import com.tencent.supersonic.headless.server.utils.MetricConverter;
import java.util.HashMap;
import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.mockito.Mockito; import org.mockito.Mockito;
import org.springframework.context.ApplicationEventPublisher; import org.springframework.context.ApplicationEventPublisher;
import java.util.HashMap;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
public class MetricServiceImplTest { public class MetricServiceImplTest {
@Test @Test
@@ -64,15 +63,16 @@ public class MetricServiceImplTest {
private MetricService mockMetricService(MetricRepository metricRepository, private MetricService mockMetricService(MetricRepository metricRepository,
ModelService modelService) { ModelService modelService) {
AliasGenerateHelper chatGptHelper = Mockito.mock(AliasGenerateHelper.class); AliasGenerateHelper aliasGenerateHelper = Mockito.mock(AliasGenerateHelper.class);
CollectService collectService = Mockito.mock(CollectService.class); CollectService collectService = Mockito.mock(CollectService.class);
ApplicationEventPublisher eventPublisher = Mockito.mock(ApplicationEventPublisher.class); ApplicationEventPublisher eventPublisher = Mockito.mock(ApplicationEventPublisher.class);
DataSetService dataSetService = Mockito.mock(DataSetServiceImpl.class); DataSetService dataSetService = Mockito.mock(DataSetServiceImpl.class);
DimensionService dimensionService = Mockito.mock(DimensionService.class); DimensionService dimensionService = Mockito.mock(DimensionService.class);
TagMetaService tagMetaService = Mockito.mock(TagMetaService.class); TagMetaService tagMetaService = Mockito.mock(TagMetaService.class);
MetaDiscoveryService metaDiscoveryService = Mockito.mock(MetaDiscoveryService.class); MetaDiscoveryService metaDiscoveryService = Mockito.mock(MetaDiscoveryService.class);
return new MetricServiceImpl(metricRepository, modelService, chatGptHelper, collectService, dataSetService, return new MetricServiceImpl(metricRepository, modelService, aliasGenerateHelper,
eventPublisher, dimensionService, tagMetaService, metaDiscoveryService); collectService, dataSetService, eventPublisher, dimensionService,
tagMetaService, metaDiscoveryService);
} }
private MetricReq buildMetricReq() { private MetricReq buildMetricReq() {