5 Commits

Author SHA1 Message Date
mislayming
58e41cd4bc Merge aaf2d46a56 into 978ae53fb3 2025-02-18 19:31:55 +08:00
williamhliu
978ae53fb3 (feature)(supersonic-fe) remove upload file (#2068)
Some checks are pending
supersonic CentOS CI / build (21) (push) Waiting to run
supersonic mac CI / build (21) (push) Waiting to run
supersonic ubuntu CI / build (21) (push) Waiting to run
supersonic windows CI / build (21) (push) Waiting to run
2025-02-18 19:28:52 +08:00
williamhliu
e04bc3cce8 (feature)(chat-sdk) add unit (#2067) 2025-02-18 19:21:13 +08:00
wua.ming
aaf2d46a56 (improvement)(chat) Enhancing the capability of embedding with LLM-based secondary judgment. 2025-02-18 15:57:39 +08:00
jerryjzhang
c8abea9c1a (improvement)(project)Introduce aibi-env.sh script to simplify user settings.
(improvement)(project)Introduce aibi-env.sh script to simplify user settings.
2025-02-18 15:50:51 +08:00
16 changed files with 16216 additions and 12847 deletions

View File

@@ -1,10 +1,6 @@
package com.tencent.supersonic.headless.api.pojo; package com.tencent.supersonic.headless.api.pojo;
import lombok.AllArgsConstructor; import lombok.*;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.ToString;
import java.io.Serializable; import java.io.Serializable;
@@ -21,6 +17,7 @@ public class SchemaElementMatch implements Serializable {
private String word; private String word;
private Long frequency; private Long frequency;
private boolean isInherited; private boolean isInherited;
private boolean llmMatched;
public boolean isFullMatched() { public boolean isFullMatched() {
return 1.0 == similarity; return 1.0 == similarity;

View File

@@ -13,6 +13,7 @@ public class EmbeddingResult extends MapResult {
private String id; private String id;
private Map<String, String> metadata; private Map<String, String> metadata;
private boolean llmMatched;
@Override @Override
public boolean equals(Object o) { public boolean equals(Object o) {

View File

@@ -1,9 +1,12 @@
package com.tencent.supersonic.headless.chat.mapper; package com.tencent.supersonic.headless.chat.mapper;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum; import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult; import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult;
@@ -11,6 +14,7 @@ import com.tencent.supersonic.headless.chat.knowledge.builder.BaseWordBuilder;
import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper; import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper;
import dev.langchain4j.store.embedding.Retrieval; import dev.langchain4j.store.embedding.Retrieval;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
@@ -23,10 +27,16 @@ public class EmbeddingMapper extends BaseMapper {
@Override @Override
public boolean accept(ChatQueryContext chatQueryContext) { public boolean accept(ChatQueryContext chatQueryContext) {
return MapModeEnum.LOOSE.equals(chatQueryContext.getRequest().getMapModeEnum()); boolean b0 = MapModeEnum.LOOSE.equals(chatQueryContext.getRequest().getMapModeEnum());
boolean b1 = chatQueryContext.getRequest().getText2SQLType() == Text2SQLType.LLM_OR_RULE;
return b0 || b1;
} }
public void doMap(ChatQueryContext chatQueryContext) { public void doMap(ChatQueryContext chatQueryContext) {
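// TODO: If mapping already ran in the LOOSE stage, it does not need to run again in the LLM_OR_RULE stage; there is currently no state to pass that information along, so this redundant work is tolerated for now.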
SchemaMapInfo mappedInfo = chatQueryContext.getMapInfo();
// 1. Query from embedding by queryText // 1. Query from embedding by queryText
EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class); EmbeddingMatchStrategy matchStrategy = ContextUtils.getBean(EmbeddingMatchStrategy.class);
List<EmbeddingResult> matchResults = getMatches(chatQueryContext, matchStrategy); List<EmbeddingResult> matchResults = getMatches(chatQueryContext, matchStrategy);
@@ -53,15 +63,26 @@ public class EmbeddingMapper extends BaseMapper {
continue; continue;
} }
// Build SchemaElementMatch object // Build SchemaElementMatch object
SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder() SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder()
.element(schemaElement).frequency(BaseWordBuilder.DEFAULT_FREQUENCY) .element(schemaElement).frequency(BaseWordBuilder.DEFAULT_FREQUENCY)
.word(matchResult.getName()).similarity(matchResult.getSimilarity()) .word(matchResult.getName()).similarity(matchResult.getSimilarity())
.detectWord(matchResult.getDetectWord()).build(); .detectWord(matchResult.getDetectWord()).build();
schemaElementMatch.setLlmMatched(matchResult.isLlmMatched());
// 3. Add SchemaElementMatch to mapInfo // 3. Add SchemaElementMatch to mapInfo
addToSchemaMap(chatQueryContext.getMapInfo(), dataSetId, schemaElementMatch); addToSchemaMap(chatQueryContext.getMapInfo(), dataSetId, schemaElementMatch);
} }
if (CollectionUtils.isEmpty(matchResults)) {
log.info("embedding mapper no match");
} else {
for (EmbeddingResult matchResult : matchResults) {
log.info("embedding match name=[{}],detectWord=[{}],similarity=[{}],metadata=[{}]",
matchResult.getName(), matchResult.getDetectWord(),
matchResult.getSimilarity(), JsonUtil.toString(matchResult.getMetadata()));
}
}
} }
} }

View File

@@ -1,9 +1,17 @@
package com.tencent.supersonic.headless.chat.mapper; package com.tencent.supersonic.headless.chat.mapper;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.hankcs.hanlp.seg.common.Term;
import com.tencent.supersonic.headless.api.pojo.response.S2Term;
import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult; import com.tencent.supersonic.headless.chat.knowledge.EmbeddingResult;
import com.tencent.supersonic.headless.chat.knowledge.MetaEmbeddingService; import com.tencent.supersonic.headless.chat.knowledge.MetaEmbeddingService;
import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.provider.ModelProvider;
import dev.langchain4j.store.embedding.Retrieval; import dev.langchain4j.store.embedding.Retrieval;
import dev.langchain4j.store.embedding.RetrieveQuery; import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult; import dev.langchain4j.store.embedding.RetrieveQueryResult;
@@ -14,18 +22,12 @@ import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.util.ArrayList; import java.util.*;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable; import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING_MAPPER_NUMBER; import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.*;
import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING_MAPPER_ROUND_NUMBER;
import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING_MAPPER_THRESHOLD;
/** /**
* EmbeddingMatchStrategy uses vector database to perform similarity search against the embeddings * EmbeddingMatchStrategy uses vector database to perform similarity search against the embeddings
@@ -35,37 +37,165 @@ import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING
@Slf4j @Slf4j
public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult> { public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult> {
@Autowired
protected MapperConfig mapperConfig;
@Autowired @Autowired
private MetaEmbeddingService metaEmbeddingService; private MetaEmbeddingService metaEmbeddingService;
/**
 * Prompt template used by the LLM filtering step of {@code detectByBatch}.
 *
 * <p>{@code {{userText}}} is filled with the request's query text and {@code {{retrievedInfo}}}
 * with the JSON-serialized recall results. The model's reply is parsed as a JSON array of
 * strings and only results whose id appears in that array are kept.
 *
 * <p>NOTE(review): the example in rule 4 uses single quotes, which is not strict JSON; the
 * reply is parsed with fastjson's {@code parseArray} — confirm that parser accepts it.
 */
private static final String LLM_FILTER_PROMPT =
"""
\
#Role: You are a professional data analyst specializing in metrics and dimensions.
#Task: Given a user query and a list of retrieved metrics/dimensions through vector recall,
please analyze which metrics/dimensions the user is most likely interested in.
#Rules:
1. Based on user query and retrieved info, accurately determine metrics/dimensions user truly cares about.
2. Do not return all retrieved info, only select those highly relevant to user query.
3. Maintain high quality output, exclude metrics/dimensions irrelevant to user intent.
4. Output must be in JSON array format, only include IDs from retrieved info, e.g.: ['id1', 'id2']
5. Return JSON content directly without markdown formatting
#Input Example:
#User Query: {{userText}}
#Retrieved Metrics/Dimensions: {{retrievedInfo}}
#Output:""";
/**
 * Detects embedding matches for the query, optionally augmented by the LLM-filtered recall.
 *
 * <p>Returns an empty list when the context is null or no dataset ids are supplied. Otherwise
 * the base-class detection runs first; when {@code EMBEDDING_MAPPER_USE_LLM} parses to
 * {@code true}, the results of {@link #detectWithLLM} are appended. Duplicates (per
 * {@code equals}) are dropped, keeping first-seen order.
 *
 * @param chatQueryContext context of the current chat query
 * @param terms segmented terms, forwarded unchanged to the base implementation
 * @param detectDataSetIds ids of the datasets to search
 * @return de-duplicated matches; never {@code null}
 */
@Override
public List<EmbeddingResult> detect(ChatQueryContext chatQueryContext, List<S2Term> terms,
        Set<Long> detectDataSetIds) {
    if (chatQueryContext == null || CollectionUtils.isEmpty(detectDataSetIds)) {
        log.warn("Invalid input parameters: context={}, dataSetIds={}", chatQueryContext,
                detectDataSetIds);
        return Collections.emptyList();
    }

    // Merge into a collection owned by this method: the list returned by the base class is
    // not ours to mutate (it may be unmodifiable or null), and an insertion-ordered set gives
    // the same first-seen de-duplication as stream().distinct().
    Set<EmbeddingResult> merged = new LinkedHashSet<>();

    // 1. Base detection
    List<EmbeddingResult> baseResults = super.detect(chatQueryContext, terms, detectDataSetIds);
    if (baseResults != null) {
        merged.addAll(baseResults);
    }

    // 2. LLM enhanced detection
    boolean useLLM =
            Boolean.parseBoolean(mapperConfig.getParameterValue(EMBEDDING_MAPPER_USE_LLM));
    if (useLLM) {
        List<EmbeddingResult> llmResults = detectWithLLM(chatQueryContext, detectDataSetIds);
        if (!CollectionUtils.isEmpty(llmResults)) {
            merged.addAll(llmResults);
        }
    }

    // 3. Results are already de-duplicated by the set
    return new ArrayList<>(merged);
}
/**
 * Runs the LLM-assisted recall path for the current query.
 *
 * <p>The query text is segmented via {@code extractValidSegments} and the segments are passed to
 * {@code detectByBatch} with LLM filtering switched on. A blank query, an empty segment set or
 * any exception results in an empty list; exceptions are logged and not rethrown.
 *
 * @param chatQueryContext context of the current chat query
 * @param detectDataSetIds ids of the datasets to search
 * @return the LLM-filtered matches, or an empty list
 */
private List<EmbeddingResult> detectWithLLM(ChatQueryContext chatQueryContext,
        Set<Long> detectDataSetIds) {
    List<EmbeddingResult> llmMatches = Collections.emptyList();
    try {
        String userQuery = chatQueryContext.getRequest().getQueryText();
        if (!StringUtils.isBlank(userQuery)) {
            // Get segmentation results
            Set<String> segments = extractValidSegments(userQuery);
            if (CollectionUtils.isEmpty(segments)) {
                log.info("No valid segments found for text: {}", userQuery);
            } else {
                llmMatches = detectByBatch(chatQueryContext, detectDataSetIds, segments, true);
            }
        }
    } catch (Exception e) {
        log.error("Error in LLM detection for context: {}", chatQueryContext, e);
        llmMatches = Collections.emptyList();
    }
    return llmMatches;
}
/**
 * Segments the text and keeps the words whose nature does not start with any configured
 * nature prefix.
 *
 * <p>The configured value may be a plain comma-separated list ({@code v,d,a}) or the bracketed,
 * quoted form used by the parameter's default ({@code ['v', 'd', 'a']}). Brackets, quotes and
 * whitespace around each entry are discarded before matching; without that, entries such as
 * {@code ['v'} never match a nature and the filter silently does nothing. A missing value or
 * empty entries disable filtering rather than failing.
 *
 * @param text query text to segment
 * @return the distinct words of the terms that survive the nature filter
 */
private Set<String> extractValidSegments(String text) {
    String configured =
            mapperConfig.getParameterValue(EMBEDDING_MAPPER_ALLOWED_SEGMENT_NATURE);
    List<String> natureList = configured == null ? Collections.emptyList()
            : Arrays.stream(configured.split(","))
                    // Drop list punctuation so "['v'" becomes "v".
                    .map(entry -> entry.replaceAll("[\\[\\]'\"\\s]", ""))
                    // An empty prefix would match every nature and filter out all terms.
                    .filter(entry -> !entry.isEmpty())
                    .collect(Collectors.toList());

    // NOTE(review): the parameter is named "allowed", but terms matching these natures are
    // excluded here — confirm which of the two is intended.
    return HanlpHelper.getSegment().seg(text).stream()
            .filter(t -> natureList.stream().noneMatch(nature -> t.nature.startsWith(nature)))
            .map(Term::getWord).collect(Collectors.toSet());
}
@Override @Override
public List<EmbeddingResult> detectByBatch(ChatQueryContext chatQueryContext, public List<EmbeddingResult> detectByBatch(ChatQueryContext chatQueryContext,
Set<Long> detectDataSetIds, Set<String> detectSegments) { Set<Long> detectDataSetIds, Set<String> detectSegments) {
return detectByBatch(chatQueryContext, detectDataSetIds, detectSegments, false);
}
/**
* Process detection in batches with LLM option
*
* @param chatQueryContext The context of the chat query
* @param detectDataSetIds Target dataset IDs for detection
* @param detectSegments Segments to be detected
* @param useLlm Whether to use LLM for filtering results
* @return List of embedding results
*/
public List<EmbeddingResult> detectByBatch(ChatQueryContext chatQueryContext,
Set<Long> detectDataSetIds, Set<String> detectSegments, boolean useLlm) {
Set<EmbeddingResult> results = ConcurrentHashMap.newKeySet(); Set<EmbeddingResult> results = ConcurrentHashMap.newKeySet();
int embeddingMapperBatch = Integer int embeddingMapperBatch = Integer
.valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH)); .valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH));
List<String> queryTextsList = // Process and filter query texts
detectSegments.stream().map(detectSegment -> detectSegment.trim()) List<String> queryTextsList = detectSegments.stream().map(String::trim)
.filter(detectSegment -> StringUtils.isNotBlank(detectSegment)) .filter(StringUtils::isNotBlank).collect(Collectors.toList());
.collect(Collectors.toList());
// Partition queries into sub-lists for batch processing
List<List<String>> queryTextsSubList = List<List<String>> queryTextsSubList =
Lists.partition(queryTextsList, embeddingMapperBatch); Lists.partition(queryTextsList, embeddingMapperBatch);
// Create and execute tasks for each batch
List<Callable<Void>> tasks = new ArrayList<>(); List<Callable<Void>> tasks = new ArrayList<>();
for (List<String> queryTextsSub : queryTextsSubList) { for (List<String> queryTextsSub : queryTextsSubList) {
tasks.add(createTask(chatQueryContext, detectDataSetIds, queryTextsSub, results)); tasks.add(
createTask(chatQueryContext, detectDataSetIds, queryTextsSub, results, useLlm));
} }
executeTasks(tasks); executeTasks(tasks);
// Apply LLM filtering if enabled
if (useLlm) {
Map<String, Object> variable = new HashMap<>();
variable.put("userText", chatQueryContext.getRequest().getQueryText());
variable.put("retrievedInfo", JSONObject.toJSONString(results));
Prompt prompt = PromptTemplate.from(LLM_FILTER_PROMPT).apply(variable);
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel();
String response = chatLanguageModel.generate(prompt.toUserMessage().singleText());
if (StringUtils.isBlank(response)) {
results.clear();
} else {
List<String> retrievedIds = JSONObject.parseArray(response, String.class);
results = results.stream().filter(t -> retrievedIds.contains(t.getId()))
.collect(Collectors.toSet());
results.forEach(r -> r.setLlmMatched(true));
}
}
return new ArrayList<>(results); return new ArrayList<>(results);
} }
/**
* Create a task for batch processing
*
* @param chatQueryContext The context of the chat query
* @param detectDataSetIds Target dataset IDs
* @param queryTextsSub Sub-list of query texts to process
* @param results Shared result set for collecting results
* @param useLlm Whether to use LLM
* @return Callable task
*/
private Callable<Void> createTask(ChatQueryContext chatQueryContext, Set<Long> detectDataSetIds, private Callable<Void> createTask(ChatQueryContext chatQueryContext, Set<Long> detectDataSetIds,
List<String> queryTextsSub, Set<EmbeddingResult> results) { List<String> queryTextsSub, Set<EmbeddingResult> results, boolean useLlm) {
return () -> { return () -> {
List<EmbeddingResult> oneRoundResults = List<EmbeddingResult> oneRoundResults = detectByQueryTextsSub(detectDataSetIds,
detectByQueryTextsSub(detectDataSetIds, queryTextsSub, chatQueryContext); queryTextsSub, chatQueryContext, useLlm);
synchronized (results) { synchronized (results) {
selectResultInOneRound(results, oneRoundResults); selectResultInOneRound(results, oneRoundResults);
} }
@@ -73,57 +203,73 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy<EmbeddingResult>
}; };
} }
/**
* Process a sub-list of query texts
*
* @param detectDataSetIds Target dataset IDs
* @param queryTextsSub Sub-list of query texts
* @param chatQueryContext Chat query context
* @param useLlm Whether to use LLM
* @return List of embedding results for this batch
*/
private List<EmbeddingResult> detectByQueryTextsSub(Set<Long> detectDataSetIds, private List<EmbeddingResult> detectByQueryTextsSub(Set<Long> detectDataSetIds,
List<String> queryTextsSub, ChatQueryContext chatQueryContext) { List<String> queryTextsSub, ChatQueryContext chatQueryContext, boolean useLlm) {
Map<Long, List<Long>> modelIdToDataSetIds = chatQueryContext.getModelIdToDataSetIds(); Map<Long, List<Long>> modelIdToDataSetIds = chatQueryContext.getModelIdToDataSetIds();
// Get configuration parameters
double threshold = double threshold =
Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD)); Double.parseDouble(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD));
// step1. build query params
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
// step2. retrieveQuery by detectSegment
int embeddingNumber = int embeddingNumber =
Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_NUMBER)); Integer.parseInt(mapperConfig.getParameterValue(EMBEDDING_MAPPER_NUMBER));
int embeddingRoundNumber =
Integer.parseInt(mapperConfig.getParameterValue(EMBEDDING_MAPPER_ROUND_NUMBER));
// Build and execute query
RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build();
List<RetrieveQueryResult> retrieveQueryResults = metaEmbeddingService.retrieveQuery( List<RetrieveQueryResult> retrieveQueryResults = metaEmbeddingService.retrieveQuery(
retrieveQuery, embeddingNumber, modelIdToDataSetIds, detectDataSetIds); retrieveQuery, embeddingNumber, modelIdToDataSetIds, detectDataSetIds);
if (CollectionUtils.isEmpty(retrieveQueryResults)) { if (CollectionUtils.isEmpty(retrieveQueryResults)) {
return new ArrayList<>(); return Collections.emptyList();
} }
// step3. build EmbeddingResults
List<EmbeddingResult> collect = retrieveQueryResults.stream().map(retrieveQueryResult -> { // Process results
List<Retrieval> retrievals = retrieveQueryResult.getRetrieval(); List<EmbeddingResult> collect = retrieveQueryResults.stream().peek(result -> {
if (CollectionUtils.isNotEmpty(retrievals)) { if (!useLlm && CollectionUtils.isNotEmpty(result.getRetrieval())) {
retrievals.removeIf(retrieval -> { result.getRetrieval()
if (!retrieveQueryResult.getQuery().contains(retrieval.getQuery())) { .removeIf(retrieval -> !result.getQuery().contains(retrieval.getQuery())
return retrieval.getSimilarity() < threshold; && retrieval.getSimilarity() < threshold);
}
return false;
});
} }
return retrieveQueryResult; }).filter(result -> CollectionUtils.isNotEmpty(result.getRetrieval()))
}).filter(retrieveQueryResult -> CollectionUtils .flatMap(result -> result.getRetrieval().stream()
.isNotEmpty(retrieveQueryResult.getRetrieval())) .map(retrieval -> convertToEmbeddingResult(result, retrieval)))
.flatMap(retrieveQueryResult -> retrieveQueryResult.getRetrieval().stream()
.map(retrieval -> {
EmbeddingResult embeddingResult = new EmbeddingResult();
BeanUtils.copyProperties(retrieval, embeddingResult);
embeddingResult.setDetectWord(retrieveQueryResult.getQuery());
embeddingResult.setName(retrieval.getQuery());
Map<String, String> convertedMap = retrieval.getMetadata().entrySet()
.stream().collect(Collectors.toMap(Map.Entry::getKey,
entry -> entry.getValue().toString()));
embeddingResult.setMetadata(convertedMap);
return embeddingResult;
}))
.collect(Collectors.toList()); .collect(Collectors.toList());
// step4. select mapResul in one round // Sort and limit results
int embeddingRoundNumber = return collect.stream()
Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_ROUND_NUMBER)); .sorted(Comparator.comparingDouble(EmbeddingResult::getSimilarity).reversed())
int roundNumber = embeddingRoundNumber * queryTextsSub.size(); .limit(embeddingRoundNumber * queryTextsSub.size()).collect(Collectors.toList());
return collect.stream().sorted(Comparator.comparingDouble(EmbeddingResult::getSimilarity)) }
.limit(roundNumber).collect(Collectors.toList());
/**
* Convert RetrieveQueryResult and Retrieval to EmbeddingResult
*
* @param queryResult The query result containing retrieval information
* @param retrieval The retrieval data to be converted
* @return Converted EmbeddingResult
*/
private EmbeddingResult convertToEmbeddingResult(RetrieveQueryResult queryResult,
Retrieval retrieval) {
EmbeddingResult embeddingResult = new EmbeddingResult();
BeanUtils.copyProperties(retrieval, embeddingResult);
embeddingResult.setDetectWord(queryResult.getQuery());
embeddingResult.setName(retrieval.getQuery());
// Convert metadata to string values
Map<String, String> metadata = retrieval.getMetadata().entrySet().stream().collect(
Collectors.toMap(Map.Entry::getKey, entry -> String.valueOf(entry.getValue())));
embeddingResult.setMetadata(metadata);
return embeddingResult;
} }
} }

View File

@@ -7,12 +7,7 @@ import com.tencent.supersonic.headless.chat.ChatQueryContext;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.ArrayList; import java.util.*;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate; import java.util.function.Predicate;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@@ -66,7 +61,7 @@ public class MapFilter {
List<SchemaElementMatch> value = entry.getValue(); List<SchemaElementMatch> value = entry.getValue();
if (!CollectionUtils.isEmpty(value)) { if (!CollectionUtils.isEmpty(value)) {
value.removeIf(schemaElementMatch -> StringUtils value.removeIf(schemaElementMatch -> StringUtils
.length(schemaElementMatch.getDetectWord()) <= 1); .length(schemaElementMatch.getDetectWord()) <= 1 && !schemaElementMatch.isLlmMatched());
} }
} }
} }
@@ -85,7 +80,7 @@ public class MapFilter {
} }
public static void filterByQueryDataType(ChatQueryContext chatQueryContext, public static void filterByQueryDataType(ChatQueryContext chatQueryContext,
Predicate<SchemaElement> needRemovePredicate) { Predicate<SchemaElement> needRemovePredicate) {
Map<Long, List<SchemaElementMatch>> dataSetElementMatches = Map<Long, List<SchemaElementMatch>> dataSetElementMatches =
chatQueryContext.getMapInfo().getDataSetElementMatches(); chatQueryContext.getMapInfo().getDataSetElementMatches();
for (Map.Entry<Long, List<SchemaElementMatch>> entry : dataSetElementMatches.entrySet()) { for (Map.Entry<Long, List<SchemaElementMatch>> entry : dataSetElementMatches.entrySet()) {

View File

@@ -57,4 +57,12 @@ public class MapperConfig extends ParameterConfig {
public static final Parameter EMBEDDING_MAPPER_ROUND_NUMBER = public static final Parameter EMBEDDING_MAPPER_ROUND_NUMBER =
new Parameter("s2.mapper.embedding.round.number", "10", "向量召回最小相似度阈值", new Parameter("s2.mapper.embedding.round.number", "10", "向量召回最小相似度阈值",
"向量召回相似度阈值在动态调整中的最低值", "number", "Mapper相关配置"); "向量召回相似度阈值在动态调整中的最低值", "number", "Mapper相关配置");
// Switch for a second, LLM-based judgment over the embedding recall results; default "false".
// Per the description text, per-vector similarity thresholds are ignored when it is enabled.
public static final Parameter EMBEDDING_MAPPER_USE_LLM =
new Parameter("s2.mapper.embedding.use-llm-enhance", "false", "使用LLM对召回的向量进行二次判断开关",
"embedding的结果再通过一次LLM来筛选这时候忽略各个向量阀值", "bool", "Mapper相关配置");
// Word natures that control which query segments take part in the LLM recall path.
// NOTE(review): the default is written as a bracketed, quoted list — confirm the consumer's
// parsing handles that format. The description says only "allowed" natures are recalled,
// while the consumer drops terms whose nature starts with one of these values — confirm intent.
public static final Parameter EMBEDDING_MAPPER_ALLOWED_SEGMENT_NATURE =
new Parameter("s2.mapper.embedding.allowed-segment-nature", "['v', 'd', 'a']", "使用LLM召回二次处理时对问题分词词性的控制",
"分词后允许的词性才会进行向量召回", "list", "Mapper相关配置");
} }

View File

@@ -1,14 +1,14 @@
spring: spring:
datasource: datasource:
driver-class-name: org.postgresql.Driver driver-class-name: org.postgresql.Driver
url: jdbc:postgresql://${S2_DB_HOST:localhost}:${S2_DB_PORT:5432}/${S2_DB_DATABASE:postgres}?stringtype=unspecified url: jdbc:postgresql://localhost:5432/postgres?stringtype=unspecified
username: ${S2_DB_USER:postgres} username: postgres
password: ${S2_DB_PASSWORD:postgres} password: postgres
sql: sql:
init: init:
mode: always mode: always
username: ${S2_DB_USER:postgres} username: postgres
password: ${S2_DB_PASSWORD:postgres} password: postgres
schema-locations: classpath:db/schema-postgres.sql,classpath:db/schema-postgres-demo.sql schema-locations: classpath:db/schema-postgres.sql,classpath:db/schema-postgres-demo.sql
data-locations: classpath:db/data-postgres.sql,classpath:db/data-postgres-demo.sql data-locations: classpath:db/data-postgres.sql,classpath:db/data-postgres-demo.sql
@@ -17,9 +17,9 @@ s2:
store: store:
provider: PGVECTOR provider: PGVECTOR
base: base:
url: ${S2_DB_HOST:127.0.0.1} url: 127.0.0.1
port: ${S2_DB_PORT:5432} port: 5432
databaseName: ${S2_DB_DATABASE:postgres} databaseName: postgres
user: ${S2_DB_USER:postgres} user: postgres
password: ${S2_DB_PASSWORD:postgres} password: postgres
dimension: 512 dimension: 512

View File

@@ -41,3 +41,5 @@ s2:
threshold: 0.5 threshold: 0.5
min: min:
threshold: 0.3 threshold: 0.3
embedding:
use-llm-enhance: true

View File

@@ -209,5 +209,6 @@
}, },
"engines": { "engines": {
"node": ">=16" "node": ">=16"
} },
"packageManager": "pnpm@9.12.3+sha512.cce0f9de9c5a7c95bef944169cc5dfe8741abfb145078c0d508b868056848a87c81e626246cb60967cbd7fd29a6c062ef73ff840d96b3c86c40ac92cf4a813ee"
} }

View File

@@ -9,7 +9,7 @@ import {
RangeValue, RangeValue,
SimilarQuestionType, SimilarQuestionType,
} from '../../common/type'; } from '../../common/type';
import { createContext, useEffect, useRef, useState } from 'react'; import { createContext, useEffect, useState } from 'react';
import { chatExecute, chatParse, queryData, deleteQuery, switchEntity } from '../../service'; import { chatExecute, chatParse, queryData, deleteQuery, switchEntity } from '../../service';
import { PARSE_ERROR_TIP, PREFIX_CLS, SEARCH_EXCEPTION_TIP } from '../../common/constants'; import { PARSE_ERROR_TIP, PREFIX_CLS, SEARCH_EXCEPTION_TIP } from '../../common/constants';
import { message, Spin } from 'antd'; import { message, Spin } from 'antd';
@@ -490,9 +490,7 @@ const ChatItem: React.FC<Props> = ({
onSwitchEntity={onSwitchEntity} onSwitchEntity={onSwitchEntity}
onFiltersChange={onFiltersChange} onFiltersChange={onFiltersChange}
onDateInfoChange={onDateInfoChange} onDateInfoChange={onDateInfoChange}
onRefresh={() => { onRefresh={onRefresh}
onRefresh();
}}
handlePresetClick={handlePresetClick} handlePresetClick={handlePresetClick}
/> />
)} )}

View File

@@ -40,6 +40,12 @@ const BarChart: React.FC<Props> = ({
}) => { }) => {
const chartRef = useRef<any>(); const chartRef = useRef<any>();
const instanceRef = useRef<ECharts>(); const instanceRef = useRef<ECharts>();
const { downloadChartAsImage } = useExportByEcharts({
instanceRef,
question,
});
const { register } = useContext(ChartItemContext);
const { queryColumns, queryResults, entityInfo } = data; const { queryColumns, queryResults, entityInfo } = data;
@@ -189,13 +195,6 @@ const BarChart: React.FC<Props> = ({
const prefixCls = `${PREFIX_CLS}-bar`; const prefixCls = `${PREFIX_CLS}-bar`;
const { downloadChartAsImage } = useExportByEcharts({
instanceRef,
question,
});
const { register } = useContext(ChartItemContext);
register('downloadChartAsImage', downloadChartAsImage); register('downloadChartAsImage', downloadChartAsImage);
return ( return (

View File

@@ -93,7 +93,9 @@ export const getFormattedValue = (value: number | string, remainZero?: boolean)
export const formatNumberWithCN = (num: number) => { export const formatNumberWithCN = (num: number) => {
if (isNaN(num)) return '-'; if (isNaN(num)) return '-';
if (num >= 10000) { if (num >= 100000000) {
return (num / 100000000).toFixed(1) + '亿';
} else if (num >= 10000) {
return (num / 10000).toFixed(1) + '万'; return (num / 10000).toFixed(1) + '万';
} else { } else {
return formatByDecimalPlaces(num, 2); return formatByDecimalPlaces(num, 2);

View File

@@ -4,5 +4,9 @@ export default {
target: 'http://127.0.0.1:9080', target: 'http://127.0.0.1:9080',
changeOrigin: true, changeOrigin: true,
}, },
'/aibi/api/': {
target: 'http://127.0.0.1:9080',
changeOrigin: true,
},
}, },
}; };

View File

@@ -12,6 +12,7 @@ import { ISemantic } from '../../data';
import { ColumnsConfig } from '../../components/TableColumnRender'; import { ColumnsConfig } from '../../components/TableColumnRender';
import ViewSearchFormModal from './ViewSearchFormModal'; import ViewSearchFormModal from './ViewSearchFormModal';
import { toDatasetEditPage } from '@/pages/SemanticModel/utils'; import { toDatasetEditPage } from '@/pages/SemanticModel/utils';
import UploadFile from './UploadFile';
type Props = { type Props = {
// dataSetList: ISemantic.IDatasetItem[]; // dataSetList: ISemantic.IDatasetItem[];
@@ -92,9 +93,6 @@ const DataSetTable: React.FC<Props> = ({ disabledEdit = false }) => {
<a <a
onClick={() => { onClick={() => {
toDatasetEditPage(record.domainId, record.id, 'relation'); toDatasetEditPage(record.domainId, record.id, 'relation');
// setEditFormStep(1);
// setViewItem(record);
// setCreateDataSourceModalOpen(true);
}} }}
> >
{name} {name}
@@ -146,9 +144,6 @@ const DataSetTable: React.FC<Props> = ({ disabledEdit = false }) => {
key="metricEditBtn" key="metricEditBtn"
onClick={() => { onClick={() => {
toDatasetEditPage(record.domainId, record.id); toDatasetEditPage(record.domainId, record.id);
// setEditFormStep(0);
// setViewItem(record);
// setCreateDataSourceModalOpen(true);
}} }}
> >
@@ -189,6 +184,12 @@ const DataSetTable: React.FC<Props> = ({ disabledEdit = false }) => {
</Button> </Button>
)} )}
<UploadFile
key="uploadFile"
buttonType="link"
domainId={record.domainId}
datasetId={record.id}
/>
<Popconfirm <Popconfirm
title="确认删除?" title="确认删除?"
okText="是" okText="是"
@@ -229,6 +230,13 @@ const DataSetTable: React.FC<Props> = ({ disabledEdit = false }) => {
disabledEdit disabledEdit
? [<></>] ? [<></>]
: [ : [
<UploadFile
key="uploadFile"
domainId={selectDomainId}
onFileUploaded={() => {
queryDataSetList();
}}
/>,
<Button <Button
key="create" key="create"
type="primary" type="primary"

View File

@@ -0,0 +1,44 @@
import { getToken } from '@/utils/utils';
import { UploadOutlined } from '@ant-design/icons';
import type { UploadProps } from 'antd';
import { Button, message, Upload } from 'antd';
// Props accepted by the UploadFile component.
type Props = {
// Trigger style: 'link' renders an anchor; any other value renders an antd Button.
buttonType?: string;
// Sent as the `domainId` query parameter of the upload URL.
domainId?: number;
// When truthy, appended to the upload URL as the `dataSetId` query parameter.
datasetId?: string;
// Called after a file upload finishes with status 'done'.
onFileUploaded?: () => void;
};
/**
 * Upload trigger that posts the chosen file (form field `multipartFile`) to the dataset
 * upload endpoint and reports the outcome with an antd message.
 *
 * The query string is built from the values that are actually present, so an unset
 * `domainId` is omitted instead of being sent as the literal text "undefined", and every
 * value is URL-encoded.
 *
 * NOTE(review): both trigger elements render without a visible label — confirm this is
 * intended.
 */
const UploadFile = ({ buttonType, domainId, datasetId, onFileUploaded }: Props) => {
  const query = new URLSearchParams({ type: 'DATASET' });
  if (domainId !== undefined && domainId !== null) {
    query.append('domainId', String(domainId));
  }
  if (datasetId) {
    query.append('dataSetId', String(datasetId));
  }

  const props: UploadProps = {
    name: 'multipartFile',
    action: `/aibi/api/data/file/uploadFileNew?${query.toString()}`,
    showUploadList: false,
    onChange(info) {
      if (info.file.status !== 'uploading') {
        console.log(info.file, info.fileList);
      }
      if (info.file.status === 'done') {
        message.success('导入成功');
        // Let the caller refresh its data once the import has completed.
        onFileUploaded?.();
      } else if (info.file.status === 'error') {
        message.error('导入失败');
      }
    },
  };

  return (
    <Upload {...props}>
      {buttonType === 'link' ? (
        <a></a>
      ) : (
        <Button icon={<UploadOutlined />}></Button>
      )}
    </Upload>
  );
};
export default UploadFile;

28637
webapp/pnpm-lock.yaml generated

File diff suppressed because it is too large Load Diff