mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-14 22:08:56 +00:00
Feature/model data embedding for chat and support status for metric and dimension (#311)
* (improvement)(semantic) add offline status for metric and dimension * (improvement)(chat) add metric recall --------- Co-authored-by: jolunoluo
This commit is contained in:
@@ -7,6 +7,7 @@ import com.tencent.supersonic.chat.parser.plugin.PluginParser;
|
||||
import com.tencent.supersonic.chat.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.plugin.PluginRecallResult;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
package com.tencent.supersonic.chat.parser.plugin.embedding;
|
||||
|
||||
import lombok.Data;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
@Configuration
|
||||
@Data
|
||||
public class EmbeddingConfig {
|
||||
|
||||
@Value("${embedding.url:}")
|
||||
private String url;
|
||||
|
||||
@Value("${embedding.recognize.path:/preset_query_retrival}")
|
||||
private String recognizePath;
|
||||
|
||||
@Value("${embedding.delete.path:/preset_delete_by_ids}")
|
||||
private String deletePath;
|
||||
|
||||
@Value("${embedding.add.path:/preset_query_add}")
|
||||
private String addPath;
|
||||
|
||||
@Value("${embedding.nResult:1}")
|
||||
private String nResult;
|
||||
|
||||
@Value("${embedding.solvedQuery.recall.path:/solved_query_retrival}")
|
||||
private String solvedQueryRecallPath;
|
||||
|
||||
@Value("${embedding.solvedQuery.add.path:/solved_query_add}")
|
||||
private String solvedQueryAddPath;
|
||||
|
||||
@Value("${embedding.solved.query.nResult:5}")
|
||||
private String solvedQueryResultNum;
|
||||
|
||||
}
|
||||
@@ -11,7 +11,7 @@ import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.agent.Agent;
|
||||
import com.tencent.supersonic.chat.agent.tool.AgentToolType;
|
||||
import com.tencent.supersonic.chat.agent.tool.PluginTool;
|
||||
import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingResp;
|
||||
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrieval;
|
||||
import com.tencent.supersonic.chat.plugin.event.PluginAddEvent;
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
package com.tencent.supersonic.chat.responder.parse;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.tencent.supersonic.chat.api.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO;
|
||||
import com.tencent.supersonic.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.embedding.EmbeddingUtils;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||
import com.tencent.supersonic.semantic.model.domain.listener.MetaEmbeddingListener;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class SimilarMetricParseResponder implements ParseResponder {
|
||||
|
||||
|
||||
@Override
|
||||
public void fillResponse(ParseResp parseResp, QueryContext queryContext, List<ChatParseDO> chatParseDOS) {
|
||||
if (CollectionUtils.isEmpty(parseResp.getSelectedParses())) {
|
||||
return;
|
||||
}
|
||||
fillSimilarMetric(parseResp.getSelectedParses().iterator().next());
|
||||
}
|
||||
|
||||
private void fillSimilarMetric(SemanticParseInfo parseInfo) {
|
||||
if (!QueryManager.isMetricQuery(parseInfo.getQueryMode())
|
||||
|| CollectionUtils.isEmpty(parseInfo.getMetrics())) {
|
||||
return;
|
||||
}
|
||||
List<String> metricNames = parseInfo.getMetrics().stream()
|
||||
.map(SchemaElement::getName).collect(Collectors.toList());
|
||||
Map<String, String> filterCondition = new HashMap<>();
|
||||
filterCondition.put("modelId", parseInfo.getModelId().toString());
|
||||
filterCondition.put("type", SchemaElementType.METRIC.name());
|
||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(metricNames)
|
||||
.filterCondition(filterCondition).queryEmbeddings(null).build();
|
||||
EmbeddingUtils embeddingUtils = ContextUtils.getBean(EmbeddingUtils.class);
|
||||
List<RetrieveQueryResult> retrieveQueryResults = embeddingUtils.retrieveQuery(
|
||||
MetaEmbeddingListener.COLLECTION_NAME, retrieveQuery, 10);
|
||||
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
|
||||
return;
|
||||
}
|
||||
List<Retrieval> retrievals = retrieveQueryResults.stream()
|
||||
.flatMap(retrieveQueryResult -> retrieveQueryResult.getRetrieval().stream())
|
||||
.sorted(Comparator.comparingDouble(Retrieval::getDistance).reversed())
|
||||
.distinct().collect(Collectors.toList());
|
||||
Set<Long> metricIds = parseInfo.getMetrics().stream().map(SchemaElement::getId).collect(Collectors.toSet());
|
||||
int metricOrder = 0;
|
||||
for (SchemaElement metric : parseInfo.getMetrics()) {
|
||||
metric.setOrder(metricOrder++);
|
||||
}
|
||||
for (Retrieval retrieval : retrievals) {
|
||||
if (!metricIds.contains(retrieval.getId())) {
|
||||
SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(retrieval.getMetadata()),
|
||||
SchemaElement.class);
|
||||
if (retrieval.getMetadata().containsKey("modelId")) {
|
||||
schemaElement.setModel(Long.parseLong(retrieval.getMetadata().get("modelId")));
|
||||
}
|
||||
schemaElement.setOrder(metricOrder++);
|
||||
parseInfo.getMetrics().add(schemaElement);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -18,7 +18,6 @@ import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
|
||||
import com.tencent.supersonic.chat.persistence.repository.ChatContextRepository;
|
||||
import com.tencent.supersonic.chat.persistence.repository.ChatQueryRepository;
|
||||
import com.tencent.supersonic.chat.persistence.repository.ChatRepository;
|
||||
|
||||
import java.text.SimpleDateFormat;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
@@ -29,7 +28,6 @@ import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import com.tencent.supersonic.chat.service.ChatService;
|
||||
import com.tencent.supersonic.chat.utils.SolvedQueryManager;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.chat.service.impl;
|
||||
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.component.SemanticInterpreter;
|
||||
import com.tencent.supersonic.chat.api.pojo.ModelSchema;
|
||||
@@ -38,10 +39,12 @@ import java.util.Objects;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import com.tencent.supersonic.semantic.api.model.pojo.SchemaItem;
|
||||
import com.tencent.supersonic.semantic.api.model.response.DimensionResp;
|
||||
import com.tencent.supersonic.semantic.api.model.response.MetricResp;
|
||||
import com.tencent.supersonic.semantic.model.domain.DimensionService;
|
||||
import com.tencent.supersonic.semantic.model.domain.MetricService;
|
||||
import com.tencent.supersonic.semantic.model.domain.pojo.MetaFilter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
@@ -143,16 +146,19 @@ public class ConfigServiceImpl implements ConfigService {
|
||||
List<Long> filterMetricIdList = blackMetricIdList.stream().distinct().collect(Collectors.toList());
|
||||
|
||||
ItemNameVisibilityInfo itemNameVisibility = new ItemNameVisibilityInfo();
|
||||
MetaFilter metaFilter = new MetaFilter();
|
||||
metaFilter.setModelIds(Lists.newArrayList(modelId));
|
||||
if (!CollectionUtils.isEmpty(blackDimIdList)) {
|
||||
List<DimensionResp> dimensionRespList = dimensionService.getDimensions(modelId);
|
||||
List<DimensionResp> dimensionRespList = dimensionService.getDimensions(metaFilter);
|
||||
List<String> blackDimNameList = dimensionRespList.stream().filter(o -> filterDimIdList.contains(o.getId()))
|
||||
.map(o -> o.getName()).collect(Collectors.toList());
|
||||
.map(SchemaItem::getName).collect(Collectors.toList());
|
||||
itemNameVisibility.setBlackDimNameList(blackDimNameList);
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(blackMetricIdList)) {
|
||||
List<MetricResp> metricRespList = metricService.getMetrics(modelId);
|
||||
|
||||
List<MetricResp> metricRespList = metricService.getMetrics(metaFilter);
|
||||
List<String> blackMetricList = metricRespList.stream().filter(o -> filterMetricIdList.contains(o.getId()))
|
||||
.map(o -> o.getName()).collect(Collectors.toList());
|
||||
.map(SchemaItem::getName).collect(Collectors.toList());
|
||||
itemNameVisibility.setBlackMetricNameList(blackMetricList);
|
||||
}
|
||||
return itemNameVisibility;
|
||||
|
||||
@@ -5,7 +5,7 @@ import com.alibaba.fastjson.serializer.SerializerFeature;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.SolvedQueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp;
|
||||
import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.chat.parser.plugin.embedding.EmbeddingResp;
|
||||
import com.tencent.supersonic.chat.parser.plugin.embedding.RecallRetrieval;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.chat.test.context;
|
||||
|
||||
import static org.mockito.ArgumentMatchers.anyList;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.ArgumentMatchers.anyLong;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
@@ -26,6 +26,9 @@ import com.tencent.supersonic.semantic.model.domain.MetricService;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
import com.tencent.supersonic.semantic.model.domain.pojo.DimensionFilter;
|
||||
import com.tencent.supersonic.semantic.model.domain.pojo.MetaFilter;
|
||||
import org.mockito.Mockito;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
@@ -81,11 +84,11 @@ public class MockBeansConfiguration {
|
||||
}
|
||||
|
||||
public static void dimensionDescBuild(DimensionService dimensionService, List<DimensionResp> dimensionDescs) {
|
||||
when(dimensionService.getDimensions(anyList())).thenReturn(dimensionDescs);
|
||||
when(dimensionService.getDimensions(any(DimensionFilter.class))).thenReturn(dimensionDescs);
|
||||
}
|
||||
|
||||
public static void metricDescBuild(MetricService dimensionService, List<MetricResp> metricDescs) {
|
||||
when(dimensionService.getMetrics(anyList())).thenReturn(metricDescs);
|
||||
public static void metricDescBuild(MetricService metricService, List<MetricResp> metricDescs) {
|
||||
when(metricService.getMetrics(any(MetaFilter.class))).thenReturn(metricDescs);
|
||||
}
|
||||
|
||||
public static DimSchemaResp getDimensionDesc(Long id, String bizName, String name) {
|
||||
|
||||
Reference in New Issue
Block a user