mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-15 14:36:47 +00:00
[improvement][Headless] Embedding supports Chinese by default and fixes the issue of abnormal number recognition (#726)
This commit is contained in:
@@ -52,7 +52,7 @@ public class OptimizationConfig {
|
|||||||
@Value("${embedding.mapper.round.number:10}")
|
@Value("${embedding.mapper.round.number:10}")
|
||||||
private int embeddingMapperRoundNumber;
|
private int embeddingMapperRoundNumber;
|
||||||
|
|
||||||
@Value("${embedding.mapper.distance.threshold:0.58}")
|
@Value("${embedding.mapper.distance.threshold:0.01}")
|
||||||
private Double embeddingMapperDistanceThreshold;
|
private Double embeddingMapperDistanceThreshold;
|
||||||
|
|
||||||
@Value("${s2SQL.linking.value.switch:true}")
|
@Value("${s2SQL.linking.value.switch:true}")
|
||||||
|
|||||||
@@ -3,12 +3,6 @@ package com.tencent.supersonic.chat.core.mapper;
|
|||||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
import com.tencent.supersonic.headless.core.knowledge.helper.NatureHelper;
|
import com.tencent.supersonic.headless.core.knowledge.helper.NatureHelper;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.collections4.CollectionUtils;
|
|
||||||
import org.apache.commons.lang3.StringUtils;
|
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
|
||||||
import org.springframework.stereotype.Service;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
@@ -19,6 +13,11 @@ import java.util.Objects;
|
|||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.collections4.CollectionUtils;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
@Slf4j
|
@Slf4j
|
||||||
@@ -29,7 +28,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms,
|
public Map<MatchText, List<T>> match(QueryContext queryContext, List<S2Term> terms,
|
||||||
Set<Long> detectViewIds) {
|
Set<Long> detectViewIds) {
|
||||||
String text = queryContext.getQueryText();
|
String text = queryContext.getQueryText();
|
||||||
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
if (Objects.isNull(terms) || StringUtils.isEmpty(text)) {
|
||||||
return null;
|
return null;
|
||||||
@@ -57,9 +56,9 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
|||||||
int offset = mapperHelper.getStepOffset(terms, startIndex);
|
int offset = mapperHelper.getStepOffset(terms, startIndex);
|
||||||
index = mapperHelper.getStepIndex(regOffsetToLength, index);
|
index = mapperHelper.getStepIndex(regOffsetToLength, index);
|
||||||
if (index <= text.length()) {
|
if (index <= text.length()) {
|
||||||
String detectSegment = text.substring(startIndex, index);
|
String detectSegment = text.substring(startIndex, index).trim();
|
||||||
detectSegments.add(detectSegment);
|
detectSegments.add(detectSegment);
|
||||||
detectByStep(queryContext, results, detectViewIds, startIndex, index, offset);
|
detectByStep(queryContext, results, detectViewIds, detectSegment, offset);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
|
startIndex = mapperHelper.getStepIndex(regOffsetToLength, startIndex);
|
||||||
@@ -151,7 +150,7 @@ public abstract class BaseMatchStrategy<T> implements MatchStrategy<T> {
|
|||||||
|
|
||||||
public abstract String getMapKey(T a);
|
public abstract String getMapKey(T a);
|
||||||
|
|
||||||
public abstract void detectByStep(QueryContext queryContext, Set<T> results,
|
public abstract void detectByStep(QueryContext queryContext, Set<T> existResults, Set<Long> detectViewIds,
|
||||||
Set<Long> detectViewIds, Integer startIndex, Integer index, int offset);
|
String detectSegment, int offset);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -55,15 +55,12 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
|||||||
}
|
}
|
||||||
|
|
||||||
public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> existResults, Set<Long> detectViewIds,
|
public void detectByStep(QueryContext queryContext, Set<DatabaseMapResult> existResults, Set<Long> detectViewIds,
|
||||||
Integer startIndex, Integer index, int offset) {
|
String detectSegment, int offset) {
|
||||||
String detectSegment = queryContext.getQueryText().substring(startIndex, index);
|
|
||||||
if (StringUtils.isBlank(detectSegment)) {
|
if (StringUtils.isBlank(detectSegment)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
Set<Long> viewIds = mapperHelper.getViewIds(queryContext.getViewId(), queryContext.getAgent());
|
|
||||||
|
|
||||||
Double metricDimensionThresholdConfig = getThreshold(queryContext);
|
Double metricDimensionThresholdConfig = getThreshold(queryContext);
|
||||||
|
|
||||||
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(allElements);
|
Map<String, Set<SchemaElement>> nameToItems = getNameToItems(allElements);
|
||||||
|
|
||||||
for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
|
for (Entry<String, Set<SchemaElement>> entry : nameToItems.entrySet()) {
|
||||||
@@ -73,9 +70,9 @@ public class DatabaseMatchStrategy extends BaseMatchStrategy<DatabaseMapResult>
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
Set<SchemaElement> schemaElements = entry.getValue();
|
Set<SchemaElement> schemaElements = entry.getValue();
|
||||||
if (!CollectionUtils.isEmpty(viewIds)) {
|
if (!CollectionUtils.isEmpty(detectViewIds)) {
|
||||||
schemaElements = schemaElements.stream()
|
schemaElements = schemaElements.stream()
|
||||||
.filter(schemaElement -> viewIds.contains(schemaElement.getView()))
|
.filter(schemaElement -> detectViewIds.contains(schemaElement.getView()))
|
||||||
.collect(Collectors.toSet());
|
.collect(Collectors.toSet());
|
||||||
}
|
}
|
||||||
for (SchemaElement schemaElement : schemaElements) {
|
for (SchemaElement schemaElement : schemaElements) {
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
package com.tencent.supersonic.chat.core.mapper;
|
package com.tencent.supersonic.chat.core.mapper;
|
||||||
|
|
||||||
import com.alibaba.fastjson.JSONObject;
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
|
||||||
import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult;
|
import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult;
|
||||||
import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder;
|
import com.tencent.supersonic.headless.core.knowledge.builder.BaseWordBuilder;
|
||||||
@@ -34,14 +34,12 @@ public class EmbeddingMapper extends BaseMapper {
|
|||||||
//2. build SchemaElementMatch by info
|
//2. build SchemaElementMatch by info
|
||||||
for (EmbeddingResult matchResult : matchResults) {
|
for (EmbeddingResult matchResult : matchResults) {
|
||||||
Long elementId = Retrieval.getLongId(matchResult.getId());
|
Long elementId = Retrieval.getLongId(matchResult.getId());
|
||||||
|
|
||||||
SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(matchResult.getMetadata()),
|
|
||||||
SchemaElement.class);
|
|
||||||
Long viewId = Retrieval.getLongId(matchResult.getMetadata().get("viewId"));
|
Long viewId = Retrieval.getLongId(matchResult.getMetadata().get("viewId"));
|
||||||
if (Objects.isNull(viewId)) {
|
if (Objects.isNull(viewId)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
schemaElement = getSchemaElement(viewId, schemaElement.getType(), elementId,
|
SchemaElementType elementType = SchemaElementType.valueOf(matchResult.getMetadata().get("type"));
|
||||||
|
SchemaElement schemaElement = getSchemaElement(viewId, elementType, elementId,
|
||||||
queryContext.getSemanticSchema());
|
queryContext.getSemanticSchema());
|
||||||
if (schemaElement == null) {
|
if (schemaElement == null) {
|
||||||
continue;
|
continue;
|
||||||
|
|||||||
@@ -2,12 +2,12 @@ package com.tencent.supersonic.chat.core.mapper;
|
|||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
import com.tencent.supersonic.chat.core.config.OptimizationConfig;
|
||||||
import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult;
|
|
||||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||||
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||||
|
import com.tencent.supersonic.headless.core.knowledge.EmbeddingResult;
|
||||||
import com.tencent.supersonic.headless.server.service.MetaEmbeddingService;
|
import com.tencent.supersonic.headless.server.service.MetaEmbeddingService;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
@@ -47,6 +47,12 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
|||||||
return a.getName() + Constants.UNDERLINE + a.getId();
|
return a.getName() + Constants.UNDERLINE + a.getId();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectViewIds,
|
||||||
|
String detectSegment, int offset) {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectViewIds,
|
protected void detectByBatch(QueryContext queryContext, Set<EmbeddingResult> results, Set<Long> detectViewIds,
|
||||||
Set<String> detectSegments) {
|
Set<String> detectSegments) {
|
||||||
@@ -111,9 +117,4 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
|||||||
selectResultInOneRound(results, oneRoundResults);
|
selectResultInOneRound(results, oneRoundResults);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public void detectByStep(QueryContext queryContext, Set<EmbeddingResult> existResults, Set<Long> detectViewIds,
|
|
||||||
Integer startIndex, Integer index, int offset) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -58,10 +58,7 @@ public class HanlpDictMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectViewIds,
|
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectViewIds,
|
||||||
Integer startIndex, Integer index, int offset) {
|
String detectSegment, int offset) {
|
||||||
String text = queryContext.getQueryText();
|
|
||||||
String detectSegment = text.substring(startIndex, index);
|
|
||||||
|
|
||||||
// step1. pre search
|
// step1. pre search
|
||||||
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
|
Integer oneDetectionMaxSize = optimizationConfig.getOneDetectionMaxSize();
|
||||||
LinkedHashSet<HanlpMapResult> hanlpMapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize,
|
LinkedHashSet<HanlpMapResult> hanlpMapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize,
|
||||||
|
|||||||
@@ -88,9 +88,9 @@ public class SearchMatchStrategy extends BaseMatchStrategy<HanlpMapResult> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> results, Set<Long> detectViewIds,
|
public void detectByStep(QueryContext queryContext, Set<HanlpMapResult> existResults, Set<Long> detectViewIds,
|
||||||
Integer startIndex,
|
String detectSegment, int offset) {
|
||||||
Integer i, int offset) {
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -32,8 +32,7 @@ public class SchemaDictUpdateListener implements ApplicationListener<DataEvent>
|
|||||||
DictWord dictWord = new DictWord();
|
DictWord dictWord = new DictWord();
|
||||||
dictWord.setWord(dataItem.getName());
|
dictWord.setWord(dataItem.getName());
|
||||||
String sign = DictWordType.NATURE_SPILT;
|
String sign = DictWordType.NATURE_SPILT;
|
||||||
String nature = sign + 1 + sign + dataItem.getId()
|
String nature = sign + 1 + sign + dataItem.getId() + dataItem.getType().name().toLowerCase();
|
||||||
+ sign + dataItem.getType().name().toLowerCase();
|
|
||||||
String natureWithFrequency = nature + " " + Constants.DEFAULT_FREQUENCY;
|
String natureWithFrequency = nature + " " + Constants.DEFAULT_FREQUENCY;
|
||||||
dictWord.setNature(nature);
|
dictWord.setNature(nature);
|
||||||
dictWord.setNatureWithFrequency(natureWithFrequency);
|
dictWord.setNatureWithFrequency(natureWithFrequency);
|
||||||
|
|||||||
@@ -8,7 +8,10 @@ import lombok.Data;
|
|||||||
@Builder
|
@Builder
|
||||||
public class DataItem {
|
public class DataItem {
|
||||||
|
|
||||||
private Long id;
|
/***
|
||||||
|
* This field uses an underscore (_) at the end.
|
||||||
|
*/
|
||||||
|
private String id;
|
||||||
|
|
||||||
private String bizName;
|
private String bizName;
|
||||||
|
|
||||||
@@ -18,6 +21,9 @@ public class DataItem {
|
|||||||
|
|
||||||
private TypeEnums type;
|
private TypeEnums type;
|
||||||
|
|
||||||
|
/***
|
||||||
|
* This field uses an underscore (_) at the end.
|
||||||
|
*/
|
||||||
private String modelId;
|
private String modelId;
|
||||||
|
|
||||||
private String defaultAgg;
|
private String defaultAgg;
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ public class SysParameter {
|
|||||||
parameters.add(new Parameter("embedding.mapper.number", "5",
|
parameters.add(new Parameter("embedding.mapper.number", "5",
|
||||||
"批量向量召回文本返回结果个数", "每个文本进行向量语义召回的文本结果个数", "number", "Mapper相关配置"));
|
"批量向量召回文本返回结果个数", "每个文本进行向量语义召回的文本结果个数", "number", "Mapper相关配置"));
|
||||||
parameters.add(new Parameter("embedding.mapper.distance.threshold",
|
parameters.add(new Parameter("embedding.mapper.distance.threshold",
|
||||||
"0.58", "向量召回相似度阈值", "相似度大于该阈值的则舍弃", "number", "Mapper相关配置"));
|
"0.01", "向量召回相似度阈值", "相似度大于该阈值的则舍弃", "number", "Mapper相关配置"));
|
||||||
|
|
||||||
//parser config
|
//parser config
|
||||||
Parameter s2SQLParameter = new Parameter("s2SQL.generation", "TWO_PASS_AUTO_COT",
|
Parameter s2SQLParameter = new Parameter("s2SQL.generation", "TWO_PASS_AUTO_COT",
|
||||||
|
|||||||
@@ -129,7 +129,7 @@ public class InMemoryS2EmbeddingStore implements S2EmbeddingStore {
|
|||||||
List<Retrieval> retrievals = new ArrayList<>();
|
List<Retrieval> retrievals = new ArrayList<>();
|
||||||
for (EmbeddingMatch<EmbeddingQuery> embeddingMatch : relevant) {
|
for (EmbeddingMatch<EmbeddingQuery> embeddingMatch : relevant) {
|
||||||
Retrieval retrieval = new Retrieval();
|
Retrieval retrieval = new Retrieval();
|
||||||
retrieval.setDistance(embeddingMatch.score());
|
retrieval.setDistance(1 - embeddingMatch.score());
|
||||||
retrieval.setId(embeddingMatch.embeddingId());
|
retrieval.setId(embeddingMatch.embeddingId());
|
||||||
retrieval.setQuery(embeddingMatch.embedded().getQuery());
|
retrieval.setQuery(embeddingMatch.embedded().getQuery());
|
||||||
Map<String, Object> metadata = embeddingMatch.embedded().getMetadata();
|
Map<String, Object> metadata = embeddingMatch.embedded().getMetadata();
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.server.listener;
|
|||||||
|
|
||||||
import com.alibaba.fastjson.JSONObject;
|
import com.alibaba.fastjson.JSONObject;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
|
||||||
import com.tencent.supersonic.common.pojo.DataEvent;
|
import com.tencent.supersonic.common.pojo.DataEvent;
|
||||||
import com.tencent.supersonic.common.pojo.enums.EventType;
|
import com.tencent.supersonic.common.pojo.enums.EventType;
|
||||||
import com.tencent.supersonic.common.util.ComponentFactory;
|
import com.tencent.supersonic.common.util.ComponentFactory;
|
||||||
@@ -43,8 +42,7 @@ public class MetaEmbeddingListener implements ApplicationListener<DataEvent> {
|
|||||||
.map(dataItem -> {
|
.map(dataItem -> {
|
||||||
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
|
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
|
||||||
embeddingQuery.setQueryId(
|
embeddingQuery.setQueryId(
|
||||||
dataItem.getId().toString() + Constants.UNDERLINE
|
dataItem.getId() + dataItem.getType().name().toLowerCase());
|
||||||
+ dataItem.getType().name().toLowerCase());
|
|
||||||
embeddingQuery.setQuery(dataItem.getName());
|
embeddingQuery.setQuery(dataItem.getName());
|
||||||
Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class);
|
Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class);
|
||||||
embeddingQuery.setMetadata(meta);
|
embeddingQuery.setMetadata(meta);
|
||||||
|
|||||||
@@ -132,7 +132,7 @@ public class DimensionServiceImpl implements DimensionService {
|
|||||||
if (!oldName.equals(dimensionDO.getName())) {
|
if (!oldName.equals(dimensionDO.getName())) {
|
||||||
sendEvent(DataItem.builder().modelId(dimensionDO.getModelId() + Constants.UNDERLINE)
|
sendEvent(DataItem.builder().modelId(dimensionDO.getModelId() + Constants.UNDERLINE)
|
||||||
.newName(dimensionReq.getName()).name(oldName).type(TypeEnums.DIMENSION)
|
.newName(dimensionReq.getName()).name(oldName).type(TypeEnums.DIMENSION)
|
||||||
.id(dimensionDO.getId()).build(), EventType.UPDATE);
|
.id(dimensionDO.getId() + Constants.UNDERLINE).build(), EventType.UPDATE);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -366,8 +366,9 @@ public class DimensionServiceImpl implements DimensionService {
|
|||||||
|
|
||||||
private void sendEventBatch(List<DimensionDO> dimensionDOS, EventType eventType) {
|
private void sendEventBatch(List<DimensionDO> dimensionDOS, EventType eventType) {
|
||||||
List<DataItem> dataItems = dimensionDOS.stream()
|
List<DataItem> dataItems = dimensionDOS.stream()
|
||||||
.map(dimensionDO -> DataItem.builder().id(dimensionDO.getId()).name(dimensionDO.getName())
|
.map(dimensionDO -> DataItem.builder().id(dimensionDO.getId() + Constants.UNDERLINE)
|
||||||
.modelId(dimensionDO.getModelId() + Constants.UNDERLINE).type(TypeEnums.DIMENSION).build())
|
.name(dimensionDO.getName()).modelId(dimensionDO.getModelId() + Constants.UNDERLINE)
|
||||||
|
.type(TypeEnums.DIMENSION).build())
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
eventPublisher.publishEvent(new DataEvent(this, dataItems, eventType));
|
eventPublisher.publishEvent(new DataEvent(this, dataItems, eventType));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -449,7 +449,7 @@ public class MetricServiceImpl implements MetricService {
|
|||||||
private DataItem getDataItem(MetricDO metricDO) {
|
private DataItem getDataItem(MetricDO metricDO) {
|
||||||
MetricResp metricResp = MetricConverter.convert2MetricResp(metricDO,
|
MetricResp metricResp = MetricConverter.convert2MetricResp(metricDO,
|
||||||
new HashMap<>(), Lists.newArrayList());
|
new HashMap<>(), Lists.newArrayList());
|
||||||
return DataItem.builder().id(metricDO.getId()).name(metricDO.getName())
|
return DataItem.builder().id(metricDO.getId() + Constants.UNDERLINE).name(metricDO.getName())
|
||||||
.bizName(metricDO.getBizName())
|
.bizName(metricDO.getBizName())
|
||||||
.modelId(metricDO.getModelId() + Constants.UNDERLINE)
|
.modelId(metricDO.getModelId() + Constants.UNDERLINE)
|
||||||
.type(TypeEnums.METRIC).defaultAgg(metricResp.getDefaultAgg()).build();
|
.type(TypeEnums.METRIC).defaultAgg(metricResp.getDefaultAgg()).build();
|
||||||
|
|||||||
@@ -51,7 +51,7 @@
|
|||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>dev.langchain4j</groupId>
|
<groupId>dev.langchain4j</groupId>
|
||||||
<artifactId>langchain4j-embeddings-all-minilm-l6-v2</artifactId>
|
<artifactId>langchain4j-embeddings-bge-small-zh</artifactId>
|
||||||
</dependency>
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
|
|||||||
@@ -5,9 +5,9 @@ import static dev.langchain4j.exception.IllegalConfigurationException.illegalCon
|
|||||||
import static dev.langchain4j.internal.Utils.isNullOrBlank;
|
import static dev.langchain4j.internal.Utils.isNullOrBlank;
|
||||||
|
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
|
|
||||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||||
import dev.langchain4j.model.embedding.S2OnnxEmbeddingModel;
|
import dev.langchain4j.model.embedding.S2OnnxEmbeddingModel;
|
||||||
|
import dev.langchain4j.model.embedding.BgeSmallZhEmbeddingModel;
|
||||||
import dev.langchain4j.model.huggingface.HuggingFaceChatModel;
|
import dev.langchain4j.model.huggingface.HuggingFaceChatModel;
|
||||||
import dev.langchain4j.model.huggingface.HuggingFaceEmbeddingModel;
|
import dev.langchain4j.model.huggingface.HuggingFaceEmbeddingModel;
|
||||||
import dev.langchain4j.model.huggingface.HuggingFaceLanguageModel;
|
import dev.langchain4j.model.huggingface.HuggingFaceLanguageModel;
|
||||||
@@ -248,7 +248,7 @@ public class S2LangChain4jAutoConfiguration {
|
|||||||
case IN_PROCESS:
|
case IN_PROCESS:
|
||||||
InProcess inProcess = properties.getEmbeddingModel().getInProcess();
|
InProcess inProcess = properties.getEmbeddingModel().getInProcess();
|
||||||
if (Objects.isNull(inProcess) || isNullOrBlank(inProcess.getModelPath())) {
|
if (Objects.isNull(inProcess) || isNullOrBlank(inProcess.getModelPath())) {
|
||||||
return new AllMiniLmL6V2EmbeddingModel();
|
return new BgeSmallZhEmbeddingModel();
|
||||||
}
|
}
|
||||||
return new S2OnnxEmbeddingModel(inProcess.getModelPath(), inProcess.getVocabularyPath());
|
return new S2OnnxEmbeddingModel(inProcess.getModelPath(), inProcess.getVocabularyPath());
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
package com.tencent.supersonic.chat;
|
package com.tencent.supersonic.chat;
|
||||||
|
|
||||||
|
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;
|
||||||
|
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||||
@@ -8,21 +11,17 @@ import com.tencent.supersonic.chat.core.query.rule.metric.MetricFilterQuery;
|
|||||||
import com.tencent.supersonic.chat.core.query.rule.metric.MetricGroupByQuery;
|
import com.tencent.supersonic.chat.core.query.rule.metric.MetricGroupByQuery;
|
||||||
import com.tencent.supersonic.chat.core.query.rule.metric.MetricModelQuery;
|
import com.tencent.supersonic.chat.core.query.rule.metric.MetricModelQuery;
|
||||||
import com.tencent.supersonic.chat.core.query.rule.metric.MetricTopNQuery;
|
import com.tencent.supersonic.chat.core.query.rule.metric.MetricTopNQuery;
|
||||||
import com.tencent.supersonic.util.DataUtils;
|
|
||||||
import com.tencent.supersonic.common.pojo.DateConf;
|
import com.tencent.supersonic.common.pojo.DateConf;
|
||||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||||
import org.junit.Assert;
|
import com.tencent.supersonic.util.DataUtils;
|
||||||
import org.junit.jupiter.api.Test;
|
|
||||||
|
|
||||||
import java.text.DateFormat;
|
import java.text.DateFormat;
|
||||||
import java.text.SimpleDateFormat;
|
import java.text.SimpleDateFormat;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
import org.junit.Assert;
|
||||||
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;
|
import org.junit.jupiter.api.Test;
|
||||||
import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM;
|
|
||||||
|
|
||||||
|
|
||||||
public class MetricTest extends BaseTest {
|
public class MetricTest extends BaseTest {
|
||||||
@@ -41,6 +40,7 @@ public class MetricTest extends BaseTest {
|
|||||||
|
|
||||||
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
|
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
|
||||||
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
|
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
|
||||||
|
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数"));
|
||||||
|
|
||||||
expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name",
|
expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name",
|
||||||
FilterOperatorEnum.EQUALS, "alice", "用户", 2L));
|
FilterOperatorEnum.EQUALS, "alice", "用户", 2L));
|
||||||
@@ -76,6 +76,7 @@ public class MetricTest extends BaseTest {
|
|||||||
|
|
||||||
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
|
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
|
||||||
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
|
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
|
||||||
|
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数"));
|
||||||
expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.RECENT, unit, period, startDay, endDay));
|
expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.RECENT, unit, period, startDay, endDay));
|
||||||
expectedParseInfo.setQueryType(QueryType.METRIC);
|
expectedParseInfo.setQueryType(QueryType.METRIC);
|
||||||
|
|
||||||
@@ -105,6 +106,7 @@ public class MetricTest extends BaseTest {
|
|||||||
|
|
||||||
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
|
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
|
||||||
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
|
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
|
||||||
|
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数"));
|
||||||
expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("部门"));
|
expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("部门"));
|
||||||
|
|
||||||
expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.RECENT, unit, period, startDay, endDay));
|
expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.RECENT, unit, period, startDay, endDay));
|
||||||
@@ -126,6 +128,7 @@ public class MetricTest extends BaseTest {
|
|||||||
expectedParseInfo.setAggType(NONE);
|
expectedParseInfo.setAggType(NONE);
|
||||||
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
|
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
|
||||||
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
|
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
|
||||||
|
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数"));
|
||||||
List<String> list = new ArrayList<>();
|
List<String> list = new ArrayList<>();
|
||||||
list.add("alice");
|
list.add("alice");
|
||||||
list.add("lucy");
|
list.add("lucy");
|
||||||
@@ -151,6 +154,7 @@ public class MetricTest extends BaseTest {
|
|||||||
|
|
||||||
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
|
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
|
||||||
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
|
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
|
||||||
|
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数"));
|
||||||
expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("用户"));
|
expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("用户"));
|
||||||
|
|
||||||
expectedParseInfo.setDateInfo(DataUtils.getDateConf(3, DateConf.DateMode.RECENT, "DAY"));
|
expectedParseInfo.setDateInfo(DataUtils.getDateConf(3, DateConf.DateMode.RECENT, "DAY"));
|
||||||
@@ -171,6 +175,7 @@ public class MetricTest extends BaseTest {
|
|||||||
|
|
||||||
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
|
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
|
||||||
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
|
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
|
||||||
|
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数"));
|
||||||
expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("部门"));
|
expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("部门"));
|
||||||
|
|
||||||
expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.RECENT, unit, period, startDay, endDay));
|
expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.RECENT, unit, period, startDay, endDay));
|
||||||
@@ -197,6 +202,7 @@ public class MetricTest extends BaseTest {
|
|||||||
|
|
||||||
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
|
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
|
||||||
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
|
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
|
||||||
|
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数"));
|
||||||
expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name",
|
expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name",
|
||||||
FilterOperatorEnum.EQUALS, "alice", "用户", 2L));
|
FilterOperatorEnum.EQUALS, "alice", "用户", 2L));
|
||||||
|
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ public class MultiTurnsTest extends BaseTest {
|
|||||||
|
|
||||||
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
|
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
|
||||||
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
|
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("人均访问次数"));
|
||||||
|
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问用户数"));
|
||||||
|
|
||||||
expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name",
|
expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name",
|
||||||
FilterOperatorEnum.EQUALS, "alice", "用户", 2L));
|
FilterOperatorEnum.EQUALS, "alice", "用户", 2L));
|
||||||
|
|||||||
2
pom.xml
2
pom.xml
@@ -145,7 +145,7 @@
|
|||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>dev.langchain4j</groupId>
|
<groupId>dev.langchain4j</groupId>
|
||||||
<artifactId>langchain4j-embeddings-all-minilm-l6-v2</artifactId>
|
<artifactId>langchain4j-embeddings-bge-small-zh</artifactId>
|
||||||
<version>${langchain4j.version}</version>
|
<version>${langchain4j.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|||||||
Reference in New Issue
Block a user