[improvement][Headless] Embedding supports Chinese by default and fixes the issue of abnormal number recognition (#726)

This commit is contained in:
lexluo09
2024-02-18 19:51:19 +08:00
committed by GitHub
parent 39158d6877
commit fdb69547e6
19 changed files with 62 additions and 59 deletions

View File

@@ -52,7 +52,7 @@ public class OptimizationConfig {
@Value("${embedding.mapper.round.number:10}") @Value("${embedding.mapper.round.number:10}")
private int embeddingMapperRoundNumber; private int embeddingMapperRoundNumber;
@Value("${embedding.mapper.distance.threshold:0.58}") @Value("${embedding.mapper.distance.threshold:0.01}")
private Double embeddingMapperDistanceThreshold; private Double embeddingMapperDistanceThreshold;
@Value("${s2SQL.linking.value.switch:true}") @Value("${s2SQL.linking.value.switch:true}")

View File

@@ -3,12 +3,6 @@ package com.tencent.supersonic.chat.core.mapper;
import com.tencent.supersonic.chat.core.pojo.QueryContext; import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.headless.api.pojo.response.S2Term; import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.knowledge.helper.NatureHelper; import com.tencent.supersonic.headless.core.knowledge.helper.NatureHelper;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Comparator; import java.util.Comparator;
import java.util.HashMap; import java.util.HashMap;
@@ -19,6 +13,11 @@ import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
@Service @Service
@Slf4j @Slf4j
@@ -29,7 +28,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
@Override @Override
public Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms, public Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms,
Set<Long> detectViewIds) { Set<Long> detectViewIds) {
String text = queryContext.getQueryText(); String text = queryContext.getQueryText();
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) { if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
return null; return null;
@@ -57,9 +56,9 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
int offset = mapperHelper.getStepOffset(terms, startIndex); int offset = mapperHelper.getStepOffset(terms, startIndex);
index = mapperHelper.getStepIndex(regOffsetToLength, index); index = mapperHelper.getStepIndex(regOffsetToLength, index);
if (index <= text.length()) { if (index <= text.length()) {
String detectSegment = text.substring(startIndex, index); String detectSegment = text.substring(startIndex, index).trim();
detectSegments.add(detectSegment); detectSegments.add(detectSegment);
detectByStep(queryContext, results, detectViewIds, startIndex, index, offset); detectByStep(queryContext, results, detectViewIds, detectSegment, offset);
} }
} }
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex); startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
@@ -151,7 +150,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
public abstract String getMapKey(T a); public abstract String getMapKey(T a);
public abstract void detectByStep(QueryContext queryContext, Set<T> results, public abstract void detectByStep(QueryContext queryContext, Set<T> existResults, Set<Long> detectViewIds,
Set<Long> detectViewIds, Integer startIndex, Integer index, int offset); String detectSegment, int offset);
} }

View File

@@ -55,15 +55,12 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
} }
public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> existResults, Set<Long> detectViewIds, public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> existResults, Set<Long> detectViewIds,
Integer startIndex, Integer index, int offset) { String detectSegment, int offset) {
String detectSegment = queryContext.getQueryText().substring(startIndex, index);
if (StringUtils.isBlank(detectSegment)) { if (StringUtils.isBlank(detectSegment)) {
return; return;
} }
Set<Long> viewIds = mapperHelper.getViewIds(queryContext.getViewId(), queryContext.getAgent());
Double metricDimensionThresholdConfig = getThreshold(queryContext); Double metricDimensionThresholdConfig = getThreshold(queryContext);
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(allElements); Map<String, Set<SchemaElement>> nameToItems = getNameToItems(allElements);
for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) { for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
@@ -73,9 +70,9 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
continue; continue;
} }
Set<SchemaElement> schemaElements = entry.getValue(); Set<SchemaElement> schemaElements = entry.getValue();
if (!CollectionUtils.isEmpty(viewIds)) { if (!CollectionUtils.isEmpty(detectViewIds)) {
schemaElements = schemaElements.stream() schemaElements = schemaElements.stream()
.filter(schemaElement -> viewIds.contains(schemaElement.getView())) .filter(schemaElement -> detectViewIds.contains(schemaElement.getView()))
.collect(Collectors.toSet()); .collect(Collectors.toSet());
} }
for (SchemaElement schemaElement : schemaElements) { for (SchemaElement schemaElement : schemaElements) {

View File

@@ -1,11 +1,11 @@
package com.tencent.supersonic.chat.core.mapper; package com.tencent.supersonic.chat.core.mapper;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.chat.core.pojo.QueryContext; import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.embedding.Retrieval; import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.response.S2Term; import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult; import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult;
import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder; import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder;
@@ -34,14 +34,12 @@ public class EmbeddingMapper extends BaseMapper {
//2. build SchemaElementMatch by info //2. build SchemaElementMatch by info
for (EmbeddingResult matchResult : matchResults) { for (EmbeddingResult matchResult : matchResults) {
Long elementId = Retrieval.getLongId(matchResult.getId()); Long elementId = Retrieval.getLongId(matchResult.getId());
SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(matchResult.getMetadata()),
SchemaElement.class);
Long viewId = Retrieval.getLongId(matchResult.getMetadata().get("viewId")); Long viewId = Retrieval.getLongId(matchResult.getMetadata().get("viewId"));
if (Objects.isNull(viewId)) { if (Objects.isNull(viewId)) {
continue; continue;
} }
schemaElement = getSchemaElement(viewId, schemaElement.getType(), elementId, SchemaElementType elementType = SchemaElementType.valueOf(matchResult.getMetadata().get("type"));
SchemaElement schemaElement = getSchemaElement(viewId, elementType, elementId,
queryContext.getSemanticSchema()); queryContext.getSemanticSchema());
if (schemaElement == null) { if (schemaElement == null) {
continue; continue;

View File

@@ -2,12 +2,12 @@ package com.tencent.supersonic.chat.core.mapper;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.core.config.OptimizationConfig; import com.tencent.supersonic.chat.core.config.OptimizationConfig;
import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult;
import com.tencent.supersonic.chat.core.pojo.QueryContext; import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.util.embedding.Retrieval; import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQuery; import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult; import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult;
import com.tencent.supersonic.headless.server.service.MetaEmbeddingService; import com.tencent.supersonic.headless.server.service.MetaEmbeddingService;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Comparator; import java.util.Comparator;
@@ -47,6 +47,12 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
return a.getName() + Constants.UNDERLINE + a.getId(); return a.getName() + Constants.UNDERLINE + a.getId();
} }
@Override
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectViewIds,
String detectSegment, int offset) {
}
@Override @Override
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectViewIds, protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectViewIds,
Set<String> detectSegments) { Set<String> detectSegments) {
@@ -111,9 +117,4 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
selectResultInOneRound(results, oneRoundResults); selectResultInOneRound(results, oneRoundResults);
} }
@Override
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectViewIds,
Integer startIndex, Integer index, int offset) {
return;
}
} }

View File

@@ -58,10 +58,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
} }
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectViewIds, public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectViewIds,
Integer startIndex, Integer index, int offset) { String detectSegment, int offset) {
String text = queryContext.getQueryText();
String detectSegment = text.substring(startIndex, index);
// step1. pre search // step1. pre search
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize(); Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
LinkedHashSet<HanlpMapResult> hanlpMapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize, LinkedHashSet<HanlpMapResult> hanlpMapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize,

View File

@@ -88,9 +88,9 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
} }
@Override @Override
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> results, Set<Long> detectViewIds, public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectViewIds,
Integer startIndex, String detectSegment, int offset) {
Integer i, int offset) {
} }
} }

View File

@@ -32,8 +32,7 @@ public class SchemaDictUpdateListener implements ApplicationListener<DataEvent>
DictWord dictWord = new DictWord(); DictWord dictWord = new DictWord();
dictWord.setWord(dataItem.getName()); dictWord.setWord(dataItem.getName());
String sign = DictWordType.NATURE_SPILT; String sign = DictWordType.NATURE_SPILT;
String nature = sign + 1 + sign + dataItem.getId() String nature = sign + 1 + sign + dataItem.getId() + dataItem.getType().name().toLowerCase();
+ sign + dataItem.getType().name().toLowerCase();
String natureWithFrequency = nature + " " + Constants.DEFAULT_FREQUENCY; String natureWithFrequency = nature + " " + Constants.DEFAULT_FREQUENCY;
dictWord.setNature(nature); dictWord.setNature(nature);
dictWord.setNatureWithFrequency(natureWithFrequency); dictWord.setNatureWithFrequency(natureWithFrequency);

View File

@@ -8,7 +8,10 @@ import lombok.Data;
@Builder @Builder
public class DataItem { public class DataItem {
private Long id; /***
* This field uses an underscore (_) at the end.
*/
private String id;
private String bizName; private String bizName;
@@ -18,6 +21,9 @@ public class DataItem {
private TypeEnums type; private TypeEnums type;
/***
* This field uses an underscore (_) at the end.
*/
private String modelId; private String modelId;
private String defaultAgg; private String defaultAgg;

View File

@@ -75,7 +75,7 @@ public class SysParameter {
parameters.add(new Parameter("embedding.mapper.number", "5", parameters.add(new Parameter("embedding.mapper.number", "5",
"批量向量召回文本返回结果个数", "每个文本进行向量语义召回的文本结果个数", "number", "Mapper相关配置")); "批量向量召回文本返回结果个数", "每个文本进行向量语义召回的文本结果个数", "number", "Mapper相关配置"));
parameters.add(new Parameter("embedding.mapper.distance.threshold", parameters.add(new Parameter("embedding.mapper.distance.threshold",
"0.58", "向量召回相似度阈值", "相似度大于该阈值的则舍弃", "number", "Mapper相关配置")); "0.01", "向量召回相似度阈值", "相似度大于该阈值的则舍弃", "number", "Mapper相关配置"));
//parser config //parser config
Parameter s2SQLParameter = new Parameter("s2SQL.generation", "TWO_PASS_AUTO_COT", Parameter s2SQLParameter = new Parameter("s2SQL.generation", "TWO_PASS_AUTO_COT",

View File

@@ -129,7 +129,7 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
List<Retrieval> retrievals = new ArrayList<>(); List<Retrieval> retrievals = new ArrayList<>();
for (EmbeddingMatch<EmbeddingQuery> embeddingMatch : relevant) { for (EmbeddingMatch<EmbeddingQuery> embeddingMatch : relevant) {
Retrieval retrieval = new Retrieval(); Retrieval retrieval = new Retrieval();
retrieval.setDistance(embeddingMatch.score()); retrieval.setDistance(1 - embeddingMatch.score());
retrieval.setId(embeddingMatch.embeddingId()); retrieval.setId(embeddingMatch.embeddingId());
retrieval.setQuery(embeddingMatch.embedded().getQuery()); retrieval.setQuery(embeddingMatch.embedded().getQuery());
Map<String, Object> metadata = embeddingMatch.embedded().getMetadata(); Map<String, Object> metadata = embeddingMatch.embedded().getMetadata();

View File

@@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.server.listener;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DataEvent; import com.tencent.supersonic.common.pojo.DataEvent;
import com.tencent.supersonic.common.pojo.enums.EventType; import com.tencent.supersonic.common.pojo.enums.EventType;
import com.tencent.supersonic.common.util.ComponentFactory; import com.tencent.supersonic.common.util.ComponentFactory;
@@ -43,8 +42,7 @@ public class MetaEmbeddingListener implements ApplicationListener<DataEvent> {
.map(dataItem -> { .map(dataItem -> {
EmbeddingQuery embeddingQuery = new EmbeddingQuery(); EmbeddingQuery embeddingQuery = new EmbeddingQuery();
embeddingQuery.setQueryId( embeddingQuery.setQueryId(
dataItem.getId().toString() + Constants.UNDERLINE dataItem.getId() + dataItem.getType().name().toLowerCase());
+ dataItem.getType().name().toLowerCase());
embeddingQuery.setQuery(dataItem.getName()); embeddingQuery.setQuery(dataItem.getName());
Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class); Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class);
embeddingQuery.setMetadata(meta); embeddingQuery.setMetadata(meta);

View File

@@ -132,7 +132,7 @@ public class DimensionServiceImpl implements DimensionService {
if (!oldName.equals(dimensionDO.getName())) { if (!oldName.equals(dimensionDO.getName())) {
sendEvent(DataItem.builder().modelId(dimensionDO.getModelId() + Constants.UNDERLINE) sendEvent(DataItem.builder().modelId(dimensionDO.getModelId() + Constants.UNDERLINE)
.newName(dimensionReq.getName()).name(oldName).type(TypeEnums.DIMENSION) .newName(dimensionReq.getName()).name(oldName).type(TypeEnums.DIMENSION)
.id(dimensionDO.getId()).build(), EventType.UPDATE); .id(dimensionDO.getId() + Constants.UNDERLINE).build(), EventType.UPDATE);
} }
} }
@@ -366,8 +366,9 @@ public class DimensionServiceImpl implements DimensionService {
private void sendEventBatch(List<DimensionDO> dimensionDOS, EventType eventType) { private void sendEventBatch(List<DimensionDO> dimensionDOS, EventType eventType) {
List<DataItem> dataItems = dimensionDOS.stream() List<DataItem> dataItems = dimensionDOS.stream()
.map(dimensionDO -> DataItem.builder().id(dimensionDO.getId()).name(dimensionDO.getName()) .map(dimensionDO -> DataItem.builder().id(dimensionDO.getId() + Constants.UNDERLINE)
.modelId(dimensionDO.getModelId() + Constants.UNDERLINE).type(TypeEnums.DIMENSION).build()) .name(dimensionDO.getName()).modelId(dimensionDO.getModelId() + Constants.UNDERLINE)
.type(TypeEnums.DIMENSION).build())
.collect(Collectors.toList()); .collect(Collectors.toList());
eventPublisher.publishEvent(new DataEvent(this, dataItems, eventType)); eventPublisher.publishEvent(new DataEvent(this, dataItems, eventType));
} }

View File

@@ -449,7 +449,7 @@ public class MetricServiceImpl implements MetricService {
private DataItem getDataItem(MetricDO metricDO) { private DataItem getDataItem(MetricDO metricDO) {
MetricResp metricResp = MetricConverter.convert2MetricResp(metricDO, MetricResp metricResp = MetricConverter.convert2MetricResp(metricDO,
new HashMap<>(), Lists.newArrayList()); new HashMap<>(), Lists.newArrayList());
return DataItem.builder().id(metricDO.getId()).name(metricDO.getName()) return DataItem.builder().id(metricDO.getId() + Constants.UNDERLINE).name(metricDO.getName())
.bizName(metricDO.getBizName()) .bizName(metricDO.getBizName())
.modelId(metricDO.getModelId() + Constants.UNDERLINE) .modelId(metricDO.getModelId() + Constants.UNDERLINE)
.type(TypeEnums.METRIC).defaultAgg(metricResp.getDefaultAgg()).build(); .type(TypeEnums.METRIC).defaultAgg(metricResp.getDefaultAgg()).build();

View File

@@ -51,7 +51,7 @@
<dependency> <dependency>
<groupId>dev.langchain4j</groupId> <groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-all-minilm-l6-v2</artifactId> <artifactId>langchain4j-embeddings-bge-small-zh</artifactId>
</dependency> </dependency>
</dependencies> </dependencies>

View File

@@ -5,9 +5,9 @@ import static dev.langchain4j.exception.IllegalConfigurationException.illegalCon
import static dev.langchain4j.internal.Utils.isNullOrBlank; import static dev.langchain4j.internal.Utils.isNullOrBlank;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.embedding.S2OnnxEmbeddingModel; import dev.langchain4j.model.embedding.S2OnnxEmbeddingModel;
import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel;
import dev.langchain4j.model.huggingface.HuggingFaceChatModel; import dev.langchain4j.model.huggingface.HuggingFaceChatModel;
import dev.langchain4j.model.huggingface.HuggingFaceEmbeddingModel; import dev.langchain4j.model.huggingface.HuggingFaceEmbeddingModel;
import dev.langchain4j.model.huggingface.HuggingFaceLanguageModel; import dev.langchain4j.model.huggingface.HuggingFaceLanguageModel;
@@ -248,7 +248,7 @@ public class S2LangChain4jAutoConfiguration {
case IN_PROCESS: case IN_PROCESS:
InProcess inProcess = properties.getEmbeddingModel().getInProcess(); InProcess inProcess = properties.getEmbeddingModel().getInProcess();
if (Objects.isNull(inProcess) || isNullOrBlank(inProcess.getModelPath())) { if (Objects.isNull(inProcess) || isNullOrBlank(inProcess.getModelPath())) {
return new AllMiniLmL6V2EmbeddingModel(); return new BgeSmallZhEmbeddingModel();
} }
return new S2OnnxEmbeddingModel(inProcess.getModelPath(), inProcess.getVocabularyPath()); return new S2OnnxEmbeddingModel(inProcess.getModelPath(), inProcess.getVocabularyPath());

View File

@@ -1,5 +1,8 @@
package com.tencent.supersonic.chat; package com.tencent.supersonic.chat;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp; import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
@@ -8,21 +11,17 @@ import com.tencent.supersonic.chat.core.query.rule.metric.MetricFilterQuery;
import com.tencent.supersonic.chat.core.query.rule.metric.MetricGroupByQuery; import com.tencent.supersonic.chat.core.query.rule.metric.MetricGroupByQuery;
import com.tencent.supersonic.chat.core.query.rule.metric.MetricModelQuery; import com.tencent.supersonic.chat.core.query.rule.metric.MetricModelQuery;
import com.tencent.supersonic.chat.core.query.rule.metric.MetricTopNQuery; import com.tencent.supersonic.chat.core.query.rule.metric.MetricTopNQuery;
import com.tencent.supersonic.util.DataUtils;
import com.tencent.supersonic.common.pojo.DateConf; import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.common.pojo.enums.QueryType;
import org.junit.Assert; import com.tencent.supersonic.util.DataUtils;
import org.junit.jupiter.api.Test;
import java.text.DateFormat; import java.text.DateFormat;
import java.text.SimpleDateFormat; import java.text.SimpleDateFormat;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.junit.Assert;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE; import org.junit.jupiter.api.Test;
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM;
public class MetricTest extends BaseTest { public class MetricTest extends BaseTest {
@@ -41,6 +40,7 @@ public class MetricTest extends BaseTest {
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数"));
expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name", expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name",
FilterOperatorEnum.EQUALS, "alice", "用户", 2L)); FilterOperatorEnum.EQUALS, "alice", "用户", 2L));
@@ -76,6 +76,7 @@ public class MetricTest extends BaseTest {
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数"));
expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.RECENT, unit, period, startDay, endDay)); expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.RECENT, unit, period, startDay, endDay));
expectedParseInfo.setQueryType(QueryType.METRIC); expectedParseInfo.setQueryType(QueryType.METRIC);
@@ -105,6 +106,7 @@ public class MetricTest extends BaseTest {
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数"));
expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("部门")); expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("部门"));
expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.RECENT, unit, period, startDay, endDay)); expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.RECENT, unit, period, startDay, endDay));
@@ -126,6 +128,7 @@ public class MetricTest extends BaseTest {
expectedParseInfo.setAggType(NONE); expectedParseInfo.setAggType(NONE);
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数"));
List<String> list = new ArrayList<>(); List<String> list = new ArrayList<>();
list.add("alice"); list.add("alice");
list.add("lucy"); list.add("lucy");
@@ -151,6 +154,7 @@ public class MetricTest extends BaseTest {
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数"));
expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("用户")); expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("用户"));
expectedParseInfo.setDateInfo(DataUtils.getDateConf(3, DateConf.DateMode.RECENT, "DAY")); expectedParseInfo.setDateInfo(DataUtils.getDateConf(3, DateConf.DateMode.RECENT, "DAY"));
@@ -171,6 +175,7 @@ public class MetricTest extends BaseTest {
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数"));
expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("部门")); expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("部门"));
expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.RECENT, unit, period, startDay, endDay)); expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.RECENT, unit, period, startDay, endDay));
@@ -197,6 +202,7 @@ public class MetricTest extends BaseTest {
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数"));
expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name", expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name",
FilterOperatorEnum.EQUALS, "alice", "用户", 2L)); FilterOperatorEnum.EQUALS, "alice", "用户", 2L));

View File

@@ -30,6 +30,7 @@ public class MultiTurnsTest extends BaseTest {
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数")); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数"));
expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name", expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name",
FilterOperatorEnum.EQUALS, "alice", "用户", 2L)); FilterOperatorEnum.EQUALS, "alice", "用户", 2L));

View File

@@ -145,7 +145,7 @@
</dependency> </dependency>
<dependency> <dependency>
<groupId>dev.langchain4j</groupId> <groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-all-minilm-l6-v2</artifactId> <artifactId>langchain4j-embeddings-bge-small-zh</artifactId>
<version>${langchain4j.version}</version> <version>${langchain4j.version}</version>
</dependency> </dependency>
</dependencies> </dependencies>