(improvement)(chat) Vector retrieval supports filtering by modelId collection during query. (#1576)

This commit is contained in:
lexluo09
2024-08-16 21:31:07 +08:00
committed by GitHub
parent 6aff51d394
commit 7150f19def
4 changed files with 73 additions and 66 deletions

View File

@@ -1,20 +1,19 @@
package com.tencent.supersonic.chat.server.processor.execute;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.store.embedding.Retrieval;
import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.chat.knowledge.MetaEmbeddingService;
import java.util.Objects;
import dev.langchain4j.store.embedding.Retrieval;
import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult;
import org.springframework.util.CollectionUtils;
import java.util.Collections;
@@ -23,6 +22,7 @@ import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
@@ -45,7 +45,7 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
return;
}
List<String> metricNames = Collections.singletonList(parseInfo.getMetrics().iterator().next().getName());
Map<String, String> filterCondition = new HashMap<>();
Map<String, Object> filterCondition = new HashMap<>();
filterCondition.put("modelId", parseInfo.getMetrics().iterator().next().getDataSetId().toString());
filterCondition.put("type", SchemaElementType.METRIC.name());
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(metricNames)