mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 12:07:42 +00:00
Feature/model data embedding for chat and support status for metric and dimension (#311)
* (improvement)(semantic) add offline status for metric and dimension * (improvement)(chat) add metric recall --------- Co-authored-by: jolunoluo
This commit is contained in:
@@ -28,6 +28,8 @@ public class SchemaElement implements Serializable {
|
||||
|
||||
private String defaultAgg;
|
||||
|
||||
private int order;
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
if (this == o) {
|
||||
|
||||
@@ -53,6 +53,9 @@ public class SemanticParseInfo {
|
||||
|
||||
@Override
|
||||
public int compare(SchemaElement o1, SchemaElement o2) {
|
||||
if (o1.getOrder() != o2.getOrder()) {
|
||||
return o1.getOrder() - o2.getOrder();
|
||||
}
|
||||
int len1 = o1.getName().length();
|
||||
int len2 = o2.getName().length();
|
||||
if (len1 != len2) {
|
||||
|
||||
@@ -7,6 +7,7 @@ import com.tencent.supersonic.chat.parser.plugin.PluginParser;
|
||||
import com.tencent.supersonic.chat.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.plugin.PluginRecallResult;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
package com.tencent.supersonic.chat.parser.plugin.embedding;
|
||||
|
||||
import lombok.Data;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
@Configuration
|
||||
@Data
|
||||
public class EmbeddingConfig {
|
||||
|
||||
@Value("${embedding.url:}")
|
||||
private String url;
|
||||
|
||||
@Value("${embedding.recognize.path:/preset_query_retrival}")
|
||||
private String recognizePath;
|
||||
|
||||
@Value("${embedding.delete.path:/preset_delete_by_ids}")
|
||||
private String deletePath;
|
||||
|
||||
@Value("${embedding.add.path:/preset_query_add}")
|
||||
private String addPath;
|
||||
|
||||
@Value("${embedding.nResult:1}")
|
||||
private String nResult;
|
||||
|
||||
@Value("${embedding.solvedQuery.recall.path:/solved_query_retrival}")
|
||||
private String solvedQueryRecallPath;
|
||||
|
||||
@Value("${embedding.solvedQuery.add.path:/solved_query_add}")
|
||||
private String solvedQueryAddPath;
|
||||
|
||||
@Value("${embedding.solved.query.nResult:5}")
|
||||
private String solvedQueryResultNum;
|
||||
|
||||
}
|
||||
@@ -11,7 +11,7 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.agent.Agent;
|
||||
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
|
||||
import com.tencent.supersonic.chat.agent.tool.PluginTool;
|
||||
import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingResp;
|
||||
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrieval;
|
||||
import com.tencent.supersonic.chat.plugin.event.PluginAddEvent;
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
package com.tencent.supersonic.chat.responder.parse;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||
import com.tencent.supersonic.semantic.model.domain.listener.MetaEmbeddingListener;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class SimilarMetricParseResponder implements ParseResponder {
|
||||
|
||||
|
||||
@Override
|
||||
public void fillResponse(ParseResp parseResp, QueryContext queryContext, List<ChatParseDO> chatParseDOS) {
|
||||
if (CollectionUtils.isEmpty(parseResp.getSelectedParses())) {
|
||||
return;
|
||||
}
|
||||
fillSimilarMetric(parseResp.getSelectedParses().iterator().next());
|
||||
}
|
||||
|
||||
private void fillSimilarMetric(SemanticParseInfo parseInfo) {
|
||||
if (!QueryManager.isMetricQuery(parseInfo.getQueryMode())
|
||||
|| CollectionUtils.isEmpty(parseInfo.getMetrics())) {
|
||||
return;
|
||||
}
|
||||
List<String> metricNames = parseInfo.getMetrics().stream()
|
||||
.map(SchemaElement::getName).collect(Collectors.toList());
|
||||
Map<String, String> filterCondition = new HashMap<>();
|
||||
filterCondition.put("modelId", parseInfo.getModelId().toString());
|
||||
filterCondition.put("type", SchemaElementType.METRIC.name());
|
||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(metricNames)
|
||||
.filterCondition(filterCondition).queryEmbeddings(null).build();
|
||||
EmbeddingUtils embeddingUtils = ContextUtils.getBean(EmbeddingUtils.class);
|
||||
List<RetrieveQueryResult> retrieveQueryResults = embeddingUtils.retrieveQuery(
|
||||
MetaEmbeddingListener.COLLECTION_NAME, retrieveQuery, 10);
|
||||
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
|
||||
return;
|
||||
}
|
||||
List<Retrieval> retrievals = retrieveQueryResults.stream()
|
||||
.flatMap(retrieveQueryResult -> retrieveQueryResult.getRetrieval().stream())
|
||||
.sorted(Comparator.comparingDouble(Retrieval::getDistance).reversed())
|
||||
.distinct().collect(Collectors.toList());
|
||||
Set<Long> metricIds = parseInfo.getMetrics().stream().map(SchemaElement::getId).collect(Collectors.toSet());
|
||||
int metricOrder = 0;
|
||||
for (SchemaElement metric : parseInfo.getMetrics()) {
|
||||
metric.setOrder(metricOrder++);
|
||||
}
|
||||
for (Retrieval retrieval : retrievals) {
|
||||
if (!metricIds.contains(retrieval.getId())) {
|
||||
SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(retrieval.getMetadata()),
|
||||
SchemaElement.class);
|
||||
if (retrieval.getMetadata().containsKey("modelId")) {
|
||||
schemaElement.setModel(Long.parseLong(retrieval.getMetadata().get("modelId")));
|
||||
}
|
||||
schemaElement.setOrder(metricOrder++);
|
||||
parseInfo.getMetrics().add(schemaElement);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -18,7 +18,6 @@ import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
|
||||
import com.tencent.supersonic.chat.persistence.repository.ChatContextRepository;
|
||||
import com.tencent.supersonic.chat.persistence.repository.ChatQueryRepository;
|
||||
import com.tencent.supersonic.chat.persistence.repository.ChatRepository;
|
||||
|
||||
import java.text.SimpleDateFormat;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
@@ -29,7 +28,6 @@ import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import com.tencent.supersonic.chat.service.ChatService;
|
||||
import com.tencent.supersonic.chat.utils.SolvedQueryManager;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.chat.service.impl;
|
||||
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
@@ -38,10 +39,12 @@ import java.util.Objects;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import com.tencent.supersonic.semantic.api.model.pojo.SchemaItem;
|
||||
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
|
||||
import com.tencent.supersonic.semantic.model.domain.DimensionService;
|
||||
import com.tencent.supersonic.semantic.model.domain.MetricService;
|
||||
import com.tencent.supersonic.semantic.model.domain.pojo.MetaFilter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
@@ -143,16 +146,19 @@ public class ConfigServiceImpl implements ConfigService {
|
||||
List<Long> filterMetricIdList = blackMetricIdList.stream().distinct().collect(Collectors.toList());
|
||||
|
||||
ItemNameVisibilityInfo itemNameVisibility = new ItemNameVisibilityInfo();
|
||||
MetaFilter metaFilter = new MetaFilter();
|
||||
metaFilter.setModelIds(Lists.newArrayList(modelId));
|
||||
if (!CollectionUtils.isEmpty(blackDimIdList)) {
|
||||
List<DimensionResp> dimensionRespList = dimensionService.getDimensions(modelId);
|
||||
List<DimensionResp> dimensionRespList = dimensionService.getDimensions(metaFilter);
|
||||
List<String> blackDimNameList = dimensionRespList.stream().filter(o -> filterDimIdList.contains(o.getId()))
|
||||
.map(o -> o.getName()).collect(Collectors.toList());
|
||||
.map(SchemaItem::getName).collect(Collectors.toList());
|
||||
itemNameVisibility.setBlackDimNameList(blackDimNameList);
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(blackMetricIdList)) {
|
||||
List<MetricResp> metricRespList = metricService.getMetrics(modelId);
|
||||
|
||||
List<MetricResp> metricRespList = metricService.getMetrics(metaFilter);
|
||||
List<String> blackMetricList = metricRespList.stream().filter(o -> filterMetricIdList.contains(o.getId()))
|
||||
.map(o -> o.getName()).collect(Collectors.toList());
|
||||
.map(SchemaItem::getName).collect(Collectors.toList());
|
||||
itemNameVisibility.setBlackMetricNameList(blackMetricList);
|
||||
}
|
||||
return itemNameVisibility;
|
||||
|
||||
@@ -5,7 +5,7 @@ import com.alibaba.fastjson.serializer.SerializerFeature;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.SolvedQueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp;
|
||||
import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingResp;
|
||||
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrieval;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.chat.test.context;
|
||||
|
||||
import static org.mockito.ArgumentMatchers.anyList;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyLong;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
@@ -26,6 +26,9 @@ import com.tencent.supersonic.semantic.model.domain.MetricService;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import com.tencent.supersonic.semantic.model.domain.pojo.DimensionFilter;
|
||||
import com.tencent.supersonic.semantic.model.domain.pojo.MetaFilter;
|
||||
import org.mockito.Mockito;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
@@ -81,11 +84,11 @@ public class MockBeansConfiguration {
|
||||
}
|
||||
|
||||
public static void dimensionDescBuild(DimensionService dimensionService, List<DimensionResp> dimensionDescs) {
|
||||
when(dimensionService.getDimensions(anyList())).thenReturn(dimensionDescs);
|
||||
when(dimensionService.getDimensions(any(DimensionFilter.class))).thenReturn(dimensionDescs);
|
||||
}
|
||||
|
||||
public static void metricDescBuild(MetricService dimensionService, List<MetricResp> metricDescs) {
|
||||
when(dimensionService.getMetrics(anyList())).thenReturn(metricDescs);
|
||||
public static void metricDescBuild(MetricService metricService, List<MetricResp> metricDescs) {
|
||||
when(metricService.getMetrics(any(MetaFilter.class))).thenReturn(metricDescs);
|
||||
}
|
||||
|
||||
public static DimSchemaResp getDimensionDesc(Long id, String bizName, String name) {
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
package com.tencent.supersonic.knowledge.listener;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.DataAddEvent;
|
||||
import com.tencent.supersonic.knowledge.dictionary.DictWord;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.context.ApplicationListener;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
public class DataAddListener implements ApplicationListener<DataAddEvent> {
|
||||
@Override
|
||||
public void onApplicationEvent(DataAddEvent event) {
|
||||
DictWord dictWord = new DictWord();
|
||||
dictWord.setWord(event.getName());
|
||||
String sign = DictWordType.NATURE_SPILT;
|
||||
String nature = sign + event.getModelId() + sign + event.getId() + event.getType();
|
||||
String natureWithFrequency = nature + " " + Constants.DEFAULT_FREQUENCY;
|
||||
dictWord.setNature(nature);
|
||||
dictWord.setNatureWithFrequency(natureWithFrequency);
|
||||
log.info("dataAddListener begins to add data:{}", dictWord);
|
||||
HanlpHelper.addToCustomDictionary(dictWord);
|
||||
}
|
||||
}
|
||||
@@ -1,27 +0,0 @@
|
||||
package com.tencent.supersonic.knowledge.listener;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.DataDeleteEvent;
|
||||
import com.tencent.supersonic.knowledge.dictionary.DictWord;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.context.ApplicationListener;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
public class DataDeleteListener implements ApplicationListener<DataDeleteEvent> {
|
||||
@Override
|
||||
public void onApplicationEvent(DataDeleteEvent event) {
|
||||
DictWord dictWord = new DictWord();
|
||||
dictWord.setWord(event.getName());
|
||||
String sign = DictWordType.NATURE_SPILT;
|
||||
String nature = sign + event.getModelId() + sign + event.getId() + event.getType();
|
||||
String natureWithFrequency = nature + " " + Constants.DEFAULT_FREQUENCY;
|
||||
dictWord.setNature(nature);
|
||||
dictWord.setNatureWithFrequency(natureWithFrequency);
|
||||
log.info("dataDeleteListener begins to delete data:{}", dictWord);
|
||||
HanlpHelper.removeFromCustomDictionary(dictWord);
|
||||
}
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
package com.tencent.supersonic.knowledge.listener;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.DataUpdateEvent;
|
||||
import com.tencent.supersonic.knowledge.dictionary.DictWord;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.context.ApplicationListener;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
public class DataUpdateListener implements ApplicationListener<DataUpdateEvent> {
|
||||
@Override
|
||||
public void onApplicationEvent(DataUpdateEvent event) {
|
||||
DictWord dictWord = new DictWord();
|
||||
dictWord.setWord(event.getName());
|
||||
String sign = DictWordType.NATURE_SPILT;
|
||||
String nature = sign + event.getModelId() + sign + event.getId() + event.getType();
|
||||
String natureWithFrequency = nature + " " + Constants.DEFAULT_FREQUENCY;
|
||||
dictWord.setNature(nature);
|
||||
dictWord.setNatureWithFrequency(natureWithFrequency);
|
||||
log.info("dataUpdateListener begins to update data:{}", dictWord);
|
||||
HanlpHelper.removeFromCustomDictionary(dictWord);
|
||||
dictWord.setWord(event.getNewName());
|
||||
HanlpHelper.addToCustomDictionary(dictWord);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
package com.tencent.supersonic.knowledge.listener;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.DataEvent;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.common.pojo.enums.EventType;
|
||||
import com.tencent.supersonic.knowledge.dictionary.DictWord;
|
||||
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.context.ApplicationListener;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
public class DictUpdateListener implements ApplicationListener<DataEvent> {
|
||||
|
||||
@Override
|
||||
public void onApplicationEvent(DataEvent dataEvent) {
|
||||
if (CollectionUtils.isEmpty(dataEvent.getDataItems())) {
|
||||
return;
|
||||
}
|
||||
dataEvent.getDataItems().forEach(dataItem -> {
|
||||
DictWord dictWord = new DictWord();
|
||||
dictWord.setWord(dataItem.getName());
|
||||
String sign = DictWordType.NATURE_SPILT;
|
||||
String nature = sign + dataItem.getModelId() + sign + dataItem.getId()
|
||||
+ sign + dataItem.getType().getName();
|
||||
String natureWithFrequency = nature + " " + Constants.DEFAULT_FREQUENCY;
|
||||
dictWord.setNature(nature);
|
||||
dictWord.setNatureWithFrequency(natureWithFrequency);
|
||||
if (EventType.ADD.equals(dataEvent.getEventType())) {
|
||||
HanlpHelper.addToCustomDictionary(dictWord);
|
||||
} else if (EventType.DELETE.equals(dataEvent.getEventType())) {
|
||||
HanlpHelper.removeFromCustomDictionary(dictWord);
|
||||
} else if (EventType.UPDATE.equals(dataEvent.getEventType())) {
|
||||
HanlpHelper.removeFromCustomDictionary(dictWord);
|
||||
dictWord.setWord(dataItem.getNewName());
|
||||
HanlpHelper.addToCustomDictionary(dictWord);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user