Feature/model data embedding for chat and support status for metric and dimension (#311)

* (improvement)(semantic) add offline status for metric and dimension

* (improvement)(chat) add metric recall

---------

Co-authored-by: jolunoluo
This commit is contained in:
LXW
2023-11-02 18:44:58 +08:00
committed by GitHub
parent f4e3922f47
commit ad20380283
89 changed files with 1572 additions and 896 deletions

View File

@@ -28,6 +28,8 @@ public class SchemaElement implements Serializable {
private String defaultAgg;
private int order;
@Override
public boolean equals(Object o) {
if (this == o) {

View File

@@ -53,6 +53,9 @@ public class SemanticParseInfo {
@Override
public int compare(SchemaElement o1, SchemaElement o2) {
if (o1.getOrder() != o2.getOrder()) {
return o1.getOrder() - o2.getOrder();
}
int len1 = o1.getName().length();
int len2 = o2.getName().length();
if (len1 != len2) {

View File

@@ -7,6 +7,7 @@ import com.tencent.supersonic.chat.parser.plugin.PluginParser;
import com.tencent.supersonic.chat.plugin.Plugin;
import com.tencent.supersonic.chat.plugin.PluginManager;
import com.tencent.supersonic.chat.plugin.PluginRecallResult;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.ContextUtils;
import java.util.Comparator;
import java.util.List;

View File

@@ -1,35 +0,0 @@
package com.tencent.supersonic.chat.parser.plugin.embedding;
import lombok.Data;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration;
@Configuration
@Data
public class EmbeddingConfig {
@Value("${embedding.url:}")
private String url;
@Value("${embedding.recognize.path:/preset_query_retrival}")
private String recognizePath;
@Value("${embedding.delete.path:/preset_delete_by_ids}")
private String deletePath;
@Value("${embedding.add.path:/preset_query_add}")
private String addPath;
@Value("${embedding.nResult:1}")
private String nResult;
@Value("${embedding.solvedQuery.recall.path:/solved_query_retrival}")
private String solvedQueryRecallPath;
@Value("${embedding.solvedQuery.add.path:/solved_query_add}")
private String solvedQueryAddPath;
@Value("${embedding.solved.query.nResult:5}")
private String solvedQueryResultNum;
}

View File

@@ -11,7 +11,7 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.agent.Agent;
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
import com.tencent.supersonic.chat.agent.tool.PluginTool;
import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingConfig;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingResp;
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrieval;
import com.tencent.supersonic.chat.plugin.event.PluginAddEvent;

View File

@@ -0,0 +1,76 @@
package com.tencent.supersonic.chat.responder.parse;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.chat.api.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.query.QueryManager;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
import com.tencent.supersonic.common.util.embedding.Retrieval;
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
import com.tencent.supersonic.semantic.model.domain.listener.MetaEmbeddingListener;
import org.springframework.util.CollectionUtils;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
public class SimilarMetricParseResponder implements ParseResponder {
@Override
public void fillResponse(ParseResp parseResp, QueryContext queryContext, List<ChatParseDO> chatParseDOS) {
if (CollectionUtils.isEmpty(parseResp.getSelectedParses())) {
return;
}
fillSimilarMetric(parseResp.getSelectedParses().iterator().next());
}
private void fillSimilarMetric(SemanticParseInfo parseInfo) {
if (!QueryManager.isMetricQuery(parseInfo.getQueryMode())
|| CollectionUtils.isEmpty(parseInfo.getMetrics())) {
return;
}
List<String> metricNames = parseInfo.getMetrics().stream()
.map(SchemaElement::getName).collect(Collectors.toList());
Map<String, String> filterCondition = new HashMap<>();
filterCondition.put("modelId", parseInfo.getModelId().toString());
filterCondition.put("type", SchemaElementType.METRIC.name());
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(metricNames)
.filterCondition(filterCondition).queryEmbeddings(null).build();
EmbeddingUtils embeddingUtils = ContextUtils.getBean(EmbeddingUtils.class);
List<RetrieveQueryResult> retrieveQueryResults = embeddingUtils.retrieveQuery(
MetaEmbeddingListener.COLLECTION_NAME, retrieveQuery, 10);
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
return;
}
List<Retrieval> retrievals = retrieveQueryResults.stream()
.flatMap(retrieveQueryResult -> retrieveQueryResult.getRetrieval().stream())
.sorted(Comparator.comparingDouble(Retrieval::getDistance).reversed())
.distinct().collect(Collectors.toList());
Set<Long> metricIds = parseInfo.getMetrics().stream().map(SchemaElement::getId).collect(Collectors.toSet());
int metricOrder = 0;
for (SchemaElement metric : parseInfo.getMetrics()) {
metric.setOrder(metricOrder++);
}
for (Retrieval retrieval : retrievals) {
if (!metricIds.contains(retrieval.getId())) {
SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(retrieval.getMetadata()),
SchemaElement.class);
if (retrieval.getMetadata().containsKey("modelId")) {
schemaElement.setModel(Long.parseLong(retrieval.getMetadata().get("modelId")));
}
schemaElement.setOrder(metricOrder++);
parseInfo.getMetrics().add(schemaElement);
}
}
}
}

View File

@@ -18,7 +18,6 @@ import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.persistence.repository.ChatContextRepository;
import com.tencent.supersonic.chat.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.persistence.repository.ChatRepository;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Comparator;
@@ -29,7 +28,6 @@ import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import com.tencent.supersonic.chat.service.ChatService;
import com.tencent.supersonic.chat.utils.SolvedQueryManager;
import com.tencent.supersonic.common.util.JsonUtil;

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.chat.service.impl;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
@@ -38,10 +39,12 @@ import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
import com.tencent.supersonic.semantic.api.model.pojo.SchemaItem;
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
import com.tencent.supersonic.semantic.model.domain.DimensionService;
import com.tencent.supersonic.semantic.model.domain.MetricService;
import com.tencent.supersonic.semantic.model.domain.pojo.MetaFilter;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
@@ -143,16 +146,19 @@ public class ConfigServiceImpl implements ConfigService {
List<Long> filterMetricIdList = blackMetricIdList.stream().distinct().collect(Collectors.toList());
ItemNameVisibilityInfo itemNameVisibility = new ItemNameVisibilityInfo();
MetaFilter metaFilter = new MetaFilter();
metaFilter.setModelIds(Lists.newArrayList(modelId));
if (!CollectionUtils.isEmpty(blackDimIdList)) {
List<DimensionResp> dimensionRespList = dimensionService.getDimensions(modelId);
List<DimensionResp> dimensionRespList = dimensionService.getDimensions(metaFilter);
List<String> blackDimNameList = dimensionRespList.stream().filter(o -> filterDimIdList.contains(o.getId()))
.map(o -> o.getName()).collect(Collectors.toList());
.map(SchemaItem::getName).collect(Collectors.toList());
itemNameVisibility.setBlackDimNameList(blackDimNameList);
}
if (!CollectionUtils.isEmpty(blackMetricIdList)) {
List<MetricResp> metricRespList = metricService.getMetrics(modelId);
List<MetricResp> metricRespList = metricService.getMetrics(metaFilter);
List<String> blackMetricList = metricRespList.stream().filter(o -> filterMetricIdList.contains(o.getId()))
.map(o -> o.getName()).collect(Collectors.toList());
.map(SchemaItem::getName).collect(Collectors.toList());
itemNameVisibility.setBlackMetricNameList(blackMetricList);
}
return itemNameVisibility;

View File

@@ -5,7 +5,7 @@ import com.alibaba.fastjson.serializer.SerializerFeature;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.request.SolvedQueryReq;
import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp;
import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingConfig;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingResp;
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrieval;
import lombok.extern.slf4j.Slf4j;

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.chat.test.context;
import static org.mockito.ArgumentMatchers.anyList;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.when;
@@ -26,6 +26,9 @@ import com.tencent.supersonic.semantic.model.domain.MetricService;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import com.tencent.supersonic.semantic.model.domain.pojo.DimensionFilter;
import com.tencent.supersonic.semantic.model.domain.pojo.MetaFilter;
import org.mockito.Mockito;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
@@ -81,11 +84,11 @@ public class MockBeansConfiguration {
}
public static void dimensionDescBuild(DimensionService dimensionService, List<DimensionResp> dimensionDescs) {
when(dimensionService.getDimensions(anyList())).thenReturn(dimensionDescs);
when(dimensionService.getDimensions(any(DimensionFilter.class))).thenReturn(dimensionDescs);
}
public static void metricDescBuild(MetricService dimensionService, List<MetricResp> metricDescs) {
when(dimensionService.getMetrics(anyList())).thenReturn(metricDescs);
public static void metricDescBuild(MetricService metricService, List<MetricResp> metricDescs) {
when(metricService.getMetrics(any(MetaFilter.class))).thenReturn(metricDescs);
}
public static DimSchemaResp getDimensionDesc(Long id, String bizName, String name) {

View File

@@ -1,27 +0,0 @@
package com.tencent.supersonic.knowledge.listener;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DataAddEvent;
import com.tencent.supersonic.knowledge.dictionary.DictWord;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.ApplicationListener;
import org.springframework.stereotype.Component;
@Component
@Slf4j
public class DataAddListener implements ApplicationListener<DataAddEvent> {
@Override
public void onApplicationEvent(DataAddEvent event) {
DictWord dictWord = new DictWord();
dictWord.setWord(event.getName());
String sign = DictWordType.NATURE_SPILT;
String nature = sign + event.getModelId() + sign + event.getId() + event.getType();
String natureWithFrequency = nature + " " + Constants.DEFAULT_FREQUENCY;
dictWord.setNature(nature);
dictWord.setNatureWithFrequency(natureWithFrequency);
log.info("dataAddListener begins to add data:{}", dictWord);
HanlpHelper.addToCustomDictionary(dictWord);
}
}

View File

@@ -1,27 +0,0 @@
package com.tencent.supersonic.knowledge.listener;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DataDeleteEvent;
import com.tencent.supersonic.knowledge.dictionary.DictWord;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.ApplicationListener;
import org.springframework.stereotype.Component;
@Component
@Slf4j
public class DataDeleteListener implements ApplicationListener<DataDeleteEvent> {
@Override
public void onApplicationEvent(DataDeleteEvent event) {
DictWord dictWord = new DictWord();
dictWord.setWord(event.getName());
String sign = DictWordType.NATURE_SPILT;
String nature = sign + event.getModelId() + sign + event.getId() + event.getType();
String natureWithFrequency = nature + " " + Constants.DEFAULT_FREQUENCY;
dictWord.setNature(nature);
dictWord.setNatureWithFrequency(natureWithFrequency);
log.info("dataDeleteListener begins to delete data:{}", dictWord);
HanlpHelper.removeFromCustomDictionary(dictWord);
}
}

View File

@@ -1,29 +0,0 @@
package com.tencent.supersonic.knowledge.listener;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DataUpdateEvent;
import com.tencent.supersonic.knowledge.dictionary.DictWord;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.ApplicationListener;
import org.springframework.stereotype.Component;
@Component
@Slf4j
public class DataUpdateListener implements ApplicationListener<DataUpdateEvent> {
@Override
public void onApplicationEvent(DataUpdateEvent event) {
DictWord dictWord = new DictWord();
dictWord.setWord(event.getName());
String sign = DictWordType.NATURE_SPILT;
String nature = sign + event.getModelId() + sign + event.getId() + event.getType();
String natureWithFrequency = nature + " " + Constants.DEFAULT_FREQUENCY;
dictWord.setNature(nature);
dictWord.setNatureWithFrequency(natureWithFrequency);
log.info("dataUpdateListener begins to update data:{}", dictWord);
HanlpHelper.removeFromCustomDictionary(dictWord);
dictWord.setWord(event.getNewName());
HanlpHelper.addToCustomDictionary(dictWord);
}
}

View File

@@ -0,0 +1,43 @@
package com.tencent.supersonic.knowledge.listener;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DataEvent;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.common.pojo.enums.EventType;
import com.tencent.supersonic.knowledge.dictionary.DictWord;
import com.tencent.supersonic.knowledge.utils.HanlpHelper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.ApplicationListener;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
@Component
@Slf4j
public class DictUpdateListener implements ApplicationListener<DataEvent> {
@Override
public void onApplicationEvent(DataEvent dataEvent) {
if (CollectionUtils.isEmpty(dataEvent.getDataItems())) {
return;
}
dataEvent.getDataItems().forEach(dataItem -> {
DictWord dictWord = new DictWord();
dictWord.setWord(dataItem.getName());
String sign = DictWordType.NATURE_SPILT;
String nature = sign + dataItem.getModelId() + sign + dataItem.getId()
+ sign + dataItem.getType().getName();
String natureWithFrequency = nature + " " + Constants.DEFAULT_FREQUENCY;
dictWord.setNature(nature);
dictWord.setNatureWithFrequency(natureWithFrequency);
if (EventType.ADD.equals(dataEvent.getEventType())) {
HanlpHelper.addToCustomDictionary(dictWord);
} else if (EventType.DELETE.equals(dataEvent.getEventType())) {
HanlpHelper.removeFromCustomDictionary(dictWord);
} else if (EventType.UPDATE.equals(dataEvent.getEventType())) {
HanlpHelper.removeFromCustomDictionary(dictWord);
dictWord.setWord(dataItem.getNewName());
HanlpHelper.addToCustomDictionary(dictWord);
}
});
}
}