[improvement][chat] In STRICT mode, embedingMapper does not perform mapping (#1858)

This commit is contained in:
lexluo09
2024-10-29 20:54:37 +08:00
committed by GitHub
parent cbb76550c7
commit d9cf874536
5 changed files with 10 additions and 13 deletions

View File

@@ -88,7 +88,8 @@ public class NL2SQLParser implements ChatQueryParser {
for (Long datasetId : requestedDatasets) {
queryNLReq.setDataSetIds(Collections.singleton(datasetId));
ChatParseResp parseResp = new ChatParseResp(parseContext.getRequest().getQueryId());
for (MapModeEnum mode : Lists.newArrayList(MapModeEnum.STRICT, MapModeEnum.MODERATE)) {
for (MapModeEnum mode : Lists.newArrayList(MapModeEnum.STRICT,
MapModeEnum.MODERATE)) {
queryNLReq.setMapModeEnum(mode);
doParse(queryNLReq, parseResp);
}

View File

@@ -27,7 +27,6 @@ public abstract class BaseMapper implements SchemaMapper {
@Override
public void map(ChatQueryContext chatQueryContext) {
String simpleName = this.getClass().getSimpleName();
long startTime = System.currentTimeMillis();
log.debug("before {},mapInfo:{}", simpleName,

View File

@@ -4,6 +4,7 @@ import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult;
import com.tencent.supersonic.headless.chat.knowledge.builder.BaseWordBuilder;
@@ -14,13 +15,18 @@ import lombok.extern.slf4j.Slf4j;
import java.util.List;
import java.util.Objects;
/** A mapper that recognizes schema elements with vector embedding. */
/**
* A mapper that recognizes schema elements with vector embedding.
*/
@Slf4j
public class EmbeddingMapper extends BaseMapper {
@Override
public void doMap(ChatQueryContext chatQueryContext) {
// 1. query from embedding by queryText
if (MapModeEnum.STRICT.equals(chatQueryContext.getRequest().getMapModeEnum())) {
return;
}
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
List<EmbeddingResult> matchResults = getMatches(chatQueryContext, matchStrategy);

View File

@@ -25,7 +25,6 @@ import java.util.stream.Collectors;
import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING_MAPPER_NUMBER;
import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING_MAPPER_ROUND_NUMBER;
import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING_MAPPER_THRESHOLD;
import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING_MAPPER_THRESHOLD_MIN;
/**
* EmbeddingMatchStrategy uses vector database to perform similarity search against the embeddings
@@ -64,12 +63,8 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
private List<EmbeddingResult> detectByQueryTextsSub(Set<Long> detectDataSetIds,
List<String> queryTextsSub, ChatQueryContext chatQueryContext) {
Map<Long, List<Long>> modelIdToDataSetIds = chatQueryContext.getModelIdToDataSetIds();
double embeddingThreshold =
double threshold =
Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD));
double embeddingThresholdMin =
Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD_MIN));
double threshold = getThreshold(embeddingThreshold, embeddingThresholdMin,
chatQueryContext.getRequest().getMapModeEnum());
// step1. build query params
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();

View File

@@ -54,10 +54,6 @@ public class MapperConfig extends ParameterConfig {
new Parameter("s2.mapper.embedding.threshold", "0.98", "向量召回相似度阈值", "相似度小于该阈值的则舍弃",
"number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_THRESHOLD_MIN =
new Parameter("s2.mapper.embedding.min.threshold", "0.9", "向量召回最小相似度阈值",
"向量召回相似度阈值在动态调整中的最低值", "number", "Mapper相关配置");
public static final Parameter EMBEDDING_MAPPER_ROUND_NUMBER =
new Parameter("s2.mapper.embedding.round.number", "10", "向量召回最小相似度阈值",
"向量召回相似度阈值在动态调整中的最低值", "number", "Mapper相关配置");